#!/usr/bin/env python3
"""
Aggregate metrics from checkpoint JSONs into a single table.
Saves results locally to the 'table_result' directory.
"""

import argparse
import csv  # noqa: F401  (kept from the original import list)
import json
import os
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import numpy as np  # noqa: F401  (kept from the original import list)
import pandas as pd

# openpyxl is only needed by the Excel-writing functions. Guard the import the
# same way as the optional helper below, so the metric-collection functions
# stay importable without it; the writers raise a clear error when it is
# missing (see _require_openpyxl).
try:
    from openpyxl import load_workbook
    from openpyxl.styles import Alignment, Font, PatternFill
except ImportError:
    load_workbook = None
    Alignment = Font = PatternFill = None

# Attempt to import the specific helper; if not found, we define a dummy
# placeholder to ensure the script doesn't crash if the user only has this
# single file.
try:
    from evaluate_aime_raw_vs_finetuned import find_best_checkpoint
except ImportError:
    def find_best_checkpoint(path):
        """Fallback used when the real helper is unavailable: no best checkpoint."""
        return None, 0


Scalar = Union[int, float, str]

# Name of the extra column counting the datasets that beat the baseline row.
COMPARISON_COLUMN = "better_datasets_than_raw"

# Ordered (substring, column suffix) rules used to map a metric name to a
# column. "exact_match_accuracy" must be tested before "accuracy", because the
# latter is a substring of the former and would otherwise shadow it.
_METRIC_SUFFIX_RULES = (
    ("exact_match_accuracy", "EM"),
    ("accuracy", "acc"),
    ("f1", "f1"),
    ("precision", "precision"),
    ("recall", "recall"),
)


def is_scalar(x: Any) -> bool:
    """Return True when *x* is an int, float or str."""
    return isinstance(x, (int, float, str))


def load_metrics_from_json(json_path: str) -> Dict[str, Scalar]:
    """Load scalar metrics from a results JSON file.

    If the top-level object is a dict with a dict under the "metrics" key,
    that nested dict is used; otherwise the top-level object itself is used.
    Only entries whose value is an int, float or str are returned. A
    non-dict payload yields an empty result.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    metrics_dict = data
    if isinstance(data, dict) and isinstance(data.get("metrics"), dict):
        metrics_dict = data["metrics"]

    if not isinstance(metrics_dict, dict):
        return {}
    return {k: v for k, v in metrics_dict.items() if is_scalar(v)}


def _collect_checkpoint_row(
    ckpt_path: str,
    label: str,
    json_names: Tuple[str, ...],
    all_metric_cols: Set[str],
    skip_substring: Optional[str] = None,
) -> Dict[str, Scalar]:
    """Build one table row from the dataset sub-directories of *ckpt_path*.

    Each sub-directory is treated as a dataset (name lower-cased); the first
    existing file among *json_names* supplies its metrics. Entries whose
    lower-cased name contains *skip_substring* are ignored.
    """
    row: Dict[str, Scalar] = {"checkpoint": label}
    for entry in sorted(os.listdir(ckpt_path)):
        if skip_substring and skip_substring in entry.lower():
            continue
        dataset_path = os.path.join(ckpt_path, entry)
        if not os.path.isdir(dataset_path):
            continue

        candidates = (os.path.join(dataset_path, n) for n in json_names)
        json_path = next((p for p in candidates if os.path.isfile(p)), None)
        if json_path is None:
            continue

        try:
            metrics = load_metrics_from_json(json_path)
        except (OSError, ValueError) as e:
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
            print(f"[WARN] Failed to read {json_path}: {e}")
            continue
        _process_metrics_into_row(row, all_metric_cols, entry.lower(), metrics)
    return row


def collect_all_rows(
    root_dir: str,
    run: str,
    best_checkpoint: Optional[str] = None,
    model_name: str = "qwen2.5-3B",
) -> Tuple[List[Dict[str, Scalar]], List[str]]:
    """Walk the checkpoints directory and collect rows + column names.

    Layout read from disk:
      * ``<root_dir>/raw_model/<dataset>/raw_results_train_all.json`` gives the
        baseline row, labelled *model_name* (datasets whose name contains
        "neu" are skipped for the baseline only).
      * ``<root_dir>/<run>/checkpoint-<step>/<dataset>/all_cases.json`` (or the
        misspelled ``all_casses.json``) gives one row per step, in
        ascending step order.

    The row whose directory name equals *best_checkpoint* gets a "(best)"
    suffix. Returns ``(rows, columns)`` where ``columns`` is "checkpoint"
    followed by the sorted metric column names.
    """
    rows: List[Dict[str, Scalar]] = []
    all_metric_cols: Set[str] = set()

    if not os.path.isdir(root_dir):
        # Graceful exit or warning if root doesn't exist
        print(f"[WARN] Root directory not found: {root_dir}")
        return [], ["checkpoint"]

    # 1. Collect Raw Model Metrics
    raw_path = os.path.join(root_dir, "raw_model")
    if os.path.isdir(raw_path):
        raw_row = _collect_checkpoint_row(
            raw_path,
            f"{model_name}",
            ("raw_results_train_all.json",),
            all_metric_cols,
            skip_substring="neu",
        )
        # Only add raw model row if we actually found data
        if len(raw_row) > 1:
            rows.append(raw_row)

    # 2. Collect Checkpoint Metrics
    run_dir = os.path.join(root_dir, run)
    if not os.path.isdir(run_dir):
        print(f"[WARN] Run directory not found: {run_dir}")
        return rows, ["checkpoint"] + sorted(all_metric_cols)

    try:
        # Any entry ending in "-<int>" contributes a step. A set is used so
        # that e.g. "backup-10" next to "checkpoint-10" does not make the
        # same checkpoint appear twice in the table.
        steps: Set[int] = set()
        for entry in os.listdir(run_dir):
            if "-" not in entry:
                continue
            try:
                steps.add(int(entry.split("-")[-1]))
            except ValueError:
                continue

        for step in sorted(steps):
            ckpt_name = f"checkpoint-{step}"
            ckpt_path = os.path.join(run_dir, ckpt_name)
            if not os.path.isdir(ckpt_path):
                continue
            is_best = bool(best_checkpoint) and ckpt_name == best_checkpoint
            label = ckpt_name + "(best)" if is_best else ckpt_name
            rows.append(
                _collect_checkpoint_row(
                    ckpt_path,
                    label,
                    ("all_cases.json", "all_casses.json"),  # second: typo fallback
                    all_metric_cols,
                )
            )
    except OSError as e:
        # Keep whatever was collected before the filesystem error.
        print(f"[ERROR] Error walking directories: {e}")

    return rows, ["checkpoint"] + sorted(all_metric_cols)


def _metric_suffix(metric_name: str) -> Optional[str]:
    """Return the column suffix for *metric_name*, or None if unrecognised."""
    for needle, suffix in _METRIC_SUFFIX_RULES:
        if needle in metric_name:
            return suffix
    return None


def _process_metrics_into_row(row, all_metric_cols, dataset_name, metrics):
    """Helper to standardize metric naming logic.

    Maps each recognised metric to a ``<dataset_name>_<suffix>`` column
    (acc / f1 / precision / recall / EM), rounds the value to 4 decimals and
    stores it in *row*; the column name is also added to *all_metric_cols*.

    Rules:
      * When the dataset reports any F1 metric, only its F1 column is kept.
      * The first metric mapping to a column wins; later ones are ignored.
      * Non-numeric values (e.g. strings) are skipped, since they cannot be
        rounded.
    """
    candidates = []
    for metric_name, metric_value in metrics.items():
        suffix = _metric_suffix(metric_name)
        if suffix is None or not isinstance(metric_value, (int, float)):
            continue
        candidates.append((suffix, metric_value))

    has_f1 = any(suffix == "f1" for suffix, _ in candidates)
    for suffix, metric_value in candidates:
        if has_f1 and suffix != "f1":
            continue
        col_name = f"{dataset_name}_{suffix}"
        if col_name not in row:
            row[col_name] = round(metric_value, 4)
        all_metric_cols.add(col_name)


def clean_sheet_name(name):
    """Ensure Excel sheet name is valid.

    Replaces the characters ``\\ / * ? : [ ]`` with "_" and truncates to 31
    characters. An empty/None name becomes "Sheet1".
    """
    if not name:
        return "Sheet1"
    for ch in ("\\", "/", "*", "?", ":", "[", "]"):
        name = name.replace(ch, "_")
    return name[:31]  # Excel limit


def _checkpoint_key(name: str) -> str:
    """Normalize checkpoint name by stripping '(best)' etc."""
    return name.split("(")[0].strip() if name else ""


def _require_openpyxl() -> None:
    """Raise ImportError with a clear message when openpyxl is unavailable."""
    if load_workbook is None:
        raise ImportError(
            "openpyxl is required to write Excel files; "
            "install it with 'pip install openpyxl'."
        )


def _group_columns_by_dataset(metric_cols: List[str]) -> Dict[str, List[str]]:
    """Group ``<dataset>_<suffix>`` column names by dataset.

    The split is done on the LAST underscore: the suffixes produced by
    _process_metrics_into_row never contain one, whereas dataset names may
    (e.g. "math_500_acc" belongs to dataset "math_500").
    """
    groups: Dict[str, List[str]] = {}
    for col in metric_cols:
        groups.setdefault(col.rsplit("_", 1)[0], []).append(col)
    return groups


def _is_improvement(value: Any, reference: Any) -> bool:
    """Return True when both values are numeric-like and value > reference."""
    if value is None or reference is None:
        return False
    try:
        return float(value) > float(reference)
    except (TypeError, ValueError):
        return False


def append_rows_in_place(
    rows: List[Dict[str, Scalar]],
    columns: List[str],
    out_path: str,
    sheet_name: str,
    best_checkpoint: str,
    model_name: str,
    old_sheet_name: str,
) -> None:
    """Update an existing Excel sheet in-place.

    Opens the workbook at *out_path*, creates the sheet if needed, adds any
    missing header columns, then writes each row either over the existing
    row with the same checkpoint key or on a new row at the bottom. Metric
    cells greater than the baseline row (the row labelled *model_name*) are
    filled green, and the best checkpoint's name is written in bold green
    with a "(best)" suffix.

    *old_sheet_name* is accepted for interface compatibility and is not used.

    Raises:
        ImportError: if openpyxl is not installed.
    """
    _require_openpyxl()
    sheet_name = clean_sheet_name(sheet_name)
    final_columns = columns + [COMPARISON_COLUMN]
    metric_cols = [c for c in columns if c != "checkpoint"]
    metric_col_set = set(metric_cols)
    dataset_to_cols = _group_columns_by_dataset(metric_cols)

    baseline = next((r for r in rows if r.get("checkpoint") == model_name), None)
    best_key = _checkpoint_key(best_checkpoint) if best_checkpoint else None

    green_fill = PatternFill(start_color="90EE90", end_color="90EE90", fill_type="solid")
    no_fill = PatternFill(fill_type=None)
    best_font = Font(color="008000", bold=True)
    normal_font = Font(color="000000")
    centered = Alignment(horizontal="center", vertical="center")

    wb = load_workbook(out_path)
    try:
        if sheet_name in wb.sheetnames:
            ws = wb[sheet_name]
        else:
            ws = wb.create_sheet(title=sheet_name)
            for idx, col_name in enumerate(final_columns, start=1):
                ws.cell(row=1, column=idx, value=col_name)

        # Re-map headers in case columns changed or are in a different order
        # in the file.
        header_map: Dict[str, int] = {
            str(cell.value): idx
            for idx, cell in enumerate(ws[1], start=1)
            if cell.value is not None
        }
        # An empty sheet reports max_column == 1; start at 0 in that case so
        # the first header lands in column A instead of leaving it blank.
        next_col = ws.max_column if header_map else 0
        for col_name in final_columns:
            if col_name not in header_map:
                next_col += 1
                ws.cell(row=1, column=next_col, value=col_name)
                header_map[col_name] = next_col

        # Map existing checkpoints to row numbers.
        first_data_row = 2
        last_row = ws.max_row
        chk_col = header_map.get("checkpoint", 1)
        existing_row_map: Dict[str, int] = {}
        if "checkpoint" in header_map:
            for r_idx in range(first_data_row, last_row + 1):
                val = ws.cell(row=r_idx, column=chk_col).value
                if val:
                    existing_row_map[_checkpoint_key(str(val))] = r_idx

        for row_data in rows:
            ckpt_val = row_data.get("checkpoint")
            if ckpt_val is None:
                continue
            ck_key = _checkpoint_key(str(ckpt_val))

            # Compare every metric against the baseline row (if any).
            better_flags: Dict[str, bool] = {}
            if baseline is not None:
                better_flags = {
                    col: _is_improvement(row_data.get(col), baseline.get(col))
                    for col in metric_cols
                }
            better_datasets_count = sum(
                1
                for ds_cols in dataset_to_cols.values()
                if any(better_flags.get(c, False) for c in ds_cols)
            )

            # Determine target row: reuse the checkpoint's row or append.
            excel_row = existing_row_map.get(ck_key)
            if excel_row is None:
                last_row += 1
                excel_row = last_row
                existing_row_map[ck_key] = excel_row

            # Write data
            for col_name in final_columns:
                cell = ws.cell(row=excel_row, column=header_map[col_name])
                if col_name == COMPARISON_COLUMN:
                    cell.value = better_datasets_count
                else:
                    cell.value = row_data.get(col_name, "")
                cell.alignment = centered
                # Conditional formatting for metrics
                if col_name in metric_col_set:
                    cell.fill = green_fill if better_flags.get(col_name, False) else no_fill

            # Highlight best checkpoint name and make sure it carries the
            # "(best)" suffix.
            checkpoint_cell = ws.cell(row=excel_row, column=chk_col)
            is_best = best_key is not None and best_key == ck_key
            checkpoint_cell.font = best_font if is_best else normal_font
            if is_best and "(best)" not in str(checkpoint_cell.value):
                checkpoint_cell.value = best_checkpoint + "(best)"

        # Auto-adjust column widths (simple approximation)
        for col_name, idx in header_map.items():
            col_letter = ws.cell(row=1, column=idx).column_letter
            ws.column_dimensions[col_letter].width = max(len(col_name) + 2, 12)

        wb.save(out_path)
    finally:
        wb.close()


def write_excel(
    rows: List[Dict[str, Scalar]],
    columns: List[str],
    out_path: str,
    sheet_name: str = "Sheet1",
    old_sheet_name: Optional[str] = None,
    best_checkpoint: Optional[str] = None,
    model_name: str = "qwen2.5-3B",
) -> None:
    """
    Write data to a local Excel file.

    Creates the file if it doesn't exist, otherwise appends/updates via
    append_rows_in_place. If the in-place update fails, a warning is printed
    and the file is rewritten from scratch with only this sheet.

    Raises:
        ImportError: if openpyxl is not installed.
    """
    _require_openpyxl()
    sheet_name = clean_sheet_name(sheet_name)

    if os.path.exists(out_path):
        try:
            append_rows_in_place(
                rows=rows,
                columns=columns,
                out_path=out_path,
                sheet_name=sheet_name,
                best_checkpoint=best_checkpoint,
                model_name=model_name,
                old_sheet_name=old_sheet_name,
            )
            print(f"Excel updated: {out_path} (sheet: {sheet_name})")
            return
        except Exception as e:
            # Deliberate best-effort boundary: fall through to a full rewrite.
            print(f"[WARN] Failed to update existing Excel: {e}. Attempting full rewrite.")

    # Create new DataFrame and Excel file
    table = [{col: row.get(col, "") for col in columns} for row in rows]
    df = pd.DataFrame(table, columns=columns)

    metric_cols = [c for c in df.columns if c != "checkpoint"]
    if metric_cols:
        # Missing metrics were filled with ""; coerce them to NaN.
        df[metric_cols] = df[metric_cols].apply(pd.to_numeric, errors="coerce")

    # Calculate comparisons for new file
    better_mask = None
    has_baseline = (
        "checkpoint" in df.columns
        and bool(metric_cols)
        and bool((df["checkpoint"] == f"{model_name}").any())
    )
    if has_baseline:
        raw_idx = df.index[df["checkpoint"] == f"{model_name}"][0]
        raw_values = df.loc[raw_idx, metric_cols]
        better_mask = df[metric_cols].gt(raw_values).fillna(False)

        # A dataset counts once if any of its metric columns improved.
        counts = pd.Series(0, index=df.index)
        for ds_cols in _group_columns_by_dataset(metric_cols).values():
            counts = counts + better_mask[ds_cols].any(axis=1).astype(int)
        df[COMPARISON_COLUMN] = counts
    else:
        df[COMPARISON_COLUMN] = 0

    final_columns = columns + [COMPARISON_COLUMN]
    excel_col = {name: pos for pos, name in enumerate(final_columns, start=1)}
    best_key = _checkpoint_key(best_checkpoint) if best_checkpoint else None

    # Write using Pandas/OpenPyxl
    with pd.ExcelWriter(out_path, engine="openpyxl") as writer:
        df.to_excel(writer, index=False, sheet_name=sheet_name, columns=final_columns)
        ws = writer.sheets[sheet_name]

        green_fill = PatternFill(start_color="90EE90", end_color="90EE90", fill_type="solid")
        best_font = Font(color="008000", bold=True)
        centered = Alignment(horizontal="center", vertical="center")

        # Formatting: worksheet row 2 corresponds to DataFrame position 0.
        for pos, cells in enumerate(ws.iter_rows(min_row=2)):
            excel_row = pos + 2
            for cell in cells:
                cell.alignment = centered

            # Highlight cells better than baseline
            if better_mask is not None:
                row_flags = better_mask.iloc[pos]
                for col_name in metric_cols:
                    if bool(row_flags[col_name]):
                        ws.cell(row=excel_row, column=excel_col[col_name]).fill = green_fill

            # Highlight best checkpoint
            chk_val = df["checkpoint"].iat[pos]
            if best_key is not None and _checkpoint_key(str(chk_val)) == best_key:
                ws.cell(row=excel_row, column=excel_col["checkpoint"]).font = best_font

    print(f"Excel created: {out_path} (sheet: {sheet_name})")


def main():
    """Parse CLI arguments, collect metrics and write them to table_result/."""
    parser = argparse.ArgumentParser(
        description="Create a metrics table from checkpoint JSON files."
    )
    parser.add_argument(
        "--root",
        type=str,
        default="./SFT/Evaluation/",
        help="Root directory containing 'raw_model' and the run's checkpoints "
             "(default: ./SFT/Evaluation/)",
    )
    parser.add_argument(
        "--out_filename",
        type=str,
        default="./SFT/Evaluation//metrics_summary.xlsx",
        help="Output Excel file name; only the base name is used and the file "
             "is written under 'table_result/' (default: metrics_summary.xlsx)",
    )
    parser.add_argument(
        "--run",
        type=str,
        default="SFT_dt12.11.19:13_e6_unsloth_Qwen2.5_14B_Instruct_bnb_4bit_bnb_4bit_lr5e-06_t0.0_r64_b4_SFT_Implementation",
        help="Name of the run directory under --root; also used to derive the "
             "Excel sheet name.",
    )
    parser.add_argument(
        "--best_checkpoint",
        type=str,
        default=None,
        help="Name of checkpoint whose name in the first column will be colored green.",
    )
    parser.add_argument(
        "--base_model_name",
        type=str,
        default="qwen2.5-3B",
        help="Name of the base model we trained on",
    )
    parser.add_argument(
        "--base_result_dir",
        type=str,
        default="./SFT/results_sft_14b",
        help="Directory of the base model we trained on",
    )
    parser.add_argument(
        "--train_data",
        type=str,
        default="UniADILR",
        help="Training data name inserted into the sheet name before 'e20_'.",
    )
    args = parser.parse_args()

    # --- Directory Setup ---
    # Ensure the table_result directory exists
    output_dir = "table_result"
    os.makedirs(output_dir, exist_ok=True)
    final_out_path = os.path.join(output_dir, os.path.basename(args.out_filename))

    # --- Best Checkpoint Auto-Detection ---
    if args.best_checkpoint is None:
        base_results_dir = os.path.expanduser(args.base_result_dir)
        # Prefer the "Training_<run>" directory, then "<run>".
        candidates = (
            os.path.join(base_results_dir, f"Training_{args.run}"),
            os.path.join(base_results_dir, args.run),
        )
        training_base = next((d for d in candidates if os.path.isdir(d)), None)
        if training_base:
            best_path, _ = find_best_checkpoint(training_base)
            if best_path:
                args.best_checkpoint = os.path.basename(best_path)

    # --- Data Collection ---
    print(f"Collecting metrics for run: {args.run}")
    rows, columns = collect_all_rows(
        args.root, args.run, args.best_checkpoint, args.base_model_name
    )
    if not rows:
        print("No metrics found. Exiting.")
        return

    # --- Sheet Naming Logic ---
    # Insert the training-data name in front of the first "e20_" marker of the
    # run name; fall back to the plain run name when the marker is absent.
    prefix, marker, suffix = args.run.partition("e20_")
    sheet_name = f"{prefix}{args.train_data}_e20_{suffix}" if marker else args.run
    old_sheet_name = args.run

    # --- Save to Local Excel ---
    write_excel(
        rows=rows,
        columns=columns,
        out_path=final_out_path,
        sheet_name=sheet_name,
        old_sheet_name=old_sheet_name,
        best_checkpoint=args.best_checkpoint,
        model_name=args.base_model_name,
    )


if __name__ == "__main__":
    main()