Spaces:
Paused
Paused
| """ | |
| compare_models.py — Vẽ biểu đồ so sánh 5 variant sau khi training xong. | |
| Cách dùng: | |
| python scripts/compare_models.py # auto-tìm tất cả history | |
| python scripts/compare_models.py --log_dir logs/history # chỉ định thư mục | |
| python scripts/compare_models.py --out results/charts # thư mục lưu chart | |
| Tự động tìm file history.json theo pattern: | |
| {log_dir}/{VARIANT}/{timestamp}/history.json (default log_dir: logs/medical_vqa/history) | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import glob | |
| from pathlib import Path | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.ticker as mticker | |
| import numpy as np | |
| # ─── Cấu hình ──────────────────────────────────────────────────────────────── | |
# Variant names, in the order they appear in legends and summary tables.
VARIANTS = ["A1", "A2", "B1", "B2", "DPO", "PPO"]

# Plot colour (hex RGB) used for each variant.
COLORS = dict(
    A1="#2ecc71",   # green
    A2="#3498db",   # blue
    B1="#e67e22",   # orange
    B2="#9b59b6",   # purple
    DPO="#e74c3c",  # red
    PPO="#1abc9c",  # turquoise
)

# Matplotlib marker style used for each variant.
MARKERS = dict(
    A1="o",
    A2="s",
    B1="^",
    B2="D",
    DPO="P",
    PPO="X",
)

# Human-readable axis/title label for each metric key found in the history.
METRICS_LABELS = dict(
    # Overall validation metrics.
    val_accuracy_normalized="Accuracy",
    val_f1_normalized="F1 Score",
    val_bleu4_normalized="BLEU-4",
    val_bert_score_raw="BERTScore",
    val_semantic_raw="Semantic Score",
    # Closed-question split.
    val_closed_accuracy="Closed Accuracy",
    val_closed_em="Closed EM",
    val_closed_f1="Closed F1",
    # Open-question split.
    val_open_semantic="Open Semantic",
    val_open_bertscore="Open BERTScore",
    val_open_f1="Open F1",
    val_open_rouge_l="Open ROUGE-L",
    # Training.
    train_loss="Train Loss",
)
| # ─── Helpers ────────────────────────────────────────────────────────────────── | |
| def find_latest_history(log_dir: str, variant: str) -> dict | None: | |
| """ | |
| Tìm file history.json mới nhất cho một variant. | |
| Hỗ trợ cả 2 format: | |
| • logs/history/{VARIANT}/{timestamp}/history.json (MedicalVQATrainer) | |
| • logs/history/{VARIANT}/history.json (flat) | |
| """ | |
| patterns = [ | |
| os.path.join(log_dir, variant, "**", "history.json"), | |
| os.path.join(log_dir, variant, "history.json"), | |
| os.path.join(log_dir, "**", variant, "**", "history.json"), | |
| ] | |
| found = [] | |
| for pat in patterns: | |
| found.extend(glob.glob(pat, recursive=True)) | |
| if not found: | |
| return None | |
| # Lấy file mới nhất theo mtime | |
| latest = max(found, key=os.path.getmtime) | |
| try: | |
| with open(latest, "r", encoding="utf-8") as f: | |
| data = json.load(f) | |
| print(f"[✓] {variant}: {latest} ({len(data)} records)") | |
| return {"path": latest, "records": data} | |
| except Exception as e: | |
| print(f"[✗] {variant}: đọc thất bại — {e}") | |
| return None | |
# Alternative top-level keys used by HF SFTTrainer/DPOTrainer logs.
_METRIC_ALIASES = {
    "val_accuracy_normalized": ("eval_accuracy", "eval_vqa_accuracy"),
    "val_f1_normalized": ("eval_f1",),
    "val_bleu4_normalized": ("eval_bleu4", "eval_bleu"),
    "val_bert_score_raw": ("eval_bertscore", "eval_bert_score"),
    "val_semantic_raw": ("eval_semantic",),
    "val_closed_accuracy": ("eval_closed_accuracy",),
    "val_closed_em": ("eval_closed_em",),
    "val_closed_f1": ("eval_closed_f1",),
    "val_open_semantic": ("eval_open_semantic",),
    "val_open_bertscore": ("eval_open_bertscore",),
    "val_open_f1": ("eval_open_f1",),
    "val_open_rouge_l": ("eval_open_rouge_l",),
    "train_loss": ("loss", "train/loss"),
}

# metric key -> (split name, primary key, fallback key) inside
# record["metrics"][split name].
_NESTED_METRIC_KEYS = {
    "val_closed_accuracy": ("closed", "accuracy_normalized", "accuracy"),
    "val_closed_em": ("closed", "em_normalized", "em"),
    "val_closed_f1": ("closed", "f1_normalized", "f1"),
    "val_open_semantic": ("open", "semantic_raw", "semantic"),
    "val_open_bertscore": ("open", "bert_score_raw", "bert_score"),
    "val_open_f1": ("open", "f1_normalized", "f1"),
    "val_open_rouge_l": ("open", "rouge_l_normalized", "rouge_l"),
}


def _first_present(mapping: dict, names):
    """Return the first non-None value stored under any of ``names``."""
    for name in names:
        value = mapping.get(name)
        if value is not None:
            return value
    return None


def _lookup_metric(record: dict, key: str):
    """Find the raw value of ``key`` in one record: key, aliases, then nested."""
    value = _first_present(record, (key, *_METRIC_ALIASES.get(key, ())))
    if value is not None or key not in _NESTED_METRIC_KEYS:
        return value
    split_key, *metric_names = _NESTED_METRIC_KEYS[key]
    metrics = record.get("metrics")
    split_metrics = metrics.get(split_key) if isinstance(metrics, dict) else None
    if not isinstance(split_metrics, dict):
        return None
    return _first_present(split_metrics, metric_names)


def _finite_float(value):
    """Convert ``value`` to a finite float, or return None if not possible."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    return number if np.isfinite(number) else None


def extract_series(records: list, key: str) -> tuple[list, list]:
    """Extract the ``(epochs, values)`` series of one metric from the records.

    Each record is searched for ``key`` itself, then for the HF-style
    aliases of ``key``, then (for the closed/open split metrics) inside
    ``record["metrics"][split]``.

    Records are skipped when they are not dicts, have no usable ``"epoch"``,
    do not contain the metric, or hold a value that is not a finite number.
    NaN/inf are dropped so that ``max``/``min`` over the result is
    well defined.

    Args:
        records: List of per-epoch/per-step log dicts.
        key: Metric key, e.g. ``"val_f1_normalized"`` or ``"train_loss"``.

    Returns:
        Two parallel lists of floats: epochs and metric values, in record
        order.
    """
    epochs, values = [], []
    for record in records:
        if not isinstance(record, dict):
            continue
        epoch = _finite_float(record.get("epoch"))
        if epoch is None:
            continue
        value = _finite_float(_lookup_metric(record, key))
        if value is None:
            continue
        epochs.append(epoch)
        values.append(value)
    return epochs, values
def get_best_metric(records: list, key: str) -> float | None:
    """Return the best recorded value of ``key``, or ``None`` if none exists.

    "Best" is the minimum for ``"train_loss"`` and the maximum for every
    other metric key.
    """
    series = extract_series(records, key)[1]
    if len(series) == 0:
        return None
    pick = min if key == "train_loss" else max
    return pick(series)
| # ─── Plot functions ─────────────────────────────────────────────────────────── | |
def plot_metric_curves(all_data: dict, metric_key: str, output_dir: str):
    """Plot one metric against epoch for every variant and save it as a PNG.

    The chart is written to ``{output_dir}/compare_{metric_key}.png``. If no
    variant has data for the metric, nothing is saved and a ``[SKIP]``
    message is printed instead.

    Args:
        all_data: Mapping of variant name to ``{"path", "records"}`` or
            ``None`` when the variant has no history.
        metric_key: Metric to plot. ``"train_loss"`` is treated as
            lower-is-better and keeps a raw y-axis; every other metric is
            shown on a percentage axis starting at 0.
        output_dir: Directory for the PNG; created if it does not exist.
    """
    label = METRICS_LABELS.get(metric_key, metric_key)
    is_loss = metric_key == "train_loss"
    pick_best = min if is_loss else max

    fig, ax = plt.subplots(figsize=(11, 6))
    # try/finally so the figure is released even if plotting or saving fails.
    try:
        plotted = 0
        for variant, info in all_data.items():
            if not info:
                continue
            epochs, values = extract_series(info["records"], metric_key)
            if not epochs:
                continue
            ax.plot(
                epochs, values,
                # .get: a variant without a configured style falls back to
                # matplotlib's colour cycle and a plain circle marker.
                color=COLORS.get(variant), linewidth=2.5,
                marker=MARKERS.get(variant, "o"), markersize=7,
                label=f"{variant} (best={pick_best(values):.3f})",
            )
            plotted += 1

        if plotted == 0:
            print(f"[SKIP] {label}: không có dữ liệu")
            return

        # Report the number of curves actually drawn rather than a fixed "5".
        ax.set_title(f"{label} — So sánh {plotted} Variant",
                     fontsize=15, fontweight="bold", pad=14)
        ax.set_xlabel("Epoch", fontsize=12)
        ax.set_ylabel(label, fontsize=12)
        ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
        if not is_loss:
            ax.set_ylim(bottom=0)
            ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))
        ax.legend(loc="best", fontsize=11, framealpha=0.9)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()

        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        fname = os.path.join(output_dir, f"compare_{metric_key}.png")
        fig.savefig(fname, dpi=150, bbox_inches="tight")
        print(f"[✓] Saved: {fname}")
    finally:
        plt.close(fig)
def plot_final_bar(all_data: dict, output_dir: str):
    """Save a grouped bar chart of each variant's best score on four metrics.

    The metrics are Accuracy, F1, BLEU-4 and BERTScore. A metric with no
    data for a variant is drawn as a zero-height bar without a value label.
    The chart is written to ``{output_dir}/compare_final_bar.png``.

    Args:
        all_data: Mapping of variant name to ``{"path", "records"}`` or
            ``None`` when the variant has no history.
        output_dir: Directory for the PNG; created if it does not exist.
    """
    metric_keys = ["val_accuracy_normalized", "val_f1_normalized",
                   "val_bleu4_normalized", "val_bert_score_raw"]
    metric_labels = ["Accuracy", "F1", "BLEU-4", "BERTScore"]

    variants_with_data = [v for v in VARIANTS if all_data.get(v)]
    if not variants_with_data:
        print("[SKIP] Final bar chart: không có dữ liệu")
        return

    n_variants = len(variants_with_data)
    x = np.arange(len(metric_labels))
    w = 0.8 / n_variants  # each metric group spans 80% of its slot

    fig, ax = plt.subplots(figsize=(13, 7))
    # try/finally so the figure is released even if plotting or saving fails.
    try:
        for i, variant in enumerate(variants_with_data):
            records = all_data[variant]["records"]
            values = [get_best_metric(records, k) or 0.0 for k in metric_keys]
            # Centre the group of bars on each metric's tick.
            offset = (i - n_variants / 2 + 0.5) * w
            bars = ax.bar(x + offset, values, w, label=variant,
                          color=COLORS[variant], alpha=0.88)
            # Print the value above each non-empty bar.
            for bar, val in zip(bars, values):
                if val > 0:
                    ax.text(
                        bar.get_x() + bar.get_width() / 2,
                        bar.get_height() + 0.008,
                        f"{val:.1%}", ha="center", va="bottom",
                        fontsize=8.5, fontweight="bold",
                    )

        # Report the number of variants actually drawn rather than a fixed "5".
        ax.set_title(f"Kết quả tốt nhất — So sánh {n_variants} Variant",
                     fontsize=15, fontweight="bold", pad=14)
        ax.set_xticks(x)
        ax.set_xticklabels(metric_labels, fontsize=12)
        ax.set_ylabel("Score", fontsize=12)
        ax.set_ylim(0, 1.10)
        ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))
        ax.legend(loc="upper right", fontsize=11, framealpha=0.9)
        ax.grid(True, alpha=0.3, axis="y")
        fig.tight_layout()

        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        fname = os.path.join(output_dir, "compare_final_bar.png")
        fig.savefig(fname, dpi=150, bbox_inches="tight")
        print(f"[✓] Saved: {fname}")
    finally:
        plt.close(fig)
def plot_radar(all_data: dict, output_dir: str):
    """Save a radar chart of each variant's best score on five metrics.

    The axes are Accuracy, F1, BLEU-4, BERTScore and Semantic. A metric with
    no data for a variant is plotted as 0. The chart needs at least two
    variants with data; otherwise a ``[SKIP]`` message is printed and
    nothing is saved. The chart is written to
    ``{output_dir}/compare_radar.png``.

    Args:
        all_data: Mapping of variant name to ``{"path", "records"}`` or
            ``None`` when the variant has no history.
        output_dir: Directory for the PNG; created if it does not exist.
    """
    metric_keys = ["val_accuracy_normalized", "val_f1_normalized",
                   "val_bleu4_normalized", "val_bert_score_raw",
                   "val_semantic_raw"]
    metric_labels = ["Accuracy", "F1", "BLEU-4", "BERTScore", "Semantic"]

    variants_with_data = [v for v in VARIANTS if all_data.get(v)]
    if len(variants_with_data) < 2:
        # Say why no file appears, as the other plot functions do.
        print("[SKIP] Radar chart: cần ít nhất 2 variant có dữ liệu")
        return

    n_axes = len(metric_labels)
    angles = [n / float(n_axes) * 2 * np.pi for n in range(n_axes)]
    angles += angles[:1]  # repeat the first angle to close the polygon

    fig, ax = plt.subplots(figsize=(9, 9), subplot_kw=dict(polar=True))
    # try/finally so the figure is released even if plotting or saving fails.
    try:
        ax.set_theta_offset(np.pi / 2)
        ax.set_theta_direction(-1)
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(metric_labels, fontsize=12)
        ax.set_ylim(0, 1)
        ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))

        for variant in variants_with_data:
            records = all_data[variant]["records"]
            values = [get_best_metric(records, k) or 0.0 for k in metric_keys]
            values += values[:1]  # close the polygon
            ax.plot(angles, values, linewidth=2.5,
                    color=COLORS[variant], label=variant,
                    marker=MARKERS[variant])
            ax.fill(angles, values, alpha=0.08, color=COLORS[variant])

        # Report the number of variants actually drawn rather than a fixed "5".
        ax.set_title(
            f"Radar — So sánh {len(variants_with_data)} Variant (Best per Metric)",
            fontsize=14, fontweight="bold", y=1.12,
        )
        ax.legend(loc="upper right", bbox_to_anchor=(1.35, 1.15), fontsize=11)
        fig.tight_layout()

        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        fname = os.path.join(output_dir, "compare_radar.png")
        fig.savefig(fname, dpi=150, bbox_inches="tight")
        print(f"[✓] Saved: {fname}")
    finally:
        plt.close(fig)
def plot_loss_comparison(all_data: dict, output_dir: str):
    """Plot the training loss of every variant on one shared set of axes."""
    plot_metric_curves(
        all_data=all_data,
        metric_key="train_loss",
        output_dir=output_dir,
    )
def print_summary_table(all_data: dict):
    """Print a console table of each variant's best overall metrics.

    Variants without history, and metrics without data, are shown as ``N/A``.
    """
    # (metric key, column title) pairs, in column order.
    columns = (
        ("val_accuracy_normalized", "Accuracy"),
        ("val_f1_normalized", "F1"),
        ("val_bleu4_normalized", "BLEU-4"),
        ("val_bert_score_raw", "BERT"),
        ("val_semantic_raw", "Semantic"),
    )
    width = 8 + 12 * len(columns)
    heavy_rule = "═" * width
    na_cell = f"{'N/A':>12}"

    print("\n" + heavy_rule)
    print(" 📊 FINAL COMPARISON — ALL VARIANTS")
    print(heavy_rule)
    print(" " + f"{'Model':<8}" + "".join(f"{title:>12}" for _, title in columns))
    print("─" * width)

    for variant in VARIANTS:
        info = all_data.get(variant)
        if info is None:
            cells = [na_cell] * len(columns)
        else:
            bests = [get_best_metric(info["records"], key) for key, _ in columns]
            cells = [na_cell if best is None else f"{best:>12.2%}" for best in bests]
        print(f" {variant:<8}" + "".join(cells))

    print(heavy_rule + "\n")
def print_split_summary_table(all_data: dict):
    """Print a console table of each variant's best closed/open split metrics.

    Variants without history, and metrics without data, are shown as ``N/A``.
    """
    # (metric key, column title) pairs, in column order.
    columns = (
        ("val_closed_accuracy", "Closed Acc"),
        ("val_closed_em", "Closed EM"),
        ("val_closed_f1", "Closed F1"),
        ("val_open_semantic", "Open Sem"),
        ("val_open_bertscore", "Open BERT"),
    )
    width = 8 + 12 * len(columns)
    heavy_rule = "═" * width
    na_cell = f"{'N/A':>12}"

    print("\n" + heavy_rule)
    print(" 📊 SPLIT EVALUATION — CLOSED VS OPEN")
    print(heavy_rule)
    print(" " + f"{'Model':<8}" + "".join(f"{title:>12}" for _, title in columns))
    print("─" * width)

    for variant in VARIANTS:
        info = all_data.get(variant)
        if info is None:
            cells = [na_cell] * len(columns)
        else:
            bests = [get_best_metric(info["records"], key) for key, _ in columns]
            cells = [na_cell if best is None else f"{best:>12.2%}" for best in bests]
        print(f" {variant:<8}" + "".join(cells))

    print(heavy_rule + "\n")
| # ─── Main ───────────────────────────────────────────────────────────────────── | |
def main():
    """Command-line entry point: load histories, draw charts, print tables."""
    parser = argparse.ArgumentParser(description="So sánh 5 variant Medical VQA")
    parser.add_argument(
        "--log_dir",
        default="logs/medical_vqa/history",
        help="Thư mục gốc chứa history (default: logs/medical_vqa/history)",
    )
    parser.add_argument(
        "--out",
        default="results/charts",
        help="Thư mục lưu biểu đồ (default: results/charts)",
    )
    args = parser.parse_args()
    log_dir, out_dir = args.log_dir, args.out
    divider = "─" * 60

    os.makedirs(out_dir, exist_ok=True)
    print(f"\n[INFO] Tìm history tại: {log_dir}")
    print(divider)

    # Load the newest history of every variant (None when there is none).
    all_data: dict = {
        variant: find_latest_history(log_dir, variant) for variant in VARIANTS
    }
    available = [variant for variant in VARIANTS if all_data[variant]]
    print(f"\n[INFO] Có dữ liệu: {available}")
    if len(available) == 0:
        print("[ERROR] Không tìm thấy bất kỳ history.json nào. Hãy train trước!")
        return

    print(f"\n[INFO] Đang vẽ biểu đồ → {out_dir}/")
    print(divider)

    # Accuracy, F1 and BLEU-4 curves.
    for metric in ("val_accuracy_normalized",
                   "val_f1_normalized",
                   "val_bleu4_normalized"):
        plot_metric_curves(all_data, metric, out_dir)
    # Training loss, then BERTScore.
    plot_loss_comparison(all_data, out_dir)
    plot_metric_curves(all_data, "val_bert_score_raw", out_dir)
    # Aggregate charts.
    plot_final_bar(all_data, out_dir)
    plot_radar(all_data, out_dir)
    # Closed/open split curves.
    for metric in ("val_closed_accuracy",
                   "val_closed_em",
                   "val_closed_f1",
                   "val_open_semantic",
                   "val_open_bertscore"):
        plot_metric_curves(all_data, metric, out_dir)

    # Console summary tables.
    print_summary_table(all_data)
    print_split_summary_table(all_data)

    print(f"[DONE] Tất cả biểu đồ đã lưu tại: {out_dir}/")
    for chart in sorted(glob.glob(os.path.join(out_dir, "compare_*.png"))):
        print(f" 📊 {os.path.basename(chart)}")


if __name__ == "__main__":
    main()