ustwo-api / scripts /train_fusion_weights.py
asdfasdfqrqwer's picture
Deploy from GitHub 2026-04-23T03:56:31Z
c857b85
Raw
History Blame Contribute Delete
11.7 kB
#!/usr/bin/env python3
"""Train per-emotion fusion weights via gradient descent.
Inputs:
- --manifest JSON: list of {"path","text","label","source",...} (2,821 samples)
- --preds-cache JSON: {"audio_preds": [dict(7)], "text_preds": [dict(7)]}
Output dir receives:
- trained_weights.json — learned w_a / w_t / val_macro_f1
- trained_fusion_report.md — comparison: audio-only, fixed 60/40, greedy optimal, trained
- trained_fusion_curve.png — training loss (NLL + L2) and validation macro-F1 curves
Parameterization:
w_a[L] = sigmoid(α[L]), w_t[L] = 1 - w_a[L] # 7 params total
fused[L] = p_a[L]*w_a[L] + p_t[L]*w_t[L]
fused ← normalize over L
loss = NLL(log fused, y) + λ * ||α||²
"""
from __future__ import annotations
import argparse
import json
import logging
from collections import Counter
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
# Configure logging at import time; every record carries a timestamp and its level.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
# Canonical emotion order. A label's position in this list is both the column
# index used by probs_to_tensor and the integer class id used for the targets.
PROJECT_LABELS = ["neutral", "joy", "sadness", "anger", "surprise", "fear", "disgust"]
class FusionHead(nn.Module):
    """Learnable per-class blend of audio and text probabilities.

    One logit ``alpha[c]`` is learned per class. The audio weight is
    ``sigmoid(alpha[c])`` and the text weight is its complement, so each pair
    of weights always sums to one.

    Args:
        init_audio_frac: Initial audio weight shared by every class. Must lie
            strictly between 0 and 1, because its logit is used as the initial
            value of ``alpha`` and is infinite at the boundaries.
        num_labels: Number of classes, i.e. the length of ``alpha``.

    Raises:
        ValueError: If ``init_audio_frac`` is not in the open interval (0, 1)
            or ``num_labels`` is smaller than 1.
    """

    def __init__(self, init_audio_frac: float = 0.6, num_labels: int = 7):
        super().__init__()
        # The chained comparison is also False for NaN, so NaN is rejected too.
        if not 0.0 < init_audio_frac < 1.0:
            raise ValueError(
                f"init_audio_frac must lie strictly between 0 and 1, got {init_audio_frac!r}"
            )
        if num_labels < 1:
            raise ValueError(f"num_labels must be at least 1, got {num_labels!r}")
        init_val = float(torch.logit(torch.tensor(float(init_audio_frac))))
        self.alpha = nn.Parameter(torch.full((num_labels,), init_val))

    @property
    def w_a(self) -> torch.Tensor:
        """Per-class audio weight, ``sigmoid(alpha)``."""
        return torch.sigmoid(self.alpha)

    @property
    def w_t(self) -> torch.Tensor:
        """Per-class text weight, ``1 - w_a``."""
        return 1.0 - self.w_a

    def forward(self, p_a: torch.Tensor, p_t: torch.Tensor) -> torch.Tensor:
        """Blend both inputs class-wise and renormalise each row to sum to one.

        The weights broadcast over the leading dimension; rows are summed over
        ``dim=1``, so both inputs must be 2-D with one column per class.
        """
        fused = p_a * self.w_a + p_t * self.w_t
        # The clamp guards against division by zero when a row is all zeros.
        return fused / fused.sum(dim=1, keepdim=True).clamp(min=1e-8)
def probs_to_tensor(preds: list[dict]) -> torch.Tensor:
    """Convert per-sample ``{label: probability}`` dicts into a float32 tensor.

    Columns follow the order of ``PROJECT_LABELS``; a label missing from a
    dict contributes 0.0.

    Args:
        preds: One mapping per sample.

    Returns:
        Tensor of shape ``(len(preds), len(PROJECT_LABELS))``. The second
        dimension is kept even when ``preds`` is empty, so row-wise
        reductions over ``dim=1`` still work.
    """
    rows = [[pred.get(label, 0.0) for label in PROJECT_LABELS] for pred in preds]
    # The reshape is a no-op for non-empty input; for an empty list it turns
    # the (0,) array NumPy would otherwise produce into (0, n_labels).
    arr = np.asarray(rows, dtype=np.float32).reshape(len(rows), len(PROJECT_LABELS))
    return torch.from_numpy(arr)
def map_label(lbl: str) -> str:
    """Translate a manifest label to the project vocabulary.

    Only ``"happiness"`` is renamed (to ``"joy"``); every other value is
    returned unchanged.
    """
    if lbl != "happiness":
        return lbl
    return "joy"
def eval_weights(p_a: torch.Tensor, p_t: torch.Tensor, y: np.ndarray,
                 w_a_vec: np.ndarray) -> dict:
    """Score a fixed vector of per-class audio weights.

    The audio and text probabilities are blended class-wise with ``w_a_vec``
    and its complement, each row is renormalised, and the arg-max prediction
    is compared with ``y``.

    Args:
        p_a: Audio probabilities, one row per sample.
        p_t: Text probabilities, same shape as ``p_a``.
        y: Integer class ids, one per sample.
        w_a_vec: Audio weight per class; the text weight is ``1 - w_a_vec``.

    Returns:
        ``{"macro_f1": float, "per_class": {label: float}}`` where each
        per-class value is the binary one-vs-rest F1 for that label.
    """
    audio_w = torch.from_numpy(w_a_vec.astype(np.float32))
    text_w = 1.0 - audio_w
    blended = p_a * audio_w + p_t * text_w
    # The clamp avoids dividing by zero on all-zero rows.
    row_totals = blended.sum(dim=1, keepdim=True).clamp(min=1e-8)
    predicted = (blended / row_totals).argmax(dim=1).numpy()

    macro_f1 = float(f1_score(y, predicted, average="macro"))
    per_class = {}
    for class_id, name in enumerate(PROJECT_LABELS):
        is_true = (y == class_id).astype(int)
        is_pred = (predicted == class_id).astype(int)
        per_class[name] = float(f1_score(is_true, is_pred))
    return {"macro_f1": macro_f1, "per_class": per_class}
def train(p_a_tr: torch.Tensor, p_t_tr: torch.Tensor, y_tr: np.ndarray,
          p_a_vl: torch.Tensor, p_t_vl: torch.Tensor, y_vl: np.ndarray,
          lr: float = 0.05, epochs: int = 500, l2: float = 0.01,
          patience: int = 50) -> tuple["FusionHead", float, dict, torch.Tensor]:
    """Fit a FusionHead with full-batch Adam and early stopping on val macro F1.

    Each epoch takes one optimiser step on
    ``NLL(log fused, y_tr) + l2 * sum(alpha ** 2)`` and then scores the
    updated weights on the validation split.

    Args:
        p_a_tr, p_t_tr: Audio / text probabilities for the training split.
        y_tr: Integer class ids for the training split.
        p_a_vl, p_t_vl: Audio / text probabilities for the validation split.
        y_vl: Integer class ids for the validation split.
        lr: Adam learning rate.
        epochs: Maximum number of epochs.
        l2: Coefficient of the squared-alpha penalty.
        patience: Stop after this many consecutive epochs without a strict
            improvement in validation macro F1.

    Returns:
        ``(model, best_f1, history, final_alpha)``. ``model`` carries the alpha
        from the best validation epoch, ``best_f1`` is that epoch's macro F1
        (``-1.0`` if no epoch ran), ``history`` holds the per-epoch
        ``"train_loss"`` and ``"val_f1"`` lists, and ``final_alpha`` is the
        alpha after the last executed epoch.
    """
    model = FusionHead()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    nll = nn.NLLLoss()
    history = {"train_loss": [], "val_f1": []}
    best_f1, waited = -1.0, 0
    # Start from the initial weights so restoring the snapshot after the loop
    # is valid even when the loop never executes (epochs <= 0).
    best_alpha = model.alpha.detach().clone()
    y_tr_t = torch.from_numpy(y_tr).long()

    for epoch in range(epochs):
        model.train()
        opt.zero_grad()
        fused = model(p_a_tr, p_t_tr)
        # Clamp before the log so zero probabilities do not produce -inf.
        loss = nll(torch.log(fused.clamp(min=1e-8)), y_tr_t) + l2 * (model.alpha ** 2).sum()
        loss.backward()
        opt.step()

        model.eval()
        with torch.no_grad():
            val_fused = model(p_a_vl, p_t_vl)
            val_pred = val_fused.argmax(dim=1).numpy()
        val_f1 = f1_score(y_vl, val_pred, average="macro")

        # The recorded loss was computed before this epoch's step; the
        # recorded F1 reflects the weights after it.
        history["train_loss"].append(float(loss.item()))
        history["val_f1"].append(float(val_f1))

        if val_f1 > best_f1:
            best_f1, best_alpha, waited = float(val_f1), model.alpha.detach().clone(), 0
        else:
            waited += 1
            if waited >= patience:
                logger.info("Early stop at epoch %d (patience=%d)", epoch, patience)
                break

    final_alpha = model.alpha.detach().clone()
    # Restore the best snapshot in place instead of rebinding ``.data``.
    with torch.no_grad():
        model.alpha.copy_(best_alpha)
    return model, best_f1, history, final_alpha
def plot_curve(history, output_path: Path) -> None:
    """Save a two-panel figure of the training history.

    The left panel plots ``history["train_loss"]`` and the right panel plots
    ``history["val_f1"]``, both against the epoch index. The figure is
    written to ``output_path`` at 150 dpi.
    """
    # Select the non-interactive Agg backend before pyplot is imported.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    fig, (loss_ax, f1_ax) = plt.subplots(1, 2, figsize=(12, 4))

    loss_ax.plot(history["train_loss"], color="#F44336", label="train CE + L2")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.set_title("Training loss")
    loss_ax.grid(alpha=0.3)
    loss_ax.legend()

    f1_ax.plot(history["val_f1"], color="#4CAF50", label="val macro F1")
    f1_ax.set_xlabel("Epoch")
    f1_ax.set_ylabel("Macro F1")
    f1_ax.set_title("Validation macro F1")
    f1_ax.grid(alpha=0.3)
    f1_ax.legend()

    fig.tight_layout()
    fig.savefig(str(output_path), dpi=150)
    plt.close(fig)
def write_report(output_path: Path, num_train: int, num_val: int, labels_dist: dict,
                 audio_only: dict, fixed: dict, greedy: dict, trained: dict,
                 trained_weights: dict, greedy_weights: dict) -> None:
    """Write the Markdown comparison report.

    Args:
        output_path: Destination file; written as UTF-8.
        num_train: Number of training samples.
        num_val: Number of validation samples.
        labels_dist: ``{label: count}``; listed by descending count.
        audio_only, fixed, greedy, trained: One result dict per strategy, each
            with ``"macro_f1"`` and ``"per_class"`` (keyed by every label in
            ``PROJECT_LABELS``).
        trained_weights: ``{label: {"audio": float, "text": float}}``.
        greedy_weights: ``{label: {"audio": float}}``; a missing label or a
            ``None`` audio value is rendered as an em dash.
    """
    lines = ["# Fusion Weight Training Report (v2)\n"]
    lines.append(f"## Dataset\n\nTotal samples: **{num_train + num_val}** (train {num_train}, val {num_val})\n")

    lines.append("### Label distribution\n\n| Label | Count |\n|---|---|")
    for lbl, count in sorted(labels_dist.items(), key=lambda item: -item[1]):
        lines.append(f"| {lbl} | {count} |")
    lines.append("")

    lines.append("## Macro F1 Comparison (validation set)\n")
    lines.append("| Strategy | Macro F1 |")
    lines.append("|---|---|")
    lines.append(f"| Audio-only (argmax p_audio) | {audio_only['macro_f1']:.4f} |")
    lines.append(f"| Fixed 60/40 | {fixed['macro_f1']:.4f} |")
    lines.append(f"| Greedy grid (v1 weights) | {greedy['macro_f1']:.4f} |")
    lines.append(f"| **Trained (gradient descent)** | **{trained['macro_f1']:.4f}** |")
    lines.append("")

    lines.append("## Per-class F1 (validation set)\n")
    lines.append("| Emotion | Audio-only | Fixed 60/40 | Greedy | Trained |")
    lines.append("|---|---|---|---|---|")
    for lbl in PROJECT_LABELS:
        lines.append(f"| {lbl} | {audio_only['per_class'][lbl]:.4f} | "
                     f"{fixed['per_class'][lbl]:.4f} | {greedy['per_class'][lbl]:.4f} | "
                     f"{trained['per_class'][lbl]:.4f} |")
    lines.append("")

    lines.append("## Learned weights\n")
    lines.append("| Emotion | Audio (trained) | Text (trained) | Audio (greedy v1) |")
    lines.append("|---|---|---|---|")
    for lbl in PROJECT_LABELS:
        w = trained_weights[lbl]
        gw = greedy_weights.get(lbl, {"audio": None})
        g_a = f"{gw['audio']:.2f}" if gw.get("audio") is not None else "—"
        lines.append(f"| {lbl} | {w['audio']:.2f} | {w['text']:.2f} | {g_a} |")
    lines.append("")

    # The report can contain non-ASCII text (the em dash placeholder, label
    # names), so do not rely on the platform's default encoding.
    output_path.write_text("\n".join(lines), encoding="utf-8")
def _weights_to_dict(w_a: np.ndarray) -> dict:
    """Map a per-class audio-weight vector to ``{label: {"audio", "text"}}``, rounded to 2 decimals."""
    return {
        label: {"audio": round(float(w), 2), "text": round(float(1 - w), 2)}
        for label, w in zip(PROJECT_LABELS, w_a)
    }


def main() -> None:
    """Command-line entry point: train fusion weights and write all artefacts.

    Reads the manifest and the cached predictions, makes a stratified 80/20
    split, trains the fusion head, evaluates it against the audio-only, fixed
    60/40 and greedy-v1 baselines on the validation split, and writes
    ``trained_weights.json``, ``trained_fusion_curve.png`` and
    ``trained_fusion_report.md`` into ``--output-dir``.

    Raises:
        ValueError: If the manifest and prediction lists differ in length, or
            the manifest contains a label outside ``PROJECT_LABELS``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--manifest", type=Path, required=True)
    parser.add_argument("--preds-cache", type=Path, required=True)
    parser.add_argument("--output-dir", type=Path, required=True)
    parser.add_argument("--lr", type=float, default=0.05)
    parser.add_argument("--epochs", type=int, default=500)
    parser.add_argument("--l2", type=float, default=0.01)
    parser.add_argument("--patience", type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--use-last-alpha", action="store_true",
                        help="Use final-epoch alpha instead of best-val-F1 alpha (keeps differentiation)")
    args = parser.parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)

    # JSON is UTF-8 by convention; do not depend on the platform locale.
    manifest = json.loads(args.manifest.read_text(encoding="utf-8"))
    cache = json.loads(args.preds_cache.read_text(encoding="utf-8"))
    audio_preds = cache["audio_preds"]
    text_preds = cache["text_preds"]
    if not (len(manifest) == len(audio_preds) == len(text_preds)):
        raise ValueError(f"Size mismatch: manifest={len(manifest)}, audio={len(audio_preds)}, text={len(text_preds)}")

    labels = [map_label(r["label"]) for r in manifest]
    # Fail early with the offending labels instead of list.index's
    # "x is not in list" message.
    unknown = sorted({repr(lbl) for lbl in labels if lbl not in PROJECT_LABELS})
    if unknown:
        raise ValueError(
            f"Unknown label(s) in manifest: {', '.join(unknown)}; expected one of {PROJECT_LABELS}"
        )
    labels_dist = dict(Counter(labels))
    logger.info("Loaded %d samples. Distribution: %s", len(manifest), labels_dist)

    p_a = probs_to_tensor(audio_preds)
    p_t = probs_to_tensor(text_preds)
    label_index = {name: i for i, name in enumerate(PROJECT_LABELS)}
    y = np.array([label_index[lbl] for lbl in labels])

    # Stratified 80/20
    tr_idx, vl_idx = train_test_split(
        np.arange(len(y)), test_size=0.2, stratify=y, random_state=args.seed,
    )
    logger.info("Train/Val: %d / %d", len(tr_idx), len(vl_idx))

    torch.manual_seed(args.seed)
    model, best_f1, history, final_alpha = train(
        p_a[tr_idx], p_t[tr_idx], y[tr_idx],
        p_a[vl_idx], p_t[vl_idx], y[vl_idx],
        lr=args.lr, epochs=args.epochs, l2=args.l2, patience=args.patience,
    )
    logger.info("Best val macro F1: %.4f", best_f1)

    # Select which alpha to deploy: "best" (peak val F1) or "last" (final epoch)
    if args.use_last_alpha:
        with torch.no_grad():
            model.alpha.copy_(final_alpha)
        logger.info("Using LAST-epoch alpha (--use-last-alpha)")

    # Weights of the selected alpha, plus the last-epoch weights for comparison.
    w_a_np = model.w_a.detach().numpy()
    trained_weights = _weights_to_dict(w_a_np)
    last_weights = _weights_to_dict(torch.sigmoid(final_alpha).numpy())

    # Baselines on val split
    pa_vl, pt_vl, y_vl = p_a[vl_idx], p_t[vl_idx], y[vl_idx]
    n_labels = len(PROJECT_LABELS)
    # Audio weight 1.0 for every class reduces the blend to p_audio, so the
    # audio-only baseline shares the evaluation path of the other strategies.
    audio_only = eval_weights(pa_vl, pt_vl, y_vl, np.ones(n_labels))
    fixed = eval_weights(pa_vl, pt_vl, y_vl, np.full(n_labels, 0.6))

    # Greedy v1 weights — hardcode from previous report
    GREEDY_V1 = {
        "neutral": 0.75, "joy": 0.55, "sadness": 0.40, "anger": 0.65,
        "surprise": 0.45, "fear": 0.00, "disgust": 0.80,
    }
    greedy_w_a = np.array([GREEDY_V1[lbl] for lbl in PROJECT_LABELS])
    greedy = eval_weights(pa_vl, pt_vl, y_vl, greedy_w_a)
    greedy_weights = {lbl: {"audio": GREEDY_V1[lbl]} for lbl in PROJECT_LABELS}

    # Scored with the unrounded weights; the JSON stores them rounded to 2 decimals.
    trained = eval_weights(pa_vl, pt_vl, y_vl, w_a_np)

    logger.info("Audio-only val macro F1: %.4f", audio_only["macro_f1"])
    logger.info("Fixed 60/40 val macro F1: %.4f", fixed["macro_f1"])
    logger.info("Greedy v1 val macro F1: %.4f", greedy["macro_f1"])
    logger.info("Trained val macro F1: %.4f", trained["macro_f1"])

    # Save
    payload = {
        "val_macro_f1": trained["macro_f1"],
        "weights": trained_weights,
        "last_epoch_weights": last_weights,
        "train_size": int(len(tr_idx)),
        "val_size": int(len(vl_idx)),
        "hyperparams": {"lr": args.lr, "l2": args.l2, "epochs": args.epochs,
                        "patience": args.patience, "seed": args.seed,
                        "use_last_alpha": args.use_last_alpha},
    }
    # ensure_ascii=False can emit non-ASCII characters, so pin the encoding.
    (args.output_dir / "trained_weights.json").write_text(
        json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8",
    )
    plot_curve(history, args.output_dir / "trained_fusion_curve.png")
    write_report(
        args.output_dir / "trained_fusion_report.md",
        len(tr_idx), len(vl_idx), labels_dist,
        audio_only, fixed, greedy, trained,
        trained_weights, greedy_weights,
    )
    logger.info("Done. Results saved to %s", args.output_dir)


if __name__ == "__main__":
    main()