"""Train per-emotion fusion weights via gradient descent.

Inputs:
- --manifest JSON: list of {"path", "text", "label", "source", ...} records
  (2,821 samples in the dataset this script was written for).
- --preds-cache JSON: {"audio_preds": [dict(7)], "text_preds": [dict(7)]},
  one prediction dict per manifest record, in the same order.

The output directory receives:
- trained_weights.json — learned per-emotion audio/text weights, last-epoch
  weights, validation macro F1, split sizes and hyperparameters
- trained_fusion_report.md — comparison of audio-only, fixed 60/40, greedy
  grid (v1 weights) and trained fusion
- trained_fusion_curve.png — training-loss and validation macro-F1 curves

Parameterization:
    w_a[L] = sigmoid(α[L]), w_t[L] = 1 - w_a[L]    # 7 params total
    fused[L] = p_a[L]*w_a[L] + p_t[L]*w_t[L]
    fused ← normalize over L
    loss = NLL(log fused, y) + λ * ||α||²
"""
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| from collections import Counter |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from sklearn.metrics import f1_score |
| from sklearn.model_selection import train_test_split |
|
|
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Canonical emotion order: column i of every probability tensor and class
# index i of every target array refer to PROJECT_LABELS[i].
PROJECT_LABELS = [
    "neutral",
    "joy",
    "sadness",
    "anger",
    "surprise",
    "fear",
    "disgust",
]
|
|
|
|
class FusionHead(nn.Module):
    """Per-emotion convex fusion of audio and text class probabilities.

    Each label L owns one trainable logit ``alpha[L]``. The audio weight is
    ``w_a[L] = sigmoid(alpha[L])`` and the text weight is ``w_t[L] = 1 - w_a[L]``.

    Args:
        init_audio_frac: Initial audio weight, shared by every label. It must
            lie strictly between 0 and 1. Its logit becomes the initial
            ``alpha``, and that logit is infinite at the endpoints, which would
            make an L2 penalty on ``alpha`` (and its gradient) infinite.
        num_labels: Number of emotion classes, i.e. the number of ``alpha``
            entries.

    Raises:
        ValueError: If ``init_audio_frac`` is not strictly inside (0, 1), or
            ``num_labels`` < 1.
    """

    def __init__(self, init_audio_frac: float = 0.6, num_labels: int = 7):
        super().__init__()
        # `not (0 < x < 1)` also rejects NaN.
        if not 0.0 < init_audio_frac < 1.0:
            raise ValueError(
                f"init_audio_frac must be strictly between 0 and 1, got {init_audio_frac!r}"
            )
        if num_labels < 1:
            raise ValueError(f"num_labels must be >= 1, got {num_labels!r}")
        init_val = float(torch.logit(torch.tensor(init_audio_frac)))
        self.alpha = nn.Parameter(torch.full((num_labels,), init_val))

    @property
    def w_a(self) -> torch.Tensor:
        """Per-label audio weights, ``sigmoid(alpha)``."""
        return torch.sigmoid(self.alpha)

    @property
    def w_t(self) -> torch.Tensor:
        """Per-label text weights, ``1 - w_a``."""
        return 1.0 - self.w_a

    def forward(self, p_a: torch.Tensor, p_t: torch.Tensor) -> torch.Tensor:
        """Fuse audio and text probabilities and renormalise each row.

        The inputs are expected to be (batch, num_labels). The per-label
        weights broadcast over the last dimension, and rows are normalised
        over dim 1.

        Because the weights differ per label, the weighted sum no longer sums
        to 1, hence the renormalisation. The denominator is clamped so an
        all-zero row does not divide by zero.
        """
        fused = p_a * self.w_a + p_t * self.w_t
        return fused / fused.sum(dim=1, keepdim=True).clamp(min=1e-8)
|
|
|
|
def probs_to_tensor(preds: list[dict]) -> torch.Tensor:
    """Stack per-sample probability dicts into a float32 (N, len(PROJECT_LABELS)) tensor.

    Column i holds ``p[PROJECT_LABELS[i]]``. A label missing from a dict
    counts as probability 0.0, and keys outside PROJECT_LABELS are ignored.

    NOTE(review): a dict keyed "happiness" (the alias that map_label renames
    for ground-truth labels) would therefore contribute nothing to "joy".
    Confirm the prediction cache uses PROJECT_LABELS names.

    The explicit reshape keeps the result 2-D for an empty ``preds`` list.
    ``np.array([])`` is otherwise 1-D with shape (0,), which breaks
    ``argmax(dim=1)`` and row-wise normalisation downstream.
    """
    n_labels = len(PROJECT_LABELS)
    arr = np.array(
        [[p.get(label, 0.0) for label in PROJECT_LABELS] for p in preds],
        dtype=np.float32,
    ).reshape(len(preds), n_labels)
    return torch.from_numpy(arr)
|
|
|
|
def map_label(lbl: str) -> str:
    """Normalise a manifest label to the PROJECT_LABELS vocabulary.

    Only the alias "happiness" is rewritten (to "joy"); any other value is
    returned unchanged.
    """
    if lbl == "happiness":
        return "joy"
    return lbl
|
|
|
|
def eval_weights(p_a: torch.Tensor, p_t: torch.Tensor, y: np.ndarray,
                 w_a_vec: np.ndarray) -> dict:
    """Score a fixed per-label audio-weight vector on a labelled set.

    Applies the same fusion as ``FusionHead.forward``, then takes the argmax
    prediction per row.

    Args:
        p_a: Audio probabilities, expected (N, len(PROJECT_LABELS)).
        p_t: Text probabilities, same shape as ``p_a``.
        y: Integer class indices into PROJECT_LABELS, length N.
        w_a_vec: Per-label audio weights; the text weight is ``1 - w_a_vec``.
            Any array-like (or scalar) that broadcasts against the label
            dimension is accepted.

    Returns:
        A dict ``{"macro_f1": float, "per_class": {label: float}}``. Per-class
        F1 is one-vs-rest for every label in PROJECT_LABELS.
    """
    # np.array (not asarray) always copies, so torch.from_numpy gets a
    # writable buffer even for read-only inputs.
    w_a = torch.from_numpy(np.array(w_a_vec, dtype=np.float32))
    w_t = 1.0 - w_a
    fused = p_a * w_a + p_t * w_t
    fused = fused / fused.sum(dim=1, keepdim=True).clamp(min=1e-8)
    pred = fused.argmax(dim=1).numpy()

    # sklearn's macro average covers the labels present in y or pred. This is
    # the same definition train() uses for model selection.
    macro = f1_score(y, pred, average="macro")
    per_class = {
        label: float(f1_score((y == i).astype(int), (pred == i).astype(int)))
        for i, label in enumerate(PROJECT_LABELS)
    }
    return {"macro_f1": float(macro), "per_class": per_class}
|
|
|
|
def train(p_a_tr, p_t_tr, y_tr, p_a_vl, p_t_vl, y_vl,
          lr=0.05, epochs=500, l2=0.01, patience=50):
    """Fit a FusionHead with full-batch Adam and early stopping on val macro F1.

    Each epoch:
      1. Takes one optimiser step on the whole training set, minimising
         ``NLL(log fused, y) + l2 * ||alpha||^2``.
      2. Scores argmax macro F1 on the validation set.

    Training stops after ``patience`` consecutive epochs without a strict F1
    improvement.

    Args:
        p_a_tr, p_t_tr: Training audio/text probability tensors.
        y_tr: Training class indices (numpy integer array).
        p_a_vl, p_t_vl: Validation audio/text probability tensors.
        y_vl: Validation class indices (numpy integer array).
        lr: Adam learning rate.
        epochs: Maximum number of epochs; must be >= 1.
        l2: Coefficient of the L2 penalty on ``alpha``.
        patience: Epochs without improvement before stopping early.

    Returns:
        A tuple ``(model, best_f1, history, final_alpha)``:
          - ``model``: alpha restored to the best-val-F1 epoch.
          - ``best_f1``: that epoch's val macro F1.
          - ``history``: ``{"train_loss": [...], "val_f1": [...]}``, one entry
            per epoch run.
          - ``final_alpha``: a detached copy of alpha from the last epoch run.

    Raises:
        ValueError: If ``epochs`` < 1. No epoch would run, so there would be
            no best alpha to restore.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs!r}")

    model = FusionHead()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    nll = nn.NLLLoss()

    history = {"train_loss": [], "val_f1": []}
    best_f1, best_alpha, waited = -1.0, None, 0

    y_tr_t = torch.from_numpy(y_tr).long()

    for epoch in range(epochs):
        model.train()
        opt.zero_grad()
        fused = model(p_a_tr, p_t_tr)
        # Clamp before log: a fused probability of exactly 0 would give -inf.
        loss = nll(torch.log(fused.clamp(min=1e-8)), y_tr_t) + l2 * (model.alpha ** 2).sum()
        loss.backward()
        opt.step()

        # Validate the post-step alpha, which is the value snapshotted below
        # when it improves on the best F1.
        model.eval()
        with torch.no_grad():
            val_pred = model(p_a_vl, p_t_vl).argmax(dim=1).numpy()
        val_f1 = f1_score(y_vl, val_pred, average="macro")

        # train_loss is this epoch's loss measured before the optimiser step.
        history["train_loss"].append(float(loss.item()))
        history["val_f1"].append(float(val_f1))

        if val_f1 > best_f1:
            best_f1, best_alpha, waited = float(val_f1), model.alpha.detach().clone(), 0
        else:
            waited += 1
            if waited >= patience:
                logger.info("Early stop at epoch %d (patience=%d)", epoch, patience)
                break

    final_alpha = model.alpha.detach().clone()
    # epochs >= 1 guarantees best_alpha is set: the first epoch's F1 (>= 0)
    # always beats the -1.0 sentinel. Restore it in place, outside autograd.
    with torch.no_grad():
        model.alpha.copy_(best_alpha)
    return model, best_f1, history, final_alpha
|
|
|
|
def plot_curve(history, output_path: Path) -> None:
    """Save a two-panel PNG: training loss (left) and validation macro F1 (right).

    Selects matplotlib's non-interactive "Agg" backend before importing
    pyplot, so no display is needed.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    loss_ax, f1_ax = axes

    # (axis, series, colour, legend label, y-label, title) per panel.
    panels = (
        (loss_ax, history["train_loss"], "#F44336", "train CE + L2", "Loss", "Training loss"),
        (f1_ax, history["val_f1"], "#4CAF50", "val macro F1", "Macro F1", "Validation macro F1"),
    )
    for ax, series, colour, legend_label, y_label, title in panels:
        ax.plot(series, color=colour, label=legend_label)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.grid(alpha=0.3)
        ax.legend()

    fig.tight_layout()
    fig.savefig(str(output_path), dpi=150)
    plt.close(fig)
|
|
|
|
def write_report(output_path: Path, num_train: int, num_val: int, labels_dist: dict,
                 audio_only: dict, fixed: dict, greedy: dict, trained: dict,
                 trained_weights: dict, greedy_weights: dict) -> None:
    """Write the Markdown report comparing fusion strategies.

    Args:
        output_path: Destination ``.md`` file; it is overwritten.
        num_train, num_val: Split sizes. Their sum is reported as the total.
        labels_dist: Label -> sample count, listed by descending count.
        audio_only, fixed, greedy, trained: Metric dicts with a ``"macro_f1"``
            float and a ``"per_class"`` mapping that must contain every label
            in PROJECT_LABELS.
        trained_weights: Label -> ``{"audio": float, "text": float}`` for
            every label in PROJECT_LABELS.
        greedy_weights: Label -> ``{"audio": float | None}``. A missing label
            or ``None`` weight is shown as "—".

    The file is written as UTF-8 explicitly. The report contains non-ASCII
    text (the "—" placeholder), which a platform-default encoding may be
    unable to encode (e.g. cp932 has no U+2014).
    """
    lines = ["# Fusion Weight Training Report (v2)\n"]

    lines.append(f"## Dataset\n\nTotal samples: **{num_train + num_val}** (train {num_train}, val {num_val})\n")
    lines.append("### Label distribution\n\n| Label | Count |\n|---|---|")
    for lbl, c in sorted(labels_dist.items(), key=lambda x: -x[1]):
        lines.append(f"| {lbl} | {c} |")
    lines.append("")

    lines.append("## Macro F1 Comparison (validation set)\n")
    lines.append("| Strategy | Macro F1 |")
    lines.append("|---|---|")
    lines.append(f"| Audio-only (argmax p_audio) | {audio_only['macro_f1']:.4f} |")
    lines.append(f"| Fixed 60/40 | {fixed['macro_f1']:.4f} |")
    lines.append(f"| Greedy grid (v1 weights) | {greedy['macro_f1']:.4f} |")
    lines.append(f"| **Trained (gradient descent)** | **{trained['macro_f1']:.4f}** |")
    lines.append("")

    lines.append("## Per-class F1 (validation set)\n")
    lines.append("| Emotion | Audio-only | Fixed 60/40 | Greedy | Trained |")
    lines.append("|---|---|---|---|---|")
    for lbl in PROJECT_LABELS:
        lines.append(f"| {lbl} | {audio_only['per_class'][lbl]:.4f} | "
                     f"{fixed['per_class'][lbl]:.4f} | {greedy['per_class'][lbl]:.4f} | "
                     f"{trained['per_class'][lbl]:.4f} |")
    lines.append("")

    lines.append("## Learned weights\n")
    lines.append("| Emotion | Audio (trained) | Text (trained) | Audio (greedy v1) |")
    lines.append("|---|---|---|---|")
    for lbl in PROJECT_LABELS:
        w = trained_weights[lbl]
        gw = greedy_weights.get(lbl, {"audio": None})
        g_a = f"{gw['audio']:.2f}" if gw.get("audio") is not None else "—"
        lines.append(f"| {lbl} | {w['audio']:.2f} | {w['text']:.2f} | {g_a} |")
    lines.append("")

    output_path.write_text("\n".join(lines), encoding="utf-8")
|
|
|
|
def _weights_table(w_a: np.ndarray) -> dict:
    """Map each PROJECT_LABELS entry to its audio weight and ``1 - audio``, rounded to 2 decimals."""
    return {
        label: {"audio": round(float(w), 2), "text": round(float(1 - w), 2)}
        for label, w in zip(PROJECT_LABELS, w_a)
    }


def main() -> None:
    """Train fusion weights, compare them with baselines and write all outputs.

    Steps:
      1. Read the manifest and prediction cache.
      2. Split the samples 80/20, stratified by label.
      3. Train a FusionHead.
      4. Evaluate audio-only, fixed 60/40, greedy-v1 and trained weights on
         the validation split.
      5. Write trained_weights.json, trained_fusion_curve.png and
         trained_fusion_report.md into --output-dir.

    Raises:
        ValueError: If the manifest and prediction lists differ in length, or
            a manifest label (after map_label) is not in PROJECT_LABELS.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--manifest", type=Path, required=True)
    parser.add_argument("--preds-cache", type=Path, required=True)
    parser.add_argument("--output-dir", type=Path, required=True)
    parser.add_argument("--lr", type=float, default=0.05)
    parser.add_argument("--epochs", type=int, default=500)
    parser.add_argument("--l2", type=float, default=0.01)
    parser.add_argument("--patience", type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--use-last-alpha", action="store_true",
                        help="Use final-epoch alpha instead of best-val-F1 alpha (keeps differentiation)")
    args = parser.parse_args()

    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Decode as UTF-8 explicitly rather than with the locale default: the
    # manifest's free-text "text" field may contain non-ASCII characters.
    manifest = json.loads(args.manifest.read_text(encoding="utf-8"))
    cache = json.loads(args.preds_cache.read_text(encoding="utf-8"))
    audio_preds = cache["audio_preds"]
    text_preds = cache["text_preds"]

    # Predictions are aligned with manifest records by position.
    if not (len(manifest) == len(audio_preds) == len(text_preds)):
        raise ValueError(f"Size mismatch: manifest={len(manifest)}, audio={len(audio_preds)}, text={len(text_preds)}")

    labels = [map_label(r["label"]) for r in manifest]
    labels_dist = dict(Counter(labels))
    logger.info("Loaded %d samples. Distribution: %s", len(manifest), labels_dist)

    # Report every unknown label at once, instead of list.index()'s bare
    # "'x' is not in list" for the first one.
    label_to_idx = {label: i for i, label in enumerate(PROJECT_LABELS)}
    unknown = sorted({lbl for lbl in labels if lbl not in label_to_idx}, key=str)
    if unknown:
        raise ValueError(f"Unknown labels in manifest: {unknown}; expected one of {PROJECT_LABELS}")

    p_a = probs_to_tensor(audio_preds)
    p_t = probs_to_tensor(text_preds)
    y = np.array([label_to_idx[lbl] for lbl in labels])

    # Stratified 80/20 split (sklearn needs >= 2 samples per class for this).
    tr_idx, vl_idx = train_test_split(
        np.arange(len(y)), test_size=0.2, stratify=y, random_state=args.seed,
    )
    logger.info("Train/Val: %d / %d", len(tr_idx), len(vl_idx))

    torch.manual_seed(args.seed)
    model, best_f1, history, final_alpha = train(
        p_a[tr_idx], p_t[tr_idx], y[tr_idx],
        p_a[vl_idx], p_t[vl_idx], y[vl_idx],
        lr=args.lr, epochs=args.epochs, l2=args.l2, patience=args.patience,
    )
    logger.info("Best val macro F1: %.4f", best_f1)

    # train() returns the model holding the best-val-F1 alpha. Optionally swap
    # in the last-epoch alpha instead.
    if args.use_last_alpha:
        model.alpha.data = final_alpha
        logger.info("Using LAST-epoch alpha (--use-last-alpha)")

    w_a_np = model.w_a.detach().numpy()
    # NOTE(review): the saved weights are rounded to 2 decimals, but the
    # "trained" F1 below uses the unrounded w_a_np. The JSON weights may
    # therefore score marginally differently from the reported val_macro_f1.
    trained_weights = _weights_table(w_a_np)
    last_weights = _weights_table(torch.sigmoid(final_alpha).numpy())

    # Every strategy is compared on the validation split.
    # NOTE(review): train() also used this split to pick the best epoch, so
    # the "trained" score is optimistically biased relative to the fixed
    # baselines. A separate held-out split would give an unbiased comparison.
    pa_vl, pt_vl, y_vl = p_a[vl_idx], p_t[vl_idx], y[vl_idx]

    audio_only_pred = pa_vl.argmax(dim=1).numpy()
    audio_only = {
        "macro_f1": float(f1_score(y_vl, audio_only_pred, average="macro")),
        "per_class": {
            label: float(f1_score((y_vl == i).astype(int), (audio_only_pred == i).astype(int)))
            for i, label in enumerate(PROJECT_LABELS)
        },
    }

    fixed = eval_weights(pa_vl, pt_vl, y_vl, np.full(len(PROJECT_LABELS), 0.6))

    # Hard-coded greedy-grid (v1) audio weights, evaluated as a baseline.
    greedy_v1 = {
        "neutral": 0.75, "joy": 0.55, "sadness": 0.40, "anger": 0.65,
        "surprise": 0.45, "fear": 0.00, "disgust": 0.80,
    }
    greedy_w_a = np.array([greedy_v1[lbl] for lbl in PROJECT_LABELS])
    greedy = eval_weights(pa_vl, pt_vl, y_vl, greedy_w_a)
    greedy_weights = {lbl: {"audio": greedy_v1[lbl]} for lbl in PROJECT_LABELS}

    trained = eval_weights(pa_vl, pt_vl, y_vl, w_a_np)

    logger.info("Audio-only val macro F1: %.4f", audio_only["macro_f1"])
    logger.info("Fixed 60/40 val macro F1: %.4f", fixed["macro_f1"])
    logger.info("Greedy v1 val macro F1: %.4f", greedy["macro_f1"])
    logger.info("Trained val macro F1: %.4f", trained["macro_f1"])

    # ensure_ascii=False may emit non-ASCII, so write UTF-8 explicitly.
    (args.output_dir / "trained_weights.json").write_text(json.dumps({
        "val_macro_f1": trained["macro_f1"],
        "weights": trained_weights,
        "last_epoch_weights": last_weights,
        "train_size": int(len(tr_idx)),
        "val_size": int(len(vl_idx)),
        "hyperparams": {"lr": args.lr, "l2": args.l2, "epochs": args.epochs,
                        "patience": args.patience, "seed": args.seed,
                        "use_last_alpha": args.use_last_alpha},
    }, indent=2, ensure_ascii=False), encoding="utf-8")

    plot_curve(history, args.output_dir / "trained_fusion_curve.png")
    write_report(
        args.output_dir / "trained_fusion_report.md",
        len(tr_idx), len(vl_idx), labels_dist,
        audio_only, fixed, greedy, trained,
        trained_weights, greedy_weights,
    )
    logger.info("Done. Results saved to %s", args.output_dir)
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|
|