Spaces:
Sleeping
Sleeping
| """Parse HF training job logs into the local training dashboard JSON.""" | |
| from __future__ import annotations | |
| import argparse | |
| import ast | |
| import json | |
| import re | |
| from datetime import datetime | |
| from pathlib import Path | |
| from statistics import mean | |
| from typing import Any | |
# Default value of the --output flag: where the dashboard JSON is written.
DEFAULT_OUTPUT = Path("results/training_dashboard.json")
# ANSI CSI escape sequences (e.g. terminal colour codes) stripped from log lines.
ANSI_RE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")
# Role / phase tokens looked up in log file stems. Explicit alphanumeric lookarounds
# are used instead of ``\b`` because ``_`` is a word character: ``\b`` never matches
# inside underscore-separated names such as ``center_grpo`` or ``hf_warehouse_sft``.
ROLE_RE = re.compile(r"(?<![A-Za-z0-9])(center|warehouse)(?![A-Za-z0-9])", re.IGNORECASE)
PHASE_RE = re.compile(r"(?<![A-Za-z0-9])(sft|grpo)(?![A-Za-z0-9])", re.IGNORECASE)
def read_text(path: Path) -> str:
    """Return the text of ``path``, tolerating the encodings training logs arrive in.

    Handles UTF-8 (with or without BOM), UTF-16 with a BOM (e.g. PowerShell
    redirected output), BOM-less UTF-16 in either byte order, and cp1252.
    As a last resort undecodable bytes are replaced with U+FFFD rather than
    raising.
    """
    data = path.read_bytes()
    if data.startswith((b"\xff\xfe", b"\xfe\xff")):
        # The BOM fixes the byte order; "replace" keeps a log that is still being
        # written (odd trailing byte) readable instead of raising.
        return data.decode("utf-16", errors="replace")
    head = data[:200]
    if head.count(b"\x00") > 20:
        # BOM-less UTF-16 of mostly-ASCII text: every other byte is NUL. The NULs
        # sit at odd offsets for little-endian and at even offsets for big-endian.
        little_endian = head[1::2].count(b"\x00") >= head[::2].count(b"\x00")
        return data.decode("utf-16-le" if little_endian else "utf-16-be", errors="replace")
    # UTF-16 was handled above. Trying it again here would "succeed" on almost any
    # even-length byte string and turn cp1252 text into mojibake.
    for encoding in ("utf-8-sig", "cp1252"):
        try:
            return data.decode(encoding)
        except UnicodeDecodeError:
            continue
    return data.decode("utf-8", errors="replace")
def load_json(path: Path) -> dict[str, Any]:
    """Load an existing dashboard JSON object, or ``{}`` when there is nothing to load.

    A missing file and a file containing only whitespace both yield an empty
    dict.

    Raises:
        ValueError: if the file holds JSON that is not an object (malformed
            JSON surfaces as ``json.JSONDecodeError``, itself a ValueError).
    """
    if path.exists():
        content = read_text(path).strip()
        if content:
            document = json.loads(content)
            if isinstance(document, dict):
                return document
            raise ValueError(f"{path} must contain a JSON object")
    return {}
def to_number(value: Any) -> Any:
    """Convert a numeric-looking string to ``float``; return anything else unchanged.

    Non-strings (ints, floats, bools, None, containers) pass through as-is,
    as do blank strings and strings that ``float()`` rejects. Note that
    ``float()`` also accepts forms such as ``"nan"``, ``"inf"`` and ``"1_000"``.
    """
    if not isinstance(value, str):
        return value
    candidate = value.strip()
    if candidate:
        try:
            return float(candidate)
        except ValueError:
            pass
    return value
def normalize_mapping(values: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict with every value passed through ``to_number``; keys keep their order."""
    normalized: dict[str, Any] = {}
    for key, raw in values.items():
        normalized[key] = to_number(raw)
    return normalized
def parse_json_line(line: str) -> dict[str, Any] | None:
    """Decode ``line`` as JSON, returning it only when it is a JSON object (else None)."""
    try:
        decoded = json.loads(line)
    except json.JSONDecodeError:
        decoded = None
    return decoded if isinstance(decoded, dict) else None
def parse_python_dict_line(line: str) -> dict[str, Any] | None:
    """Parse a Python-literal dict line such as ``{'loss': 0.5}``; return None otherwise.

    ``ast.literal_eval`` is used so nothing from the log is executed. Any
    line that is not a brace-delimited literal evaluating to a dict yields
    None; malformed literals never raise.
    """
    if not (line.startswith("{") and line.endswith("}")):
        return None
    try:
        parsed = ast.literal_eval(line)
    except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
        # literal_eval documents all of these for malformed input, e.g.
        # ValueError for names like ``nan``, TypeError for unhashable keys like
        # ``{[1]: 2}``, MemoryError/RecursionError for pathological nesting.
        return None
    return parsed if isinstance(parsed, dict) else None
def looks_like_step(record: dict[str, Any]) -> bool:
    """Whether a parsed dict looks like a per-step training metrics row.

    The dict must carry at least one loss/reward/completion metric. Rows
    containing ``train_runtime`` (the trainer's end-of-run summary) are
    rejected.
    """
    if "train_runtime" in record:
        return False
    step_metrics = (
        "loss",
        "reward",
        "rewards/reward_completions/mean",
        "completions/mean_length",
        "completions/clipped_ratio",
    )
    return any(metric in record for metric in step_metrics)
def infer_role(path: Path, records: list[dict[str, Any]]) -> str:
    """Decide whether a log belongs to the ``center`` or ``warehouse`` role.

    The first record whose ``role`` field is (case-insensitively) one of the
    two roles wins. Otherwise ``ROLE_RE`` is searched in the file stem. The
    result is lower-cased.

    Raises:
        ValueError: if neither the records nor the file name reveal a role.
    """
    known_roles = {"center", "warehouse"}
    declared = (record.get("role") for record in records)
    for value in declared:
        if isinstance(value, str) and value.lower() in known_roles:
            return value.lower()
    name_match = ROLE_RE.search(path.stem)
    if name_match is None:
        raise ValueError(f"Could not infer role for {path}; use --center-log or --warehouse-log")
    return name_match.group(1).lower()
def infer_phase(path: Path, records: list[dict[str, Any]]) -> str | None:
    """Determine a log's training phase (``sft`` or ``grpo``), or None if unknown.

    For each record the first truthy value among ``phase``, ``training_phase``
    and ``stage`` is examined. The first one that is (case-insensitively)
    ``sft`` or ``grpo`` wins. Failing that, ``PHASE_RE`` is searched in the
    file stem. The result is lower-cased.
    """
    known_phases = {"sft", "grpo"}
    phase_fields = ("phase", "training_phase", "stage")
    for record in records:
        label = next((record[field] for field in phase_fields if record.get(field)), None)
        if isinstance(label, str) and label.lower() in known_phases:
            return label.lower()
    found = PHASE_RE.search(path.stem)
    if found is None:
        return None
    return found.group(1).lower()
def series_key(role: str, phase: str | None) -> str:
    """Dashboard key for a series: the bare ``role``, or ``role_phase`` when a phase is set."""
    if not phase:
        return role
    return f"{role}_{phase}"
def parse_log(path: Path, role_hint: str | None = None, phase_hint: str | None = None) -> tuple[str, dict[str, Any]]:
    """Parse one HF training log into a ``(series_key, series)`` pair.

    Each line is stripped of ANSI escapes and surrounding whitespace, then:

    * a JSON object is number-normalised and kept for role/phase inference.
      Objects whose ``message`` is ``"reward_batch"`` are also collected as
      reward batches;
    * otherwise, a Python-literal dict accepted by ``looks_like_step`` becomes
      a training step. It gains ``reward``/``completion_length``/``clipped_ratio``
      aliases (only when absent) and a 1-based ``step`` counter.

    Truthy ``role_hint``/``phase_hint`` values take precedence over inference.

    Raises:
        ValueError: from ``infer_role`` when no role hint is given and none can
            be found in the JSON records or the file name.
    """
    alias_sources = (
        ("reward", "rewards/reward_completions/mean"),
        ("completion_length", "completions/mean_length"),
        ("clipped_ratio", "completions/clipped_ratio"),
    )
    json_records: list[dict[str, Any]] = []
    batches: list[dict[str, Any]] = []
    step_rows: list[dict[str, Any]] = []

    cleaned_lines = (ANSI_RE.sub("", raw).strip() for raw in read_text(path).splitlines())
    for text in cleaned_lines:
        if not text:
            continue
        as_json = parse_json_line(text)
        if as_json:
            record = normalize_mapping(as_json)
            json_records.append(record)
            if record.get("message") == "reward_batch":
                batches.append(record)
            continue
        as_dict = parse_python_dict_line(text)
        if not as_dict or not looks_like_step(as_dict):
            continue
        row = normalize_mapping(as_dict)
        for alias, source in alias_sources:
            row.setdefault(alias, row.get(source))
        row["step"] = len(step_rows) + 1
        step_rows.append(row)

    role = role_hint or infer_role(path, json_records)
    phase = phase_hint or infer_phase(path, json_records)
    series = {
        "source_log": str(path),
        "role": role,
        "phase": phase,
        "steps": step_rows,
        "reward_batches": batches,
    }
    return series_key(role, phase), series
def parse_eval_log(path: Path) -> list[dict[str, Any]]:
    """Collect the ``eval_result`` JSON rows of an eval log, in file order.

    Each row is number-normalised and its ``episodes`` field, if any, is
    dropped. Non-JSON lines and other messages are ignored.
    """
    rows: list[dict[str, Any]] = []
    for raw_line in read_text(path).splitlines():
        cleaned = ANSI_RE.sub("", raw_line).strip()
        record = parse_json_line(cleaned) if cleaned else None
        if record and record.get("message") == "eval_result":
            trimmed = dict(record)
            trimmed.pop("episodes", None)
            rows.append(normalize_mapping(trimmed))
    return rows
def average(values: list[Any]) -> float | None:
    """Arithmetic mean of the int/float entries of ``values``, or None if there are none.

    Non-numeric entries (None, strings, ...) are skipped. Bools count as ints.
    """
    numeric = list(filter(lambda item: isinstance(item, (int, float)), values))
    if not numeric:
        return None
    return mean(numeric)
def summarize(series: dict[str, Any]) -> dict[str, Any]:
    """Compute headline statistics for one parsed training series.

    Always reports step and reward-batch counts plus the summed
    ``invalid_payloads``/``invalid_actions`` of the reward batches. Reward,
    loss, clipped-ratio and completion-length statistics are added only when
    at least one step carries a numeric (int/float) value for them. ``None``
    and non-numeric entries are skipped, so ``first_*``/``last_*`` are the
    first/last *numeric* values and are never ``None``.

    Args:
        series: a series dict as built by ``parse_log``; missing ``steps`` /
            ``reward_batches`` keys are treated as empty lists.

    Returns:
        A new summary dict.
    """
    steps = series.get("steps", [])
    reward_batches = series.get("reward_batches", [])

    def numeric_column(*keys: str) -> list[int | float]:
        # For each step take the first non-None value among ``keys`` and keep it only
        # if it is numeric. parse_log stores both the raw TRL metric name and a short
        # alias; the alias fallback covers logs that only carry the short form.
        column: list[int | float] = []
        for step in steps:
            value = next((step[key] for key in keys if step.get(key) is not None), None)
            if isinstance(value, (int, float)):
                column.append(value)
        return column

    rewards = numeric_column("reward", "rewards/reward_completions/mean")
    losses = numeric_column("loss")
    clipped = numeric_column("completions/clipped_ratio", "clipped_ratio")
    lengths = numeric_column("completions/mean_length", "completion_length")

    summary: dict[str, Any] = {
        "steps": len(steps),
        "reward_batches": len(reward_batches),
        "invalid_payloads": sum(int(batch.get("invalid_payloads") or 0) for batch in reward_batches),
        "invalid_actions": sum(int(batch.get("invalid_actions") or 0) for batch in reward_batches),
    }
    if rewards:
        summary.update(
            {
                "first_reward": rewards[0],
                "last_reward": rewards[-1],
                "best_reward": max(rewards),
                "mean_reward": average(rewards),
            }
        )
    # Previously guarded by ``if losses:`` on the raw list, which is true whenever any
    # step exists and produced None-valued loss stats; now consistent with rewards.
    if losses:
        summary.update(
            {
                "first_loss": losses[0],
                "last_loss": losses[-1],
                "mean_loss": average(losses),
            }
        )
    if clipped:
        summary["mean_clipped_ratio"] = average(clipped)
    if lengths:
        summary["mean_completion_length"] = average(lengths)
    return summary
def discover_logs() -> list[tuple[str | None, str | None, Path]]:
    """List every ``results/*.log`` (relative to the CWD), sorted, with no role/phase hints."""
    candidates = sorted(Path("results").glob("*.log"))
    return [(None, None, candidate) for candidate in candidates]
def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser; the module docstring serves as its description."""
    parser = argparse.ArgumentParser(description=__doc__)
    # Per-role / per-phase training logs, registered in the order shown by --help.
    role_log_flags = (
        ("--center-log", "HF training log for the center role"),
        ("--warehouse-log", "HF training log for the warehouse role"),
        ("--center-sft-log", "SFT training log for the center role"),
        ("--warehouse-sft-log", "SFT training log for the warehouse role"),
        ("--center-grpo-log", "GRPO training log for the center role"),
        ("--warehouse-grpo-log", "GRPO training log for the warehouse role"),
    )
    for flag, help_text in role_log_flags:
        parser.add_argument(flag, type=Path, help=help_text)
    parser.add_argument(
        "--eval-log",
        action="append",
        type=Path,
        default=[],
        help="HF eval log containing eval_result JSON rows",
    )
    parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT, help="Dashboard JSON path")
    for flag, help_text in (
        ("--status", "Optional top-level dashboard status"),
        ("--updated-at", "Optional top-level dashboard updated_at value"),
    ):
        parser.add_argument(flag, help=help_text)
    return parser
def _log_requests_from_args(args: argparse.Namespace) -> list[tuple[str | None, str | None, Path]]:
    """Turn the per-role/per-phase log flags into (role_hint, phase_hint, path) triples."""
    flag_hints = (
        ("center_log", "center", None),
        ("warehouse_log", "warehouse", None),
        ("center_sft_log", "center", "sft"),
        ("warehouse_sft_log", "warehouse", "sft"),
        ("center_grpo_log", "center", "grpo"),
        ("warehouse_grpo_log", "warehouse", "grpo"),
    )
    return [(role, phase, getattr(args, attr)) for attr, role, phase in flag_hints if getattr(args, attr)]


def _merge_training_logs(
    dashboard: dict[str, Any],
    requested_logs: list[tuple[str | None, str | None, Path]],
) -> list[str]:
    """Parse each requested log into ``dashboard``'s series/summary maps; return the keys written.

    Raises:
        FileNotFoundError: if a requested log path does not exist.
        ValueError: if a log passed with an explicit role hint cannot be parsed.
    """
    valid_keys = {"center", "warehouse", "center_sft", "center_grpo", "warehouse_sft", "warehouse_grpo"}
    parsed_keys: list[str] = []
    for role_hint, phase_hint, log_path in requested_logs:
        if not log_path.exists():
            raise FileNotFoundError(log_path)
        try:
            key, series = parse_log(log_path, role_hint, phase_hint)
        except ValueError:
            # Auto-discovered logs (no role hint) whose role cannot be inferred are
            # skipped; with an explicit role flag the failure is fatal.
            if role_hint is None:
                continue
            raise
        if key not in valid_keys:
            continue
        if not series.get("steps") and not series.get("reward_batches"):
            continue
        summary = summarize(series)  # computed once, shared with the role alias below
        dashboard["training_series"][key] = series
        dashboard["training_summary"][key] = summary
        role = series.get("role")
        if key != role and role in {"center", "warehouse"}:
            # Phase-specific series also populate the bare role key, but only if it is
            # not already present (e.g. from --center-log or an earlier run).
            dashboard["training_series"].setdefault(role, series)
            dashboard["training_summary"].setdefault(role, summary)
        parsed_keys.append(key)
    return parsed_keys


def _latest_comparisons(comparisons: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Keep only the last row per (role, label), preserving the relative order of the survivors."""
    seen: set[tuple[str, str]] = set()
    unique: list[dict[str, Any]] = []
    for row in reversed(comparisons):
        identity = (str(row.get("role")), str(row.get("label")))
        if identity in seen:
            continue
        seen.add(identity)
        unique.append(row)
    unique.reverse()
    return unique


def main() -> None:
    """CLI entry point: merge parsed training/eval logs into the dashboard JSON at ``--output``.

    With no log flags, every ``results/*.log`` is auto-discovered. The existing
    dashboard file is loaded first and its unrelated keys are preserved.
    """
    args = build_arg_parser().parse_args()
    requested_logs = _log_requests_from_args(args) or discover_logs()

    dashboard = load_json(args.output)
    dashboard.setdefault("training_series", {})
    dashboard.setdefault("training_summary", {})
    parsed_keys = _merge_training_logs(dashboard, requested_logs)

    comparisons: list[dict[str, Any]] = []
    for eval_log in args.eval_log:
        if not eval_log.exists():
            raise FileNotFoundError(eval_log)
        comparisons.extend(parse_eval_log(eval_log))
    if comparisons:
        dashboard["comparisons"] = _latest_comparisons(comparisons)

    if args.status:
        dashboard["status"] = args.status
    if args.updated_at:
        dashboard["updated_at"] = args.updated_at
    elif parsed_keys and "updated_at" not in dashboard:
        # NOTE(review): only written when absent, so once set the timestamp is never
        # refreshed by later runs unless --updated-at is passed — confirm intended.
        dashboard["updated_at"] = datetime.now().astimezone().isoformat(timespec="seconds")

    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(dashboard, indent=2) + "\n", encoding="utf-8")
    print(f"Wrote {args.output} with series: {', '.join(parsed_keys) or 'none'}")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()