| | |
| | """ |
| | Aggregate metrics from checkpoint JSONs into a single table. |
| | Saves results locally to the 'table_result' directory. |
| | """ |
| |
|
| | import os |
| | import json |
| | import argparse |
| | import csv |
| | import numpy as np |
| | from typing import Dict, Any, List, Tuple, Set, Union |
| | import pandas as pd |
| | from openpyxl import load_workbook |
| | from openpyxl.styles import PatternFill, Font, Alignment |
| |
|
| | |
| | |
# Optional dependency: reuse the best-checkpoint finder from the evaluation
# script when it is importable; otherwise install a stub so the rest of the
# script still runs without best-checkpoint detection.
try:
    from evaluate_aime_raw_vs_finetuned import find_best_checkpoint
except ImportError:
    def find_best_checkpoint(path):
        # Fallback stub: (None, 0) signals "no best checkpoint found"
        # (caller unpacks it as (best_path, score) — see main()).
        return None, 0
| |
|
| | Scalar = Union[int, float, str] |
| |
|
def is_scalar(x: Any) -> bool:
    """Return True when *x* is a plain scalar value (int, float, or str)."""
    scalar_types = (int, float, str)
    return isinstance(x, scalar_types)
| |
|
def load_metrics_from_json(json_path: str) -> Dict[str, Scalar]:
    """Load scalar metrics from all_cases.json."""
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Prefer a nested "metrics" mapping when present; otherwise treat the
    # whole payload as the metrics mapping.
    metrics_dict = data
    if isinstance(data, dict) and isinstance(data.get("metrics"), dict):
        metrics_dict = data["metrics"]

    if not isinstance(metrics_dict, dict):
        return {}

    # Keep only plain scalar values (ints, floats, strings).
    return {
        key: value
        for key, value in metrics_dict.items()
        if isinstance(value, (int, float, str))
    }
| |
|
def collect_all_rows(root_dir: str, run: str, best_checkpoint: str = None, model_name: str = "qwen2.5-3B") -> Tuple[List[Dict[str, Scalar]], List[str]]:
    """Walk the checkpoints directory and collect rows + column names.

    Builds one row per checkpoint (plus a baseline row for the raw model
    found under ``<root_dir>/raw_model``) and returns ``(rows, columns)``
    where ``columns`` is ``["checkpoint"]`` followed by the sorted union of
    every metric column encountered.
    """
    rows: List[Dict[str, Scalar]] = []
    all_metric_cols: Set[str] = set()

    if not os.path.isdir(root_dir):
        print(f"[WARN] Root directory not found: {root_dir}")
        return [], ["checkpoint"]

    # --- Baseline row: metrics of the untrained base model -----------------
    ckpt_path = os.path.join(root_dir, "raw_model")
    row: Dict[str, Scalar] = {"checkpoint": f"{model_name}"}

    if os.path.isdir(ckpt_path):
        for dataset_name in sorted(os.listdir(ckpt_path)):
            # Datasets whose name contains "neu" are skipped for the
            # baseline (presumably "neutral" splits — TODO confirm).
            if "neu" in dataset_name.lower():
                continue
            dataset_path = os.path.join(ckpt_path, dataset_name)
            # NOTE: the loop variable is rebound to its lowercased form for
            # column naming; the path was already built from the raw name.
            dataset_name = dataset_name.lower()
            if not os.path.isdir(dataset_path):
                continue

            json_path = os.path.join(dataset_path, "raw_results_train_all.json")
            if not os.path.isfile(json_path):
                continue

            try:
                metrics = load_metrics_from_json(json_path)
            except Exception as e:
                print(f"[WARN] Failed to read {json_path}: {e}")
                continue

            _process_metrics_into_row(row, all_metric_cols, dataset_name, metrics)

    # Keep the baseline row only if at least one metric was found
    # (the dict always holds the "checkpoint" key, hence > 1).
    if len(row) > 1:
        rows.append(row)

    # --- Per-checkpoint rows for the requested run -------------------------
    # Rebinds root_dir to the run subdirectory from here on.
    root_dir = os.path.join(root_dir, run)
    if not os.path.isdir(root_dir):
        print(f"[WARN] Run directory not found: {root_dir}")
        ordered_cols = ["checkpoint"] + sorted(all_metric_cols)
        return rows, ordered_cols

    try:
        # Collect numeric step ids from "checkpoint-<step>" directory names;
        # non-numeric suffixes are silently ignored.
        training_step = []
        for d in os.listdir(root_dir):
            if "-" in d:
                try:
                    step_val = int(d.split("-")[-1])
                    training_step.append(step_val)
                except ValueError:
                    continue

        # Iterate checkpoints in ascending step order.
        for step in sorted(training_step):
            ckpt_name = f"checkpoint-{step}"
            ckpt_path = os.path.join(root_dir, ckpt_name)

            if not os.path.isdir(ckpt_path):
                continue

            # Tag the best checkpoint so it stands out in the table.
            if best_checkpoint and ckpt_name == best_checkpoint:
                row: Dict[str, Scalar] = {"checkpoint": ckpt_name+"(best)"}
            else:
                row: Dict[str, Scalar] = {"checkpoint": ckpt_name}

            for dataset_name in sorted(os.listdir(ckpt_path)):
                dataset_path = os.path.join(ckpt_path, dataset_name)
                dataset_name = dataset_name.lower()
                if not os.path.isdir(dataset_path):
                    continue

                # Fall back to the misspelled legacy filename when the
                # canonical one is missing.
                json_path = os.path.join(dataset_path, "all_cases.json")
                if not os.path.isfile(json_path):
                    json_path = os.path.join(dataset_path, "all_casses.json")
                if not os.path.isfile(json_path):
                    continue

                try:
                    metrics = load_metrics_from_json(json_path)
                except Exception as e:
                    print(f"[WARN] Failed to read {json_path}: {e}")
                    continue

                _process_metrics_into_row(row, all_metric_cols, dataset_name, metrics)

            # Checkpoint rows are appended even when no metrics were found.
            rows.append(row)
    except Exception as e:
        print(f"[ERROR] Error walking directories: {e}")

    ordered_cols = ["checkpoint"] + sorted(all_metric_cols)
    return rows, ordered_cols
| |
|
| | def _process_metrics_into_row(row, all_metric_cols, dataset_name, metrics): |
| | """Helper to standardize metric naming logic.""" |
| | f1_flag = False |
| | possible_col_names = [] |
| | |
| | for metric_name, metric_value in metrics.items(): |
| | col_name = None |
| | if "accuracy" in metric_name or "hamming_accuracy" in metric_name: |
| | col_name = f"{dataset_name}_acc" |
| | elif "macro_f1" in metric_name or "f1_macro" in metric_name or "f1" in metric_name: |
| | col_name = f"{dataset_name}_f1" |
| | f1_flag = True |
| | elif "precision" in metric_name: |
| | col_name = f"{dataset_name}_precision" |
| | elif "recall" in metric_name: |
| | col_name = f"{dataset_name}_recall" |
| | elif "exact_match_accuracy" in metric_name: |
| | col_name = f"{dataset_name}_EM" |
| | |
| | if col_name: |
| | possible_col_names.append((col_name, metric_value)) |
| | |
| | for col_name, metric_value in possible_col_names: |
| | |
| | |
| | |
| | if f1_flag and ("f1" in col_name or "_f1" in col_name): |
| | if col_name not in row: |
| | row[col_name] = round(metric_value, 4) |
| | all_metric_cols.add(col_name) |
| | elif not f1_flag: |
| | if col_name not in row: |
| | row[col_name] = round(metric_value, 4) |
| | all_metric_cols.add(col_name) |
| |
|
def clean_sheet_name(name):
    """Sanitize *name* into a legal Excel worksheet title.

    Replaces characters Excel forbids with underscores, falls back to
    "Sheet1" for empty input, and truncates to Excel's 31-character limit.
    """
    if not name:
        return "Sheet1"
    for forbidden in '\\/*?:[]':
        name = name.replace(forbidden, '_')
    return name[:31]
| |
|
| | def _checkpoint_key(name: str) -> str: |
| | """Normalize checkpoint name by stripping '(best)' etc.""" |
| | return name.split("(")[0].strip() if name else "" |
| |
|
def append_rows_in_place(
    rows: List[Dict[str, Scalar]],
    columns: List[str],
    out_path: str,
    sheet_name: str,
    best_checkpoint: str,
    model_name: str,
    old_sheet_name: str
) -> None:
    """Update an existing Excel sheet in-place.

    Merges *rows* into the sheet *sheet_name* of the workbook at *out_path*:
    existing checkpoint rows (matched by normalized checkpoint name) are
    updated, new ones are appended, and missing header columns are added.
    Metric cells that beat the *model_name* baseline row are filled green,
    and the *best_checkpoint* label is rendered bold green.

    NOTE(review): *old_sheet_name* is accepted but never used here —
    presumably kept for interface compatibility with write_excel; confirm.
    """
    sheet_name = clean_sheet_name(sheet_name)
    # The synthetic summary column is always the last one.
    final_columns = columns + ["better_datasets_than_raw"]
    metric_cols = [c for c in columns if c != "checkpoint"]

    # Group metric columns by dataset prefix (text before the first '_').
    dataset_to_cols: Dict[str, List[str]] = {}
    for col in metric_cols:
        ds = col.split("_", 1)[0]
        dataset_to_cols.setdefault(ds, []).append(col)

    # Baseline row = the row whose checkpoint label equals the model name.
    baseline = None
    for r in rows:
        if r.get("checkpoint") == model_name:
            baseline = r
            break

    wb = load_workbook(out_path)

    # Reuse the target sheet if present; otherwise create it with headers.
    if sheet_name in wb.sheetnames:
        ws = wb[sheet_name]
    else:
        ws = wb.create_sheet(title=sheet_name)
        for idx, col_name in enumerate(final_columns, start=1):
            ws.cell(row=1, column=idx, value=col_name)

    # Map header name -> 1-based column index from the existing header row.
    header_map: Dict[str, int] = {}
    max_col = ws.max_column

    for idx, cell in enumerate(ws[1], start=1):
        if cell.value is not None:
            header_map[str(cell.value)] = idx

    # Append any columns the sheet does not have yet.
    for col_name in final_columns:
        if col_name not in header_map:
            max_col += 1
            ws.cell(row=1, column=max_col, value=col_name)
            header_map[col_name] = max_col

    col_index = header_map
    first_data_row = 2  # row 1 is the header

    # Map normalized checkpoint name -> existing sheet row index.
    existing_row_map: Dict[str, int] = {}
    mx_row = ws.max_row

    if "checkpoint" in col_index:
        chk_col = col_index["checkpoint"]
        for r_idx in range(first_data_row, ws.max_row + 1):
            val = ws.cell(row=r_idx, column=chk_col).value
            if val:
                # Normalize so "ckpt(best)" matches "ckpt".
                ck = _checkpoint_key(str(val))
                existing_row_map[ck] = r_idx
    else:
        # No checkpoint header found: fall back to the first column.
        chk_col = 1

    green_fill = PatternFill(start_color="90EE90", end_color="90EE90", fill_type="solid")
    no_fill = PatternFill(fill_type=None)
    best_font = Font(color="008000", bold=True)
    normal_font = Font(color="000000")

    for row_data in rows:
        ckpt_val = row_data.get("checkpoint")
        if ckpt_val is None:
            continue

        ck_key = _checkpoint_key(str(ckpt_val))

        # Per-row comparison against the baseline: which metric cells
        # improved, and how many datasets improved on any metric.
        better_datasets_count = 0
        metric_better_flags: Dict[str, bool] = {}

        if baseline is not None:
            for col in metric_cols:
                v = row_data.get(col)
                b = baseline.get(col)
                improved = False
                try:
                    # float() so numeric strings compare numerically;
                    # non-numeric values fall through to "not improved".
                    if v is not None and b is not None:
                        if float(v) > float(b):
                            improved = True
                except Exception:
                    improved = False
                metric_better_flags[col] = improved

            # A dataset counts as "better" if any of its metrics improved.
            for ds, ds_cols in dataset_to_cols.items():
                if any(metric_better_flags.get(c, False) for c in ds_cols):
                    better_datasets_count += 1

        # Update the existing row for this checkpoint, or append a new one.
        if ck_key in existing_row_map:
            excel_row = existing_row_map[ck_key]
        else:
            mx_row += 1
            excel_row = mx_row
            existing_row_map[ck_key] = excel_row

        # Write every known column for this row (missing values -> "").
        for col_name in final_columns:
            if col_name not in col_index: continue

            idx = col_index[col_name]
            cell = ws.cell(row=excel_row, column=idx)

            if col_name == "better_datasets_than_raw":
                value = better_datasets_count
            else:
                value = row_data.get(col_name, "")

            cell.value = value
            cell.alignment = Alignment(horizontal='center', vertical='center')

            # Highlight improved metric cells; explicitly clear the fill
            # otherwise so stale highlights from earlier runs disappear.
            if col_name in metric_cols:
                if metric_better_flags.get(col_name, False):
                    cell.fill = green_fill
                else:
                    cell.fill = no_fill

        # Style the checkpoint label: bold green for the best checkpoint.
        checkpoint_cell = ws.cell(row=excel_row, column=chk_col)
        if best_checkpoint and _checkpoint_key(best_checkpoint) == ck_key:
            checkpoint_cell.font = best_font
        else:
            checkpoint_cell.font = normal_font

        # Ensure the best checkpoint's label carries the "(best)" suffix
        # exactly once.
        if best_checkpoint and _checkpoint_key(best_checkpoint) == ck_key:
            checkpoint_cell.value = best_checkpoint + "(best)" if "(best)" not in str(checkpoint_cell.value) else checkpoint_cell.value

    # Auto-size columns from their header length (min width 12).
    for col_name, idx in col_index.items():
        col_letter = ws.cell(row=1, column=idx).column_letter
        ws.column_dimensions[col_letter].width = max(len(col_name) + 2, 12)

    wb.save(out_path)
    wb.close()
| |
|
| |
|
def write_excel(
    rows: List[Dict[str, Scalar]],
    columns: List[str],
    out_path: str,
    sheet_name: str = "Sheet1",
    old_sheet_name: str = None,
    best_checkpoint: str = None,
    model_name: str = "qwen2.5-3B"
) -> None:
    """
    Write data to a local Excel file.
    Creates file if it doesn't exist, otherwise appends/updates via append_rows_in_place.

    On the creation path, builds a DataFrame from *rows*, compares every
    checkpoint row against the *model_name* baseline row, appends a
    'better_datasets_than_raw' count column, and applies green highlighting
    to improved metric cells plus bold-green font to the best checkpoint.
    """
    sheet_name = clean_sheet_name(sheet_name)

    # Update path: merge into the existing workbook; on failure fall
    # through to a full rewrite below.
    if os.path.exists(out_path):
        try:
            append_rows_in_place(
                rows=rows,
                columns=columns,
                out_path=out_path,
                sheet_name=sheet_name,
                best_checkpoint=best_checkpoint,
                model_name=model_name,
                old_sheet_name=old_sheet_name
            )
            print(f"Excel updated: {out_path} (sheet: {sheet_name})")
            return
        except Exception as e:
            print(f"[WARN] Failed to update existing Excel: {e}. Attempting full rewrite.")

    # Creation path: normalize rows so every column exists (missing -> "").
    table = [{col: row.get(col, "") for col in columns} for row in rows]
    df = pd.DataFrame(table, columns=columns)

    # Coerce metric columns to numbers; unparseable values become NaN.
    metric_cols = [c for c in df.columns if c != "checkpoint"]
    if metric_cols:
        df[metric_cols] = df[metric_cols].apply(pd.to_numeric, errors="coerce")

    # Compare each row's metrics against the baseline (raw model) row.
    has_baseline = False
    better_mask = None

    if "checkpoint" in df.columns and (df["checkpoint"] == f"{model_name}").any() and metric_cols:
        has_baseline = True
        raw_idx = df.index[df["checkpoint"] == f"{model_name}"][0]
        raw_values = df.loc[raw_idx, metric_cols]
        # Boolean frame: True where a cell strictly beats the baseline.
        better_mask = df[metric_cols].gt(raw_values).fillna(False)

        # Group metric columns by dataset prefix (text before first '_').
        dataset_to_cols = {}
        for col in metric_cols:
            dataset_name = col.split("_", 1)[0]
            dataset_to_cols.setdefault(dataset_name, []).append(col)

        # Per row: count datasets with at least one improved metric.
        better_datasets_counts = []
        for idx, row_bool in better_mask.iterrows():
            count = 0
            for ds, ds_cols in dataset_to_cols.items():
                if row_bool[ds_cols].any():
                    count += 1
            better_datasets_counts.append(count)
        df["better_datasets_than_raw"] = better_datasets_counts
    else:
        # No baseline available: the summary column is all zeros.
        df["better_datasets_than_raw"] = 0

    final_columns = columns + ["better_datasets_than_raw"]

    # Write the sheet, then style it via the underlying openpyxl worksheet.
    with pd.ExcelWriter(out_path, engine="openpyxl") as writer:
        df.to_excel(writer, index=False, sheet_name=sheet_name, columns=final_columns)
        ws = writer.sheets[sheet_name]

        green_fill = PatternFill(start_color="90EE90", end_color="90EE90", fill_type="solid")

        # r_idx counts data rows from 0; sheet data starts at Excel row 2.
        for r_idx, row in enumerate(ws.iter_rows(min_row=2), start=0):
            for cell in row:
                cell.alignment = Alignment(horizontal='center', vertical='center')

            # Highlight metric cells that beat the baseline.
            if has_baseline and better_mask is not None:
                row_label = df.index[r_idx]
                for col_name in metric_cols:
                    if col_name in df.columns:
                        col_idx = df.columns.get_loc(col_name)

                        # Excel columns follow final_columns order (1-based).
                        if col_name in final_columns:
                            excel_col_idx = final_columns.index(col_name) + 1
                            if bool(better_mask.loc[row_label, col_name]):
                                ws.cell(row=r_idx+2, column=excel_col_idx).fill = green_fill

            # Bold-green font for the best checkpoint's label cell.
            chk_val = df.iloc[r_idx]["checkpoint"]
            if best_checkpoint and _checkpoint_key(str(chk_val)) == _checkpoint_key(best_checkpoint):
                ws.cell(row=r_idx+2, column=final_columns.index("checkpoint")+1).font = Font(color="008000", bold=True)

    print(f"Excel created: {out_path} (sheet: {sheet_name})")
| |
|
| |
|
def main():
    """CLI entry point: aggregate checkpoint metrics into an Excel summary.

    Parses arguments, optionally auto-detects the best checkpoint via
    find_best_checkpoint, collects per-checkpoint metric rows with
    collect_all_rows, and writes them to an Excel file inside the local
    'table_result' directory.
    """
    parser = argparse.ArgumentParser(
        description="Create a metrics table from checkpoint JSON files."
    )
    parser.add_argument(
        "--root",
        type=str,
        default="./SFT/Evaluation/",
        help="Root directory containing checkpoints (default: checkpoints)",
    )
    parser.add_argument(
        "--out_csv",
        type=str,
        default="./SFT/Evaluation//metrics_summary.xlsx",
        help="Output Excel file path (default: metrics_summary.xlsx)",
    )
    parser.add_argument(
        "--run",
        type=str,
        default="SFT_dt12.11.19:13_e6_unsloth_Qwen2.5_14B_Instruct_bnb_4bit_bnb_4bit_lr5e-06_t0.0_r64_b4_SFT_Implementation",
        help="Name of the Excel sheet (subsheet) to write results into.",
    )
    parser.add_argument(
        "--best_checkpoint",
        type=str,
        default=None,
        help="Name of checkpoint whose name in the first column will be colored green.",
    )
    parser.add_argument(
        "--base_model_name",
        type=str,
        default="qwen2.5-3B",
        help="Name of the base model we trained on",
    )
    parser.add_argument(
        "--base_result_dir",
        type=str,
        default="./SFT/results_sft_14b",
        help="Directory of the base model we trained on",
    )
    parser.add_argument(
        "--train_data",
        type=str,
        default="UniADILR",
        help="Name of the training data that the model was trained on.",
    )

    args = parser.parse_args()

    # All outputs go into a local 'table_result' directory.
    output_dir = "table_result"
    os.makedirs(output_dir, exist_ok=True)

    # BUG FIX: the original read 'args.out_filename', but the parser only
    # defines '--out_csv' — that raised AttributeError on every run.
    final_out_path = os.path.join(output_dir, os.path.basename(args.out_csv))

    # Auto-detect the best checkpoint when the caller did not provide one,
    # preferring the 'Training_<run>' directory over the plain run directory.
    if args.best_checkpoint is None:
        BASE_RESULTS_DIR = os.path.expanduser(args.base_result_dir)
        TRAINING_DIR = os.path.join(BASE_RESULTS_DIR, f"Training_{args.run}")
        FINAL_DIR = os.path.join(BASE_RESULTS_DIR, args.run)

        TRAINING_BASE = None
        if os.path.isdir(TRAINING_DIR):
            TRAINING_BASE = TRAINING_DIR
        elif os.path.isdir(FINAL_DIR):
            TRAINING_BASE = FINAL_DIR

        if TRAINING_BASE:
            best_path, _ = find_best_checkpoint(TRAINING_BASE)
            if best_path:
                args.best_checkpoint = os.path.basename(best_path)

    print(f"Collecting metrics for run: {args.run}")
    rows, columns = collect_all_rows(args.root, args.run, args.best_checkpoint, args.base_model_name)

    if not rows:
        print("No metrics found. Exiting.")
        return

    # Build the sheet name by injecting the training-data name before the
    # 'e20_' marker when present; fall back to the raw run name otherwise.
    try:
        parts = args.run.split("e20_")
        if len(parts) >= 2:
            sheet_name = parts[0] + args.train_data + "_e20_" + parts[1]
        else:
            sheet_name = args.run
    except Exception:
        sheet_name = args.run

    old_sheet_name = args.run

    write_excel(
        rows=rows,
        columns=columns,
        out_path=final_out_path,
        sheet_name=sheet_name,
        old_sheet_name=old_sheet_name,
        best_checkpoint=args.best_checkpoint,
        model_name=args.base_model_name,
    )


if __name__ == "__main__":
    main()