"""Grid Search for Emotion-Specific Fusion Weights.

Uses the AI Hub 263 val split (audio + text + ground truth) to find the optimal
audio/text fusion weights per emotion class.

Outputs (written to --output-dir; every file name is prefixed with "en_" when
--lang en is used):
    - fusion_grid_search.json — full weight-F1 curves per emotion
    - optimal_fusion_weights.json — best weights per emotion
    - fusion_comparison.json — macro and per-class F1 for audio-only, fixed 60/40 and optimal
    - fusion_grid_search.png — 7 subplots: weight vs F1 per emotion
    - fusion_comparison.png — bar chart: audio-only vs fixed 60/40 vs optimal
    - fusion_report.md — text summary
    - preds_cache.json — checkpoint of the model predictions, reused on the next run

Usage:
    python scripts/optimize_fusion_weights.py \
        --val-manifest data/lora_dataset/val_manifest.json \
        --onnx-model data/models/lora_emotion2vec_7class/model.onnx \
        --anchor-dir "data/AI Hub 263" \
        --output-dir data/models/lora_emotion2vec_7class
"""
| from __future__ import annotations |
|
|
| import argparse |
| import csv |
| import gc |
| import json |
| import logging |
| import sys |
| from collections import Counter, defaultdict |
| from pathlib import Path |
|
|
| import numpy as np |
|
|
# Script-wide logging: timestamped messages at INFO level and above.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The 7 emotion classes used as keys of every score dict and as the class list
# for every metric computed in this script.
PROJECT_LABELS = [
    "neutral",
    "joy",
    "sadness",
    "anger",
    "surprise",
    "fear",
    "disgust",
]

# Label order that predict_audio_onnx() zips against the audio model's output
# probabilities.
LORA_LABELS = [
    "happiness",
    "anger",
    "disgust",
    "fear",
    "neutral",
    "sadness",
    "surprise",
]

# LoRA label -> project label. Only "happiness" is renamed (to "joy"); every
# other label keeps its name.
LORA_TO_PROJECT = {
    name: ("joy" if name == "happiness" else name) for name in LORA_LABELS
}

# Alternative label spellings -> LoRA-style label names (note: "happiness",
# not "joy"). Not referenced by any function in this part of the file.
MAP_263 = dict(
    angry="anger",
    anger="anger",
    sadness="sadness",
    sad="sadness",
    happiness="happiness",
    happy="happiness",
    fear="fear",
    disgust="disgust",
    surprise="surprise",
    neutral="neutral",
)
|
|
| |
# Fine-grained Korean emotion labels -> 7-class project labels, grouped by the
# project class they collapse into. Not referenced by any function in this
# part of the file.
KO_LABEL_MAP = {
    ko_label: project_label
    for project_label, ko_labels in (
        ("joy", (
            "기쁨", "즐거움/신남", "행복",
            "감동/감탄", "고마움", "환영/호의",
            "뿌듯함", "흐뭇함(귀여움/예쁨)", "기대감",
            "편안/쾌적", "안심/신뢰", "아껴주는", "존경",
        )),
        ("surprise", (
            "놀람", "신기함/관심", "경악", "어이없음",
        )),
        ("sadness", (
            "슬픔", "서러움", "안타까움/실망",
            "절망", "부끄러움", "불쌍함/연민",
            "패배/자기혐오", "힘듦/지침", "죄책감",
        )),
        ("anger", (
            "화남/분노", "짜증", "불평/불만",
            "지긋지긋", "우쭐댐/무시함", "한심함",
            "증오/혐오", "귀찮음",
        )),
        ("fear", (
            "공포/무서움", "불안/걱정", "당황/난처", "의심/불신",
        )),
        ("neutral", (
            "없음", "깨달음", "재미없음",
            "부담/안_내킴", "비장함",
        )),
        ("disgust", (
            "역겨움/징그러움",
        )),
    )
    for ko_label in ko_labels
}
|
|
|
|
def load_263_texts(anchor_dir: Path) -> dict[str, str]:
    """Load the wav_id -> utterance text mapping from the 263 CSV files.

    Every ``*.csv`` file directly inside ``anchor_dir`` is read as cp949 text.
    The first row of each file is treated as a header and skipped; in every
    other row, column 0 is the wav id and column 1 is the utterance text.

    Args:
        anchor_dir: Directory that contains the CSV files.

    Returns:
        Mapping wav_id -> text. Files are visited in sorted path order and a
        later row overwrites an earlier one with the same wav id.
    """
    texts: dict[str, str] = {}
    for csv_path in sorted(anchor_dir.glob("*.csv")):
        # newline="" is what the csv module asks for, so that quoted fields
        # containing line breaks are parsed correctly.
        with open(csv_path, encoding="cp949", newline="") as f:
            reader = csv.reader(f)
            # Skip the header row; the default keeps an empty file from
            # raising StopIteration.
            next(reader, None)
            for row in reader:
                # Blank or truncated rows have no text column; skip them
                # instead of failing with IndexError.
                if len(row) < 2:
                    continue
                texts[row[0]] = row[1]
    logger.info("Loaded %d texts from 263 CSVs", len(texts))
    return texts
|
|
|
|
def predict_audio_base(audio_path: str, funasr_model, max_seconds: float = 15.0) -> dict[str, float]:
    """Run base (non-finetuned) emotion2vec via FunASR, 9-class -> 7-class mapping.

    Audio is trimmed to ``max_seconds``: the FunASR transformer has quadratic
    memory in sequence length, so a 100s clip can blow past 15GB RAM. This
    matches the trimming done in predict_audio_onnx().

    Args:
        audio_path: Path of the audio file, read with ``soundfile``.
        funasr_model: Object whose ``generate(audio, granularity=...,
            extract_embedding=...)`` returns a list of records carrying
            ``"labels"`` and ``"scores"``.
        max_seconds: Maximum clip length, in seconds at 16 kHz, passed to the model.

    Returns:
        Scores keyed by PROJECT_LABELS and normalized to sum to 1. A uniform
        distribution is returned when inference raises or yields no positive
        score, so both "no information" cases look the same to callers.
    """
    import soundfile as sf

    # Native model labels (plain and "Chinese/English" spellings) -> project labels.
    label_map = {
        "angry": "anger", "disgusted": "disgust", "fearful": "fear",
        "happy": "joy", "neutral": "neutral", "sad": "sadness", "surprised": "surprise",
        "other": "neutral", "unknown": "neutral",
        "生气/angry": "anger", "厌恶/disgusted": "disgust", "恐惧/fearful": "fear",
        "开心/happy": "joy", "中立/neutral": "neutral", "难过/sad": "sadness",
        "吃惊/surprised": "surprise", "其他/other": "neutral", "<unk>": "neutral",
    }
    uniform = {label: 1.0 / len(PROJECT_LABELS) for label in PROJECT_LABELS}

    audio, sr = sf.read(audio_path, dtype="float32")
    if audio.ndim == 2:
        # Average a 2-D read over axis 1 to get a mono signal.
        audio = audio.mean(axis=1)
    if sr != 16000:
        import librosa
        audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)

    max_samples = int(max_seconds * 16000)
    if len(audio) > max_samples:
        audio = audio[:max_samples]

    try:
        output = funasr_model.generate(
            audio, granularity="utterance", extract_embedding=False,
        )
    except Exception:
        # Deliberate best-effort: one bad clip must not abort the whole run,
        # but the failure is logged so it cannot go unnoticed.
        logger.exception("FunASR inference failed for %s; using uniform scores", audio_path)
        return uniform

    scores = {label: 0.0 for label in PROJECT_LABELS}
    if output and isinstance(output, list):
        rec = output[0]
        raw_labels = rec.get("labels", [])
        raw_scores = rec.get("scores", [])
        for native_label, score in zip(raw_labels, raw_scores):
            # Labels missing from the map are counted as neutral.
            scores[label_map.get(native_label, "neutral")] += float(score)

    total = sum(scores.values())
    if total <= 0:
        # Nothing usable came back; stay consistent with the exception path
        # instead of returning an all-zero dict.
        logger.warning("FunASR returned no scores for %s; using uniform scores", audio_path)
        return uniform
    return {label: value / total for label, value in scores.items()}
|
|
|
|
def predict_audio_onnx(audio_path: str, session, max_seconds: float = 15.0) -> dict[str, float]:
    """Predict emotion probabilities for one audio file with the ONNX audio model.

    The file is read as float32, averaged over axis 1 when the read is 2-D,
    resampled to 16 kHz if needed and cut to ``max_seconds`` (to avoid OOM)
    before being fed to ``session`` under the input name ``"waveform"``.

    Args:
        audio_path: Path of the audio file, read with ``soundfile``.
        session: Object whose ``run(None, feeds)`` returns the logits as its
            first output.
        max_seconds: Maximum clip length, in seconds at 16 kHz.

    Returns:
        Softmax probabilities keyed by project label, in LORA_LABELS order.
    """
    import soundfile as sf

    target_sr = 16000
    signal, native_sr = sf.read(audio_path, dtype="float32")
    if signal.ndim == 2:
        signal = signal.mean(axis=1)
    if native_sr != target_sr:
        import librosa
        signal = librosa.resample(signal, orig_sr=native_sr, target_sr=target_sr)

    # Slicing is a no-op for clips that are already short enough.
    signal = signal[:int(max_seconds * target_sr)]

    batch = signal.reshape(1, -1).astype(np.float32)
    logits = session.run(None, {"waveform": batch})[0]

    # Numerically stable softmax: subtract the row maximum before exponentiating.
    shifted = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = (shifted / shifted.sum(axis=-1, keepdims=True)).squeeze()

    return {
        LORA_TO_PROJECT[lora_label]: float(prob)
        for lora_label, prob in zip(LORA_LABELS, probs)
    }
|
|
|
|
def predict_text_onnx(text: str, tokenizer, session) -> dict[str, float]:
    """Predict emotion probabilities for one utterance with the text ONNX model.

    Args:
        text: Utterance text. Empty or whitespace-only text yields a uniform
            distribution without running the model.
        tokenizer: Callable that returns ``"input_ids"`` and
            ``"attention_mask"`` arrays when called with ``return_tensors="np"``.
        session: Object whose ``run(None, feeds)`` returns the logits as its
            first output.

    Returns:
        Softmax probabilities keyed by PROJECT_LABELS.
    """
    if not (text and text.strip()):
        return dict.fromkeys(PROJECT_LABELS, 1.0 / len(PROJECT_LABELS))

    encoded = tokenizer(text, return_tensors="np", truncation=True, max_length=128, padding="max_length")
    feeds = {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
    }
    logits = session.run(None, feeds)[0]

    # Numerically stable softmax over the class axis.
    shifted = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = (shifted / shifted.sum(axis=-1, keepdims=True)).squeeze()

    # (model label, project label) pairs in the order zipped against the
    # model's output probabilities.
    output_labels = (
        ("happiness", "joy"),
        ("anger", "anger"),
        ("disgust", "disgust"),
        ("fear", "fear"),
        ("neutral", "neutral"),
        ("sadness", "sadness"),
        ("surprise", "surprise"),
    )
    scores = dict.fromkeys(PROJECT_LABELS, 0.0)
    for (_, project_label), prob in zip(output_labels, probs):
        scores[project_label] = float(prob)
    return scores
|
|
|
|
def fuse_scores(audio_scores, text_scores, weights):
    """Blend audio and text scores using a separate weight pair per emotion.

    Args:
        audio_scores: Mapping label -> audio score; missing labels count as 0.0.
        text_scores: Mapping label -> text score; missing labels count as 0.0.
        weights: Mapping label -> ``{"audio": w, "text": w}``. A missing label
            or key falls back to 0.6 for audio and 0.4 for text.

    Returns:
        Fused scores for every label in PROJECT_LABELS, rescaled to sum to 1
        whenever their total is positive; otherwise the raw values.
    """
    fused = {}
    for label in PROJECT_LABELS:
        pair = weights.get(label, {})
        audio_part = audio_scores.get(label, 0.0) * pair.get("audio", 0.6)
        text_part = text_scores.get(label, 0.0) * pair.get("text", 0.4)
        fused[label] = audio_part + text_part

    norm = sum(fused.values())
    if norm > 0:
        return {label: value / norm for label, value in fused.items()}
    return fused
|
|
|
|
def compute_f1(y_true, y_pred, target_label):
    """Return the one-vs-rest F1 score of ``target_label``.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, paired positionally with ``y_true``.
        target_label: The label treated as the positive class.

    Returns:
        The F1 score, or 0.0 when precision and recall are both zero (which
        includes empty input).
    """
    tp = fp = fn = 0
    for truth, guess in zip(y_true, y_pred):
        is_positive = truth == target_label
        predicted_positive = guess == target_label
        if is_positive and predicted_positive:
            tp += 1
        elif predicted_positive:
            fp += 1
        elif is_positive:
            fn += 1

    predicted_total = tp + fp
    actual_total = tp + fn
    precision = tp / predicted_total if predicted_total > 0 else 0
    recall = tp / actual_total if actual_total > 0 else 0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)
|
|
|
|
def compute_macro_f1(y_true, y_pred):
    """Return the unweighted mean of the one-vs-rest F1 over PROJECT_LABELS.

    Every project label contributes equally, including labels that never
    occur in ``y_true`` (their F1 is 0.0).
    """
    per_class = []
    for label in PROJECT_LABELS:
        per_class.append(compute_f1(y_true, y_pred, label))
    return np.mean(per_class)
|
|
|
|
def grid_search(samples, audio_preds, text_preds):
    """Run the grid search for emotion-specific fusion weights.

    Each emotion is tuned on its own: its audio weight sweeps 0.0..1.0 in
    steps of 0.05 (text weight = 1 - audio weight) while every other emotion
    stays at the fixed 0.6/0.4 split. The weight with the highest
    one-vs-rest F1 for that emotion wins; on ties the lowest audio weight is
    kept.

    Args:
        samples: Sequence of dicts carrying the ground-truth ``"label"``.
        audio_preds: Per-sample audio score dicts, aligned with ``samples``.
        text_preds: Per-sample text score dicts, aligned with ``samples``.

    Returns:
        grid_results: dict[emotion] -> list of {"audio_weight": float, "f1": float}
        optimal_weights: dict[emotion] -> {"audio": float, "text": float, "f1": float}

    Raises:
        ValueError: If the three inputs do not have the same length. Predictions
            that are not aligned one-to-one with the samples (for example from
            a stale cache) would otherwise be scored against the wrong labels.
    """
    if not (len(samples) == len(audio_preds) == len(text_preds)):
        raise ValueError(
            "samples, audio_preds and text_preds must be aligned: "
            f"got {len(samples)}, {len(audio_preds)} and {len(text_preds)} items"
        )

    # linspace pins both the number of points and the end point; arange with
    # a float step leaves them to floating-point rounding.
    weight_range = np.linspace(0.0, 1.0, 21)
    # The ground truth does not depend on the weights, so build it once.
    y_true = [s["label"] for s in samples]

    grid_results = {}
    optimal_weights = {}

    for target_emotion in PROJECT_LABELS:
        results = []
        best_f1 = -1
        best_aw = 0.6

        # Every emotion starts at the fixed split; only the target's entry is
        # replaced inside the sweep.
        weights = {label: {"audio": 0.6, "text": 0.4} for label in PROJECT_LABELS}

        for aw in weight_range:
            aw = float(aw)
            weights[target_emotion] = {"audio": aw, "text": 1.0 - aw}

            y_pred = []
            for audio_scores, text_scores in zip(audio_preds, text_preds):
                fused = fuse_scores(audio_scores, text_scores, weights)
                y_pred.append(max(fused, key=fused.get))

            f1 = compute_f1(y_true, y_pred, target_emotion)
            results.append({"audio_weight": round(aw, 2), "f1": round(f1, 4)})

            # Strict ">" keeps the first (lowest) audio weight among ties.
            if f1 > best_f1:
                best_f1 = f1
                best_aw = aw

        grid_results[target_emotion] = results
        optimal_weights[target_emotion] = {
            "audio": round(best_aw, 2),
            "text": round(1.0 - best_aw, 2),
            "f1": round(best_f1, 4),
        }
        logger.info("%s: optimal audio_weight=%.2f (F1=%.4f)", target_emotion, best_aw, best_f1)

    return grid_results, optimal_weights
|
|
|
|
def compute_overall_comparison(samples, audio_preds, text_preds, optimal_weights):
    """Score three strategies on the same samples: audio only, fixed 60/40, optimal.

    Args:
        samples: Sequence of dicts carrying the ground-truth ``"label"``.
        audio_preds: Per-sample audio score dicts, indexed like ``samples``.
        text_preds: Per-sample text score dicts, indexed like ``samples``.
        optimal_weights: dict[emotion] -> dict with at least ``"audio"`` and
            ``"text"`` (extra keys such as ``"f1"`` are ignored).

    Returns:
        Dict with keys ``"audio_only"``, ``"fixed_60_40"`` and ``"optimal"``;
        each value holds ``"macro_f1"`` and a ``"per_class"`` dict, rounded to
        4 decimals.
    """
    y_true = [sample["label"] for sample in samples]
    indices = range(len(samples))

    def fused_predictions(weights):
        # Argmax label of the fused scores for every sample.
        picks = []
        for idx in indices:
            fused = fuse_scores(audio_preds[idx], text_preds[idx], weights)
            picks.append(max(fused, key=fused.get))
        return picks

    def summarize(y_pred):
        macro = compute_macro_f1(y_true, y_pred)
        per_class = {label: compute_f1(y_true, y_pred, label) for label in PROJECT_LABELS}
        return {
            "macro_f1": round(macro, 4),
            "per_class": {label: round(score, 4) for label, score in per_class.items()},
        }

    fixed_weights = {label: {"audio": 0.6, "text": 0.4} for label in PROJECT_LABELS}
    tuned_weights = {
        emotion: {"audio": entry["audio"], "text": entry["text"]}
        for emotion, entry in optimal_weights.items()
    }
    # Audio-only baseline: argmax of the raw audio scores, no fusion.
    audio_only = [max(audio_preds[idx], key=audio_preds[idx].get) for idx in indices]

    return {
        "audio_only": summarize(audio_only),
        "fixed_60_40": summarize(fused_predictions(fixed_weights)),
        "optimal": summarize(fused_predictions(tuned_weights)),
    }
|
|
|
|
def plot_grid_search(grid_results, optimal_weights, output_path: Path):
    """Save a 2x4 figure with one audio-weight-vs-F1 curve per emotion.

    Args:
        grid_results: dict[emotion] -> list of {"audio_weight", "f1"} points.
        optimal_weights: dict[emotion] -> dict with ``"audio"`` and ``"f1"``.
        output_path: Destination image path.
    """
    import matplotlib
    matplotlib.use("Agg")  # select the Agg backend before importing pyplot
    import matplotlib.pyplot as plt

    fig, axes_grid = plt.subplots(2, 4, figsize=(18, 9))
    panels = axes_grid.flatten()

    for panel, emotion in zip(panels, PROJECT_LABELS):
        curve = grid_results[emotion]
        best = optimal_weights[emotion]
        xs = [point["audio_weight"] for point in curve]
        ys = [point["f1"] for point in curve]

        panel.plot(xs, ys, "b-o", markersize=3, linewidth=1.5)
        # Red dashed line: the tuned weight; gray dotted line: the fixed 0.6 baseline.
        panel.axvline(x=best["audio"], color="r", linestyle="--", alpha=0.7,
                      label=f"optimal={best['audio']:.2f}")
        panel.axvline(x=0.6, color="gray", linestyle=":", alpha=0.5, label="fixed=0.60")
        panel.set_title(f"{emotion} (best F1={best['f1']:.3f})", fontsize=11, fontweight="bold")
        panel.set_xlabel("Audio Weight")
        panel.set_ylabel("F1 Score")
        panel.set_xlim(-0.05, 1.05)
        panel.legend(fontsize=8)
        panel.grid(True, alpha=0.3)

    # Seven emotions on a 2x4 grid leave the last panel empty; hide it.
    panels[7].set_visible(False)

    fig.suptitle("Emotion-Specific Fusion Weight Grid Search", fontsize=14, fontweight="bold")
    plt.tight_layout()
    plt.savefig(str(output_path), dpi=150)
    plt.close()
    logger.info("Grid search plot saved: %s", output_path)
|
|
|
|
def plot_comparison(comparison, optimal_weights, output_path: Path):
    """Save a grouped bar chart of per-emotion F1 for the three strategies.

    Args:
        comparison: Dict with ``"audio_only"``, ``"fixed_60_40"`` and
            ``"optimal"`` entries, each holding ``"macro_f1"`` and ``"per_class"``.
        optimal_weights: dict[emotion] -> dict with ``"audio"``; used to
            annotate the "optimal" bars with the tuned audio weight.
        output_path: Destination image path.
    """
    import matplotlib
    matplotlib.use("Agg")  # select the Agg backend before importing pyplot
    import matplotlib.pyplot as plt

    emotions = PROJECT_LABELS
    per_class = {
        key: [comparison[key]["per_class"][e] for e in emotions]
        for key in ("audio_only", "fixed_60_40", "optimal")
    }

    x = np.arange(len(emotions))
    width = 0.25

    fig, ax = plt.subplots(figsize=(12, 6))
    # (comparison key, legend title, bar color, horizontal offset within the group)
    series = (
        ("audio_only", "Audio Only", "#2196F3", -width),
        ("fixed_60_40", "Fixed 60/40", "#FF9800", 0.0),
        ("optimal", "Optimal", "#4CAF50", width),
    )
    for key, title, color, offset in series:
        ax.bar(
            x + offset,
            per_class[key],
            width,
            label=f"{title} (macro={comparison[key]['macro_f1']:.3f})",
            color=color,
            alpha=0.8,
        )

    # Write the tuned audio weight just above each "optimal" bar.
    for pos, emotion, height in zip(x, emotions, per_class["optimal"]):
        aw = optimal_weights[emotion]["audio"]
        ax.text(pos + width, height + 0.01, f"a={aw:.0%}", ha="center", fontsize=7, color="#2E7D32")

    ax.set_ylabel("F1 Score")
    ax.set_title("Fusion Strategy Comparison: Audio Only vs Fixed 60/40 vs Emotion-Specific Optimal", fontweight="bold")
    ax.set_xticks(x)
    ax.set_xticklabels(emotions, fontsize=10)
    ax.legend(fontsize=10)
    ax.set_ylim(0, 1.0)
    ax.grid(axis="y", alpha=0.3)

    plt.tight_layout()
    plt.savefig(str(output_path), dpi=150)
    plt.close()
    logger.info("Comparison plot saved: %s", output_path)
|
|
|
|
def write_report(comparison, optimal_weights, output_path: Path):
    """Write the markdown summary report.

    Args:
        comparison: Dict with ``"audio_only"``, ``"fixed_60_40"`` and
            ``"optimal"`` entries, each holding ``"macro_f1"`` and ``"per_class"``.
        optimal_weights: dict[emotion] -> dict with ``"audio"`` and ``"text"``.
        output_path: Destination ``.md`` path, written as UTF-8.
    """
    audio_macro = comparison["audio_only"]["macro_f1"]
    fixed_macro = comparison["fixed_60_40"]["macro_f1"]
    optimal_macro = comparison["optimal"]["macro_f1"]

    lines = [
        "# Fusion Weight Optimization Report",
        "",
        "## Summary",
        "",
        "| Strategy | Macro F1 |",
        "|---|---|",
        f"| Audio Only | {audio_macro:.4f} |",
        f"| Fixed 60/40 | {fixed_macro:.4f} |",
        f"| **Emotion-Specific Optimal** | **{optimal_macro:.4f}** |",
        # The "+" format flag prints the real sign; a hard-coded "+" rendered
        # a regression as "+-0.0123".
        f"| Improvement over Fixed | **{optimal_macro - fixed_macro:+.4f}** |",
        "",
        "## Optimal Weights Per Emotion",
        "",
        "| Emotion | Audio Weight | Text Weight | F1 (optimal) | F1 (fixed 60/40) | Delta |",
        "|---|---|---|---|---|---|",
    ]
    for e in PROJECT_LABELS:
        aw = optimal_weights[e]["audio"]
        tw = optimal_weights[e]["text"]
        opt_f1 = comparison["optimal"]["per_class"][e]
        fixed_f1 = comparison["fixed_60_40"]["per_class"][e]
        delta = opt_f1 - fixed_f1
        lines.append(f"| {e} | {aw:.0%} | {tw:.0%} | {opt_f1:.4f} | {fixed_f1:.4f} | {delta:+.4f} |")

    # NOTE(review): the methodology text below is static. It is not derived
    # from the arguments, so it is not adjusted for a different sample count,
    # language or audio model — confirm it matches the run being reported.
    lines.extend([
        "",
        "## Methodology",
        "",
        "- **Data:** AI Hub 263 val split (1,294 samples, 7-class, speaker-isolated)",
        "- **Audio model:** LoRA emotion2vec ONNX (7-class, macro F1=0.552)",
        "- **Text model:** KcELECTRA LoRA fine-tuned (beomi/KcELECTRA-base-v2022, 7-class direct)",
        "- **Search:** Per-emotion audio weight 0.0~1.0 in 0.05 steps (21 points × 7 emotions)",
        "- **Metric:** Per-emotion F1 score on val set",
        "",
        "## Files",
        "",
        "- `fusion_grid_search.json` — full weight-F1 curve data",
        "- `optimal_fusion_weights.json` — best weights",
        "- `fusion_grid_search.png` — per-emotion weight vs F1 plots",
        "- `fusion_comparison.png` — strategy comparison bar chart",
    ])

    output_path.write_text("\n".join(lines), encoding="utf-8")
    logger.info("Report saved: %s", output_path)
|
|
|
|
def predict_text_distilroberta(text: str, tokenizer, model) -> dict[str, float]:
    """Predict emotion probabilities for one utterance with a DistilRoBERTa classifier.

    Args:
        text: Utterance text. Empty or whitespace-only text yields a uniform
            distribution without running the model.
        tokenizer: Callable returning the model inputs when called with
            ``return_tensors="pt"``.
        model: Classifier called as ``model(**inputs)``; its ``.logits`` are
            soft-maxed and named through ``model.config.id2label``.

    Returns:
        Scores keyed by PROJECT_LABELS. Model labels that are not project
        labels are dropped, and project labels the model never emits stay at 0.0.
    """
    import torch

    if not (text and text.strip()):
        return dict.fromkeys(PROJECT_LABELS, 1.0 / len(PROJECT_LABELS))

    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        logits = model(**encoded).logits
        probs = torch.softmax(logits, dim=-1).squeeze().cpu().numpy()

    scores = dict.fromkeys(PROJECT_LABELS, 0.0)
    for class_idx, prob in enumerate(probs):
        name = model.config.id2label[class_idx]
        if name in scores:
            scores[name] = float(prob)
    return scores
|
|
|
|
def _load_prediction_cache(cache_path, expected_paths, signature):
    """Load checkpointed predictions, discarding them when they do not fit this run.

    Args:
        cache_path: Path of the JSON checkpoint.
        expected_paths: Audio paths of the samples of this run, in order.
        signature: Dict naming the audio and text models of this run.

    Returns:
        ``(audio_preds, text_preds)``; two empty lists when there is no usable cache.
    """
    if not cache_path.exists():
        return [], []
    try:
        with open(cache_path, encoding="utf-8") as f:
            cache = json.load(f)
    except (OSError, ValueError):
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        logger.warning("Prediction cache %s is unreadable; recomputing from scratch", cache_path)
        return [], []
    if not isinstance(cache, dict):
        logger.warning("Prediction cache %s has an unexpected format; recomputing from scratch", cache_path)
        return [], []

    audio_preds = cache.get("audio_preds", [])
    text_preds = cache.get("text_preds", [])
    done = len(audio_preds)
    # "paths" and "signature" are missing from caches written before they
    # were introduced; such caches are accepted on length alone.
    cached_paths = cache.get("paths")
    cached_signature = cache.get("signature")
    usable = (
        len(text_preds) == done
        and done <= len(expected_paths)
        and (cached_paths is None or cached_paths == expected_paths[:done])
        and (cached_signature is None or cached_signature == signature)
    )
    if not usable:
        logger.warning(
            "Prediction cache %s does not match the current samples/models; recomputing from scratch",
            cache_path,
        )
        return [], []
    return audio_preds, text_preds


def _save_prediction_cache(cache_path, audio_preds, text_preds, expected_paths, signature):
    """Write the prediction checkpoint without ever leaving a half-written file."""
    payload = {
        "audio_preds": audio_preds,
        "text_preds": text_preds,
        "paths": expected_paths[:len(audio_preds)],
        "signature": signature,
    }
    # Write to a sibling file first and then rename it over the target, so an
    # interrupted run cannot leave truncated JSON behind.
    tmp_path = cache_path.with_name(cache_path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(payload, f)
    tmp_path.replace(cache_path)


def main():
    """Command-line entry point: predict, grid-search the fusion weights, write outputs.

    Steps:
        1. Load the validation manifest and keep the samples that have text
           (for ``--lang ko`` the text can be filled in from the 263 CSVs).
        2. Run the audio and text models on every sample, checkpointing the
           predictions to ``<prefix>preds_cache.json`` so a run can resume.
        3. Grid-search the per-emotion weights and compare them against the
           audio-only and fixed 60/40 baselines.
        4. Write the JSON results, the plots and the markdown report.

    Exits with status 1 when fewer than 50 samples have both audio and text.
    """
    parser = argparse.ArgumentParser(description="Optimize emotion-specific fusion weights")
    parser.add_argument("--lang", default="ko", choices=["ko", "en"], help="Language: ko=Korean, en=English")
    parser.add_argument("--val-manifest", type=Path, default=Path("data/lora_dataset/val_manifest.json"))
    parser.add_argument("--onnx-model", type=Path, default=Path("data/models/lora_emotion2vec_7class/model.onnx"))
    parser.add_argument("--anchor-dir", type=Path, default=Path("data/AI Hub 263"))
    parser.add_argument("--output-dir", type=Path, default=Path("data/models/fusion_optimization"))
    parser.add_argument("--text-onnx", type=Path, default=Path("data/models/lora_kcelectra_7class/model.onnx"))
    parser.add_argument("--text-tokenizer", default="data/models/lora_kcelectra_7class/best_model")
    parser.add_argument("--en-text-model", default="j-hartmann/emotion-english-distilroberta-base")
    parser.add_argument("--use-base-audio", action="store_true",
                        help="Use base (non-finetuned) emotion2vec via FunASR instead of LoRA ONNX")
    args = parser.parse_args()

    lang = args.lang
    prefix = "en_" if lang == "en" else ""
    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    # --- Load and filter the samples -------------------------------------
    with open(args.val_manifest, encoding="utf-8") as f:
        val_all = json.load(f)

    if lang == "ko":
        # Korean runs use only the manifest entries whose source is "263".
        samples = [s for s in val_all if s.get("source") == "263"]
        logger.info("Korean 263 val samples: %d", len(samples))
    else:
        samples = val_all
        logger.info("English MELD test samples: %d", len(samples))

    # The manifest says "happiness" where the project label set says "joy".
    for s in samples:
        if s["label"] == "happiness":
            s["label"] = "joy"

    # "or ''" also covers a manifest entry whose text is null.
    matched = [s for s in samples if (s.get("text") or "").strip()]

    # Fallback: when no sample carries text, look it up in the 263 CSVs by wav id.
    if not matched and lang == "ko":
        logger.info("No text in manifest, loading from 263 CSVs...")
        texts_map = load_263_texts(args.anchor_dir)
        for s in samples:
            wav_id = Path(s["path"]).stem
            text = texts_map.get(wav_id, "")
            if text:
                s["text"] = text
                matched.append(s)

    logger.info("Matched audio+text: %d / %d", len(matched), len(samples))
    if len(matched) < 50:
        logger.error("Too few matched samples.")
        sys.exit(1)

    # --- Load the models ---------------------------------------------------
    import onnxruntime as ort

    if args.use_base_audio:
        from funasr import AutoModel
        logger.info("Loading base emotion2vec_plus_base via FunASR (not LoRA)...")
        funasr_model = AutoModel(model="iic/emotion2vec_plus_base", device="cpu", hub="hf")

        def audio_predict_fn(path):
            return predict_audio_base(path, funasr_model)
    else:
        logger.info("Loading audio ONNX (LoRA): %s", args.onnx_model)
        onnx_session = ort.InferenceSession(str(args.onnx_model), providers=["CPUExecutionProvider"])

        def audio_predict_fn(path):
            return predict_audio_onnx(path, onnx_session)

    if lang == "ko":
        from transformers import AutoTokenizer
        logger.info("Loading KcELECTRA ONNX: %s", args.text_onnx)
        text_session = ort.InferenceSession(str(args.text_onnx), providers=["CPUExecutionProvider"])
        tokenizer = AutoTokenizer.from_pretrained(args.text_tokenizer)

        def text_predict_fn(text):
            return predict_text_onnx(text, tokenizer, text_session)
    else:
        from transformers import AutoTokenizer, AutoModelForSequenceClassification
        logger.info("Loading DistilRoBERTa: %s", args.en_text_model)
        en_tokenizer = AutoTokenizer.from_pretrained(args.en_text_model)
        en_model = AutoModelForSequenceClassification.from_pretrained(args.en_text_model)
        en_model.eval()

        def text_predict_fn(text):
            return predict_text_distilroberta(text, en_tokenizer, en_model)

    # --- Predict, with a resumable checkpoint ------------------------------
    preds_cache_path = output_dir / f"{prefix}preds_cache.json"
    sample_paths = [s["path"] for s in matched]
    # Identifies the models behind the cached numbers, so that switching
    # models does not silently reuse predictions from the previous ones.
    signature = {
        "audio": "funasr:iic/emotion2vec_plus_base" if args.use_base_audio else f"onnx:{args.onnx_model}",
        "text": f"onnx:{args.text_onnx}" if lang == "ko" else f"hf:{args.en_text_model}",
    }

    audio_preds, text_preds = _load_prediction_cache(preds_cache_path, sample_paths, signature)
    start_idx = len(audio_preds)
    if start_idx:
        logger.info("Resumed from checkpoint: %d predictions already done", start_idx)

    for i, s in enumerate(matched[start_idx:], start=start_idx):
        audio_preds.append(audio_predict_fn(s["path"]))
        text_preds.append(text_predict_fn(s["text"]))

        # Periodically release memory held by the inference back ends.
        if (i + 1) % 25 == 0:
            gc.collect()
            try:
                import torch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except ImportError:
                pass

        if (i + 1) % 100 == 0:
            logger.info("Predicted %d / %d (saving checkpoint)", i + 1, len(matched))
            _save_prediction_cache(preds_cache_path, audio_preds, text_preds, sample_paths, signature)

    _save_prediction_cache(preds_cache_path, audio_preds, text_preds, sample_paths, signature)
    logger.info("All predictions done (%d samples)", len(matched))

    # --- Search and compare --------------------------------------------------
    grid_results, optimal_weights = grid_search(matched, audio_preds, text_preds)

    comparison = compute_overall_comparison(matched, audio_preds, text_preds, optimal_weights)
    logger.info("Audio-only macro F1: %.4f", comparison["audio_only"]["macro_f1"])
    logger.info("Fixed 60/40 macro F1: %.4f", comparison["fixed_60_40"]["macro_f1"])
    logger.info("Optimal macro F1: %.4f", comparison["optimal"]["macro_f1"])

    # --- Write the results ---------------------------------------------------
    with open(output_dir / f"{prefix}fusion_grid_search.json", "w", encoding="utf-8") as f:
        json.dump(grid_results, f, indent=2)
    with open(output_dir / f"{prefix}optimal_fusion_weights.json", "w", encoding="utf-8") as f:
        json.dump(optimal_weights, f, indent=2, ensure_ascii=False)
    with open(output_dir / f"{prefix}fusion_comparison.json", "w", encoding="utf-8") as f:
        json.dump(comparison, f, indent=2)

    plot_grid_search(grid_results, optimal_weights, output_dir / f"{prefix}fusion_grid_search.png")
    plot_comparison(comparison, optimal_weights, output_dir / f"{prefix}fusion_comparison.png")
    write_report(comparison, optimal_weights, output_dir / f"{prefix}fusion_report.md")

    logger.info("Done! All results saved to %s", output_dir)


if __name__ == "__main__":
    main()
|
|