supplymind / scripts /parse_training_logs.py
Rishav
Split dashboard curves by training phase
7579e54
Raw
History Blame Contribute Delete
11.6 kB
"""Parse HF training job logs into the local training dashboard JSON."""
from __future__ import annotations
import argparse
import ast
import json
import re
from datetime import datetime
from pathlib import Path
from statistics import mean
from typing import Any
# Default dashboard location (relative to the current working directory).
DEFAULT_OUTPUT = Path("results/training_dashboard.json")
# ANSI CSI escape sequences (e.g. terminal colour codes), stripped from each
# log line before it is parsed.
ANSI_RE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")
# Role / phase names embedded in a log filename stem. ``\b`` cannot be used
# here: "_" is a word character, so ``\bgrpo\b`` never matches "center_grpo".
# Instead only require that the name is not glued to other letters, which
# accepts "_", "-", ".", digits and the start/end of the stem as separators
# (so "center_grpo2" matches, but "datacenter" does not).
ROLE_RE = re.compile(r"(?<![a-z])(center|warehouse)(?![a-z])", re.IGNORECASE)
PHASE_RE = re.compile(r"(?<![a-z])(sft|grpo)(?![a-z])", re.IGNORECASE)
def read_text(path: Path) -> str:
    """Read a log file written as UTF-8, UTF-16 (e.g. PowerShell redirects) or cp1252.

    Detection order:

    1. A UTF-16 byte-order mark selects UTF-16.
    2. More than 20 NUL bytes in the first 200 bytes is taken as BOM-less
       UTF-16-LE (ASCII text encoded as UTF-16-LE is half NULs).
    3. Otherwise UTF-8 (with or without BOM) is tried, then cp1252.
    4. As a last resort the bytes are decoded as UTF-8 with replacement
       characters, so this function never raises ``UnicodeDecodeError``.

    UTF-16 decoding uses ``errors="replace"`` so a log that is still being
    written (odd trailing byte) does not abort parsing.
    """
    data = path.read_bytes()
    if data.startswith((b"\xff\xfe", b"\xfe\xff")):
        return data.decode("utf-16", errors="replace")
    if data[:200].count(b"\x00") > 20:
        return data.decode("utf-16-le", errors="replace")
    # UTF-16 is deliberately NOT tried here: it "succeeds" on almost any
    # even-length byte string and would turn cp1252 text into mojibake.
    # "utf-8-sig" also accepts plain UTF-8, so a separate "utf-8" try is moot.
    for encoding in ("utf-8-sig", "cp1252"):
        try:
            return data.decode(encoding)
        except UnicodeDecodeError:
            continue
    return data.decode("utf-8", errors="replace")
def load_json(path: Path) -> dict[str, Any]:
    """Load the JSON object stored at *path*.

    A missing file, or one that is empty/whitespace only, yields ``{}``.

    Raises:
        ValueError: if the file holds valid JSON that is not an object
            (invalid JSON surfaces as ``json.JSONDecodeError``, itself a
            ``ValueError`` subclass).
    """
    if path.exists():
        content = read_text(path).strip()
        if content:
            document = json.loads(content)
            if isinstance(document, dict):
                return document
            raise ValueError(f"{path} must contain a JSON object")
    return {}
def to_number(value: Any) -> Any:
    """Turn a numeric-looking string into a ``float``; return anything else unchanged.

    Non-string values (including ``int``, ``float`` and ``None``) pass
    through as-is, as do blank strings and strings ``float()`` rejects.
    Note that ``float()`` also accepts spellings such as ``"nan"``/``"inf"``.
    """
    if not isinstance(value, str):
        return value
    text = value.strip()
    if text:
        try:
            return float(text)
        except ValueError:
            pass
    return value
def normalize_mapping(values: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict with the same keys and each value passed through ``to_number``."""
    converted: dict[str, Any] = {}
    for name, raw in values.items():
        converted[name] = to_number(raw)
    return converted
def parse_json_line(line: str) -> dict[str, Any] | None:
try:
parsed = json.loads(line)
except json.JSONDecodeError:
return None
return parsed if isinstance(parsed, dict) else None
def parse_python_dict_line(line: str) -> dict[str, Any] | None:
if not (line.startswith("{") and line.endswith("}")):
return None
try:
parsed = ast.literal_eval(line)
except (SyntaxError, ValueError):
return None
return parsed if isinstance(parsed, dict) else None
def looks_like_step(record: dict[str, Any]) -> bool:
    """Return True if *record* looks like a per-step trainer log entry.

    A step carries at least one loss / reward / completion metric key.
    Records containing ``train_runtime`` (presumably the end-of-training
    summary) are never treated as steps.
    """
    if "train_runtime" in record:
        return False
    step_metrics = (
        "loss",
        "reward",
        "rewards/reward_completions/mean",
        "completions/mean_length",
        "completions/clipped_ratio",
    )
    return any(metric in record for metric in step_metrics)
def infer_role(path: Path, records: list[dict[str, Any]]) -> str:
    """Return ``"center"`` or ``"warehouse"`` for the log at *path*.

    The first record whose ``role`` field is one of those names
    (case-insensitive) wins. Otherwise the role is read from the filename
    stem via ``ROLE_RE``.

    Raises:
        ValueError: if neither the records nor the filename yield a role.
    """
    known_roles = {"center", "warehouse"}
    for declared in (record.get("role") for record in records):
        if isinstance(declared, str) and declared.lower() in known_roles:
            return declared.lower()
    found = ROLE_RE.search(path.stem)
    if found is None:
        raise ValueError(f"Could not infer role for {path}; use --center-log or --warehouse-log")
    return found.group(1).lower()
def infer_phase(path: Path, records: list[dict[str, Any]]) -> str | None:
for record in records:
phase = record.get("phase") or record.get("training_phase") or record.get("stage")
if isinstance(phase, str) and phase.lower() in {"sft", "grpo"}:
return phase.lower()
match = PHASE_RE.search(path.stem)
return match.group(1).lower() if match else None
def series_key(role: str, phase: str | None) -> str:
return f"{role}_{phase}" if phase else role
def parse_log(path: Path, role_hint: str | None = None, phase_hint: str | None = None) -> tuple[str, dict[str, Any]]:
    """Parse one training log into ``(series_key, series)``.

    Each non-blank line (ANSI escapes removed) is handled as follows:

    * A JSON object is number-normalised and kept for role/phase
      inference. Objects whose ``message`` is ``"reward_batch"`` are also
      collected as reward batches.
    * Otherwise, a Python-dict literal that ``looks_like_step`` becomes a
      step. Missing ``reward`` / ``completion_length`` / ``clipped_ratio``
      aliases are filled from the corresponding TRL-style metric keys
      (``None`` if absent), and a 1-based sequential ``step`` is assigned.

    ``role_hint`` / ``phase_hint`` take precedence over inference.

    Raises:
        ValueError: if no ``role_hint`` is given and no role can be inferred.
    """
    steps: list[dict[str, Any]] = []
    reward_batches: list[dict[str, Any]] = []
    json_records: list[dict[str, Any]] = []
    # (alias, source metric) pairs; order matters for key order in the output JSON.
    aliases = (
        ("reward", "rewards/reward_completions/mean"),
        ("completion_length", "completions/mean_length"),
        ("clipped_ratio", "completions/clipped_ratio"),
    )
    cleaned_lines = (ANSI_RE.sub("", raw).strip() for raw in read_text(path).splitlines())
    for cleaned in cleaned_lines:
        if not cleaned:
            continue
        as_json = parse_json_line(cleaned)
        if as_json:
            record = normalize_mapping(as_json)
            json_records.append(record)
            if record.get("message") == "reward_batch":
                reward_batches.append(record)
            continue
        as_dict = parse_python_dict_line(cleaned)
        if not as_dict or not looks_like_step(as_dict):
            continue
        step = normalize_mapping(as_dict)
        for alias, source in aliases:
            step.setdefault(alias, step.get(source))
        step["step"] = len(steps) + 1
        steps.append(step)
    role = role_hint or infer_role(path, json_records)
    phase = phase_hint or infer_phase(path, json_records)
    series = {
        "source_log": str(path),
        "role": role,
        "phase": phase,
        "steps": steps,
        "reward_batches": reward_batches,
    }
    return series_key(role, phase), series
def parse_eval_log(path: Path) -> list[dict[str, Any]]:
    """Collect the ``eval_result`` JSON rows from an eval log, in file order.

    Each matching row has its ``episodes`` field (if any) removed and its
    values number-normalised via ``normalize_mapping``.
    """
    rows: list[dict[str, Any]] = []
    for raw_line in read_text(path).splitlines():
        cleaned = ANSI_RE.sub("", raw_line).strip()
        record = parse_json_line(cleaned) if cleaned else None
        if record and record.get("message") == "eval_result":
            trimmed = dict(record)
            trimmed.pop("episodes", None)
            rows.append(normalize_mapping(trimmed))
    return rows
def average(values: list[Any]) -> float | None:
numbers = [value for value in values if isinstance(value, (int, float))]
return mean(numbers) if numbers else None
def summarize(series: dict[str, Any]) -> dict[str, Any]:
    """Condense one parsed training series into headline numbers for the dashboard.

    Args:
        series: A series dict such as ``parse_log`` builds; its ``steps`` and
            ``reward_batches`` lists are both optional.

    Returns:
        A dict that always has ``steps``, ``reward_batches``,
        ``invalid_payloads`` and ``invalid_actions``. Reward, loss,
        clipped-ratio and completion-length statistics are added only when
        at least one step carries a numeric value for them. ``first_*`` and
        ``last_*`` are the first/last *numeric* values, so steps that lack a
        metric (stored as ``None``) do not leak into the summary.
    """
    steps = series.get("steps", [])
    reward_batches = series.get("reward_batches", [])

    def numeric(key: str) -> list[float]:
        # Keep only real numbers for *key*; missing metrics are None.
        values = (step.get(key) for step in steps)
        return [value for value in values if isinstance(value, (int, float))]

    def batch_total(key: str) -> int:
        # Sum numeric counters across reward batches; non-numeric or missing
        # values count as 0 instead of crashing int().
        values = (batch.get(key) for batch in reward_batches)
        return sum(int(value) for value in values if isinstance(value, (int, float)))

    rewards = numeric("reward")
    losses = numeric("loss")
    clipped = numeric("completions/clipped_ratio")
    lengths = numeric("completions/mean_length")

    summary: dict[str, Any] = {
        "steps": len(steps),
        "reward_batches": len(reward_batches),
        "invalid_payloads": batch_total("invalid_payloads"),
        "invalid_actions": batch_total("invalid_actions"),
    }
    if rewards:
        summary.update(
            {
                "first_reward": rewards[0],
                "last_reward": rewards[-1],
                "best_reward": max(rewards),
                "mean_reward": mean(rewards),
            }
        )
    if losses:
        summary.update(
            {
                "first_loss": losses[0],
                "last_loss": losses[-1],
                "mean_loss": mean(losses),
            }
        )
    if clipped:
        summary["mean_clipped_ratio"] = mean(clipped)
    if lengths:
        summary["mean_completion_length"] = mean(lengths)
    return summary
def discover_logs() -> list[tuple[str | None, str | None, Path]]:
return [(None, None, path) for path in sorted(Path("results").glob("*.log"))]
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser; all options are optional and log paths are parsed as ``Path``."""
    parser = argparse.ArgumentParser(description=__doc__)
    # (flag, help) for the per-role / per-phase training logs, in help order.
    training_logs = (
        ("--center-log", "HF training log for the center role"),
        ("--warehouse-log", "HF training log for the warehouse role"),
        ("--center-sft-log", "SFT training log for the center role"),
        ("--warehouse-sft-log", "SFT training log for the warehouse role"),
        ("--center-grpo-log", "GRPO training log for the center role"),
        ("--warehouse-grpo-log", "GRPO training log for the warehouse role"),
    )
    for flag, help_text in training_logs:
        parser.add_argument(flag, type=Path, help=help_text)
    parser.add_argument(
        "--eval-log",
        action="append",
        type=Path,
        default=[],
        help="HF eval log containing eval_result JSON rows",
    )
    parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT, help="Dashboard JSON path")
    parser.add_argument("--status", help="Optional top-level dashboard status")
    parser.add_argument("--updated-at", help="Optional top-level dashboard updated_at value")
    return parser
def main() -> None:
    """Command-line entry point: merge parsed training/eval logs into the dashboard JSON.

    Log selection: each of the six ``--*-log`` flags pins a role (and, for
    the ``*-sft``/``*-grpo`` variants, a phase). If none is given, every
    ``results/*.log`` file is parsed with role and phase inferred, and files
    whose role cannot be inferred are skipped.

    The existing ``--output`` file (if any) is loaded and updated:

    * each parsed series and its summary are stored under the series key;
    * ``comparisons`` is replaced only when ``--eval-log`` rows were found;
    * ``status`` / ``updated_at`` are set from the flags.

    Without ``--updated-at``, a timestamp is written only if the dashboard
    has none yet.

    Raises:
        FileNotFoundError: for a requested/discovered log or ``--eval-log``
            path that does not exist.
        ValueError: if the existing output file is not a JSON object.
    """
    args = build_arg_parser().parse_args()

    flag_table = (
        ("center", None, args.center_log),
        ("warehouse", None, args.warehouse_log),
        ("center", "sft", args.center_sft_log),
        ("warehouse", "sft", args.warehouse_sft_log),
        ("center", "grpo", args.center_grpo_log),
        ("warehouse", "grpo", args.warehouse_grpo_log),
    )
    requested_logs: list[tuple[str | None, str | None, Path]] = [
        (role_name, phase_name, log_file)
        for role_name, phase_name, log_file in flag_table
        if log_file
    ]
    if not requested_logs:
        requested_logs = discover_logs()

    dashboard = load_json(args.output)
    series_by_key = dashboard.setdefault("training_series", {})
    summary_by_key = dashboard.setdefault("training_summary", {})
    allowed_keys = {"center", "warehouse", "center_sft", "center_grpo", "warehouse_sft", "warehouse_grpo"}
    parsed_keys: list[str] = []

    for role_hint, phase_hint, log_path in requested_logs:
        if not log_path.exists():
            raise FileNotFoundError(log_path)
        try:
            key, series = parse_log(log_path, role_hint, phase_hint)
        except ValueError:
            # Discovered logs with no inferable role are skipped; explicit ones propagate.
            if role_hint is None:
                continue
            raise
        has_data = series.get("steps") or series.get("reward_batches")
        if key not in allowed_keys or not has_data:
            continue
        series_by_key[key] = series
        summary_by_key[key] = summarize(series)
        role = series.get("role")
        # Also populate the role-only key for phase-split series. setdefault
        # means an existing entry (including one loaded from disk) is kept.
        if key != role and role in {"center", "warehouse"}:
            series_by_key.setdefault(role, series)
            summary_by_key.setdefault(role, summarize(series))
        parsed_keys.append(key)

    comparisons: list[dict[str, Any]] = []
    for eval_log in args.eval_log:
        if not eval_log.exists():
            raise FileNotFoundError(eval_log)
        comparisons.extend(parse_eval_log(eval_log))
    if comparisons:
        # Keep only the last row per (role, label), ordered by where that
        # last occurrence appeared: re-inserting moves a key to the end.
        latest: dict[tuple[str, str], dict[str, Any]] = {}
        for row in comparisons:
            identity = (str(row.get("role")), str(row.get("label")))
            latest.pop(identity, None)
            latest[identity] = row
        dashboard["comparisons"] = list(latest.values())

    if args.status:
        dashboard["status"] = args.status
    if args.updated_at:
        dashboard["updated_at"] = args.updated_at
    elif parsed_keys and "updated_at" not in dashboard:
        dashboard["updated_at"] = datetime.now().astimezone().isoformat(timespec="seconds")

    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(dashboard, indent=2) + "\n", encoding="utf-8")
    print(f"Wrote {args.output} with series: {', '.join(parsed_keys) or 'none'}")
# Run the CLI only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()