#!/usr/bin/env python3
"""Grid search for emotion-specific audio/text fusion weights.

Uses a validation manifest (audio path + text + ground-truth label) to find,
for each emotion class, the audio/text weight split that maximizes that
class's F1 score when the two modality score vectors are linearly fused.

Outputs (written to ``--output-dir``; prefixed with ``en_`` when ``--lang en``):
  - fusion_grid_search.json     — full weight-F1 curves per emotion
  - optimal_fusion_weights.json — best weights per emotion
  - fusion_comparison.json      — audio-only vs fixed 60/40 vs optimal F1
  - fusion_grid_search.png      — 7 subplots: weight vs F1 per emotion
  - fusion_comparison.png       — bar chart: fixed 60/40 vs optimal
  - fusion_report.md            — text summary
  - preds_cache.json            — per-sample prediction checkpoint (resume)

Usage:
    python scripts/optimize_fusion_weights.py \
        --val-manifest data/lora_dataset/val_manifest.json \
        --onnx-model data/models/lora_emotion2vec_7class/model.onnx \
        --anchor-dir "data/AI Hub 263" \
        --output-dir data/models/lora_emotion2vec_7class
"""

from __future__ import annotations

import argparse
import csv
import gc
import json
import logging
import sys
from collections import Counter, defaultdict
from pathlib import Path

import numpy as np

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

PROJECT_LABELS = ["neutral", "joy", "sadness", "anger", "surprise", "fear", "disgust"]

# Baseline fusion split used for every emotion that is not being tuned.
DEFAULT_AUDIO_WEIGHT = 0.6
DEFAULT_TEXT_WEIGHT = 0.4

# Sample rate every audio clip is resampled to before inference.
TARGET_SAMPLE_RATE = 16000

# LoRA model labels → project labels (list order = model output index order).
LORA_LABELS = ["happiness", "anger", "disgust", "fear", "neutral", "sadness", "surprise"]
LORA_TO_PROJECT = {
    "happiness": "joy",
    "anger": "anger",
    "disgust": "disgust",
    "fear": "fear",
    "neutral": "neutral",
    "sadness": "sadness",
    "surprise": "surprise",
}

# 263 label mapping (same as prepare_lora_dataset.py).
# Not referenced elsewhere in this script; kept for parity with that file.
MAP_263 = {
    "angry": "anger",
    "anger": "anger",
    "sadness": "sadness",
    "sad": "sadness",
    "happiness": "happiness",
    "happy": "happiness",
    "fear": "fear",
    "disgust": "disgust",
    "surprise": "surprise",
    "neutral": "neutral",
}

# KcELECTRA 44-class → 7-class (from src/stage2/text_emotion.py).
# Not referenced elsewhere in this script.
KO_LABEL_MAP = {
    "기쁨": "joy",
    "즐거움/신남": "joy",
    "행복": "joy",
    "감동/감탄": "joy",
    "고마움": "joy",
    "환영/호의": "joy",
    "뿌듯함": "joy",
    "흐뭇함(귀여움/예쁨)": "joy",
    "기대감": "joy",
    "편안/쾌적": "joy",
    "안심/신뢰": "joy",
    "아껴주는": "joy",
    "존경": "joy",
    "놀람": "surprise",
    "신기함/관심": "surprise",
    "경악": "surprise",
    "어이없음": "surprise",
    "슬픔": "sadness",
    "서러움": "sadness",
    "안타까움/실망": "sadness",
    "절망": "sadness",
    "부끄러움": "sadness",
    "불쌍함/연민": "sadness",
    "패배/자기혐오": "sadness",
    "힘듦/지침": "sadness",
    "죄책감": "sadness",
    "화남/분노": "anger",
    "짜증": "anger",
    "불평/불만": "anger",
    "지긋지긋": "anger",
    "우쭐댐/무시함": "anger",
    "한심함": "anger",
    "증오/혐오": "anger",
    "귀찮음": "anger",
    "공포/무서움": "fear",
    "불안/걱정": "fear",
    "당황/난처": "fear",
    "의심/불신": "fear",
    "없음": "neutral",
    "깨달음": "neutral",
    "재미없음": "neutral",
    "부담/안_내킴": "neutral",
    "비장함": "neutral",
    "역겨움/징그러움": "disgust",
}

# emotion2vec base native labels → project labels. Unrecognized labels fall
# back to "neutral" at lookup time.
BASE_LABEL_MAP = {
    "angry": "anger",
    "disgusted": "disgust",
    "fearful": "fear",
    "happy": "joy",
    "neutral": "neutral",
    "sad": "sadness",
    "surprised": "surprise",
    "other": "neutral",
    "unknown": "neutral",
    "生气/angry": "anger",
    "厌恶/disgusted": "disgust",
    "恐惧/fearful": "fear",
    "开心/happy": "joy",
    "中立/neutral": "neutral",
    "难过/sad": "sadness",
    "吃惊/surprised": "surprise",
    "其他/other": "neutral",
    "": "neutral",
}


def _uniform_scores() -> dict[str, float]:
    """Return an uninformative distribution: equal score for every project label."""
    share = 1.0 / len(PROJECT_LABELS)
    return {label: share for label in PROJECT_LABELS}


def _softmax(logits: np.ndarray) -> np.ndarray:
    """Softmax over the last axis, shifted by the row max for numerical stability."""
    exp_logits = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return exp_logits / exp_logits.sum(axis=-1, keepdims=True)


def _load_mono_16k(audio_path: str, max_seconds: float) -> np.ndarray:
    """Read an audio file as float32, downmix to mono, resample to 16 kHz and trim.

    Trimming to ``max_seconds`` bounds inference memory on very long clips.
    """
    import soundfile as sf

    audio, sr = sf.read(audio_path, dtype="float32")
    if audio.ndim == 2:
        # (frames, channels) → mono by averaging channels.
        audio = audio.mean(axis=1)
    if sr != TARGET_SAMPLE_RATE:
        import librosa

        audio = librosa.resample(audio, orig_sr=sr, target_sr=TARGET_SAMPLE_RATE)
    max_samples = int(max_seconds * TARGET_SAMPLE_RATE)
    if len(audio) > max_samples:
        audio = audio[:max_samples]
    return audio


def load_263_texts(anchor_dir: Path) -> dict[str, str]:
    """Load the wav_id → utterance-text mapping from the 263 CSV files.

    Every ``*.csv`` directly under ``anchor_dir`` is read as cp949; the first
    row is treated as a header, column 0 as the wav id and column 1 as the
    text. Empty files and rows with fewer than two columns are skipped instead
    of aborting the whole load. Later files override earlier ones on
    duplicate ids (files are visited in sorted order).
    """
    texts: dict[str, str] = {}
    for csv_path in sorted(anchor_dir.glob("*.csv")):
        # newline="" is what the csv module expects so quoted newlines survive.
        with open(csv_path, encoding="cp949", newline="") as f:
            reader = csv.reader(f)
            next(reader, None)  # skip header; tolerate an empty file
            for row in reader:
                if len(row) < 2:
                    continue  # blank or malformed line
                texts[row[0]] = row[1]
    logger.info("Loaded %d texts from 263 CSVs", len(texts))
    return texts


def predict_audio_base(audio_path: str, funasr_model, max_seconds: float = 15.0) -> dict[str, float]:
    """Run base (non-finetuned) emotion2vec via FunASR, 9-class → 7-class mapping.

    Audio is trimmed to ``max_seconds`` — the FunASR transformer has quadratic
    memory in sequence length, so a 100s clip can blow past 15GB RAM. Matches
    predict_audio_onnx() behavior.

    Returns:
        Project label → normalized score. A uniform distribution is returned
        when inference fails or yields no usable scores, so one bad clip does
        not abort the run.
    """
    audio = _load_mono_16k(audio_path, max_seconds)

    try:
        output = funasr_model.generate(
            audio,
            granularity="utterance",
            extract_embedding=False,
        )
    except Exception:
        # Deliberate best-effort: keep going, but leave a trace of what failed.
        logger.warning("FunASR inference failed for %s; using uniform scores", audio_path, exc_info=True)
        return _uniform_scores()

    scores = {label: 0.0 for label in PROJECT_LABELS}
    if output and isinstance(output, list):
        rec = output[0]
        for native_label, score in zip(rec.get("labels", []), rec.get("scores", [])):
            # Several native labels can collapse onto one project label, so accumulate.
            scores[BASE_LABEL_MAP.get(native_label, "neutral")] += float(score)

    total = sum(scores.values())
    if total <= 0:
        # No usable output: same fallback as the exception path above.
        return _uniform_scores()
    return {label: value / total for label, value in scores.items()}


def predict_audio_onnx(audio_path: str, session, max_seconds: float = 15.0) -> dict[str, float]:
    """Run ONNX audio emotion prediction (trimmed to max_seconds to avoid OOM).

    The session is fed a single ``waveform`` input of shape (1, n_samples) and
    its first output is treated as logits ordered like ``LORA_LABELS``.
    """
    audio = _load_mono_16k(audio_path, max_seconds)
    waveform = audio.reshape(1, -1).astype(np.float32)

    logits = session.run(None, {"waveform": waveform})[0]
    probs = _softmax(logits).squeeze()

    return {LORA_TO_PROJECT[lora_label]: float(prob) for lora_label, prob in zip(LORA_LABELS, probs)}


def predict_text_onnx(text: str, tokenizer, session) -> dict[str, float]:
    """Run fine-tuned KcELECTRA ONNX text emotion prediction (7-class direct).

    Empty or whitespace-only text yields a uniform distribution without
    touching the tokenizer or the session.
    """
    if not text or not text.strip():
        return _uniform_scores()

    enc = tokenizer(text, return_tensors="np", truncation=True, max_length=128, padding="max_length")
    logits = session.run(
        None,
        {
            "input_ids": enc["input_ids"],
            "attention_mask": enc["attention_mask"],
        },
    )[0]
    probs = _softmax(logits).squeeze()

    # The text model shares the LoRA label order and naming (happiness → joy).
    scores = {label: 0.0 for label in PROJECT_LABELS}
    for text_label, prob in zip(LORA_LABELS, probs):
        scores[LORA_TO_PROJECT[text_label]] = float(prob)
    return scores


def fuse_scores(audio_scores, text_scores, weights):
    """Fuse audio and text scores with emotion-specific weights.

    Args:
        audio_scores: label → audio score (missing labels count as 0.0).
        text_scores: label → text score (missing labels count as 0.0).
        weights: label → {"audio": w, "text": w}. Labels or keys that are
            absent fall back to the default 0.6 / 0.4 split.

    Returns:
        label → fused score, normalized to sum to 1 when the total is positive.
    """
    fused = {}
    for label in PROJECT_LABELS:
        label_weights = weights.get(label, {})
        audio_weight = label_weights.get("audio", DEFAULT_AUDIO_WEIGHT)
        text_weight = label_weights.get("text", DEFAULT_TEXT_WEIGHT)
        fused[label] = audio_scores.get(label, 0.0) * audio_weight + text_scores.get(label, 0.0) * text_weight

    total = sum(fused.values())
    if total > 0:
        fused = {label: value / total for label, value in fused.items()}
    return fused


def compute_f1(y_true, y_pred, target_label):
    """Compute F1 for a specific label (binary: target vs rest).

    Returns 0.0 when the label is never predicted correctly (including empty input).
    """
    tp = fp = fn = 0
    for truth, pred in zip(y_true, y_pred):
        if pred == target_label:
            if truth == target_label:
                tp += 1
            else:
                fp += 1
        elif truth == target_label:
            fn += 1

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def compute_macro_f1(y_true, y_pred):
    """Compute macro F1 (unweighted mean of per-class F1) across all 7 classes."""
    return float(np.mean([compute_f1(y_true, y_pred, label) for label in PROJECT_LABELS]))


def _predict_fused_labels(audio_preds, text_preds, weights, count):
    """Return the arg-max fused label for each of the first ``count`` samples."""
    labels = []
    # Indexed on purpose: a prediction list shorter than the sample list must
    # fail loudly rather than be silently truncated by zip().
    for i in range(count):
        fused = fuse_scores(audio_preds[i], text_preds[i], weights)
        labels.append(max(fused, key=fused.get))
    return labels


def grid_search(samples, audio_preds, text_preds):
    """Run grid search for emotion-specific weights.

    For each emotion, the audio weight is swept over 0.0..1.0 in 0.05 steps
    (text weight = 1 - audio weight) while every other emotion stays at the
    default 0.6 / 0.4. The first weight reaching the highest target-emotion F1
    wins, so ties resolve to the lowest audio weight.

    Args:
        samples: list of dicts with a "label" key (ground truth).
        audio_preds: per-sample label → score dicts, aligned with ``samples``.
        text_preds: per-sample label → score dicts, aligned with ``samples``.

    Returns:
        grid_results: dict[emotion] → list of {"audio_weight": float, "f1": float}
        optimal_weights: dict[emotion] → {"audio": float, "text": float, "f1": float}
    """
    # linspace guarantees exactly 21 points; arange with a float step does not.
    weight_range = np.linspace(0.0, 1.0, 21)
    y_true = [s["label"] for s in samples]

    grid_results = {}
    optimal_weights = {}

    for target_emotion in PROJECT_LABELS:
        results = []
        best_f1 = -1
        best_aw = DEFAULT_AUDIO_WEIGHT

        for aw in weight_range:
            audio_weight = float(aw)
            weights = {
                label: {"audio": DEFAULT_AUDIO_WEIGHT, "text": DEFAULT_TEXT_WEIGHT}
                for label in PROJECT_LABELS
            }
            weights[target_emotion] = {"audio": audio_weight, "text": float(1.0 - aw)}

            y_pred = _predict_fused_labels(audio_preds, text_preds, weights, len(samples))
            f1 = compute_f1(y_true, y_pred, target_emotion)
            results.append({"audio_weight": round(audio_weight, 2), "f1": round(f1, 4)})

            if f1 > best_f1:
                best_f1 = f1
                best_aw = audio_weight

        grid_results[target_emotion] = results
        optimal_weights[target_emotion] = {
            "audio": round(best_aw, 2),
            "text": round(1.0 - best_aw, 2),
            "f1": round(best_f1, 4),
        }
        logger.info("%s: optimal audio_weight=%.2f (F1=%.4f)", target_emotion, best_aw, best_f1)

    return grid_results, optimal_weights


def _summarize(y_true, y_pred):
    """Return rounded macro F1 and per-class F1 for one prediction strategy."""
    return {
        "macro_f1": round(compute_macro_f1(y_true, y_pred), 4),
        "per_class": {label: round(compute_f1(y_true, y_pred, label), 4) for label in PROJECT_LABELS},
    }


def compute_overall_comparison(samples, audio_preds, text_preds, optimal_weights):
    """Compare audio-only, fixed 60/40 and optimal weights on macro and per-class F1.

    Returns:
        {"audio_only" | "fixed_60_40" | "optimal": {"macro_f1": float,
        "per_class": {label: float}}}, all values rounded to 4 decimals.
    """
    count = len(samples)
    y_true = [s["label"] for s in samples]

    fixed_weights = {
        label: {"audio": DEFAULT_AUDIO_WEIGHT, "text": DEFAULT_TEXT_WEIGHT} for label in PROJECT_LABELS
    }
    # Drop the stored "f1" field; fuse_scores only needs the two weights.
    opt_weight_dict = {e: {"audio": w["audio"], "text": w["text"]} for e, w in optimal_weights.items()}

    y_pred_fixed = _predict_fused_labels(audio_preds, text_preds, fixed_weights, count)
    y_pred_opt = _predict_fused_labels(audio_preds, text_preds, opt_weight_dict, count)
    y_pred_audio = [max(audio_preds[i], key=audio_preds[i].get) for i in range(count)]

    return {
        "audio_only": _summarize(y_true, y_pred_audio),
        "fixed_60_40": _summarize(y_true, y_pred_fixed),
        "optimal": _summarize(y_true, y_pred_opt),
    }


def plot_grid_search(grid_results, optimal_weights, output_path: Path):
    """Plot 7 subplots (2x4 grid): audio weight vs F1 per emotion, saved as an image."""
    import matplotlib

    matplotlib.use("Agg")  # headless backend: no display required
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 4, figsize=(18, 9))
    axes = axes.flatten()

    for ax, emotion in zip(axes, PROJECT_LABELS):
        data = grid_results[emotion]
        weights = [d["audio_weight"] for d in data]
        f1s = [d["f1"] for d in data]
        opt = optimal_weights[emotion]

        ax.plot(weights, f1s, "b-o", markersize=3, linewidth=1.5)
        ax.axvline(x=opt["audio"], color="r", linestyle="--", alpha=0.7, label=f"optimal={opt['audio']:.2f}")
        ax.axvline(x=0.6, color="gray", linestyle=":", alpha=0.5, label="fixed=0.60")
        ax.set_title(f"{emotion} (best F1={opt['f1']:.3f})", fontsize=11, fontweight="bold")
        ax.set_xlabel("Audio Weight")
        ax.set_ylabel("F1 Score")
        ax.set_xlim(-0.05, 1.05)
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    # Hide the leftover axes (2x4 = 8 slots, but only 7 emotions).
    for unused_ax in axes[len(PROJECT_LABELS):]:
        unused_ax.set_visible(False)

    fig.suptitle("Emotion-Specific Fusion Weight Grid Search", fontsize=14, fontweight="bold")
    fig.tight_layout()
    fig.savefig(str(output_path), dpi=150)
    plt.close(fig)
    logger.info("Grid search plot saved: %s", output_path)


def plot_comparison(comparison, optimal_weights, output_path: Path):
    """Bar chart: audio-only vs fixed 60/40 vs optimal per emotion, saved as an image."""
    import matplotlib

    matplotlib.use("Agg")  # headless backend: no display required
    import matplotlib.pyplot as plt

    emotions = PROJECT_LABELS
    audio_f1s = [comparison["audio_only"]["per_class"][e] for e in emotions]
    fixed_f1s = [comparison["fixed_60_40"]["per_class"][e] for e in emotions]
    opt_f1s = [comparison["optimal"]["per_class"][e] for e in emotions]

    x = np.arange(len(emotions))
    width = 0.25

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(
        x - width, audio_f1s, width,
        label=f"Audio Only (macro={comparison['audio_only']['macro_f1']:.3f})",
        color="#2196F3", alpha=0.8,
    )
    ax.bar(
        x, fixed_f1s, width,
        label=f"Fixed 60/40 (macro={comparison['fixed_60_40']['macro_f1']:.3f})",
        color="#FF9800", alpha=0.8,
    )
    ax.bar(
        x + width, opt_f1s, width,
        label=f"Optimal (macro={comparison['optimal']['macro_f1']:.3f})",
        color="#4CAF50", alpha=0.8,
    )

    # Annotate each optimal bar with the audio weight that produced it.
    for i, e in enumerate(emotions):
        aw = optimal_weights[e]["audio"]
        ax.text(x[i] + width, opt_f1s[i] + 0.01, f"a={aw:.0%}", ha="center", fontsize=7, color="#2E7D32")

    ax.set_ylabel("F1 Score")
    ax.set_title(
        "Fusion Strategy Comparison: Audio Only vs Fixed 60/40 vs Emotion-Specific Optimal",
        fontweight="bold",
    )
    ax.set_xticks(x)
    ax.set_xticklabels(emotions, fontsize=10)
    ax.legend(fontsize=10)
    ax.set_ylim(0, 1.0)
    ax.grid(axis="y", alpha=0.3)

    fig.tight_layout()
    fig.savefig(str(output_path), dpi=150)
    plt.close(fig)
    logger.info("Comparison plot saved: %s", output_path)


def write_report(comparison, optimal_weights, output_path: Path):
    """Write the markdown summary report to ``output_path`` (UTF-8).

    Differences are rendered with an explicit sign, so a regression shows up
    as ``-0.0123`` rather than ``+-0.0123``.
    """
    improvement = comparison["optimal"]["macro_f1"] - comparison["fixed_60_40"]["macro_f1"]
    lines = [
        "# Fusion Weight Optimization Report",
        "",
        "## Summary",
        "",
        "| Strategy | Macro F1 |",
        "|---|---|",
        f"| Audio Only | {comparison['audio_only']['macro_f1']:.4f} |",
        f"| Fixed 60/40 | {comparison['fixed_60_40']['macro_f1']:.4f} |",
        f"| **Emotion-Specific Optimal** | **{comparison['optimal']['macro_f1']:.4f}** |",
        f"| Improvement over Fixed | **{improvement:+.4f}** |",
        "",
        "## Optimal Weights Per Emotion",
        "",
        "| Emotion | Audio Weight | Text Weight | F1 (optimal) | F1 (fixed 60/40) | Delta |",
        "|---|---|---|---|---|---|",
    ]

    for e in PROJECT_LABELS:
        aw = optimal_weights[e]["audio"]
        tw = optimal_weights[e]["text"]
        opt_f1 = comparison["optimal"]["per_class"][e]
        fixed_f1 = comparison["fixed_60_40"]["per_class"][e]
        delta = opt_f1 - fixed_f1
        lines.append(f"| {e} | {aw:.0%} | {tw:.0%} | {opt_f1:.4f} | {fixed_f1:.4f} | {delta:+.4f} |")

    # NOTE(review): the methodology figures below are hard-coded text, not
    # derived from this run (they describe the Korean 263 / LoRA setup even
    # when --lang en or --use-base-audio is used) — confirm before publishing.
    lines.extend([
        "",
        "## Methodology",
        "",
        "- **Data:** AI Hub 263 val split (1,294 samples, 7-class, speaker-isolated)",
        "- **Audio model:** LoRA emotion2vec ONNX (7-class, macro F1=0.552)",
        "- **Text model:** KcELECTRA LoRA fine-tuned (beomi/KcELECTRA-base-v2022, 7-class direct)",
        "- **Search:** Per-emotion audio weight 0.0~1.0 in 0.05 steps (21 points × 7 emotions)",
        "- **Metric:** Per-emotion F1 score on val set",
        "",
        "## Files",
        "",
        "- `fusion_grid_search.json` — full weight-F1 curve data",
        "- `optimal_fusion_weights.json` — best weights",
        "- `fusion_grid_search.png` — per-emotion weight vs F1 plots",
        "- `fusion_comparison.png` — strategy comparison bar chart",
    ])

    output_path.write_text("\n".join(lines), encoding="utf-8")
    logger.info("Report saved: %s", output_path)


def predict_text_distilroberta(text: str, tokenizer, model) -> dict[str, float]:
    """Run j-hartmann/DistilRoBERTa text emotion prediction (7-class direct).

    Model labels are read from ``model.config.id2label``; any label that is
    not a project label is ignored. Empty text yields a uniform distribution.
    """
    import torch

    if not text or not text.strip():
        return _uniform_scores()

    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1).squeeze().cpu().numpy()

    # DistilRoBERTa labels: anger, disgust, fear, joy, neutral, sadness, surprise
    dr_labels = [model.config.id2label[i] for i in range(len(probs))]
    scores = {label: 0.0 for label in PROJECT_LABELS}
    for dl, prob in zip(dr_labels, probs):
        if dl in scores:
            scores[dl] = float(prob)
    return scores


def _load_prediction_cache(cache_path: Path, expected_paths: list[str], audio_backend: str):
    """Restore checkpointed predictions, or return two empty lists if unusable.

    The cache is rejected (with a warning) when it is unreadable, holds more
    predictions than there are samples, or was recorded for different sample
    paths or a different audio backend. Caches written before that metadata
    existed are accepted on the length check alone.
    """
    if not cache_path.exists():
        return [], []

    try:
        with open(cache_path, encoding="utf-8") as f:
            cache = json.load(f)
    except (OSError, json.JSONDecodeError):
        logger.warning("Ignoring unreadable prediction cache: %s", cache_path, exc_info=True)
        return [], []

    audio_preds = cache.get("audio_preds", [])
    text_preds = cache.get("text_preds", [])
    # Only index-aligned audio/text pairs are usable.
    done = min(len(audio_preds), len(text_preds))

    cached_paths = cache.get("sample_paths")
    cached_backend = cache.get("audio_backend")
    stale = (
        done > len(expected_paths)
        or (cached_paths is not None and list(cached_paths[:done]) != expected_paths[:done])
        or (cached_backend is not None and cached_backend != audio_backend)
    )
    if stale:
        logger.warning("Prediction cache %s does not match this run; recomputing", cache_path)
        return [], []

    return audio_preds[:done], text_preds[:done]


def _save_prediction_cache(cache_path: Path, audio_preds, text_preds, sample_paths, audio_backend: str) -> None:
    """Write the prediction checkpoint atomically (temp file, then rename)."""
    payload = {
        "audio_preds": audio_preds,
        "text_preds": text_preds,
        "sample_paths": sample_paths[:len(audio_preds)],
        "audio_backend": audio_backend,
    }
    tmp_path = cache_path.with_suffix(cache_path.suffix + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(payload, f)
    # An interrupted run leaves either the old or the new file, never half of one.
    tmp_path.replace(cache_path)


def main() -> None:
    """CLI entry point: predict, grid-search fusion weights, and write all outputs."""
    parser = argparse.ArgumentParser(description="Optimize emotion-specific fusion weights")
    parser.add_argument("--lang", default="ko", choices=["ko", "en"],
                        help="Language: ko=Korean, en=English")
    parser.add_argument("--val-manifest", type=Path,
                        default=Path("data/lora_dataset/val_manifest.json"))
    parser.add_argument("--onnx-model", type=Path,
                        default=Path("data/models/lora_emotion2vec_7class/model.onnx"))
    parser.add_argument("--anchor-dir", type=Path,
                        default=Path("data/AI Hub 263"))
    parser.add_argument("--output-dir", type=Path,
                        default=Path("data/models/fusion_optimization"))
    parser.add_argument("--text-onnx", type=Path,
                        default=Path("data/models/lora_kcelectra_7class/model.onnx"))
    parser.add_argument("--text-tokenizer", default="data/models/lora_kcelectra_7class/best_model")
    parser.add_argument("--en-text-model", default="j-hartmann/emotion-english-distilroberta-base")
    parser.add_argument("--use-base-audio", action="store_true",
                        help="Use base (non-finetuned) emotion2vec via FunASR instead of LoRA ONNX")
    args = parser.parse_args()

    lang = args.lang
    prefix = "en_" if lang == "en" else ""
    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    # Step 1: Load manifest
    with open(args.val_manifest, encoding="utf-8") as f:
        val_all = json.load(f)

    if lang == "ko":
        # Korean: 263 val only
        samples = [s for s in val_all if s.get("source") == "263"]
        logger.info("Korean 263 val samples: %d", len(samples))
    else:
        # English: MELD test (all samples have text)
        samples = val_all
        logger.info("English MELD test samples: %d", len(samples))

    # Map label: happiness → joy for consistency with PROJECT_LABELS
    for s in samples:
        if s["label"] == "happiness":
            s["label"] = "joy"

    # Step 2: Filter samples with text ("text" may be missing or null)
    matched = [s for s in samples if (s.get("text") or "").strip()]

    # Korean fallback: load from CSV if no text in manifest
    if not matched and lang == "ko":
        logger.info("No text in manifest, loading from 263 CSVs...")
        texts_map = load_263_texts(args.anchor_dir)
        for s in samples:
            wav_id = Path(s["path"]).stem
            text = texts_map.get(wav_id, "")
            if text:
                s["text"] = text
                matched.append(s)

    logger.info("Matched audio+text: %d / %d", len(matched), len(samples))

    if len(matched) < 50:
        logger.error("Too few matched samples.")
        sys.exit(1)

    # Step 3: Load models
    import onnxruntime as ort

    if args.use_base_audio:
        from funasr import AutoModel

        logger.info("Loading base emotion2vec_plus_base via FunASR (not LoRA)...")
        funasr_model = AutoModel(model="iic/emotion2vec_plus_base", device="cpu", hub="hf")
        audio_backend = "funasr_base"

        def audio_predict_fn(path):
            return predict_audio_base(path, funasr_model)
    else:
        logger.info("Loading audio ONNX (LoRA): %s", args.onnx_model)
        onnx_session = ort.InferenceSession(str(args.onnx_model), providers=["CPUExecutionProvider"])
        audio_backend = "lora_onnx"

        def audio_predict_fn(path):
            return predict_audio_onnx(path, onnx_session)

    if lang == "ko":
        from transformers import AutoTokenizer

        logger.info("Loading KcELECTRA ONNX: %s", args.text_onnx)
        text_session = ort.InferenceSession(str(args.text_onnx), providers=["CPUExecutionProvider"])
        tokenizer = AutoTokenizer.from_pretrained(args.text_tokenizer)

        def text_predict_fn(text):
            return predict_text_onnx(text, tokenizer, text_session)
    else:
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        logger.info("Loading DistilRoBERTa: %s", args.en_text_model)
        en_tokenizer = AutoTokenizer.from_pretrained(args.en_text_model)
        en_model = AutoModelForSequenceClassification.from_pretrained(args.en_text_model)
        en_model.eval()

        def text_predict_fn(text):
            return predict_text_distilroberta(text, en_tokenizer, en_model)

    # Step 4: Predict all samples (with checkpoint for resume safety)
    preds_cache_path = output_dir / f"{prefix}preds_cache.json"
    sample_paths = [s["path"] for s in matched]

    audio_preds, text_preds = _load_prediction_cache(preds_cache_path, sample_paths, audio_backend)
    start_idx = len(audio_preds)
    if start_idx:
        logger.info("Resumed from checkpoint: %d predictions already done", start_idx)

    for i in range(start_idx, len(matched)):
        s = matched[i]
        audio_preds.append(audio_predict_fn(s["path"]))
        text_preds.append(text_predict_fn(s["text"]))

        # FunASR/PyTorch leak audio tensors across .generate() calls — force release every 25 samples
        if (i + 1) % 25 == 0:
            gc.collect()
            try:
                import torch

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except ImportError:
                pass  # torch is optional on the ONNX-only path

        # Checkpoint every 100 samples
        if (i + 1) % 100 == 0:
            logger.info("Predicted %d / %d (saving checkpoint)", i + 1, len(matched))
            _save_prediction_cache(preds_cache_path, audio_preds, text_preds, sample_paths, audio_backend)

    # Final checkpoint save
    _save_prediction_cache(preds_cache_path, audio_preds, text_preds, sample_paths, audio_backend)
    logger.info("All predictions done (%d samples)", len(matched))

    # Step 5: Grid search
    grid_results, optimal_weights = grid_search(matched, audio_preds, text_preds)

    # Step 6: Overall comparison
    comparison = compute_overall_comparison(matched, audio_preds, text_preds, optimal_weights)
    logger.info("Audio-only macro F1: %.4f", comparison["audio_only"]["macro_f1"])
    logger.info("Fixed 60/40 macro F1: %.4f", comparison["fixed_60_40"]["macro_f1"])
    logger.info("Optimal macro F1: %.4f", comparison["optimal"]["macro_f1"])

    # Step 7: Save everything
    with open(output_dir / f"{prefix}fusion_grid_search.json", "w", encoding="utf-8") as f:
        json.dump(grid_results, f, indent=2)

    with open(output_dir / f"{prefix}optimal_fusion_weights.json", "w", encoding="utf-8") as f:
        json.dump(optimal_weights, f, indent=2, ensure_ascii=False)

    with open(output_dir / f"{prefix}fusion_comparison.json", "w", encoding="utf-8") as f:
        json.dump(comparison, f, indent=2)

    # Step 8: Plots + report
    plot_grid_search(grid_results, optimal_weights, output_dir / f"{prefix}fusion_grid_search.png")
    plot_comparison(comparison, optimal_weights, output_dir / f"{prefix}fusion_comparison.png")
    write_report(comparison, optimal_weights, output_dir / f"{prefix}fusion_report.md")

    logger.info("Done! All results saved to %s", output_dir)


if __name__ == "__main__":
    main()