#!/usr/bin/env python3
"""Train per-emotion fusion weights via gradient descent.

Inputs:
  - --manifest JSON: list of {"path","text","label","source",...} (2,821 samples)
  - --preds-cache JSON: {"audio_preds": [dict(7)], "text_preds": [dict(7)]}

Output dir receives:
  - trained_weights.json      — learned w_a / w_t / val_macro_f1
  - trained_fusion_report.md  — comparison: audio-only, fixed 60/40, greedy optimal, trained
  - trained_fusion_curve.png  — training loss + validation macro-F1 curves

Parameterization:
  w_a[L] = sigmoid(α[L]),  w_t[L] = 1 - w_a[L]      # one α per label
  fused[L] = p_a[L]*w_a[L] + p_t[L]*w_t[L]
  fused ← normalize over L
  loss = NLL(log fused, y) + λ * ||α||²
"""
from __future__ import annotations

import argparse
import json
import logging
from collections import Counter
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

PROJECT_LABELS = ["neutral", "joy", "sadness", "anger", "surprise", "fear", "disgust"]
NUM_LABELS = len(PROJECT_LABELS)

# Per-label audio weights of the earlier greedy grid search ("v1"),
# hard-coded from the previous report and used only as a baseline.
GREEDY_V1 = {
    "neutral": 0.75,
    "joy": 0.55,
    "sadness": 0.40,
    "anger": 0.65,
    "surprise": 0.45,
    "fear": 0.00,
    "disgust": 0.80,
}


class FusionHead(nn.Module):
    """Per-label convex mix of audio and text probabilities.

    Holds one learnable logit ``alpha`` per label; the audio weight is
    ``sigmoid(alpha)`` and the text weight is its complement.
    """

    def __init__(self, init_audio_frac: float = 0.6):
        """Initialise every label's audio weight to ``init_audio_frac``."""
        super().__init__()
        # Store the weight in logit space so sigmoid() recovers the fraction.
        init_val = float(torch.logit(torch.tensor(init_audio_frac)))
        self.alpha = nn.Parameter(torch.full((NUM_LABELS,), init_val))

    @property
    def w_a(self) -> torch.Tensor:
        """Audio weight per label, ``sigmoid(alpha)``."""
        return torch.sigmoid(self.alpha)

    @property
    def w_t(self) -> torch.Tensor:
        """Text weight per label, ``1 - w_a``."""
        return 1.0 - self.w_a

    def forward(self, p_a: torch.Tensor, p_t: torch.Tensor) -> torch.Tensor:
        """Return the weighted mix of ``p_a`` and ``p_t``, renormalised over dim 1."""
        fused = p_a * self.w_a + p_t * self.w_t
        # Clamp the row sum so an all-zero row cannot divide by zero.
        return fused / fused.sum(dim=1, keepdim=True).clamp(min=1e-8)


def probs_to_tensor(preds: list[dict]) -> torch.Tensor:
    """Convert label->probability dicts into a float32 tensor.

    Columns follow ``PROJECT_LABELS`` order; labels missing from a dict
    are filled with 0.0.
    """
    arr = np.array(
        [[p.get(label, 0.0) for label in PROJECT_LABELS] for p in preds],
        dtype=np.float32,
    )
    return torch.from_numpy(arr)


def map_label(lbl: str) -> str:
    """Map the manifest label "happiness" to "joy"; pass others through."""
    return "joy" if lbl == "happiness" else lbl


def _score_predictions(y: np.ndarray, pred: np.ndarray) -> dict:
    """Return macro F1 plus one-vs-rest F1 per label for index predictions."""
    macro = f1_score(y, pred, average="macro")
    per_class = {
        label: float(f1_score((y == i).astype(int), (pred == i).astype(int)))
        for i, label in enumerate(PROJECT_LABELS)
    }
    return {"macro_f1": float(macro), "per_class": per_class}


def eval_weights(p_a: torch.Tensor, p_t: torch.Tensor, y: np.ndarray,
                 w_a_vec: np.ndarray) -> dict:
    """Score a fixed vector of per-label audio weights.

    Args:
        p_a: Audio probabilities, one row per sample.
        p_t: Text probabilities, same shape as ``p_a``.
        y: Ground-truth label indices into ``PROJECT_LABELS``.
        w_a_vec: Audio weight per label; the text weight is ``1 - w_a_vec``.

    Returns:
        ``{"macro_f1": float, "per_class": {label: float}}``.
    """
    w_a = torch.from_numpy(w_a_vec.astype(np.float32))
    w_t = 1.0 - w_a
    fused = p_a * w_a + p_t * w_t
    fused = fused / fused.sum(dim=1, keepdim=True).clamp(min=1e-8)
    pred = fused.argmax(dim=1).numpy()
    return _score_predictions(y, pred)


def train(p_a_tr, p_t_tr, y_tr, p_a_vl, p_t_vl, y_vl,
          lr=0.05, epochs=500, l2=0.01, patience=50):
    """Fit a ``FusionHead`` with full-batch Adam and early stopping.

    Args:
        p_a_tr, p_t_tr: Training audio / text probability tensors.
        y_tr: Training label indices (NumPy integer array).
        p_a_vl, p_t_vl: Validation audio / text probability tensors.
        y_vl: Validation label indices (NumPy integer array).
        lr: Adam learning rate.
        epochs: Maximum number of full-batch updates.
        l2: Coefficient of the ``sum(alpha ** 2)`` penalty.
        patience: Stop after this many epochs without a new best val macro F1.

    Returns:
        ``(model, best_f1, history, final_alpha)``. ``model`` carries the
        alpha with the best validation macro F1, ``final_alpha`` is the alpha
        after the last update, and ``history`` holds the per-epoch
        ``train_loss`` and ``val_f1`` lists.
    """
    model = FusionHead()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    nll = nn.NLLLoss()

    history = {"train_loss": [], "val_f1": []}
    # Start from the initial alpha so the restore below is valid even when
    # no epoch runs (epochs=0); best_f1 stays at -1.0 in that case.
    best_f1, best_alpha, waited = -1.0, model.alpha.detach().clone(), 0

    y_tr_t = torch.from_numpy(y_tr).long()

    for epoch in range(epochs):
        model.train()
        opt.zero_grad()
        fused = model(p_a_tr, p_t_tr)
        # Clamp before log so zero probabilities do not produce -inf.
        loss = nll(torch.log(fused.clamp(min=1e-8)), y_tr_t) + l2 * (model.alpha ** 2).sum()
        loss.backward()
        opt.step()

        model.eval()
        with torch.no_grad():
            val_fused = model(p_a_vl, p_t_vl)
            val_pred = val_fused.argmax(dim=1).numpy()
        val_f1 = f1_score(y_vl, val_pred, average="macro")

        # The recorded loss was computed before this epoch's optimizer step.
        history["train_loss"].append(float(loss.item()))
        history["val_f1"].append(float(val_f1))

        if val_f1 > best_f1:
            best_f1, best_alpha, waited = float(val_f1), model.alpha.detach().clone(), 0
        else:
            waited += 1
            if waited >= patience:
                logger.info("Early stop at epoch %d (patience=%d)", epoch, patience)
                break

    final_alpha = model.alpha.detach().clone()
    # Restore the best-validation alpha in place instead of rebinding .data.
    with torch.no_grad():
        model.alpha.copy_(best_alpha)
    return model, best_f1, history, final_alpha


def plot_curve(history, output_path: Path) -> None:
    """Save a two-panel PNG: training loss and validation macro F1 per epoch."""
    # Imported lazily so the non-interactive backend is selected before pyplot loads.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    ax1.plot(history["train_loss"], color="#F44336", label="train CE + L2")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")
    ax1.set_title("Training loss")
    ax1.grid(alpha=0.3)
    ax1.legend()

    ax2.plot(history["val_f1"], color="#4CAF50", label="val macro F1")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Macro F1")
    ax2.set_title("Validation macro F1")
    ax2.grid(alpha=0.3)
    ax2.legend()

    fig.tight_layout()
    fig.savefig(str(output_path), dpi=150)
    # Close this specific figure rather than whatever is "current".
    plt.close(fig)


def write_report(output_path: Path, num_train: int, num_val: int,
                 labels_dist: dict, audio_only: dict, fixed: dict,
                 greedy: dict, trained: dict, trained_weights: dict,
                 greedy_weights: dict) -> None:
    """Write the Markdown comparison report to ``output_path``.

    Args:
        output_path: Destination ``.md`` file (written as UTF-8).
        num_train, num_val: Split sizes shown in the dataset section.
        labels_dist: Label -> sample count, listed by descending count.
        audio_only, fixed, greedy, trained: Score dicts with ``"macro_f1"``
            and ``"per_class"`` keys, one per strategy.
        trained_weights: Label -> ``{"audio": float, "text": float}``.
        greedy_weights: Label -> ``{"audio": float}``; a missing label or a
            ``None`` audio weight is rendered as "—".
    """
    lines = ["# Fusion Weight Training Report (v2)\n"]
    lines.append(f"## Dataset\n\nTotal samples: **{num_train + num_val}** (train {num_train}, val {num_val})\n")
    lines.append("### Label distribution\n\n| Label | Count |\n|---|---|")
    for lbl, c in sorted(labels_dist.items(), key=lambda x: -x[1]):
        lines.append(f"| {lbl} | {c} |")
    lines.append("")

    lines.append("## Macro F1 Comparison (validation set)\n")
    lines.append("| Strategy | Macro F1 |")
    lines.append("|---|---|")
    lines.append(f"| Audio-only (argmax p_audio) | {audio_only['macro_f1']:.4f} |")
    lines.append(f"| Fixed 60/40 | {fixed['macro_f1']:.4f} |")
    lines.append(f"| Greedy grid (v1 weights) | {greedy['macro_f1']:.4f} |")
    lines.append(f"| **Trained (gradient descent)** | **{trained['macro_f1']:.4f}** |")
    lines.append("")

    lines.append("## Per-class F1 (validation set)\n")
    lines.append("| Emotion | Audio-only | Fixed 60/40 | Greedy | Trained |")
    lines.append("|---|---|---|---|---|")
    for lbl in PROJECT_LABELS:
        lines.append(f"| {lbl} | {audio_only['per_class'][lbl]:.4f} | "
                     f"{fixed['per_class'][lbl]:.4f} | {greedy['per_class'][lbl]:.4f} | "
                     f"{trained['per_class'][lbl]:.4f} |")
    lines.append("")

    lines.append("## Learned weights\n")
    lines.append("| Emotion | Audio (trained) | Text (trained) | Audio (greedy v1) |")
    lines.append("|---|---|---|---|")
    for lbl in PROJECT_LABELS:
        w = trained_weights[lbl]
        gw = greedy_weights.get(lbl, {"audio": None})
        g_a = f"{gw['audio']:.2f}" if gw.get("audio") is not None else "—"
        lines.append(f"| {lbl} | {w['audio']:.2f} | {w['text']:.2f} | {g_a} |")
    lines.append("")

    # Explicit UTF-8: the report contains non-ASCII characters ("—").
    output_path.write_text("\n".join(lines), encoding="utf-8")


def _weights_dict(w_a: np.ndarray) -> dict:
    """Return label -> {"audio", "text"} weights rounded to two decimals."""
    return {
        label: {"audio": round(float(w), 2), "text": round(float(1 - w), 2)}
        for label, w in zip(PROJECT_LABELS, w_a)
    }


def main() -> None:
    """Parse CLI arguments, train the fusion head, and write all artifacts."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--manifest", type=Path, required=True)
    parser.add_argument("--preds-cache", type=Path, required=True)
    parser.add_argument("--output-dir", type=Path, required=True)
    parser.add_argument("--lr", type=float, default=0.05)
    parser.add_argument("--epochs", type=int, default=500)
    parser.add_argument("--l2", type=float, default=0.01)
    parser.add_argument("--patience", type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--use-last-alpha", action="store_true",
                        help="Use final-epoch alpha instead of best-val-F1 alpha (keeps differentiation)")
    args = parser.parse_args()

    args.output_dir.mkdir(parents=True, exist_ok=True)

    manifest = json.loads(args.manifest.read_text(encoding="utf-8"))
    cache = json.loads(args.preds_cache.read_text(encoding="utf-8"))
    audio_preds = cache["audio_preds"]
    text_preds = cache["text_preds"]
    if not (len(manifest) == len(audio_preds) == len(text_preds)):
        raise ValueError(f"Size mismatch: manifest={len(manifest)}, audio={len(audio_preds)}, text={len(text_preds)}")

    labels = [map_label(r["label"]) for r in manifest]
    labels_dist = dict(Counter(labels))
    logger.info("Loaded %d samples. Distribution: %s", len(manifest), labels_dist)

    # Fail with a readable message instead of list.index's "x is not in list".
    unknown = sorted(set(labels) - set(PROJECT_LABELS))
    if unknown:
        raise ValueError(f"Unknown labels in manifest: {unknown} (expected one of {PROJECT_LABELS})")

    p_a = probs_to_tensor(audio_preds)
    p_t = probs_to_tensor(text_preds)
    label_to_idx = {label: i for i, label in enumerate(PROJECT_LABELS)}
    y = np.array([label_to_idx[label] for label in labels])

    # Stratified 80/20
    tr_idx, vl_idx = train_test_split(
        np.arange(len(y)), test_size=0.2, stratify=y, random_state=args.seed,
    )
    logger.info("Train/Val: %d / %d", len(tr_idx), len(vl_idx))

    torch.manual_seed(args.seed)
    model, best_f1, history, final_alpha = train(
        p_a[tr_idx], p_t[tr_idx], y[tr_idx],
        p_a[vl_idx], p_t[vl_idx], y[vl_idx],
        lr=args.lr, epochs=args.epochs, l2=args.l2, patience=args.patience,
    )
    logger.info("Best val macro F1: %.4f", best_f1)

    # Select which alpha to deploy: "best" (peak val F1) or "last" (final epoch)
    if args.use_last_alpha:
        with torch.no_grad():
            model.alpha.copy_(final_alpha)
        logger.info("Using LAST-epoch alpha (--use-last-alpha)")

    # Derive trained weights dict.
    # NOTE: the saved/reported weights are rounded to two decimals, while the
    # "trained" F1 below is computed from the unrounded weights.
    w_a_np = model.w_a.detach().numpy()
    trained_weights = _weights_dict(w_a_np)

    # Also compute last-alpha weights for comparison
    last_weights = _weights_dict(torch.sigmoid(final_alpha).numpy())

    # Baselines on val split
    pa_vl = p_a[vl_idx]
    pt_vl = p_t[vl_idx]
    y_vl = y[vl_idx]

    # Audio-only argmax
    audio_only = _score_predictions(y_vl, pa_vl.argmax(dim=1).numpy())

    fixed = eval_weights(pa_vl, pt_vl, y_vl, np.full(NUM_LABELS, 0.6))

    # Greedy v1 weights — hardcoded from previous report
    greedy_w_a = np.array([GREEDY_V1[label] for label in PROJECT_LABELS])
    greedy = eval_weights(pa_vl, pt_vl, y_vl, greedy_w_a)
    greedy_weights = {label: {"audio": GREEDY_V1[label]} for label in PROJECT_LABELS}

    trained = eval_weights(pa_vl, pt_vl, y_vl, w_a_np)

    logger.info("Audio-only val macro F1: %.4f", audio_only["macro_f1"])
    logger.info("Fixed 60/40 val macro F1: %.4f", fixed["macro_f1"])
    logger.info("Greedy v1 val macro F1: %.4f", greedy["macro_f1"])
    logger.info("Trained val macro F1: %.4f", trained["macro_f1"])

    # Save (explicit UTF-8 because ensure_ascii=False may emit non-ASCII text)
    (args.output_dir / "trained_weights.json").write_text(json.dumps({
        "val_macro_f1": trained["macro_f1"],
        "weights": trained_weights,
        "last_epoch_weights": last_weights,
        "train_size": int(len(tr_idx)),
        "val_size": int(len(vl_idx)),
        "hyperparams": {"lr": args.lr, "l2": args.l2, "epochs": args.epochs,
                        "patience": args.patience, "seed": args.seed,
                        "use_last_alpha": args.use_last_alpha},
    }, indent=2, ensure_ascii=False), encoding="utf-8")

    plot_curve(history, args.output_dir / "trained_fusion_curve.png")

    write_report(
        args.output_dir / "trained_fusion_report.md",
        len(tr_idx), len(vl_idx), labels_dist,
        audio_only, fixed, greedy, trained,
        trained_weights, greedy_weights,
    )

    logger.info("Done. Results saved to %s", args.output_dir)


if __name__ == "__main__":
    main()