ustwo-api / scripts /optimize_fusion_weights.py
asdfasdfqrqwer's picture
Deploy from GitHub 2026-04-23T03:56:31Z
c857b85
Raw
History Blame Contribute Delete
25.9 kB
#!/usr/bin/env python3
"""Grid Search for Emotion-Specific Fusion Weights.
Uses AI Hub 263 val split (audio + text + ground truth) to find optimal
audio/text fusion weights per emotion class.
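Outputs (written to --output-dir; file names get an "en_" prefix when --lang en):
- fusion_grid_search.json — full weight-F1 curves per emotion
- optimal_fusion_weights.json — best weights per emotion
- fusion_comparison.json — macro and per-class F1 for audio-only, fixed 60/40 and optimal fusion
- fusion_grid_search.png — 7 subplots: weight vs F1 per emotion
- fusion_comparison.png — bar chart: audio-only vs fixed 60/40 vs optimal
- fusion_report.md — text summary
- preds_cache.json — per-sample audio/text predictions, reused to resume an interrupted run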
Usage:
python scripts/optimize_fusion_weights.py \
--val-manifest data/lora_dataset/val_manifest.json \
--onnx-model data/models/lora_emotion2vec_7class/model.onnx \
--anchor-dir "data/AI Hub 263" \
--output-dir data/models/lora_emotion2vec_7class
"""
from __future__ import annotations
import argparse
import csv
import gc
import json
import logging
import sys
from collections import Counter, defaultdict
from pathlib import Path
import numpy as np
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
# Canonical 7-class label set used for every score dict, metric and plot in this script.
# The list order is also the tie-break order of max(scores, key=scores.get).
PROJECT_LABELS = ["neutral", "joy", "sadness", "anger", "surprise", "fear", "disgust"]
# LoRA model labels → project labels
# predict_audio_onnx() zips this list with the model's output probabilities, so the
# order here is presumably the ONNX model's class-index order — confirm against the export.
LORA_LABELS = ["happiness", "anger", "disgust", "fear", "neutral", "sadness", "surprise"]
# Only "happiness" is renamed (to "joy"); every other label maps to itself.
LORA_TO_PROJECT = {
    "happiness": "joy", "anger": "anger", "disgust": "disgust",
    "fear": "fear", "neutral": "neutral", "sadness": "sadness", "surprise": "surprise",
}
# 263 label mapping (same as prepare_lora_dataset.py)
# NOTE(review): maps onto the LoRA label names ("happiness", not "joy") and is not
# referenced by any function visible in this file.
MAP_263 = {
    "angry": "anger", "anger": "anger",
    "sadness": "sadness", "sad": "sadness",
    "happiness": "happiness", "happy": "happiness",
    "fear": "fear", "disgust": "disgust",
    "surprise": "surprise", "neutral": "neutral",
}
# KcELECTRA 44-class → 7-class (from src/stage2/text_emotion.py)
# NOTE(review): not referenced by any function visible in this file; predict_text_onnx()
# reads 7-class outputs directly.
KO_LABEL_MAP = {
    "기쁨": "joy", "즐거움/신남": "joy", "행복": "joy",
    "감동/감탄": "joy", "고마움": "joy", "환영/호의": "joy",
    "뿌듯함": "joy", "흐뭇함(귀여움/예쁨)": "joy", "기대감": "joy",
    "편안/쾌적": "joy", "안심/신뢰": "joy", "아껴주는": "joy", "존경": "joy",
    "놀람": "surprise", "신기함/관심": "surprise", "경악": "surprise", "어이없음": "surprise",
    "슬픔": "sadness", "서러움": "sadness", "안타까움/실망": "sadness",
    "절망": "sadness", "부끄러움": "sadness", "불쌍함/연민": "sadness",
    "패배/자기혐오": "sadness", "힘듦/지침": "sadness", "죄책감": "sadness",
    "화남/분노": "anger", "짜증": "anger", "불평/불만": "anger",
    "지긋지긋": "anger", "우쭐댐/무시함": "anger", "한심함": "anger",
    "증오/혐오": "anger", "귀찮음": "anger",
    "공포/무서움": "fear", "불안/걱정": "fear", "당황/난처": "fear", "의심/불신": "fear",
    "없음": "neutral", "깨달음": "neutral", "재미없음": "neutral",
    "부담/안_내킴": "neutral", "비장함": "neutral",
    "역겨움/징그러움": "disgust",
}
def load_263_texts(anchor_dir: Path) -> dict[str, str]:
    """Load the wav_id → utterance text mapping from the 263 CSV files.

    Every ``*.csv`` directly inside ``anchor_dir`` is read as cp949, in sorted
    file-name order. The first row of each file is treated as a header and
    skipped; for the remaining rows column 0 is the wav id and column 1 the
    utterance text. When the same wav id occurs more than once, the last
    occurrence wins.

    Args:
        anchor_dir: Directory containing the CSV files.

    Returns:
        Dict mapping wav id to utterance text (empty if no usable rows exist).
    """
    texts: dict[str, str] = {}
    for csv_path in sorted(anchor_dir.glob("*.csv")):
        # newline="" is what the csv module requires so that quoted fields
        # containing line breaks are parsed correctly.
        with open(csv_path, encoding="cp949", newline="") as f:
            reader = csv.reader(f)
            # Skip the header; the default keeps an empty file from raising StopIteration.
            next(reader, None)
            for row in reader:
                # Blank lines and truncated rows carry no (id, text) pair.
                if len(row) < 2:
                    continue
                texts[row[0]] = row[1]
    logger.info("Loaded %d texts from 263 CSVs", len(texts))
    return texts
def predict_audio_base(audio_path: str, funasr_model, max_seconds: float = 15.0) -> dict[str, float]:
    """Run base (non-finetuned) emotion2vec via FunASR, 9-class → 7-class mapping.

    Audio is trimmed to max_seconds — FunASR transformer has quadratic memory in
    sequence length, so a 100s clip can blow past 15GB RAM. Matches
    predict_audio_onnx() behavior.

    Args:
        audio_path: Path of the audio file, read with soundfile.
        funasr_model: Object exposing ``generate(audio, granularity=..., extract_embedding=...)``.
        max_seconds: Maximum clip length kept, measured at 16 kHz.

    Returns:
        Dict over PROJECT_LABELS normalized to sum to 1. A uniform distribution
        is returned when inference raises or yields no usable scores, so callers
        always receive a valid distribution.
    """
    # emotion2vec base native labels → project labels
    LABEL_MAP = {
        "angry": "anger", "disgusted": "disgust", "fearful": "fear",
        "happy": "joy", "neutral": "neutral", "sad": "sadness", "surprised": "surprise",
        "other": "neutral", "unknown": "neutral",
        "生气/angry": "anger", "厌恶/disgusted": "disgust", "恐惧/fearful": "fear",
        "开心/happy": "joy", "中立/neutral": "neutral", "难过/sad": "sadness",
        "吃惊/surprised": "surprise", "其他/other": "neutral", "<unk>": "neutral",
    }
    import soundfile as sf

    uniform = {label: 1.0 / len(PROJECT_LABELS) for label in PROJECT_LABELS}

    audio, sr = sf.read(audio_path, dtype="float32")
    if audio.ndim == 2:
        # Down-mix multi-channel audio to mono.
        audio = audio.mean(axis=1)
    if sr != 16000:
        import librosa
        audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
    max_samples = int(max_seconds * 16000)
    if len(audio) > max_samples:
        audio = audio[:max_samples]

    try:
        output = funasr_model.generate(
            audio, granularity="utterance", extract_embedding=False,
        )
    except Exception:
        # Best-effort: one bad clip must not abort the whole run, but the
        # failure has to be visible because it dilutes the evaluation.
        logger.warning(
            "FunASR inference failed for %s; falling back to uniform scores",
            audio_path, exc_info=True,
        )
        return uniform

    scores = {label: 0.0 for label in PROJECT_LABELS}
    if output and isinstance(output, list):
        rec = output[0]
        raw_labels = rec.get("labels", [])
        raw_scores = rec.get("scores", [])
        for native_label, score in zip(raw_labels, raw_scores):
            # Unrecognized native labels are folded into "neutral".
            project_label = LABEL_MAP.get(native_label, "neutral")
            scores[project_label] += float(score)

    total = sum(scores.values())
    if total <= 0:
        # Same fallback as the exception path, instead of an all-zero vector.
        logger.warning("FunASR returned no usable scores for %s; using uniform scores", audio_path)
        return uniform
    return {k: v / total for k, v in scores.items()}
def predict_audio_onnx(audio_path: str, session, max_seconds: float = 15.0) -> dict[str, float]:
    """Predict 7-class emotion probabilities for one audio file with the ONNX audio model.

    The file is read with soundfile, down-mixed to mono, resampled to 16 kHz when
    needed and cut to ``max_seconds`` (to avoid OOM on very long audio). The
    resulting 1 x N float32 array is fed to the session's "waveform" input and
    the first output is turned into probabilities with a softmax.

    Args:
        audio_path: Path of the audio file.
        session: Object exposing ``run(None, {"waveform": ...})``.
        max_seconds: Maximum clip length kept, measured at 16 kHz.

    Returns:
        Dict mapping project labels to probabilities.
        Assumes the model emits one score per entry of LORA_LABELS, in that order — TODO confirm.
    """
    import soundfile as sf

    target_sr = 16000
    signal, sample_rate = sf.read(audio_path, dtype="float32")

    # Down-mix multi-channel audio to mono.
    if signal.ndim == 2:
        signal = signal.mean(axis=1)

    if sample_rate != target_sr:
        import librosa
        signal = librosa.resample(signal, orig_sr=sample_rate, target_sr=target_sr)

    # Keep only the leading max_seconds to prevent OOM on very long audio.
    sample_limit = int(max_seconds * target_sr)
    if len(signal) > sample_limit:
        signal = signal[:sample_limit]

    batch = signal.reshape(1, -1).astype(np.float32)
    logits = session.run(None, {"waveform": batch})[0]

    # Numerically stable softmax over the class axis.
    shifted = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = (shifted / shifted.sum(axis=-1, keepdims=True)).squeeze()

    return {
        LORA_TO_PROJECT[native]: float(prob)
        for native, prob in zip(LORA_LABELS, probs)
    }
def predict_text_onnx(text: str, tokenizer, session) -> dict[str, float]:
    """Predict 7-class emotion probabilities for one utterance with the text ONNX model.

    Args:
        text: Utterance text. Empty or whitespace-only text yields a uniform distribution.
        tokenizer: Callable producing "input_ids" and "attention_mask" arrays.
        session: Object exposing ``run(None, feeds)``.

    Returns:
        Dict over PROJECT_LABELS with the softmax probability of each class.
        Assumes the model's class order is the one listed in ``label_pairs`` — TODO confirm.
    """
    if not (text and text.strip()):
        # Nothing to classify: every class is equally likely.
        return dict.fromkeys(PROJECT_LABELS, 1.0 / len(PROJECT_LABELS))

    encoded = tokenizer(text, return_tensors="np", truncation=True, max_length=128, padding="max_length")
    feeds = {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
    }
    logits = session.run(None, feeds)[0]

    # Numerically stable softmax over the class axis.
    shifted = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = (shifted / shifted.sum(axis=-1, keepdims=True)).squeeze()

    # (model label, project label) in model output order; only happiness is renamed.
    label_pairs = (
        ("happiness", "joy"),
        ("anger", "anger"),
        ("disgust", "disgust"),
        ("fear", "fear"),
        ("neutral", "neutral"),
        ("sadness", "sadness"),
        ("surprise", "surprise"),
    )
    scores = dict.fromkeys(PROJECT_LABELS, 0.0)
    for (_, project_label), prob in zip(label_pairs, probs):
        scores[project_label] = float(prob)
    return scores
def fuse_scores(audio_scores, text_scores, weights):
    """Combine audio and text scores using a separate weight pair per emotion.

    Args:
        audio_scores: Dict label → audio score; missing labels count as 0.0.
        text_scores: Dict label → text score; missing labels count as 0.0.
        weights: Dict label → {"audio": w, "text": w}. A missing label or key
            falls back to 0.6 for audio and 0.4 for text.

    Returns:
        Dict over PROJECT_LABELS, normalized to sum to 1 when the total is
        positive; otherwise the un-normalized values are returned as they are.
    """
    combined = {}
    for emotion in PROJECT_LABELS:
        pair = weights.get(emotion, {})
        audio_weight = pair.get("audio", 0.6)
        text_weight = pair.get("text", 0.4)
        combined[emotion] = (
            audio_scores.get(emotion, 0.0) * audio_weight
            + text_scores.get(emotion, 0.0) * text_weight
        )

    norm = sum(combined.values())
    if norm > 0:
        return {emotion: value / norm for emotion, value in combined.items()}
    return combined
def compute_f1(y_true, y_pred, target_label):
    """Return the one-vs-rest F1 score of ``target_label``.

    Args:
        y_true: Sequence of ground-truth labels.
        y_pred: Sequence of predicted labels, paired with ``y_true`` by position.
        target_label: Label treated as the positive class.

    Returns:
        F1 as a float; 0.0 when precision and recall are both zero
        (including when the inputs are empty).
    """
    true_pos = false_pos = false_neg = 0
    for actual, predicted in zip(y_true, y_pred):
        actual_is_target = actual == target_label
        predicted_is_target = predicted == target_label
        if actual_is_target and predicted_is_target:
            true_pos += 1
        elif predicted_is_target:
            false_pos += 1
        elif actual_is_target:
            false_neg += 1

    predicted_total = true_pos + false_pos
    actual_total = true_pos + false_neg
    precision = true_pos / predicted_total if predicted_total > 0 else 0
    recall = true_pos / actual_total if actual_total > 0 else 0

    denominator = precision + recall
    if denominator == 0:
        return 0.0
    return 2 * precision * recall / denominator
def compute_macro_f1(y_true, y_pred):
    """Return the unweighted mean of the per-class F1 scores over PROJECT_LABELS.

    Every label in PROJECT_LABELS contributes equally, including labels that
    never occur in ``y_true`` (their F1 is whatever compute_f1 returns for them).
    """
    per_class_scores = []
    for emotion in PROJECT_LABELS:
        per_class_scores.append(compute_f1(y_true, y_pred, emotion))
    return np.mean(per_class_scores)
def grid_search(samples, audio_preds, text_preds):
    """Run grid search for emotion-specific weights.

    For each emotion, the audio weight of that emotion alone is swept from 0.0
    to 1.0 in 0.05 steps (text weight = 1 - audio weight) while every other
    emotion stays at the fixed 0.6/0.4 split. The weight with the best
    one-vs-rest F1 for that emotion is kept.

    Ties are resolved in favor of the weight closest to the 0.6 default, so a
    flat curve (e.g. an emotion that never occurs in ``samples``) reports the
    default split instead of whichever grid point happened to be tried first.

    Args:
        samples: List of dicts with a "label" key.
        audio_preds: Per-sample audio score dicts, aligned with ``samples`` by index.
        text_preds: Per-sample text score dicts, aligned with ``samples`` by index.

    Returns:
        grid_results: dict[emotion] → list of {"audio_weight": float, "f1": float}
        optimal_weights: dict[emotion] → {"audio": float, "text": float, "f1": float}
    """
    default_aw, default_tw = 0.6, 0.4
    weight_range = np.arange(0.0, 1.05, 0.05)

    # Loop-invariant: the ground truth does not depend on the weights under test.
    y_true = [s["label"] for s in samples]
    n_samples = len(samples)

    grid_results = {}
    optimal_weights = {}
    for target_emotion in PROJECT_LABELS:
        # All emotions start at the fixed split; only the target entry is swept.
        weights = {
            label: {"audio": default_aw, "text": default_tw}
            for label in PROJECT_LABELS
        }
        results = []
        best_f1 = -1.0
        best_aw = default_aw
        for aw in weight_range:
            aw = float(aw)
            weights[target_emotion] = {"audio": aw, "text": 1.0 - aw}

            # Predict with these weights
            y_pred = []
            for i in range(n_samples):
                fused = fuse_scores(audio_preds[i], text_preds[i], weights)
                y_pred.append(max(fused, key=fused.get))

            f1 = compute_f1(y_true, y_pred, target_emotion)
            results.append({"audio_weight": round(aw, 2), "f1": round(f1, 4)})

            closer_to_default = abs(aw - default_aw) < abs(best_aw - default_aw)
            if f1 > best_f1 or (f1 == best_f1 and closer_to_default):
                best_f1 = f1
                best_aw = aw

        grid_results[target_emotion] = results
        optimal_weights[target_emotion] = {
            "audio": round(best_aw, 2),
            "text": round(1.0 - best_aw, 2),
            "f1": round(best_f1, 4),
        }
        logger.info("%s: optimal audio_weight=%.2f (F1=%.4f)", target_emotion, best_aw, best_f1)
    return grid_results, optimal_weights
def compute_overall_comparison(samples, audio_preds, text_preds, optimal_weights):
    """Score three strategies on the same samples: audio-only, fixed 60/40 and optimal fusion.

    Args:
        samples: List of dicts with a "label" key.
        audio_preds: Per-sample audio score dicts, aligned with ``samples`` by index.
        text_preds: Per-sample text score dicts, aligned with ``samples`` by index.
        optimal_weights: dict[emotion] → dict holding at least "audio" and "text".

    Returns:
        Dict with keys "audio_only", "fixed_60_40" and "optimal"; each value is
        {"macro_f1": ..., "per_class": {label: f1}} rounded to 4 decimals.
    """
    y_true = [sample["label"] for sample in samples]
    sample_indices = range(len(samples))

    def top_label(score_dict):
        # Highest-scoring label; ties go to the first key in dict order.
        return max(score_dict, key=score_dict.get)

    def fused_predictions(weight_table):
        return [
            top_label(fuse_scores(audio_preds[idx], text_preds[idx], weight_table))
            for idx in sample_indices
        ]

    def summarize(y_pred):
        macro = compute_macro_f1(y_true, y_pred)
        per_class = {label: compute_f1(y_true, y_pred, label) for label in PROJECT_LABELS}
        return {
            "macro_f1": round(macro, 4),
            "per_class": {label: round(score, 4) for label, score in per_class.items()},
        }

    # Fixed 60/40 for every emotion.
    fixed_table = {label: {"audio": 0.6, "text": 0.4} for label in PROJECT_LABELS}
    # Optimal weights, stripped down to the two keys that fusion reads.
    optimal_table = {
        emotion: {"audio": entry["audio"], "text": entry["text"]}
        for emotion, entry in optimal_weights.items()
    }

    fixed_summary = summarize(fused_predictions(fixed_table))
    optimal_summary = summarize(fused_predictions(optimal_table))
    # Audio-only baseline: argmax of the raw audio scores.
    audio_summary = summarize([top_label(audio_preds[idx]) for idx in sample_indices])

    return {
        "audio_only": audio_summary,
        "fixed_60_40": fixed_summary,
        "optimal": optimal_summary,
    }
def plot_grid_search(grid_results, optimal_weights, output_path: Path):
    """Save a 2x4 figure with one audio-weight vs F1 curve per emotion.

    Args:
        grid_results: dict[emotion] → list of {"audio_weight": ..., "f1": ...}.
        optimal_weights: dict[emotion] → dict with "audio" and "f1".
        output_path: Destination image path.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 4, figsize=(18, 9))
    panels = axes.flatten()

    for panel, emotion in zip(panels, PROJECT_LABELS):
        curve = grid_results[emotion]
        best = optimal_weights[emotion]
        audio_weights = [point["audio_weight"] for point in curve]
        f1_values = [point["f1"] for point in curve]

        panel.plot(audio_weights, f1_values, "b-o", markersize=3, linewidth=1.5)
        # Red dashed line: selected optimum; gray dotted line: fixed baseline.
        panel.axvline(
            x=best["audio"], color="r", linestyle="--", alpha=0.7,
            label=f"optimal={best['audio']:.2f}",
        )
        panel.axvline(x=0.6, color="gray", linestyle=":", alpha=0.5, label="fixed=0.60")
        panel.set_title(f"{emotion} (best F1={best['f1']:.3f})", fontsize=11, fontweight="bold")
        panel.set_xlabel("Audio Weight")
        panel.set_ylabel("F1 Score")
        panel.set_xlim(-0.05, 1.05)
        panel.legend(fontsize=8)
        panel.grid(True, alpha=0.3)

    # The grid has 8 cells but there are only 7 emotions: hide the spare one.
    panels[7].set_visible(False)

    fig.suptitle("Emotion-Specific Fusion Weight Grid Search", fontsize=14, fontweight="bold")
    plt.tight_layout()
    plt.savefig(str(output_path), dpi=150)
    plt.close()
    logger.info("Grid search plot saved: %s", output_path)
def plot_comparison(comparison, optimal_weights, output_path: Path):
    """Save a grouped bar chart of per-emotion F1 for the three strategies.

    Args:
        comparison: Dict with "audio_only", "fixed_60_40" and "optimal" entries,
            each holding "macro_f1" and "per_class".
        optimal_weights: dict[emotion] → dict with an "audio" weight, printed above the optimal bars.
        output_path: Destination image path.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    emotions = PROJECT_LABELS
    positions = np.arange(len(emotions))
    bar_width = 0.25

    def per_class(strategy):
        return [comparison[strategy]["per_class"][emotion] for emotion in emotions]

    audio_f1s = per_class("audio_only")
    fixed_f1s = per_class("fixed_60_40")
    opt_f1s = per_class("optimal")

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(
        positions - bar_width, audio_f1s, bar_width,
        label=f"Audio Only (macro={comparison['audio_only']['macro_f1']:.3f})",
        color="#2196F3", alpha=0.8,
    )
    ax.bar(
        positions, fixed_f1s, bar_width,
        label=f"Fixed 60/40 (macro={comparison['fixed_60_40']['macro_f1']:.3f})",
        color="#FF9800", alpha=0.8,
    )
    ax.bar(
        positions + bar_width, opt_f1s, bar_width,
        label=f"Optimal (macro={comparison['optimal']['macro_f1']:.3f})",
        color="#4CAF50", alpha=0.8,
    )

    # Annotate each optimal bar with the audio weight that produced it.
    for position, emotion, height in zip(positions, emotions, opt_f1s):
        audio_weight = optimal_weights[emotion]["audio"]
        ax.text(
            position + bar_width, height + 0.01, f"a={audio_weight:.0%}",
            ha="center", fontsize=7, color="#2E7D32",
        )

    ax.set_ylabel("F1 Score")
    ax.set_title("Fusion Strategy Comparison: Audio Only vs Fixed 60/40 vs Emotion-Specific Optimal", fontweight="bold")
    ax.set_xticks(positions)
    ax.set_xticklabels(emotions, fontsize=10)
    ax.legend(fontsize=10)
    ax.set_ylim(0, 1.0)
    ax.grid(axis="y", alpha=0.3)

    plt.tight_layout()
    plt.savefig(str(output_path), dpi=150)
    plt.close()
    logger.info("Comparison plot saved: %s", output_path)
def write_report(comparison, optimal_weights, output_path: Path):
    """Write the markdown summary report.

    Args:
        comparison: Dict with "audio_only", "fixed_60_40" and "optimal" entries,
            each holding "macro_f1" and "per_class".
        optimal_weights: dict[emotion] → dict with "audio" and "text" weights.
        output_path: Destination markdown file, written as UTF-8.

    NOTE(review): the "Methodology" section is static text (sample count, model
    names, reference macro F1). It is emitted unchanged whatever data or models
    produced ``comparison`` — confirm it is still accurate for the run at hand.
    """
    macro_delta = comparison["optimal"]["macro_f1"] - comparison["fixed_60_40"]["macro_f1"]
    lines = [
        "# Fusion Weight Optimization Report",
        "",
        "## Summary",
        "",
        "| Strategy | Macro F1 |",
        "|---|---|",
        f"| Audio Only | {comparison['audio_only']['macro_f1']:.4f} |",
        f"| Fixed 60/40 | {comparison['fixed_60_40']['macro_f1']:.4f} |",
        f"| **Emotion-Specific Optimal** | **{comparison['optimal']['macro_f1']:.4f}** |",
        # The "+" format flag prints the real sign; a literal "+" prefix would
        # render a regression as "+-0.0123".
        f"| Improvement over Fixed | **{macro_delta:+.4f}** |",
        "",
        "## Optimal Weights Per Emotion",
        "",
        "| Emotion | Audio Weight | Text Weight | F1 (optimal) | F1 (fixed 60/40) | Delta |",
        "|---|---|---|---|---|---|",
    ]
    for e in PROJECT_LABELS:
        aw = optimal_weights[e]["audio"]
        tw = optimal_weights[e]["text"]
        opt_f1 = comparison["optimal"]["per_class"][e]
        fixed_f1 = comparison["fixed_60_40"]["per_class"][e]
        delta = opt_f1 - fixed_f1
        lines.append(f"| {e} | {aw:.0%} | {tw:.0%} | {opt_f1:.4f} | {fixed_f1:.4f} | {delta:+.4f} |")
    lines.extend([
        "",
        "## Methodology",
        "",
        "- **Data:** AI Hub 263 val split (1,294 samples, 7-class, speaker-isolated)",
        "- **Audio model:** LoRA emotion2vec ONNX (7-class, macro F1=0.552)",
        "- **Text model:** KcELECTRA LoRA fine-tuned (beomi/KcELECTRA-base-v2022, 7-class direct)",
        "- **Search:** Per-emotion audio weight 0.0~1.0 in 0.05 steps (21 points × 7 emotions)",
        "- **Metric:** Per-emotion F1 score on val set",
        "",
        "## Files",
        "",
        "- `fusion_grid_search.json` — full weight-F1 curve data",
        "- `optimal_fusion_weights.json` — best weights",
        "- `fusion_grid_search.png` — per-emotion weight vs F1 plots",
        "- `fusion_comparison.png` — strategy comparison bar chart",
    ])
    output_path.write_text("\n".join(lines), encoding="utf-8")
    logger.info("Report saved: %s", output_path)
def predict_text_distilroberta(text: str, tokenizer, model) -> dict[str, float]:
    """Predict 7-class emotion probabilities with a DistilRoBERTa classifier.

    Args:
        text: Utterance text. Empty or whitespace-only text yields a uniform distribution.
        tokenizer: Callable returning PyTorch tensors accepted by ``model(**inputs)``.
        model: Classifier exposing ``.logits`` on its output and ``config.id2label``.

    Returns:
        Dict over PROJECT_LABELS. Model labels that are not in PROJECT_LABELS are
        dropped, so the values sum to 1 only if every model label is a project label.
    """
    import torch

    if not (text and text.strip()):
        # Nothing to classify: every class is equally likely.
        return dict.fromkeys(PROJECT_LABELS, 1.0 / len(PROJECT_LABELS))

    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        logits = model(**encoded).logits
    probs = torch.softmax(logits, dim=-1).squeeze().cpu().numpy()

    # Label names come from the model config, in class-index order.
    model_labels = [model.config.id2label[idx] for idx in range(len(probs))]

    scores = dict.fromkeys(PROJECT_LABELS, 0.0)
    for model_label, prob in zip(model_labels, probs):
        if model_label in scores:
            scores[model_label] = float(prob)
    return scores
def _load_prediction_cache(cache_path: Path, expected_meta: dict, n_samples: int) -> tuple[list, list]:
    """Return (audio_preds, text_preds) from a resume cache, or two empty lists if it is unusable.

    A cache is rejected when it cannot be parsed, when its stored "meta" differs
    from ``expected_meta`` (different language, models or sample count), or when
    it holds more predictions than there are samples. Caches without a "meta"
    entry are accepted as long as the length check passes.
    """
    if not cache_path.exists():
        return [], []
    try:
        with open(cache_path, encoding="utf-8") as f:
            cache = json.load(f)
    except (OSError, ValueError) as exc:
        logger.warning("Ignoring unreadable prediction cache %s: %s", cache_path, exc)
        return [], []
    if not isinstance(cache, dict):
        logger.warning("Ignoring malformed prediction cache %s", cache_path)
        return [], []

    meta = cache.get("meta")
    if meta is not None and meta != expected_meta:
        logger.warning(
            "Prediction cache %s was produced with different settings (%s != %s); recomputing",
            cache_path, meta, expected_meta,
        )
        return [], []

    audio_preds = cache.get("audio_preds", [])
    text_preds = cache.get("text_preds", [])
    # Both lists must stay aligned by sample index; keep only the common prefix.
    usable = min(len(audio_preds), len(text_preds))
    if usable > n_samples:
        logger.warning(
            "Prediction cache %s holds %d predictions for %d samples; recomputing",
            cache_path, usable, n_samples,
        )
        return [], []
    return audio_preds[:usable], text_preds[:usable]


def _save_prediction_cache(cache_path: Path, audio_preds: list, text_preds: list, meta: dict) -> None:
    """Write the resume cache via a temp file so an interrupted write cannot corrupt it."""
    tmp_path = cache_path.with_name(cache_path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump({"meta": meta, "audio_preds": audio_preds, "text_preds": text_preds}, f)
    tmp_path.replace(cache_path)


def main():
    """Command-line entry point: predict, grid-search fusion weights, then save results.

    Steps: load the manifest, keep samples that have text, load the audio and
    text models, predict every sample (checkpointed for resume), run the
    per-emotion grid search, compare strategies, and write JSON, plots and the
    markdown report into ``--output-dir``. Exits with status 1 when fewer than
    50 samples have both audio and text.
    """
    parser = argparse.ArgumentParser(description="Optimize emotion-specific fusion weights")
    parser.add_argument("--lang", default="ko", choices=["ko", "en"], help="Language: ko=Korean, en=English")
    parser.add_argument("--val-manifest", type=Path, default=Path("data/lora_dataset/val_manifest.json"))
    parser.add_argument("--onnx-model", type=Path, default=Path("data/models/lora_emotion2vec_7class/model.onnx"))
    parser.add_argument("--anchor-dir", type=Path, default=Path("data/AI Hub 263"))
    parser.add_argument("--output-dir", type=Path, default=Path("data/models/fusion_optimization"))
    parser.add_argument("--text-onnx", type=Path, default=Path("data/models/lora_kcelectra_7class/model.onnx"))
    parser.add_argument("--text-tokenizer", default="data/models/lora_kcelectra_7class/best_model")
    parser.add_argument("--en-text-model", default="j-hartmann/emotion-english-distilroberta-base")
    parser.add_argument("--use-base-audio", action="store_true",
                        help="Use base (non-finetuned) emotion2vec via FunASR instead of LoRA ONNX")
    args = parser.parse_args()

    lang = args.lang
    # English outputs get an "en_" prefix so both languages can share one output dir.
    prefix = "en_" if lang == "en" else ""
    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    # Step 1: Load manifest
    with open(args.val_manifest, encoding="utf-8") as f:
        val_all = json.load(f)
    if lang == "ko":
        # Korean: 263 val only
        samples = [s for s in val_all if s.get("source") == "263"]
        logger.info("Korean 263 val samples: %d", len(samples))
    else:
        # English: MELD test (all samples have text)
        samples = val_all
        logger.info("English MELD test samples: %d", len(samples))

    # Map label: happiness → joy for consistency
    for s in samples:
        if s["label"] == "happiness":
            s["label"] = "joy"
    # A label outside the 7 classes can never be predicted, which silently lowers every F1.
    unknown_labels = Counter(s["label"] for s in samples if s["label"] not in PROJECT_LABELS)
    if unknown_labels:
        logger.warning("Samples with labels outside PROJECT_LABELS: %s", dict(unknown_labels))

    # Step 2: Filter samples with text ("or" guards against a null text field)
    matched = [s for s in samples if (s.get("text") or "").strip()]
    # Korean fallback: load from CSV if no text in manifest
    if not matched and lang == "ko":
        logger.info("No text in manifest, loading from 263 CSVs...")
        texts_map = load_263_texts(args.anchor_dir)
        for s in samples:
            wav_id = Path(s["path"]).stem
            text = texts_map.get(wav_id, "")
            if text:
                s["text"] = text
                matched.append(s)
    logger.info("Matched audio+text: %d / %d", len(matched), len(samples))
    if len(matched) < 50:
        logger.error("Too few matched samples.")
        sys.exit(1)

    # Step 3: Load models
    import onnxruntime as ort
    if args.use_base_audio:
        from funasr import AutoModel
        logger.info("Loading base emotion2vec_plus_base via FunASR (not LoRA)...")
        funasr_model = AutoModel(model="iic/emotion2vec_plus_base", device="cpu", hub="hf")
        audio_model_id = "iic/emotion2vec_plus_base"

        def audio_predict_fn(path):
            return predict_audio_base(path, funasr_model)
    else:
        logger.info("Loading audio ONNX (LoRA): %s", args.onnx_model)
        onnx_session = ort.InferenceSession(str(args.onnx_model), providers=["CPUExecutionProvider"])
        audio_model_id = str(args.onnx_model)

        def audio_predict_fn(path):
            return predict_audio_onnx(path, onnx_session)

    if lang == "ko":
        from transformers import AutoTokenizer
        logger.info("Loading KcELECTRA ONNX: %s", args.text_onnx)
        text_session = ort.InferenceSession(str(args.text_onnx), providers=["CPUExecutionProvider"])
        tokenizer = AutoTokenizer.from_pretrained(args.text_tokenizer)
        text_model_id = str(args.text_onnx)

        def text_predict_fn(text):
            return predict_text_onnx(text, tokenizer, text_session)
    else:
        from transformers import AutoTokenizer, AutoModelForSequenceClassification
        logger.info("Loading DistilRoBERTa: %s", args.en_text_model)
        en_tokenizer = AutoTokenizer.from_pretrained(args.en_text_model)
        en_model = AutoModelForSequenceClassification.from_pretrained(args.en_text_model)
        en_model.eval()
        text_model_id = args.en_text_model

        def text_predict_fn(text):
            return predict_text_distilroberta(text, en_tokenizer, en_model)

    # Step 4: Predict all samples (with checkpoint for resume safety)
    preds_cache_path = output_dir / f"{prefix}preds_cache.json"
    # Stored with the cache so a run with other models or data does not reuse stale predictions.
    cache_meta = {
        "lang": lang,
        "audio_model": audio_model_id,
        "text_model": text_model_id,
        "n_samples": len(matched),
    }
    audio_preds, text_preds = _load_prediction_cache(preds_cache_path, cache_meta, len(matched))
    start_idx = len(audio_preds)
    if start_idx:
        logger.info("Resumed from checkpoint: %d predictions already done", start_idx)

    for i in range(start_idx, len(matched)):
        s = matched[i]
        audio_scores = audio_predict_fn(s["path"])
        audio_preds.append(audio_scores)
        text_scores = text_predict_fn(s["text"])
        text_preds.append(text_scores)
        # FunASR/PyTorch leak audio tensors across .generate() calls — force release every 25 samples
        if (i + 1) % 25 == 0:
            gc.collect()
            try:
                import torch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except ImportError:
                pass
        # Checkpoint every 100 samples
        if (i + 1) % 100 == 0:
            logger.info("Predicted %d / %d (saving checkpoint)", i + 1, len(matched))
            _save_prediction_cache(preds_cache_path, audio_preds, text_preds, cache_meta)
    # Final checkpoint save
    _save_prediction_cache(preds_cache_path, audio_preds, text_preds, cache_meta)
    logger.info("All predictions done (%d samples)", len(matched))

    # Step 5: Grid search
    grid_results, optimal_weights = grid_search(matched, audio_preds, text_preds)

    # Step 6: Overall comparison
    comparison = compute_overall_comparison(matched, audio_preds, text_preds, optimal_weights)
    logger.info("Audio-only macro F1: %.4f", comparison["audio_only"]["macro_f1"])
    logger.info("Fixed 60/40 macro F1: %.4f", comparison["fixed_60_40"]["macro_f1"])
    logger.info("Optimal macro F1: %.4f", comparison["optimal"]["macro_f1"])

    # Step 7: Save everything
    with open(output_dir / f"{prefix}fusion_grid_search.json", "w", encoding="utf-8") as f:
        json.dump(grid_results, f, indent=2)
    with open(output_dir / f"{prefix}optimal_fusion_weights.json", "w", encoding="utf-8") as f:
        json.dump(optimal_weights, f, indent=2, ensure_ascii=False)
    with open(output_dir / f"{prefix}fusion_comparison.json", "w", encoding="utf-8") as f:
        json.dump(comparison, f, indent=2)

    # Step 8: Plots + report
    plot_grid_search(grid_results, optimal_weights, output_dir / f"{prefix}fusion_grid_search.png")
    plot_comparison(comparison, optimal_weights, output_dir / f"{prefix}fusion_comparison.png")
    write_report(comparison, optimal_weights, output_dir / f"{prefix}fusion_report.md")
    logger.info("Done! All results saved to %s", output_dir)


if __name__ == "__main__":
    main()