""" compare_models.py — Vẽ biểu đồ so sánh 5 variant sau khi training xong. Cách dùng: python scripts/compare_models.py # auto-tìm tất cả history python scripts/compare_models.py --log_dir logs/history # chỉ định thư mục python scripts/compare_models.py --out results/charts # thư mục lưu chart Tự động tìm file history.json theo pattern: logs/history/{VARIANT}/{timestamp}/history.json """ import argparse import json import os import glob from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import matplotlib.ticker as mticker import numpy as np # ─── Cấu hình ──────────────────────────────────────────────────────────────── VARIANTS = ["A1", "A2", "B1", "B2", "DPO", "PPO"] COLORS = { "A1": "#2ecc71", # xanh lá "A2": "#3498db", # xanh dương "B1": "#e67e22", # cam "B2": "#9b59b6", # tím "DPO": "#e74c3c", # đỏ "PPO": "#1abc9c", # xanh ngoc } MARKERS = { "A1": "o", "A2": "s", "B1": "^", "B2": "D", "DPO": "P", "PPO": "X" } METRICS_LABELS = { "val_accuracy_normalized": "Accuracy", "val_f1_normalized": "F1 Score", "val_bleu4_normalized": "BLEU-4", "val_bert_score_raw": "BERTScore", "val_semantic_raw": "Semantic Score", "val_closed_accuracy": "Closed Accuracy", "val_closed_em": "Closed EM", "val_closed_f1": "Closed F1", "val_open_semantic": "Open Semantic", "val_open_bertscore": "Open BERTScore", "val_open_f1": "Open F1", "val_open_rouge_l": "Open ROUGE-L", "train_loss": "Train Loss", } # ─── Helpers ────────────────────────────────────────────────────────────────── def find_latest_history(log_dir: str, variant: str) -> dict | None: """ Tìm file history.json mới nhất cho một variant. Hỗ trợ cả 2 format: • logs/history/{VARIANT}/{timestamp}/history.json (MedicalVQATrainer) • logs/history/{VARIANT}/history.json (flat) """ patterns = [ os.path.join(log_dir, variant, "**", "history.json"), os.path.join(log_dir, variant, "history.json"), os.path.join(log_dir, "**", variant, "**", "history.json"), ] found = [] for pat in patterns: found.extend(glob.glob(pat, recursive=True)) if not found: return None # Lấy file mới nhất theo mtime latest = max(found, key=os.path.getmtime) try: with open(latest, "r", encoding="utf-8") as f: data = json.load(f) print(f"[✓] {variant}: {latest} ({len(data)} records)") return {"path": latest, "records": data} except Exception as e: print(f"[✗] {variant}: đọc thất bại — {e}") return None def extract_series(records: list, key: str) -> tuple[list, list]: """Trích xuất (epochs, values) từ list records.""" nested_metric_map = { "val_closed_accuracy": ("closed", "accuracy_normalized", "accuracy"), "val_closed_em": ("closed", "em_normalized", "em"), "val_closed_f1": ("closed", "f1_normalized", "f1"), "val_open_semantic": ("open", "semantic_raw", "semantic"), "val_open_bertscore": ("open", "bert_score_raw", "bert_score"), "val_open_f1": ("open", "f1_normalized", "f1"), "val_open_rouge_l": ("open", "rouge_l_normalized", "rouge_l"), } epochs, values = [], [] for r in records: # Hỗ trợ cả HuggingFace log format (có 'epoch' float) và MedicalVQATrainer format epoch = r.get("epoch") if epoch is None: continue val = r.get(key) if val is None: # Thử alias cho HF SFTTrainer/DPOTrainer logs aliases = { "val_accuracy_normalized": ["eval_accuracy", "eval_vqa_accuracy"], "val_f1_normalized": ["eval_f1"], "val_bleu4_normalized": ["eval_bleu4", "eval_bleu"], "val_bert_score_raw": ["eval_bertscore", "eval_bert_score"], "val_semantic_raw": ["eval_semantic"], "val_closed_accuracy": ["eval_closed_accuracy"], "val_closed_em": ["eval_closed_em"], "val_closed_f1": ["eval_closed_f1"], "val_open_semantic": ["eval_open_semantic"], "val_open_bertscore": ["eval_open_bertscore"], "val_open_f1": ["eval_open_f1"], "val_open_rouge_l": ["eval_open_rouge_l"], "train_loss": ["loss", "train/loss"], } for alias in aliases.get(key, []): val = r.get(alias) if val is not None: break if val is None and key in nested_metric_map: split_key, primary_key, fallback_key = nested_metric_map[key] split_metrics = r.get("metrics", {}).get(split_key, {}) val = split_metrics.get(primary_key, split_metrics.get(fallback_key)) if val is not None: epochs.append(float(epoch)) values.append(float(val)) return epochs, values def get_best_metric(records: list, key: str) -> float | None: """Trả về giá trị tốt nhất của một metric.""" _, values = extract_series(records, key) if not values: return None return max(values) if key != "train_loss" else min(values) # ─── Plot functions ─────────────────────────────────────────────────────────── def plot_metric_curves(all_data: dict, metric_key: str, output_dir: str): """Vẽ đường cong một metric cho tất cả variant.""" label = METRICS_LABELS.get(metric_key, metric_key) minimize = metric_key == "train_loss" fig, ax = plt.subplots(figsize=(11, 6)) plotted = 0 for variant, info in all_data.items(): if info is None: continue epochs, values = extract_series(info["records"], metric_key) if not epochs: continue ax.plot( epochs, values, color=COLORS[variant], linewidth=2.5, marker=MARKERS[variant], markersize=7, label=f"{variant} (best={min(values) if minimize else max(values):.3f})" ) plotted += 1 if plotted == 0: plt.close(fig) print(f"[SKIP] {label}: không có dữ liệu") return ax.set_title(f"{label} — So sánh 5 Variant", fontsize=15, fontweight="bold", pad=14) ax.set_xlabel("Epoch", fontsize=12) ax.set_ylabel(label, fontsize=12) ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True)) if metric_key != "train_loss": ax.set_ylim(bottom=0) ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0)) ax.legend(loc="best", fontsize=11, framealpha=0.9) ax.grid(True, alpha=0.3) fig.tight_layout() fname = os.path.join(output_dir, f"compare_{metric_key}.png") fig.savefig(fname, dpi=150, bbox_inches="tight") plt.close(fig) print(f"[✓] Saved: {fname}") def plot_final_bar(all_data: dict, output_dir: str): """ Bar chart so sánh kết quả cuối (best) của từng model trên 4 metrics: Accuracy, F1, BLEU-4, BERTScore. """ metric_keys = ["val_accuracy_normalized", "val_f1_normalized", "val_bleu4_normalized", "val_bert_score_raw"] metric_labels = ["Accuracy", "F1", "BLEU-4", "BERTScore"] variants_with_data = [v for v in VARIANTS if all_data.get(v)] if not variants_with_data: print("[SKIP] Final bar chart: không có dữ liệu") return x = np.arange(len(metric_labels)) w = 0.8 / len(variants_with_data) fig, ax = plt.subplots(figsize=(13, 7)) for i, variant in enumerate(variants_with_data): info = all_data[variant] values = [get_best_metric(info["records"], k) or 0.0 for k in metric_keys] offset = (i - len(variants_with_data) / 2 + 0.5) * w bars = ax.bar(x + offset, values, w, label=variant, color=COLORS[variant], alpha=0.88) # Hiển thị số liệu trên đầu cột for bar, val in zip(bars, values): if val > 0: ax.text( bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.008, f"{val:.1%}", ha="center", va="bottom", fontsize=8.5, fontweight="bold" ) ax.set_title("Kết quả tốt nhất — So sánh 5 Variant", fontsize=15, fontweight="bold", pad=14) ax.set_xticks(x) ax.set_xticklabels(metric_labels, fontsize=12) ax.set_ylabel("Score", fontsize=12) ax.set_ylim(0, 1.10) ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0)) ax.legend(loc="upper right", fontsize=11, framealpha=0.9) ax.grid(True, alpha=0.3, axis="y") fig.tight_layout() fname = os.path.join(output_dir, "compare_final_bar.png") fig.savefig(fname, dpi=150, bbox_inches="tight") plt.close(fig) print(f"[✓] Saved: {fname}") def plot_radar(all_data: dict, output_dir: str): """Radar chart so sánh 5 model trên 5 chiều.""" metric_keys = ["val_accuracy_normalized", "val_f1_normalized", "val_bleu4_normalized", "val_bert_score_raw", "val_semantic_raw"] metric_labels = ["Accuracy", "F1", "BLEU-4", "BERTScore", "Semantic"] variants_with_data = [v for v in VARIANTS if all_data.get(v)] if len(variants_with_data) < 2: return N = len(metric_labels) angles = [n / float(N) * 2 * np.pi for n in range(N)] angles += angles[:1] fig, ax = plt.subplots(figsize=(9, 9), subplot_kw=dict(polar=True)) ax.set_theta_offset(np.pi / 2) ax.set_theta_direction(-1) ax.set_xticks(angles[:-1]) ax.set_xticklabels(metric_labels, fontsize=12) ax.set_ylim(0, 1) ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0)) for variant in variants_with_data: info = all_data[variant] values = [get_best_metric(info["records"], k) or 0.0 for k in metric_keys] values += values[:1] ax.plot(angles, values, linewidth=2.5, color=COLORS[variant], label=variant, marker=MARKERS[variant]) ax.fill(angles, values, alpha=0.08, color=COLORS[variant]) ax.set_title("Radar — So sánh 5 Variant (Best per Metric)", fontsize=14, fontweight="bold", y=1.12) ax.legend(loc="upper right", bbox_to_anchor=(1.35, 1.15), fontsize=11) fig.tight_layout() fname = os.path.join(output_dir, "compare_radar.png") fig.savefig(fname, dpi=150, bbox_inches="tight") plt.close(fig) print(f"[✓] Saved: {fname}") def plot_loss_comparison(all_data: dict, output_dir: str): """Train Loss của tất cả variant trên cùng trục.""" plot_metric_curves(all_data, "train_loss", output_dir) def print_summary_table(all_data: dict): """In bảng tóm tắt ra console.""" metric_keys = ["val_accuracy_normalized", "val_f1_normalized", "val_bleu4_normalized", "val_bert_score_raw", "val_semantic_raw"] metric_short = ["Accuracy", "F1", "BLEU-4", "BERT", "Semantic"] header = f"{'Model':<8}" + "".join(f"{m:>12}" for m in metric_short) print("\n" + "═" * (8 + 12 * len(metric_short))) print(" 📊 FINAL COMPARISON — ALL VARIANTS") print("═" * (8 + 12 * len(metric_short))) print(f" {header}") print("─" * (8 + 12 * len(metric_short))) for variant in VARIANTS: info = all_data.get(variant) if info is None: print(f" {variant:<8}" + "".join(f"{'N/A':>12}" for _ in metric_keys)) continue row = f" {variant:<8}" for k in metric_keys: best = get_best_metric(info["records"], k) row += f"{best:>12.2%}" if best is not None else f"{'N/A':>12}" print(row) print("═" * (8 + 12 * len(metric_short)) + "\n") def print_split_summary_table(all_data: dict): """In bảng tóm tắt theo protocol closed/open.""" metric_keys = [ "val_closed_accuracy", "val_closed_em", "val_closed_f1", "val_open_semantic", "val_open_bertscore", ] metric_short = ["Closed Acc", "Closed EM", "Closed F1", "Open Sem", "Open BERT"] header = f"{'Model':<8}" + "".join(f"{m:>12}" for m in metric_short) print("\n" + "═" * (8 + 12 * len(metric_short))) print(" 📊 SPLIT EVALUATION — CLOSED VS OPEN") print("═" * (8 + 12 * len(metric_short))) print(f" {header}") print("─" * (8 + 12 * len(metric_short))) for variant in VARIANTS: info = all_data.get(variant) if info is None: print(f" {variant:<8}" + "".join(f"{'N/A':>12}" for _ in metric_keys)) continue row = f" {variant:<8}" for k in metric_keys: best = get_best_metric(info["records"], k) row += f"{best:>12.2%}" if best is not None else f"{'N/A':>12}" print(row) print("═" * (8 + 12 * len(metric_short)) + "\n") # ─── Main ───────────────────────────────────────────────────────────────────── def main(): parser = argparse.ArgumentParser(description="So sánh 5 variant Medical VQA") parser.add_argument("--log_dir", default="logs/medical_vqa/history", help="Thư mục gốc chứa history (default: logs/medical_vqa/history)") parser.add_argument("--out", default="results/charts", help="Thư mục lưu biểu đồ (default: results/charts)") args = parser.parse_args() os.makedirs(args.out, exist_ok=True) print(f"\n[INFO] Tìm history tại: {args.log_dir}") print("─" * 60) # Thu thập dữ liệu từ tất cả variant all_data: dict = {} for variant in VARIANTS: all_data[variant] = find_latest_history(args.log_dir, variant) available = [v for v in VARIANTS if all_data[v]] print(f"\n[INFO] Có dữ liệu: {available}") if not available: print("[ERROR] Không tìm thấy bất kỳ history.json nào. Hãy train trước!") return print(f"\n[INFO] Đang vẽ biểu đồ → {args.out}/") print("─" * 60) # 1. Accuracy curves plot_metric_curves(all_data, "val_accuracy_normalized", args.out) # 2. F1 curves plot_metric_curves(all_data, "val_f1_normalized", args.out) # 3. BLEU-4 curves plot_metric_curves(all_data, "val_bleu4_normalized", args.out) # 4. Train loss plot_loss_comparison(all_data, args.out) # 5. BERTScore plot_metric_curves(all_data, "val_bert_score_raw", args.out) # 6. Bar chart tổng hợp plot_final_bar(all_data, args.out) # 7. Radar chart plot_radar(all_data, args.out) # 8. Protocol chấm riêng closed/open plot_metric_curves(all_data, "val_closed_accuracy", args.out) plot_metric_curves(all_data, "val_closed_em", args.out) plot_metric_curves(all_data, "val_closed_f1", args.out) plot_metric_curves(all_data, "val_open_semantic", args.out) plot_metric_curves(all_data, "val_open_bertscore", args.out) # In bảng tóm tắt print_summary_table(all_data) print_split_summary_table(all_data) print(f"[DONE] Tất cả biểu đồ đã lưu tại: {args.out}/") charts = glob.glob(os.path.join(args.out, "compare_*.png")) for c in sorted(charts): print(f" 📊 {os.path.basename(c)}") if __name__ == "__main__": main()