AniFileBERT / tools /recover_training_context.py
ModerRAS's picture
chore: checkpoint current training and manual relabel progress
efb213a
raw
history blame
21.5 kB
"""Summarize AniFileBERT training context from local and optional remote workers.
This helper is designed for "session lost" recovery:
- find the most recent run under `checkpoints/`
- inspect latest resumable checkpoint and final artifacts
- tail the run log (`logs/<run>/combined.log`) when available
- optionally probe a remote Windows worker over SSH
Examples
--------
Local latest run:
uv run python -m tools.recover_training_context
Specific run:
uv run python -m tools.recover_training_context --run dmhy-char-virtual-foo
Include remote worker:
uv run python -m tools.recover_training_context \
--remote-host adqew@192.168.63.157 \
--remote-repo "C:\\WorkSpace\\Python\\AniFileBERT"
"""
from __future__ import annotations
import argparse
import json
import re
import subprocess
import sys
from collections import deque
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional
# Default value of the --remote-repo option: repository path on the remote worker.
DEFAULT_REMOTE_REPO = r"C:\WorkSpace\Python\AniFileBERT"
# Matches names of the form "checkpoint-<digits>"; group 1 captures the digits.
CHECKPOINT_PATTERN = re.compile(r"^checkpoint-(\d+)$")
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the recovery summary tool.

    Returns:
        The populated namespace with ``repo``, ``run``, ``tail``,
        ``remote_host``, ``remote_repo``, ``timeout``, ``format`` and
        ``output`` attributes.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring contains pre-formatted example commands; the
        # default formatter would re-wrap them into a single paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--repo", default=".", help="Local AniFileBERT repository path")
    parser.add_argument("--run", default=None, help="Run name under checkpoints/ (default: latest)")
    parser.add_argument("--tail", type=int, default=80, help="Tail lines from combined.log")
    parser.add_argument("--remote-host", default=None, help="Optional SSH host (e.g. adqew@192.168.63.157)")
    parser.add_argument("--remote-repo", default=DEFAULT_REMOTE_REPO, help="Remote repository path")
    parser.add_argument("--timeout", type=int, default=45, help="Remote probe timeout seconds")
    parser.add_argument("--format", choices=["text", "json"], default="text", help="Output format")
    parser.add_argument("--output", default=None, help="Optional output file path (.json or .txt)")
    return parser.parse_args()
def iso_ts(value: float) -> str:
    """Render a POSIX timestamp as ``YYYY-MM-DD HH:MM:SS`` in local time."""
    moment = datetime.fromtimestamp(value)
    return f"{moment:%Y-%m-%d %H:%M:%S}"
def tail_file(path: Path, lines: int) -> list[str]:
    """Return the last ``lines`` lines of ``path`` without their line endings.

    A missing file yields an empty list. At least one line is always kept,
    so ``lines <= 0`` behaves like ``lines == 1``. The file is read as UTF-8
    with undecodable bytes replaced.
    """
    if not path.exists():
        return []
    keep = max(lines, 1)
    with path.open("r", encoding="utf-8", errors="replace") as stream:
        # A bounded deque consumes the whole stream but retains only the newest rows.
        window = deque(stream, maxlen=keep)
    return [row.rstrip("\r\n") for row in window]
def read_json(path: Path) -> Optional[dict[str, Any]]:
    """Load a JSON object from ``path``, or return ``None`` when unavailable.

    ``None`` is returned when the file is missing or unreadable, is not valid
    UTF-8 / JSON, or holds a JSON value that is not an object. Callers use
    ``.get`` on the result, so only dictionaries are ever handed back.

    Args:
        path: Location of the JSON file.

    Returns:
        The decoded dictionary, or ``None``.
    """
    try:
        # "utf-8-sig" accepts plain UTF-8 and also drops a leading BOM, which
        # json.loads would otherwise reject.
        payload = json.loads(path.read_text(encoding="utf-8-sig"))
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        return None
    return payload if isinstance(payload, dict) else None
def parse_checkpoint_step(name: str) -> Optional[int]:
    """Extract the numeric step from a ``checkpoint-<digits>`` name.

    Returns ``None`` when ``name`` does not match ``CHECKPOINT_PATTERN``.
    """
    found = CHECKPOINT_PATTERN.match(name)
    return int(found.group(1)) if found else None
def sorted_runs(checkpoints_dir: Path) -> list[Path]:
    """List the sub-directories of ``checkpoints_dir``, most recently modified first.

    Returns an empty list when ``checkpoints_dir`` does not exist. Plain
    files are ignored.
    """
    if not checkpoints_dir.exists():
        return []
    run_dirs = []
    for child in checkpoints_dir.iterdir():
        if child.is_dir():
            run_dirs.append(child)
    run_dirs.sort(key=lambda folder: folder.stat().st_mtime, reverse=True)
    return run_dirs
def select_run(checkpoints_dir: Path, requested_run: Optional[str]) -> Optional[Path]:
    """Pick the run directory to inspect.

    With no ``requested_run`` the most recently modified run is chosen.
    Otherwise ``checkpoints_dir / requested_run`` is returned if it is a
    directory. ``None`` means no run directories exist or the requested one
    was not found.
    """
    available = sorted_runs(checkpoints_dir)
    if not available:
        return None
    if requested_run:
        wanted = checkpoints_dir / requested_run
        return wanted if wanted.is_dir() else None
    return available[0]
def list_checkpoints(run_dir: Path) -> list[dict[str, Any]]:
    """Describe the ``checkpoint-<step>`` directories found in ``run_dir``.

    Each record has ``name``, ``step`` and ``mtime`` (formatted by
    ``iso_ts``). Records are ordered by step, highest first. A missing
    ``run_dir`` yields an empty list.
    """
    if not run_dir.exists():
        return []
    # Pair every sub-directory with its parsed step; non-checkpoint names get None.
    stepped = (
        (child, parse_checkpoint_step(child.name))
        for child in run_dir.iterdir()
        if child.is_dir()
    )
    records: list[dict[str, Any]] = [
        {
            "name": child.name,
            "step": step,
            "mtime": iso_ts(child.stat().st_mtime),
        }
        for child, step in stepped
        if step is not None
    ]
    return sorted(records, key=lambda record: record["step"], reverse=True)
def summarize_case_metrics(path: Path) -> Optional[dict[str, Any]]:
    """Condense a case-metrics JSON file to its headline numbers.

    For each of the ``model_only`` and ``normalized_only`` entries under the
    top-level ``modes`` mapping, keeps ``full_accuracy``, ``full_correct``
    and ``case_count``. Returns ``None`` when the file cannot be read, has no
    ``modes`` mapping, or contains neither mode.
    """
    payload = read_json(path)
    modes = payload.get("modes") if isinstance(payload, dict) else None
    if not isinstance(modes, dict):
        return None
    wanted_fields = ("full_accuracy", "full_correct", "case_count")
    digest: dict[str, Any] = {}
    for mode_name in ("model_only", "normalized_only"):
        stats = modes.get(mode_name)
        if isinstance(stats, dict):
            digest[mode_name] = {field: stats.get(field) for field in wanted_fields}
    return digest or None
def inspect_local(repo: Path, requested_run: Optional[str], tail_lines: int) -> dict[str, Any]:
    """Collect a recovery summary for one training run in the local repository.

    Args:
        repo: Repository root containing ``checkpoints/`` and ``logs/``.
        requested_run: Run name under ``checkpoints/``; a falsy value selects
            the most recently modified run.
        tail_lines: Number of trailing ``combined.log`` lines to include.

    Returns:
        A dictionary that always has ``repository``, ``requested_run`` and
        ``run_found``. When no run is found it also carries
        ``error == "run_not_found"``; otherwise it carries the run status,
        checkpoints, selected metadata, metric summaries, log tail, run
        script and resume hint.
    """
    run_dir = select_run(repo / "checkpoints", requested_run)
    summary: dict[str, Any] = {
        "repository": str(repo.resolve()),
        "requested_run": requested_run,
        "run_found": run_dir is not None,
    }
    if run_dir is None:
        summary["error"] = "run_not_found"
        return summary

    run_name = run_dir.name
    final_dir = run_dir / "final"
    logs_dir = repo / "logs" / run_name
    log_path = logs_dir / "combined.log"
    script_path = logs_dir / "run.ps1"

    checkpoints = list_checkpoints(run_dir)
    newest = checkpoints[0] if checkpoints else None

    metadata = read_json(final_dir / "run_metadata.json") or {}
    metadata_fields = (
        "experiment_name",
        "model_head",
        "tokenizer_variant",
        "dataset_mode",
        "virtual_dataset_dir",
        "train_samples",
        "epochs",
        "batch_size",
        "learning_rate",
        "seed",
    )

    # Trainer state is only reported when the newest checkpoint has a non-empty one.
    trainer_summary = None
    if newest is not None:
        state = read_json(run_dir / newest["name"] / "trainer_state.json")
        if state:
            trainer_summary = {
                field: state.get(field)
                for field in ("global_step", "epoch", "best_metric")
            }

    log_tail = tail_file(log_path, tail_lines)
    run_script = None
    if script_path.exists():
        run_script = script_path.read_text(encoding="utf-8", errors="replace")

    case_metrics = summarize_case_metrics(final_dir / "case_metrics.json")
    path_case_metrics = summarize_case_metrics(final_dir / "path_prefix_case_metrics.json")

    has_final = final_dir.exists()
    if has_final:
        status = "completed"
    else:
        status = "checkpointed_no_final" if newest is not None else "started_no_checkpoint"

    summary["run"] = run_name
    summary["run_dir"] = str(run_dir)
    summary["final_dir_exists"] = has_final
    summary["status"] = status
    summary["latest_checkpoint"] = newest
    summary["checkpoint_count"] = len(checkpoints)
    summary["checkpoints"] = checkpoints[:5]
    summary["run_metadata"] = {field: metadata.get(field) for field in metadata_fields}
    summary["latest_trainer_state"] = trainer_summary
    summary["case_metrics_summary"] = case_metrics
    summary["path_case_metrics_summary"] = path_case_metrics
    summary["log_path"] = str(log_path) if log_path.exists() else None
    summary["log_tail"] = log_tail
    summary["run_script_path"] = str(script_path) if run_script is not None else None
    summary["run_script"] = run_script
    summary["resume_hint"] = build_resume_hint(run_name, newest, status)
    return summary
def build_resume_hint(run_name: str, latest_checkpoint: Optional[dict[str, Any]], status: str) -> Optional[str]:
    """Compose a one-line hint for resuming an unfinished run.

    Returns ``None`` for a ``"completed"`` run. Otherwise the hint points at
    the latest checkpoint directory, or at ``auto`` when no named checkpoint
    is available.
    """
    if status == "completed":
        return None
    has_named_checkpoint = bool(latest_checkpoint) and bool(latest_checkpoint.get("name"))
    if has_named_checkpoint:
        resume_target = f"checkpoints/{run_name}/{latest_checkpoint['name']}"
    else:
        resume_target = "auto"
    pieces = [
        "Resume with: .\\.venv\\Scripts\\python.exe -m anifilebert.train ",
        f"--save-dir checkpoints/{run_name} --resume-from-checkpoint {resume_target} ",
        "(plus the same training arguments used originally).",
    ]
    return "".join(pieces)
def read_json_text(raw_text: str) -> Optional[dict[str, Any]]:
    """Decode a JSON object from text captured from a remote command.

    Args:
        raw_text: Raw command output expected to contain one JSON document.

    Returns:
        The decoded dictionary, or ``None`` when the text is blank, is not
        valid JSON, or decodes to something other than an object. Callers use
        ``.get`` on the result, so only dictionaries are handed back.
    """
    # A file saved with a UTF-8 BOM arrives as a leading U+FEFF once the bytes
    # are decoded as plain UTF-8; json.loads rejects it, so drop it first.
    text = raw_text.strip().lstrip("\ufeff").strip()
    if not text:
        return None
    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        return None
    return payload if isinstance(payload, dict) else None
def run_ssh(host: str, command: str, timeout_seconds: int) -> subprocess.CompletedProcess[str]:
    """Run ``command`` on ``host`` through the ``ssh`` client and capture its output.

    Output is decoded as UTF-8 with undecodable bytes replaced. The timeout
    is clamped to at least one second. A non-zero exit status does not raise;
    a missing ``ssh`` executable or an expired timeout propagates as the
    corresponding ``subprocess.run`` exception.
    """
    argv = ["ssh", host, command]
    effective_timeout = max(timeout_seconds, 1)
    return subprocess.run(
        argv,
        check=False,
        capture_output=True,
        text=True,
        encoding="utf-8",
        errors="replace",
        timeout=effective_timeout,
    )
def fetch_remote_stdout(host: str, command: str, timeout_seconds: int) -> tuple[Optional[str], Optional[dict[str, Any]]]:
    """Run a remote command and return ``(stdout, error)``.

    Exactly one element of the pair is ``None``. On success the error is
    ``None``. On failure stdout is ``None`` and the error dictionary has an
    ``error`` code: ``ssh_not_found``, ``remote_timeout_<n>s`` or
    ``remote_command_failed``. The last one also carries ``returncode``,
    ``stderr`` and ``stdout``.
    """
    try:
        proc = run_ssh(host, command, timeout_seconds)
    except FileNotFoundError:
        return None, {"error": "ssh_not_found"}
    except subprocess.TimeoutExpired:
        return None, {"error": f"remote_timeout_{timeout_seconds}s"}

    if proc.returncode == 0:
        return proc.stdout, None

    failure: dict[str, Any] = {
        "error": "remote_command_failed",
        "returncode": proc.returncode,
        "stderr": proc.stderr.strip(),
        "stdout": proc.stdout.strip(),
    }
    return None, failure
def win_join(base: str, *parts: str) -> str:
    """Join path components with backslashes, Windows style.

    Trailing separators are removed from ``base`` and separators on both ends
    of each part are removed, so components never produce doubled separators.
    """
    pieces = [base.rstrip("\\/")]
    pieces.extend(part.strip("\\/") for part in parts)
    return "\\".join(pieces)
def remote_quote_cmd(value: str) -> str:
    """Wrap ``value`` in double quotes, prefixing embedded quotes with a backslash."""
    # NOTE(review): whether the remote shell honours backslash-escaped quotes
    # is not established here; confirm against the worker's shell.
    escaped = value.replace('"', r'\"')
    return f'"{escaped}"'
def parse_remote_run_list(raw: str) -> list[str]:
    """Split command output into its non-blank lines, each stripped of whitespace."""
    stripped_lines = (line.strip() for line in raw.splitlines())
    return [entry for entry in stripped_lines if entry]
def parse_remote_checkpoint_rows(raw: str) -> list[dict[str, Any]]:
    """Turn a directory listing into checkpoint records, highest step first.

    Blank lines and names that ``parse_checkpoint_step`` rejects are skipped.
    Each record has ``name`` and ``step``.
    """
    parsed: list[dict[str, Any]] = []
    for name in (line.strip() for line in raw.splitlines()):
        step = parse_checkpoint_step(name) if name else None
        if step is not None:
            parsed.append({"name": name, "step": step})
    return sorted(parsed, key=lambda record: record["step"], reverse=True)
def inspect_remote(
    host: str,
    remote_repo: str,
    requested_run: Optional[str],
    tail_lines: int,
    timeout_seconds: int,
) -> dict[str, Any]:
    """Collect a recovery summary for one training run on a remote worker.

    Every probe is a command line of the form ``cd /d "<remote_repo>" && ...``
    sent through ``fetch_remote_stdout``.

    Args:
        host: SSH destination passed to ``fetch_remote_stdout``.
        remote_repo: Repository path on the remote machine.
        requested_run: Run name under ``checkpoints``; a falsy value selects
            the first entry of the remote listing (sorted newest first).
        tail_lines: Number of trailing ``combined.log`` lines to keep.
        timeout_seconds: Timeout applied to each individual probe.

    Returns:
        A dictionary that always has ``repository``, ``requested_run`` and
        ``run_found``. When a required probe fails, the keys of the error
        reported by ``fetch_remote_stdout`` (``error`` and, where present,
        ``returncode``/``stderr``/``stdout``) are merged into it. When the
        run cannot be found, ``error`` is ``"run_not_found"``.
    """
    result: dict[str, Any] = {
        "repository": remote_repo,
        "requested_run": requested_run,
        "run_found": False,
    }
    cd_prefix = f"cd /d {remote_quote_cmd(remote_repo)} && "

    def probe(command: str) -> tuple[Optional[str], Optional[dict[str, Any]]]:
        # Run one command from inside the remote repository.
        return fetch_remote_stdout(host, cd_prefix + command, timeout_seconds)

    def fail(error: dict[str, Any]) -> dict[str, Any]:
        # Keep the context gathered so far instead of returning the bare error.
        result.update(error)
        return result

    # 1) Find run list on remote.
    run_list_raw, error = probe("dir /b /ad /o-d checkpoints")
    if error is not None:
        return fail(error)
    runs = parse_remote_run_list(run_list_raw or "")
    if not runs:
        result["error"] = "run_not_found"
        return result
    if requested_run and requested_run not in runs:
        result["error"] = "run_not_found"
        return result
    run_name = requested_run if requested_run else runs[0]
    result["run_found"] = True
    result["run"] = run_name
    result["run_dir"] = win_join(remote_repo, "checkpoints", run_name)

    # 2) Check final dir (detected through its run_metadata.json).
    final_dir = win_join("checkpoints", run_name, "final")
    metadata_rel = win_join(final_dir, "run_metadata.json")
    final_raw, error = probe(
        f"if exist {remote_quote_cmd(metadata_rel)} (echo TRUE) else (echo FALSE)"
    )
    if error is not None:
        return fail(error)
    final_exists = (final_raw or "").strip().upper().endswith("TRUE")
    result["final_dir_exists"] = final_exists

    # 3) Find checkpoints for the selected run.
    checkpoint_glob = win_join("checkpoints", run_name, "checkpoint-*")
    checkpoint_raw, error = probe(f"dir /b /ad /o-d {remote_quote_cmd(checkpoint_glob)}")
    if error is not None:
        if error.get("error") != "remote_command_failed":
            # Timeout or missing ssh client: the listing never ran, so
            # reporting "no checkpoints" would be misleading.
            return fail(error)
        # If no checkpoint exists, Windows returns non-zero; treat as empty checkpoint list.
        checkpoint_raw = ""
    checkpoints = parse_remote_checkpoint_rows(checkpoint_raw or "")
    latest_checkpoint = checkpoints[0] if checkpoints else None
    result["latest_checkpoint"] = latest_checkpoint
    result["checkpoint_count"] = len(checkpoints)
    result["checkpoints"] = checkpoints[:5]
    if final_exists:
        status = "completed"
    elif latest_checkpoint is not None:
        status = "checkpointed_no_final"
    else:
        status = "started_no_checkpoint"
    result["status"] = status

    # 4) Pull run metadata JSON if available.
    metadata_fields = (
        "experiment_name",
        "model_head",
        "tokenizer_variant",
        "dataset_mode",
        "virtual_dataset_dir",
        "train_samples",
        "epochs",
        "batch_size",
        "learning_rate",
        "seed",
    )
    run_meta_raw, error = probe(f"type {remote_quote_cmd(metadata_rel)}")
    run_metadata = None
    if error is None and run_meta_raw is not None:
        run_metadata = read_json_text(run_meta_raw)
    # Only a non-empty JSON object is summarized; anything else is reported as absent.
    if isinstance(run_metadata, dict) and run_metadata:
        result["run_metadata"] = {field: run_metadata.get(field) for field in metadata_fields}
    else:
        result["run_metadata"] = None

    # 5) Pull latest trainer state when resumable checkpoint exists.
    trainer_state = None
    if latest_checkpoint is not None:
        trainer_rel = win_join(
            "checkpoints", run_name, latest_checkpoint["name"], "trainer_state.json"
        )
        trainer_raw, error = probe(f"type {remote_quote_cmd(trainer_rel)}")
        if error is None and trainer_raw is not None:
            trainer_state = read_json_text(trainer_raw)
    if isinstance(trainer_state, dict) and trainer_state:
        result["latest_trainer_state"] = {
            field: trainer_state.get(field)
            for field in ("global_step", "epoch", "best_metric")
        }
    else:
        result["latest_trainer_state"] = None

    # 6) Pull run script if present.
    run_script_rel = win_join("logs", run_name, "run.ps1")
    run_script_raw, error = probe(f"type {remote_quote_cmd(run_script_rel)}")
    if error is None and run_script_raw is not None and run_script_raw.strip():
        result["run_script_path"] = win_join(remote_repo, run_script_rel)
        result["run_script"] = run_script_raw
    else:
        result["run_script_path"] = None
        result["run_script"] = None

    # 7) Pull combined log and trim tail locally for stability across shells.
    log_tail: list[str] = []
    log_path: Optional[str] = None
    log_rel = win_join("logs", run_name, "combined.log")
    log_raw, error = probe(f"type {remote_quote_cmd(log_rel)}")
    if error is None and log_raw is not None:
        # splitlines() already drops every line terminator.
        all_lines = log_raw.splitlines()
        if all_lines:
            log_tail = all_lines[-max(tail_lines, 1):]
            log_path = win_join(remote_repo, log_rel)
    result["log_path"] = log_path
    result["log_tail"] = log_tail
    return result
def format_report_text(report: dict[str, Any]) -> str:
    """Render the report dictionary as a plain-text summary.

    The output has a header, a ``[local]`` section and, when the report has a
    non-``None`` ``remote`` entry, a ``[remote]`` section. At most the last
    20 entries of each log tail are shown. The text ends with exactly one
    newline.
    """
    out: list[str] = []
    emit = out.append

    def emit_checkpoint(section: dict[str, Any]) -> None:
        newest = section.get("latest_checkpoint") or {}
        if newest:
            emit(f"latest_checkpoint: {newest.get('name')} (step={newest.get('step')})")

    def emit_metadata(section: dict[str, Any], fields: tuple) -> None:
        # Only metadata fields with a truthy value are printed.
        meta = section.get("run_metadata") or {}
        for field in fields:
            if meta.get(field):
                emit(f"{field}: {meta.get(field)}")

    def emit_case_line(label: str, stats: Any) -> None:
        if stats and stats.get("full_correct") is not None:
            emit(
                f"{label}: {stats.get('full_correct')}/{stats.get('case_count')} ({stats.get('full_accuracy')})"
            )

    def emit_tail(section: dict[str, Any], heading: str) -> None:
        tail = section.get("log_tail") or []
        if tail:
            emit("")
            emit(heading)
            out.extend(tail[-20:])

    local = report.get("local", {})
    remote = report.get("remote")

    emit("AniFileBERT Training Recovery Summary")
    emit("=" * 38)
    emit(f"generated_at: {report.get('generated_at')}")
    emit("")
    emit("[local]")
    if not local.get("run_found"):
        emit(f"error: {local.get('error', 'unknown')}")
    else:
        emit(f"run: {local.get('run')}")
        emit(f"status: {local.get('status')}")
        emit(f"final_dir_exists: {local.get('final_dir_exists')}")
        emit_checkpoint(local)
        emit_metadata(local, ("experiment_name", "model_head", "dataset_mode"))
        case_summary = local.get("case_metrics_summary") or {}
        emit_case_line("case_metrics(model_only)", case_summary.get("model_only"))
        emit_case_line("case_metrics(normalized_only)", case_summary.get("normalized_only"))
        path_case_summary = local.get("path_case_metrics_summary") or {}
        emit_case_line(
            "path_case_metrics(normalized_only)",
            path_case_summary.get("normalized_only"),
        )
        if local.get("resume_hint"):
            emit(f"resume_hint: {local.get('resume_hint')}")
        emit_tail(local, "local_log_tail:")

    if remote is not None:
        emit("")
        emit("[remote]")
        if remote.get("error"):
            emit(f"error: {remote.get('error')}")
            if remote.get("stderr"):
                emit(f"stderr: {remote.get('stderr')}")
        else:
            emit(f"run: {remote.get('run')}")
            emit(f"status: {remote.get('status')}")
            emit_checkpoint(remote)
            emit_metadata(remote, ("experiment_name", "model_head"))
            emit_tail(remote, "remote_log_tail:")

    return "\n".join(out).rstrip() + "\n"
def safe_stdout_write(text: str) -> None:
    """Write ``text`` to standard output without failing on unencodable characters.

    The text is written normally first. If the stream's encoding cannot
    represent it, the text is re-encoded with unencodable characters replaced
    and written to the underlying binary buffer, or back through the text
    stream when no buffer is available.

    Args:
        text: The text to emit.
    """
    stream = sys.stdout
    try:
        stream.write(text)
        return
    except UnicodeEncodeError:
        encoding = getattr(stream, "encoding", None) or "utf-8"

    data = text.encode(encoding, errors="replace")
    raw = getattr(stream, "buffer", None)
    if raw is None:
        stream.write(data.decode(encoding, errors="replace"))
        return
    # Push out text still pending in the text layer first; otherwise bytes
    # written straight to the binary buffer would overtake earlier output.
    stream.flush()
    raw.write(data)
    raw.flush()
def main() -> None:
    """Build the recovery report and emit it.

    The local run is always inspected; the remote worker only when
    ``--remote-host`` is given. The rendered report is written to
    ``--output`` when set, and always written to standard output.
    """
    args = parse_args()
    repo_root = Path(args.repo).resolve()
    generated_at = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

    report: dict[str, Any] = {"generated_at": generated_at}
    report["local"] = inspect_local(repo_root, args.run, args.tail)
    if args.remote_host:
        report["remote"] = inspect_remote(
            host=args.remote_host,
            remote_repo=args.remote_repo,
            requested_run=args.run,
            tail_lines=args.tail,
            timeout_seconds=args.timeout,
        )

    if args.format != "json":
        rendered = format_report_text(report)
    else:
        rendered = json.dumps(report, ensure_ascii=False, indent=2)

    if args.output:
        destination = Path(args.output)
        # Create missing parent directories so any output path is usable.
        destination.parent.mkdir(parents=True, exist_ok=True)
        destination.write_text(rendered, encoding="utf-8")
    safe_stdout_write(rendered)


if __name__ == "__main__":
    main()