# SFT_Dataset/Evaluation/create_table.py
#!/usr/bin/env python3
"""
Aggregate metrics from checkpoint JSONs into a single table.
Saves results locally to the 'table_result' directory.
"""
import os
import json
import argparse
import csv
import numpy as np
from typing import Dict, Any, List, Tuple, Set, Union
import pandas as pd
from openpyxl import load_workbook
from openpyxl.styles import PatternFill, Font, Alignment
# Attempt to import the specific helper; if not found, we define a dummy placeholder
# to ensure the script doesn't crash if the user only has this single file.
try:
    from evaluate_aime_raw_vs_finetuned import find_best_checkpoint
except ImportError:
    def find_best_checkpoint(path):
        # Fallback stub: returns (None, 0) meaning "no best checkpoint found".
        # NOTE(review): presumably matches the real helper's (path, metric)
        # return contract — main() unpacks it as `best_path, _` — verify.
        return None, 0
# Scalar value types that survive into the metrics table.
Scalar = Union[int, float, str]


def is_scalar(x: Any) -> bool:
    """Return True when *x* is a plain int, float, or str."""
    return isinstance(x, (int, float, str))


def load_metrics_from_json(json_path: str) -> Dict[str, Scalar]:
    """Load scalar metrics from all_cases.json.

    Accepts either a ``{"metrics": {...}}`` wrapper or a flat metrics
    mapping; non-scalar values (lists, nested dicts, ...) are dropped.
    """
    with open(json_path, "r", encoding="utf-8") as fh:
        payload = json.load(fh)
    # Unwrap the optional "metrics" envelope when it is itself a dict.
    metrics_dict = payload
    if isinstance(payload, dict) and isinstance(payload.get("metrics"), dict):
        metrics_dict = payload["metrics"]
    if not isinstance(metrics_dict, dict):
        return {}
    return {key: val for key, val in metrics_dict.items() if is_scalar(val)}
def collect_all_rows(root_dir: str, run: str, best_checkpoint: Union[str, None] = None, model_name: str = "qwen2.5-3B") -> Tuple[List[Dict[str, Scalar]], List[str]]:
    """Walk the checkpoints directory and collect rows + column names.

    Expected on-disk layout:
        <root_dir>/raw_model/<dataset>/raw_results_train_all.json   (baseline)
        <root_dir>/<run>/checkpoint-<step>/<dataset>/all_cases.json (fine-tuned)

    Args:
        root_dir: Evaluation root containing ``raw_model`` and the run dir.
        run: Name of the run subdirectory holding ``checkpoint-<step>`` dirs.
        best_checkpoint: Checkpoint dir name to tag with a "(best)" suffix.
        model_name: Label used for the baseline (raw model) row.

    Returns:
        (rows, ordered_cols): each row is a dict keyed by column name with a
        mandatory "checkpoint" entry; ordered_cols is "checkpoint" first,
        then the sorted union of all metric columns seen.
    """
    rows: List[Dict[str, Scalar]] = []
    all_metric_cols: Set[str] = set()
    if not os.path.isdir(root_dir):
        # Graceful exit or warning if root doesn't exist
        print(f"[WARN] Root directory not found: {root_dir}")
        return [], ["checkpoint"]
    # 1. Collect Raw Model Metrics (baseline row labeled with model_name)
    ckpt_path = os.path.join(root_dir, "raw_model")
    row: Dict[str, Scalar] = {"checkpoint": f"{model_name}"}
    if os.path.isdir(ckpt_path):
        for dataset_name in sorted(os.listdir(ckpt_path)):
            # Datasets whose name contains "neu" are deliberately skipped.
            if "neu" in dataset_name.lower():
                continue
            dataset_path = os.path.join(ckpt_path, dataset_name)
            dataset_name = dataset_name.lower()  # column names are lowercased
            if not os.path.isdir(dataset_path):
                continue
            json_path = os.path.join(dataset_path, "raw_results_train_all.json")
            if not os.path.isfile(json_path):
                continue
            try:
                metrics = load_metrics_from_json(json_path)
            except Exception as e:
                # Best-effort: one unreadable JSON must not abort the table.
                print(f"[WARN] Failed to read {json_path}: {e}")
                continue
            _process_metrics_into_row(row, all_metric_cols, dataset_name, metrics)
    # Only add raw model row if we actually found data
    # (row always holds the "checkpoint" key, hence the > 1 test).
    if len(row) > 1:
        rows.append(row)
    # 2. Collect Checkpoint Metrics — descend into the run directory.
    root_dir = os.path.join(root_dir, run)
    if not os.path.isdir(root_dir):
        print(f"[WARN] Run directory not found: {root_dir}")
        ordered_cols = ["checkpoint"] + sorted(all_metric_cols)
        return rows, ordered_cols
    try:
        # Filter for directories that look like steps (integers), e.g.
        # "checkpoint-100" -> 100; non-numeric suffixes are ignored.
        training_step = []
        for d in os.listdir(root_dir):
            if "-" in d:
                try:
                    step_val = int(d.split("-")[-1])
                    training_step.append(step_val)
                except ValueError:
                    continue
        # Process checkpoints in ascending step order.
        for step in sorted(training_step):
            ckpt_name = f"checkpoint-{step}"
            ckpt_path = os.path.join(root_dir, ckpt_name)
            if not os.path.isdir(ckpt_path):
                continue
            # Tag the best checkpoint's label so it stands out in the table.
            if best_checkpoint and ckpt_name == best_checkpoint:
                row: Dict[str, Scalar] = {"checkpoint": ckpt_name+"(best)"}
            else:
                row: Dict[str, Scalar] = {"checkpoint": ckpt_name}
            for dataset_name in sorted(os.listdir(ckpt_path)):
                dataset_path = os.path.join(ckpt_path, dataset_name)
                dataset_name = dataset_name.lower()
                if not os.path.isdir(dataset_path):
                    continue
                # Try common filenames
                json_path = os.path.join(dataset_path, "all_cases.json")
                if not os.path.isfile(json_path):
                    json_path = os.path.join(dataset_path, "all_casses.json")  # Typo fallback
                if not os.path.isfile(json_path):
                    continue
                try:
                    metrics = load_metrics_from_json(json_path)
                except Exception as e:
                    print(f"[WARN] Failed to read {json_path}: {e}")
                    continue
                _process_metrics_into_row(row, all_metric_cols, dataset_name, metrics)
            # Checkpoint rows are kept even when no metrics were found.
            rows.append(row)
    except Exception as e:
        print(f"[ERROR] Error walking directories: {e}")
    ordered_cols = ["checkpoint"] + sorted(all_metric_cols)
    return rows, ordered_cols
def _process_metrics_into_row(row, all_metric_cols, dataset_name, metrics):
"""Helper to standardize metric naming logic."""
f1_flag = False
possible_col_names = []
for metric_name, metric_value in metrics.items():
col_name = None
if "accuracy" in metric_name or "hamming_accuracy" in metric_name:
col_name = f"{dataset_name}_acc"
elif "macro_f1" in metric_name or "f1_macro" in metric_name or "f1" in metric_name:
col_name = f"{dataset_name}_f1"
f1_flag = True
elif "precision" in metric_name:
col_name = f"{dataset_name}_precision"
elif "recall" in metric_name:
col_name = f"{dataset_name}_recall"
elif "exact_match_accuracy" in metric_name:
col_name = f"{dataset_name}_EM"
if col_name:
possible_col_names.append((col_name, metric_value))
for col_name, metric_value in possible_col_names:
# Priority logic: if F1 exists, prioritize it?
# The original logic seemed to allow both, but checked flags.
# Preserving original logic structure:
if f1_flag and ("f1" in col_name or "_f1" in col_name):
if col_name not in row:
row[col_name] = round(metric_value, 4)
all_metric_cols.add(col_name)
elif not f1_flag:
if col_name not in row:
row[col_name] = round(metric_value, 4)
all_metric_cols.add(col_name)
def clean_sheet_name(name):
    """Return *name* made safe for use as an Excel sheet title.

    Excel forbids ``\\ / * ? : [ ]`` in sheet names and caps the length at
    31 characters; forbidden characters become underscores. Falsy input
    yields the default "Sheet1".
    """
    if not name:
        return "Sheet1"
    sanitized = name.translate(str.maketrans({ch: '_' for ch in '\\/*?:[]'}))
    return sanitized[:31]
def _checkpoint_key(name: str) -> str:
"""Normalize checkpoint name by stripping '(best)' etc."""
return name.split("(")[0].strip() if name else ""
def append_rows_in_place(
    rows: List[Dict[str, Scalar]],
    columns: List[str],
    out_path: str,
    sheet_name: str,
    best_checkpoint: str,
    model_name: str,
    old_sheet_name: str
) -> None:
    """Update an existing Excel sheet in-place.

    Matches incoming rows to existing sheet rows by normalized checkpoint
    name (ignoring a "(best)" suffix), overwriting matches and appending
    new rows. Metric cells that beat the *model_name* baseline row are
    filled green, a "better_datasets_than_raw" count column is maintained,
    and the best checkpoint's name cell is rendered bold green.

    NOTE(review): *old_sheet_name* is accepted but never read in this
    function — kept for interface compatibility with write_excel().
    """
    sheet_name = clean_sheet_name(sheet_name)
    final_columns = columns + ["better_datasets_than_raw"]
    metric_cols = [c for c in columns if c != "checkpoint"]
    # Group metric columns by dataset prefix (text before the first "_").
    dataset_to_cols: Dict[str, List[str]] = {}
    for col in metric_cols:
        ds = col.split("_", 1)[0]
        dataset_to_cols.setdefault(ds, []).append(col)
    # Baseline row = the raw-model row labeled with model_name, if present.
    baseline = None
    for r in rows:
        if r.get("checkpoint") == model_name:
            baseline = r
            break
    wb = load_workbook(out_path)
    if sheet_name in wb.sheetnames:
        ws = wb[sheet_name]
    else:
        # Fresh sheet: write the header row up front.
        ws = wb.create_sheet(title=sheet_name)
        for idx, col_name in enumerate(final_columns, start=1):
            ws.cell(row=1, column=idx, value=col_name)
    # Re-map headers in case columns changed or are in different order in file
    header_map: Dict[str, int] = {}
    max_col = ws.max_column
    # Read existing headers
    for idx, cell in enumerate(ws[1], start=1):
        if cell.value is not None:
            header_map[str(cell.value)] = idx
    # Add new columns if missing
    for col_name in final_columns:
        if col_name not in header_map:
            max_col += 1
            ws.cell(row=1, column=max_col, value=col_name)
            header_map[col_name] = max_col
    col_index = header_map
    first_data_row = 2
    # Map existing checkpoints to row numbers (keyed without "(best)").
    existing_row_map: Dict[str, int] = {}
    mx_row = ws.max_row
    if "checkpoint" in col_index:
        chk_col = col_index["checkpoint"]
        for r_idx in range(first_data_row, ws.max_row + 1):
            val = ws.cell(row=r_idx, column=chk_col).value
            if val:
                ck = _checkpoint_key(str(val))
                existing_row_map[ck] = r_idx
    else:
        # Should not happen if file was created correctly
        chk_col = 1
    green_fill = PatternFill(start_color="90EE90", end_color="90EE90", fill_type="solid")
    no_fill = PatternFill(fill_type=None)
    best_font = Font(color="008000", bold=True)
    normal_font = Font(color="000000")
    for row_data in rows:
        ckpt_val = row_data.get("checkpoint")
        if ckpt_val is None:
            continue
        ck_key = _checkpoint_key(str(ckpt_val))
        # Calculate metrics comparison against the baseline row.
        better_datasets_count = 0
        metric_better_flags: Dict[str, bool] = {}
        if baseline is not None:
            for col in metric_cols:
                v = row_data.get(col)
                b = baseline.get(col)
                improved = False
                try:
                    if v is not None and b is not None:
                        # Strictly greater counts as improved.
                        if float(v) > float(b):
                            improved = True
                except Exception:
                    # Non-numeric values simply don't count as improved.
                    improved = False
                metric_better_flags[col] = improved
            # A dataset counts once if ANY of its metrics improved.
            for ds, ds_cols in dataset_to_cols.items():
                if any(metric_better_flags.get(c, False) for c in ds_cols):
                    better_datasets_count += 1
        # Determine target row: reuse the matching row or append a new one.
        if ck_key in existing_row_map:
            excel_row = existing_row_map[ck_key]
        else:
            mx_row += 1
            excel_row = mx_row
            existing_row_map[ck_key] = excel_row
        # Write data
        for col_name in final_columns:
            if col_name not in col_index: continue
            idx = col_index[col_name]
            cell = ws.cell(row=excel_row, column=idx)
            if col_name == "better_datasets_than_raw":
                value = better_datasets_count
            else:
                # Missing metrics become empty cells, not errors.
                value = row_data.get(col_name, "")
            cell.value = value
            cell.alignment = Alignment(horizontal='center', vertical='center')
            # Conditional formatting for metrics (reset fill when not better).
            if col_name in metric_cols:
                if metric_better_flags.get(col_name, False):
                    cell.fill = green_fill
                else:
                    cell.fill = no_fill
        # Highlight best checkpoint name
        checkpoint_cell = ws.cell(row=excel_row, column=chk_col)
        if best_checkpoint and _checkpoint_key(best_checkpoint) == ck_key:
            checkpoint_cell.font = best_font
        else:
            checkpoint_cell.font = normal_font
        # Ensure checkpoint name has (best) suffix if applicable
        if best_checkpoint and _checkpoint_key(best_checkpoint) == ck_key:
            checkpoint_cell.value = best_checkpoint + "(best)" if "(best)" not in str(checkpoint_cell.value) else checkpoint_cell.value
    # Auto-adjust column widths (simple approximation)
    for col_name, idx in col_index.items():
        col_letter = ws.cell(row=1, column=idx).column_letter
        ws.column_dimensions[col_letter].width = max(len(col_name) + 2, 12)
    wb.save(out_path)
    wb.close()
def write_excel(
    rows: List[Dict[str, Scalar]],
    columns: List[str],
    out_path: str,
    sheet_name: str = "Sheet1",
    old_sheet_name: str = None,
    best_checkpoint: str = None,
    model_name: str = "qwen2.5-3B"
) -> None:
    """
    Write data to a local Excel file.
    Creates file if it doesn't exist, otherwise appends/updates via append_rows_in_place.

    Args:
        rows: One dict per checkpoint (must contain a "checkpoint" key).
        columns: Ordered column names, "checkpoint" first.
        out_path: Target .xlsx path.
        sheet_name: Sheet to (over)write; sanitized for Excel rules.
        old_sheet_name: Forwarded to append_rows_in_place (currently unused there).
        best_checkpoint: Checkpoint whose name cell gets bold green styling.
        model_name: Label of the baseline row used for "better than raw" checks.
    """
    sheet_name = clean_sheet_name(sheet_name)
    if os.path.exists(out_path):
        try:
            # Preferred path: update the existing workbook in place.
            append_rows_in_place(
                rows=rows,
                columns=columns,
                out_path=out_path,
                sheet_name=sheet_name,
                best_checkpoint=best_checkpoint,
                model_name=model_name,
                old_sheet_name=old_sheet_name
            )
            print(f"Excel updated: {out_path} (sheet: {sheet_name})")
            return
        except Exception as e:
            # Fall through to a full rewrite rather than failing outright.
            print(f"[WARN] Failed to update existing Excel: {e}. Attempting full rewrite.")
    # Create new DataFrame and Excel file
    table = [{col: row.get(col, "") for col in columns} for row in rows]
    df = pd.DataFrame(table, columns=columns)
    metric_cols = [c for c in df.columns if c != "checkpoint"]
    if metric_cols:
        # Coerce metric columns to numbers; non-numeric entries become NaN.
        df[metric_cols] = df[metric_cols].apply(pd.to_numeric, errors="coerce")
    # Calculate comparisons for new file: per-cell "better than baseline" mask
    # and a per-row count of datasets with at least one improved metric.
    has_baseline = False
    better_mask = None
    if "checkpoint" in df.columns and (df["checkpoint"] == f"{model_name}").any() and metric_cols:
        has_baseline = True
        raw_idx = df.index[df["checkpoint"] == f"{model_name}"][0]
        raw_values = df.loc[raw_idx, metric_cols]
        # NaN comparisons yield False, so missing values never count as better.
        better_mask = df[metric_cols].gt(raw_values).fillna(False)
        dataset_to_cols = {}
        for col in metric_cols:
            dataset_name = col.split("_", 1)[0]
            dataset_to_cols.setdefault(dataset_name, []).append(col)
        better_datasets_counts = []
        for idx, row_bool in better_mask.iterrows():
            count = 0
            for ds, ds_cols in dataset_to_cols.items():
                if row_bool[ds_cols].any():
                    count += 1
            better_datasets_counts.append(count)
        df["better_datasets_than_raw"] = better_datasets_counts
    else:
        df["better_datasets_than_raw"] = 0
    final_columns = columns + ["better_datasets_than_raw"]
    # Write using Pandas/OpenPyxl
    with pd.ExcelWriter(out_path, engine="openpyxl") as writer:
        df.to_excel(writer, index=False, sheet_name=sheet_name, columns=final_columns)
        ws = writer.sheets[sheet_name]
        green_fill = PatternFill(start_color="90EE90", end_color="90EE90", fill_type="solid")
        # Formatting: r_idx counts data rows from 0; Excel row = r_idx + 2
        # (1-based plus the header row). This relies on df having a default
        # RangeIndex so positional and label indexing agree.
        for r_idx, row in enumerate(ws.iter_rows(min_row=2), start=0):  # Data rows
            for cell in row:
                cell.alignment = Alignment(horizontal='center', vertical='center')
            # Highlight cells better than baseline
            if has_baseline and better_mask is not None:
                row_label = df.index[r_idx]
                for col_name in metric_cols:
                    if col_name in df.columns:
                        col_idx = df.columns.get_loc(col_name)  # pandas index
                        # Map to the 1-based Excel column via final_columns order.
                        if col_name in final_columns:
                            excel_col_idx = final_columns.index(col_name) + 1
                            if bool(better_mask.loc[row_label, col_name]):
                                ws.cell(row=r_idx+2, column=excel_col_idx).fill = green_fill
            # Highlight best checkpoint
            chk_val = df.iloc[r_idx]["checkpoint"]
            if best_checkpoint and _checkpoint_key(str(chk_val)) == _checkpoint_key(best_checkpoint):
                ws.cell(row=r_idx+2, column=final_columns.index("checkpoint")+1).font = Font(color="008000", bold=True)
    print(f"Excel created: {out_path} (sheet: {sheet_name})")
def main():
    """CLI entry point: collect metrics, auto-detect the best checkpoint,
    and write/update the summary Excel table under ./table_result.

    Fix vs. previous version: several --help strings contradicted the actual
    defaults/semantics (--root claimed default "checkpoints"; --run claimed
    to be the sheet name, but it is the run directory name). Help text is
    corrected; all default values are unchanged.
    """
    parser = argparse.ArgumentParser(
        description="Create a metrics table from checkpoint JSON files."
    )
    parser.add_argument(
        "--root",
        type=str,
        default="./SFT/Evaluation/",
        help="Root directory containing checkpoints (default: ./SFT/Evaluation/)",
    )
    parser.add_argument(
        "--out_filename",
        type=str,
        default="./SFT/Evaluation//metrics_summary.xlsx",
        help="Output Excel file name; only its basename is used, placed under ./table_result/ (default: metrics_summary.xlsx)",
    )
    parser.add_argument(
        "--run",
        type=str,
        default="SFT_dt12.11.19:13_e6_unsloth_Qwen2.5_14B_Instruct_bnb_4bit_bnb_4bit_lr5e-06_t0.0_r64_b4_SFT_Implementation",
        help="Run name: subdirectory of --root holding checkpoint-<step> folders; also used to derive the Excel sheet name.",
    )
    parser.add_argument(
        "--best_checkpoint",
        type=str,
        default=None,
        help="Name of checkpoint whose name in the first column will be colored green (auto-detected when omitted).",
    )
    parser.add_argument(
        "--base_model_name",
        type=str,
        default="qwen2.5-3B",
        help="Name of the base model we trained on",
    )
    parser.add_argument(
        "--base_result_dir",
        type=str,
        default="./SFT/results_sft_14b",
        help="Directory of the base model we trained on",
    )
    parser.add_argument(
        "--train_data",
        type=str,
        default="UniADILR",
        help="Training-data tag injected into the sheet name.",
    )
    args = parser.parse_args()
    # --- Directory Setup ---
    # Results always land in ./table_result regardless of --out_filename's dir.
    output_dir = "table_result"
    os.makedirs(output_dir, exist_ok=True)
    final_out_path = os.path.join(output_dir, os.path.basename(args.out_filename))
    # --- Best Checkpoint Auto-Detection ---
    # Look for the training dir under base_result_dir, preferring the
    # "Training_<run>" layout over the plain "<run>" one.
    if args.best_checkpoint is None:
        BASE_RESULTS_DIR = os.path.expanduser(args.base_result_dir)
        TRAINING_DIR = os.path.join(BASE_RESULTS_DIR, f"Training_{args.run}")
        FINAL_DIR = os.path.join(BASE_RESULTS_DIR, args.run)
        TRAINING_BASE = None
        if os.path.isdir(TRAINING_DIR):
            TRAINING_BASE = TRAINING_DIR
        elif os.path.isdir(FINAL_DIR):
            TRAINING_BASE = FINAL_DIR
        if TRAINING_BASE:
            best_path, _ = find_best_checkpoint(TRAINING_BASE)
            if best_path:
                args.best_checkpoint = os.path.basename(best_path)
    # --- Data Collection ---
    print(f"Collecting metrics for run: {args.run}")
    rows, columns = collect_all_rows(args.root, args.run, args.best_checkpoint, args.base_model_name)
    if not rows:
        print("No metrics found. Exiting.")
        return
    # --- Sheet Naming Logic ---
    # Inject the train-data tag right before the "e20_" marker when present;
    # fall back to the raw run name otherwise.
    try:
        parts = args.run.split("e20_")
        if len(parts) >= 2:
            sheet_name = parts[0] + args.train_data + "_e20_" + parts[1]
        else:
            sheet_name = args.run
    except Exception:
        sheet_name = args.run
    old_sheet_name = args.run
    # --- Save to Local Excel ---
    write_excel(
        rows=rows,
        columns=columns,
        out_path=final_out_path,
        sheet_name=sheet_name,
        old_sheet_name=old_sheet_name,
        best_checkpoint=args.best_checkpoint,
        model_name=args.base_model_name,
    )


if __name__ == "__main__":
    main()