"""Summarize AniFileBERT training context from local and optional remote workers. This helper is designed for "session lost" recovery: - find the most recent run under `checkpoints/` - inspect latest resumable checkpoint and final artifacts - tail the run log (`logs//combined.log`) when available - optionally probe a remote Windows worker over SSH Examples -------- Local latest run: uv run python -m tools.recover_training_context Specific run: uv run python -m tools.recover_training_context --run dmhy-char-virtual-foo Include remote worker: uv run python -m tools.recover_training_context \ --remote-host adqew@192.168.63.157 \ --remote-repo "C:\\WorkSpace\\Python\\AniFileBERT" """ from __future__ import annotations import argparse import json import re import subprocess import sys from collections import deque from datetime import datetime, timezone from pathlib import Path from typing import Any, Optional DEFAULT_REMOTE_REPO = r"C:\WorkSpace\Python\AniFileBERT" CHECKPOINT_PATTERN = re.compile(r"^checkpoint-(\d+)$") def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--repo", default=".", help="Local AniFileBERT repository path") parser.add_argument("--run", default=None, help="Run name under checkpoints/ (default: latest)") parser.add_argument("--tail", type=int, default=80, help="Tail lines from combined.log") parser.add_argument("--remote-host", default=None, help="Optional SSH host (e.g. adqew@192.168.63.157)") parser.add_argument("--remote-repo", default=DEFAULT_REMOTE_REPO, help="Remote repository path") parser.add_argument("--timeout", type=int, default=45, help="Remote probe timeout seconds") parser.add_argument("--format", choices=["text", "json"], default="text", help="Output format") parser.add_argument("--output", default=None, help="Optional output file path (.json or .txt)") return parser.parse_args() def iso_ts(value: float) -> str: return datetime.fromtimestamp(value).strftime("%Y-%m-%d %H:%M:%S") def tail_file(path: Path, lines: int) -> list[str]: if not path.exists(): return [] q: deque[str] = deque(maxlen=max(lines, 1)) with path.open("r", encoding="utf-8", errors="replace") as handle: for line in handle: q.append(line.rstrip("\r\n")) return list(q) def read_json(path: Path) -> Optional[dict[str, Any]]: if not path.exists(): return None try: return json.loads(path.read_text(encoding="utf-8")) except json.JSONDecodeError: return None def parse_checkpoint_step(name: str) -> Optional[int]: match = CHECKPOINT_PATTERN.match(name) if not match: return None return int(match.group(1)) def sorted_runs(checkpoints_dir: Path) -> list[Path]: if not checkpoints_dir.exists(): return [] dirs = [entry for entry in checkpoints_dir.iterdir() if entry.is_dir()] return sorted(dirs, key=lambda item: item.stat().st_mtime, reverse=True) def select_run(checkpoints_dir: Path, requested_run: Optional[str]) -> Optional[Path]: runs = sorted_runs(checkpoints_dir) if not runs: return None if not requested_run: return runs[0] candidate = checkpoints_dir / requested_run if candidate.is_dir(): return candidate return None def list_checkpoints(run_dir: Path) -> list[dict[str, Any]]: entries: list[dict[str, Any]] = [] if not run_dir.exists(): return entries for entry in run_dir.iterdir(): if not entry.is_dir(): continue step = parse_checkpoint_step(entry.name) if step is None: continue entries.append( { "name": entry.name, "step": step, "mtime": iso_ts(entry.stat().st_mtime), } ) entries.sort(key=lambda item: item["step"], reverse=True) return entries def summarize_case_metrics(path: Path) -> Optional[dict[str, Any]]: raw = read_json(path) if not isinstance(raw, dict): return None modes = raw.get("modes") if not isinstance(modes, dict): return None summary: dict[str, Any] = {} for mode_key in ("model_only", "normalized_only"): mode = modes.get(mode_key) if not isinstance(mode, dict): continue summary[mode_key] = { "full_accuracy": mode.get("full_accuracy"), "full_correct": mode.get("full_correct"), "case_count": mode.get("case_count"), } if not summary: return None return summary def inspect_local(repo: Path, requested_run: Optional[str], tail_lines: int) -> dict[str, Any]: checkpoints_dir = repo / "checkpoints" run_dir = select_run(checkpoints_dir, requested_run) result: dict[str, Any] = { "repository": str(repo.resolve()), "requested_run": requested_run, "run_found": run_dir is not None, } if run_dir is None: result["error"] = "run_not_found" return result run_name = run_dir.name final_dir = run_dir / "final" checkpoints = list_checkpoints(run_dir) latest_checkpoint = checkpoints[0] if checkpoints else None run_metadata_path = final_dir / "run_metadata.json" run_metadata = read_json(run_metadata_path) or {} trainer_state = None if latest_checkpoint is not None: trainer_state = read_json(run_dir / latest_checkpoint["name"] / "trainer_state.json") or {} logs_dir = repo / "logs" / run_name combined_log_path = logs_dir / "combined.log" log_tail = tail_file(combined_log_path, tail_lines) run_script = (logs_dir / "run.ps1").read_text(encoding="utf-8", errors="replace") if (logs_dir / "run.ps1").exists() else None case_metrics = summarize_case_metrics(final_dir / "case_metrics.json") path_case_metrics = summarize_case_metrics(final_dir / "path_prefix_case_metrics.json") if final_dir.exists(): status = "completed" elif latest_checkpoint is not None: status = "checkpointed_no_final" else: status = "started_no_checkpoint" result.update( { "run": run_name, "run_dir": str(run_dir), "final_dir_exists": final_dir.exists(), "status": status, "latest_checkpoint": latest_checkpoint, "checkpoint_count": len(checkpoints), "checkpoints": checkpoints[:5], "run_metadata": { "experiment_name": run_metadata.get("experiment_name"), "model_head": run_metadata.get("model_head"), "tokenizer_variant": run_metadata.get("tokenizer_variant"), "dataset_mode": run_metadata.get("dataset_mode"), "virtual_dataset_dir": run_metadata.get("virtual_dataset_dir"), "train_samples": run_metadata.get("train_samples"), "epochs": run_metadata.get("epochs"), "batch_size": run_metadata.get("batch_size"), "learning_rate": run_metadata.get("learning_rate"), "seed": run_metadata.get("seed"), }, "latest_trainer_state": { "global_step": trainer_state.get("global_step"), "epoch": trainer_state.get("epoch"), "best_metric": trainer_state.get("best_metric"), } if trainer_state else None, "case_metrics_summary": case_metrics, "path_case_metrics_summary": path_case_metrics, "log_path": str(combined_log_path) if combined_log_path.exists() else None, "log_tail": log_tail, "run_script_path": str(logs_dir / "run.ps1") if run_script is not None else None, "run_script": run_script, "resume_hint": build_resume_hint(run_name, latest_checkpoint, status), } ) return result def build_resume_hint(run_name: str, latest_checkpoint: Optional[dict[str, Any]], status: str) -> Optional[str]: if status == "completed": return None checkpoint_part = "auto" if latest_checkpoint and latest_checkpoint.get("name"): checkpoint_part = f"checkpoints/{run_name}/{latest_checkpoint['name']}" return ( "Resume with: .\\.venv\\Scripts\\python.exe -m anifilebert.train " f"--save-dir checkpoints/{run_name} --resume-from-checkpoint {checkpoint_part} " "(plus the same training arguments used originally)." ) def read_json_text(raw_text: str) -> Optional[dict[str, Any]]: text = raw_text.strip() if not text: return None try: return json.loads(text) except json.JSONDecodeError: return None def run_ssh(host: str, command: str, timeout_seconds: int) -> subprocess.CompletedProcess[str]: return subprocess.run( ["ssh", host, command], capture_output=True, text=True, encoding="utf-8", errors="replace", timeout=max(timeout_seconds, 1), check=False, ) def fetch_remote_stdout(host: str, command: str, timeout_seconds: int) -> tuple[Optional[str], Optional[dict[str, Any]]]: try: completed = run_ssh(host, command, timeout_seconds) except FileNotFoundError: return None, {"error": "ssh_not_found"} except subprocess.TimeoutExpired: return None, {"error": f"remote_timeout_{timeout_seconds}s"} if completed.returncode != 0: return None, { "error": "remote_command_failed", "returncode": completed.returncode, "stderr": completed.stderr.strip(), "stdout": completed.stdout.strip(), } return completed.stdout, None def win_join(base: str, *parts: str) -> str: path = base.rstrip("\\/") for part in parts: path += "\\" + part.strip("\\/") return path def remote_quote_cmd(value: str) -> str: return '"' + value.replace('"', r'\"') + '"' def parse_remote_run_list(raw: str) -> list[str]: rows = [] for line in raw.splitlines(): text = line.strip() if text: rows.append(text) return rows def parse_remote_checkpoint_rows(raw: str) -> list[dict[str, Any]]: rows: list[dict[str, Any]] = [] for line in raw.splitlines(): text = line.strip() if not text: continue step = parse_checkpoint_step(text) if step is None: continue rows.append({"name": text, "step": step}) rows.sort(key=lambda item: item["step"], reverse=True) return rows def inspect_remote( host: str, remote_repo: str, requested_run: Optional[str], tail_lines: int, timeout_seconds: int, ) -> dict[str, Any]: result: dict[str, Any] = { "repository": remote_repo, "requested_run": requested_run, "run_found": False, } # 1) Find run list on remote. run_list_cmd = ( f"cd /d {remote_quote_cmd(remote_repo)} && " "dir /b /ad /o-d checkpoints" ) run_list_raw, error = fetch_remote_stdout(host, run_list_cmd, timeout_seconds) if error is not None: return error assert run_list_raw is not None runs = parse_remote_run_list(run_list_raw) if not runs: result["error"] = "run_not_found" return result run_name = requested_run if requested_run else runs[0] if requested_run and requested_run not in runs: result["error"] = "run_not_found" return result result["run_found"] = True result["run"] = run_name result["run_dir"] = win_join(remote_repo, "checkpoints", run_name) # 2) Check final dir. final_dir = win_join("checkpoints", run_name, "final") final_check_cmd = ( f"cd /d {remote_quote_cmd(remote_repo)} && " f'if exist {remote_quote_cmd(win_join(final_dir, "run_metadata.json"))} (echo TRUE) else (echo FALSE)' ) final_raw, error = fetch_remote_stdout(host, final_check_cmd, timeout_seconds) if error is not None: return error assert final_raw is not None final_exists = final_raw.strip().upper().endswith("TRUE") result["final_dir_exists"] = final_exists # 3) Find checkpoints for the selected run. checkpoint_list_cmd = ( f"cd /d {remote_quote_cmd(remote_repo)} && " f"dir /b /ad /o-d {remote_quote_cmd(win_join('checkpoints', run_name, 'checkpoint-*'))}" ) checkpoint_raw, error = fetch_remote_stdout(host, checkpoint_list_cmd, timeout_seconds) if error is not None: # If no checkpoint exists, Windows returns non-zero; treat as empty checkpoint list. checkpoint_raw = "" checkpoints = parse_remote_checkpoint_rows(checkpoint_raw or "") latest_checkpoint = checkpoints[0] if checkpoints else None result["latest_checkpoint"] = latest_checkpoint result["checkpoint_count"] = len(checkpoints) result["checkpoints"] = checkpoints[:5] if final_exists: status = "completed" elif latest_checkpoint is not None: status = "checkpointed_no_final" else: status = "started_no_checkpoint" result["status"] = status # 4) Pull run metadata JSON if available. run_metadata_cmd = ( f"cd /d {remote_quote_cmd(remote_repo)} && " f"type {remote_quote_cmd(win_join(final_dir, 'run_metadata.json'))}" ) run_meta_raw, error = fetch_remote_stdout(host, run_metadata_cmd, timeout_seconds) run_metadata: Optional[dict[str, Any]] = None if error is None and run_meta_raw is not None: run_metadata = read_json_text(run_meta_raw) if run_metadata: result["run_metadata"] = { "experiment_name": run_metadata.get("experiment_name"), "model_head": run_metadata.get("model_head"), "tokenizer_variant": run_metadata.get("tokenizer_variant"), "dataset_mode": run_metadata.get("dataset_mode"), "virtual_dataset_dir": run_metadata.get("virtual_dataset_dir"), "train_samples": run_metadata.get("train_samples"), "epochs": run_metadata.get("epochs"), "batch_size": run_metadata.get("batch_size"), "learning_rate": run_metadata.get("learning_rate"), "seed": run_metadata.get("seed"), } else: result["run_metadata"] = None # 5) Pull latest trainer state when resumable checkpoint exists. trainer_state = None if latest_checkpoint is not None: trainer_cmd = ( f"cd /d {remote_quote_cmd(remote_repo)} && " f"type {remote_quote_cmd(win_join('checkpoints', run_name, latest_checkpoint['name'], 'trainer_state.json'))}" ) trainer_raw, error = fetch_remote_stdout(host, trainer_cmd, timeout_seconds) if error is None and trainer_raw is not None: trainer_state = read_json_text(trainer_raw) if trainer_state: result["latest_trainer_state"] = { "global_step": trainer_state.get("global_step"), "epoch": trainer_state.get("epoch"), "best_metric": trainer_state.get("best_metric"), } else: result["latest_trainer_state"] = None # 6) Pull run script if present. run_script_rel = win_join("logs", run_name, "run.ps1") run_script_cmd = ( f"cd /d {remote_quote_cmd(remote_repo)} && " f"type {remote_quote_cmd(run_script_rel)}" ) run_script_raw, error = fetch_remote_stdout(host, run_script_cmd, timeout_seconds) if error is None and run_script_raw is not None and run_script_raw.strip(): result["run_script_path"] = win_join(remote_repo, run_script_rel) result["run_script"] = run_script_raw else: result["run_script_path"] = None result["run_script"] = None # 7) Pull combined log and trim tail locally for stability across shells. log_tail: list[str] = [] log_rel = win_join("logs", run_name, "combined.log") log_cmd = ( f"cd /d {remote_quote_cmd(remote_repo)} && " f"type {remote_quote_cmd(log_rel)}" ) log_raw, error = fetch_remote_stdout(host, log_cmd, timeout_seconds) if error is None and log_raw is not None: all_lines = [line.rstrip("\r\n") for line in log_raw.splitlines()] if all_lines: log_tail = all_lines[-max(tail_lines, 1):] result["log_path"] = win_join(remote_repo, log_rel) else: result["log_path"] = None else: result["log_path"] = None result["log_tail"] = log_tail return result def format_report_text(report: dict[str, Any]) -> str: lines: list[str] = [] local = report.get("local", {}) remote = report.get("remote") lines.append("AniFileBERT Training Recovery Summary") lines.append("=" * 38) lines.append(f"generated_at: {report.get('generated_at')}") lines.append("") lines.append("[local]") if local.get("run_found"): lines.append(f"run: {local.get('run')}") lines.append(f"status: {local.get('status')}") lines.append(f"final_dir_exists: {local.get('final_dir_exists')}") latest = local.get("latest_checkpoint") or {} if latest: lines.append(f"latest_checkpoint: {latest.get('name')} (step={latest.get('step')})") meta = local.get("run_metadata") or {} if meta.get("experiment_name"): lines.append(f"experiment_name: {meta.get('experiment_name')}") if meta.get("model_head"): lines.append(f"model_head: {meta.get('model_head')}") if meta.get("dataset_mode"): lines.append(f"dataset_mode: {meta.get('dataset_mode')}") case_summary = local.get("case_metrics_summary") or {} mo = case_summary.get("model_only") no = case_summary.get("normalized_only") if mo and mo.get("full_correct") is not None: lines.append( f"case_metrics(model_only): {mo.get('full_correct')}/{mo.get('case_count')} ({mo.get('full_accuracy')})" ) if no and no.get("full_correct") is not None: lines.append( f"case_metrics(normalized_only): {no.get('full_correct')}/{no.get('case_count')} ({no.get('full_accuracy')})" ) path_case_summary = local.get("path_case_metrics_summary") or {} pno = path_case_summary.get("normalized_only") if pno and pno.get("full_correct") is not None: lines.append( f"path_case_metrics(normalized_only): {pno.get('full_correct')}/{pno.get('case_count')} ({pno.get('full_accuracy')})" ) if local.get("resume_hint"): lines.append(f"resume_hint: {local.get('resume_hint')}") tail = local.get("log_tail") or [] if tail: lines.append("") lines.append("local_log_tail:") lines.extend(tail[-20:]) else: lines.append(f"error: {local.get('error', 'unknown')}") if remote is not None: lines.append("") lines.append("[remote]") if remote.get("error"): lines.append(f"error: {remote.get('error')}") if remote.get("stderr"): lines.append(f"stderr: {remote.get('stderr')}") else: lines.append(f"run: {remote.get('run')}") lines.append(f"status: {remote.get('status')}") latest = remote.get("latest_checkpoint") or {} if latest: lines.append(f"latest_checkpoint: {latest.get('name')} (step={latest.get('step')})") meta = remote.get("run_metadata") or {} if meta.get("experiment_name"): lines.append(f"experiment_name: {meta.get('experiment_name')}") if meta.get("model_head"): lines.append(f"model_head: {meta.get('model_head')}") tail = remote.get("log_tail") or [] if tail: lines.append("") lines.append("remote_log_tail:") lines.extend(tail[-20:]) return "\n".join(lines).rstrip() + "\n" def safe_stdout_write(text: str) -> None: try: sys.stdout.write(text) return except UnicodeEncodeError: encoding = getattr(sys.stdout, "encoding", None) or "utf-8" data = text.encode(encoding, errors="replace") buffer = getattr(sys.stdout, "buffer", None) if buffer is not None: buffer.write(data) else: sys.stdout.write(data.decode(encoding, errors="replace")) def main() -> None: args = parse_args() repo = Path(args.repo).resolve() report: dict[str, Any] = { "generated_at": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), "local": inspect_local(repo, args.run, args.tail), } if args.remote_host: report["remote"] = inspect_remote( host=args.remote_host, remote_repo=args.remote_repo, requested_run=args.run, tail_lines=args.tail, timeout_seconds=args.timeout, ) if args.format == "json": text = json.dumps(report, ensure_ascii=False, indent=2) else: text = format_report_text(report) if args.output: output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) output_path.write_text(text, encoding="utf-8") safe_stdout_write(text) if __name__ == "__main__": main()