# math_trainer / app.py
# Source: NorthernTribe-Research (commit 11f9ebf, verified)
# Change: enable autonomous-by-default run profile and auto-apply full execution parameters.
#!/usr/bin/env python3
"""Gradio app to run SOTA conjecture-model training on Hugging Face Space GPU."""
from __future__ import annotations
import datetime as dt
import html
import inspect
import json
import os
import re
import select
import shutil
import signal
import subprocess
import sys
import threading
import time
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Tuple
import gradio as gr
import torch
import yaml
from huggingface_hub import hf_hub_download
# Filesystem layout: everything the app reads/writes lives under ./workspace
# next to this file.
ROOT = Path(__file__).resolve().parent
WORKSPACE_DIR = ROOT / "workspace"
DATA_DIR = WORKSPACE_DIR / "data" / "releases" / "v1"  # downloaded parquet splits
RUNTIME_DIR = WORKSPACE_DIR / "runtime"  # generated runtime YAML configs
CONFIG_TEMPLATE = ROOT / "configs" / "deepseek_math_sota.yaml"
TRAIN_SCRIPT = ROOT / "scripts" / "train_sota.py"
EVAL_SCRIPT = ROOT / "scripts" / "eval_sota.py"
TRAIN_OUTPUT_DIR = WORKSPACE_DIR / "runs" / "math-conjecture-sota"
# Fallback HF credential files, checked only when no token env var is set.
CREDENTIAL_FILE_CANDIDATES = [
    ROOT / "huggingface-api-key.json",
    ROOT.parent / "huggingface-api-key.json",
]
# '<owner>/<repo>': each side starts alphanumeric, then up to 95 more of [A-Za-z0-9._-].
REPO_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,95}/[A-Za-z0-9][A-Za-z0-9._-]{0,95}$")
# Matches "[stage N]" markers in training log lines; group(1) is the stage number.
STAGE_LOG_RE = re.compile(r"\[stage\s+(\d+)\]")
# Captures loss values like "loss: 1.23" or "'train_loss'=4e-2" from log text
# (searched against lowercased lines).
LOSS_LOG_RE = re.compile(r"(?:^|[\s{,'\"])(?:loss|train_loss)\s*[:=]\s*([-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)")
# Single-run concurrency state; all reads/writes go through RUN_STATE_LOCK.
RUN_STATE_LOCK = threading.Lock()
RUN_IN_PROGRESS = False
CANCEL_REQUESTED = False
ACTIVE_PROCESS: Optional[subprocess.Popen] = None
ACTIVE_RUN_LABEL = ""
TACTICAL_CSS = """
@import url("https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;500;600&family=Rajdhani:wght@500;600;700&display=swap");
:root {
--ops-bg: #070707;
--ops-bg-2: #0f0f0f;
--ops-panel: #111111;
--ops-panel-2: #161616;
--ops-border: #2a2a2a;
--ops-border-strong: #3d3d3d;
--ops-text: #ececec;
--ops-muted: #a8a8a8;
--ops-bright: #ffffff;
}
.gradio-container {
color: var(--ops-text) !important;
background:
linear-gradient(rgba(255, 255, 255, 0.02) 1px, transparent 1px),
linear-gradient(90deg, rgba(255, 255, 255, 0.02) 1px, transparent 1px),
radial-gradient(circle at 50% -10%, #1d1d1d 0%, #0f0f0f 38%, #070707 100%) !important;
background-size: 26px 26px, 26px 26px, 100% 100% !important;
font-family: "IBM Plex Mono", "JetBrains Mono", "Fira Code", monospace !important;
}
.gradio-container .prose h1,
.gradio-container .prose h2,
.gradio-container .prose h3,
.gradio-container .prose p,
.gradio-container .prose li,
.gradio-container .prose strong {
color: var(--ops-text) !important;
}
.gradio-container .prose h1,
.gradio-container .prose h2 {
font-family: "Rajdhani", "IBM Plex Mono", monospace !important;
letter-spacing: 0.08em !important;
text-transform: uppercase !important;
}
.gradio-container .prose code {
color: var(--ops-bright) !important;
background: #1b1b1b !important;
border: 1px solid var(--ops-border) !important;
}
.gradio-container .block,
.gradio-container .form {
background: linear-gradient(180deg, var(--ops-panel) 0%, var(--ops-panel-2) 100%) !important;
border: 1px solid var(--ops-border) !important;
box-shadow: inset 0 0 0 1px rgba(255, 255, 255, 0.03), 0 12px 28px rgba(0, 0, 0, 0.35) !important;
}
.gradio-container label span,
.gradio-container .block-info,
.gradio-container [data-testid="block-info"] {
color: var(--ops-muted) !important;
letter-spacing: 0.12em !important;
text-transform: uppercase !important;
font-size: 0.74rem !important;
}
.gradio-container input,
.gradio-container textarea {
background: #0c0c0c !important;
color: var(--ops-text) !important;
border: 1px solid var(--ops-border-strong) !important;
box-shadow: none !important;
font-family: "IBM Plex Mono", "JetBrains Mono", monospace !important;
}
.gradio-container input::placeholder,
.gradio-container textarea::placeholder {
color: #7f7f7f !important;
}
.gradio-container input:focus,
.gradio-container textarea:focus {
border-color: #656565 !important;
outline: none !important;
}
.gradio-container button {
border: 1px solid #565656 !important;
background: linear-gradient(180deg, #212121 0%, #151515 100%) !important;
color: var(--ops-bright) !important;
letter-spacing: 0.08em !important;
text-transform: uppercase !important;
font-family: "Rajdhani", "IBM Plex Mono", monospace !important;
}
.gradio-container button.primary,
.gradio-container button.stop,
.gradio-container button.secondary {
background: linear-gradient(180deg, #2a2a2a 0%, #171717 100%) !important;
border-color: #686868 !important;
color: #f7f7f7 !important;
}
.gradio-container button:hover {
filter: brightness(1.08);
}
.ops-header {
border: 1px solid var(--ops-border);
background: linear-gradient(180deg, #101010 0%, #0c0c0c 100%);
padding: 12px 14px;
margin: 2px 0 8px 0;
}
.ops-header-title {
font-family: "Rajdhani", "IBM Plex Mono", monospace;
letter-spacing: 0.16em;
text-transform: uppercase;
color: #f4f4f4;
font-weight: 700;
font-size: 1rem;
}
.ops-header-tags {
margin-top: 8px;
display: flex;
flex-wrap: wrap;
gap: 8px;
}
.ops-tag {
border: 1px solid #474747;
background: #181818;
color: #d5d5d5;
padding: 3px 7px;
font-size: 0.72rem;
letter-spacing: 0.12em;
text-transform: uppercase;
}
.ops-visual {
border: 1px solid var(--ops-border);
background: linear-gradient(180deg, #101010 0%, #0b0b0b 100%);
padding: 12px;
}
.ops-visual-head {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 10px;
gap: 10px;
}
.ops-visual-title {
font-family: "Rajdhani", "IBM Plex Mono", monospace;
font-weight: 700;
letter-spacing: 0.14em;
text-transform: uppercase;
color: #f1f1f1;
}
.ops-visual-sub {
color: #9f9f9f;
font-size: 0.78rem;
letter-spacing: 0.08em;
text-transform: uppercase;
}
.ops-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
gap: 10px;
}
.ops-card {
border: 1px solid #323232;
background: linear-gradient(180deg, #161616 0%, #101010 100%);
padding: 9px;
min-height: 72px;
}
.ops-k {
color: #9a9a9a;
font-size: 0.68rem;
letter-spacing: 0.11em;
text-transform: uppercase;
}
.ops-v {
color: #f0f0f0;
font-family: "Rajdhani", "IBM Plex Mono", monospace;
font-size: 1.05rem;
margin-top: 5px;
letter-spacing: 0.05em;
}
.ops-v-small {
color: #d1d1d1;
font-size: 0.83rem;
margin-top: 4px;
}
.ops-meter {
margin-top: 8px;
width: 100%;
height: 8px;
border: 1px solid #383838;
background: #111111;
position: relative;
overflow: hidden;
}
.ops-meter-fill {
position: absolute;
left: 0;
top: 0;
bottom: 0;
background: linear-gradient(90deg, #bdbdbd 0%, #f0f0f0 100%);
}
.ops-spark {
margin-top: 8px;
border: 1px solid #343434;
background: #0e0e0e;
padding: 3px;
}
.ops-spark svg {
width: 100%;
height: 74px;
display: block;
}
.ops-foot {
margin-top: 10px;
color: #8f8f8f;
font-size: 0.74rem;
letter-spacing: 0.08em;
text-transform: uppercase;
}
.gradio-container footer,
.gradio-container .built-with,
.gradio-container [data-testid="footer"] {
display: none !important;
}
.nt-footer {
margin-top: 12px;
border: 1px solid #2f2f2f;
background: linear-gradient(180deg, #111111 0%, #0b0b0b 100%);
color: #bcbcbc;
text-align: center;
padding: 10px 12px;
font-size: 0.74rem;
letter-spacing: 0.08em;
text-transform: uppercase;
}
"""
# Static banner rendered above the console UI (styled by TACTICAL_CSS classes).
TACTICAL_HEADER_HTML = """
<div class="ops-header">
<div class="ops-header-title">Maths Conjecture Solutions // Training Operations Console</div>
<div class="ops-header-tags">
<span class="ops-tag">Tactical Monochrome</span>
<span class="ops-tag">Controlled Ops</span>
<span class="ops-tag">Staged Curriculum</span>
<span class="ops-tag">Live Telemetry</span>
</div>
</div>
"""
# Static footer; the stock Gradio footer is hidden via CSS and replaced by this.
TACTICAL_FOOTER_HTML = """
<div class="nt-footer">© 2026 NorthernTribe Research, Inc. All rights reserved.</div>
"""
# Markdown blurb shown at the top of the app describing the pipeline steps.
PROJECT_DESCRIPTION = """
# Math Conjecture Trainer
This console runs the full training operations lane for the `maths-conjuncture-solutions` project:
An autonomous training operations console for DeepSeek-Math that runs multi-stage curriculum fine-tuning on Space GPU, executes post-training quality evaluation, and publishes only qualified adapters, checkpoints, and run reports to your Hugging Face model repository.
1. Pull released parquet splits from `NorthernTribe-Research/math-conjecture-training-corpus`.
2. Build runtime training configuration from `configs/deepseek_math_sota.yaml`.
3. Execute multi-stage DeepSeek-Math curriculum fine-tuning via `scripts/train_sota.py`.
4. Run post-training evaluation with pass@k-style sampling and family-level metrics.
5. Enforce autonomous quality gates before adapter promotion/push.
6. Stream live terminal telemetry, tactical visualization, and structured run summaries.
Autonomous Mode is enabled by default and applies full-stage execution parameters automatically.
"""
def _safe_float(value: Any, default: float) -> float:
try:
return float(value)
except (TypeError, ValueError):
return default
def _safe_int(value: Any, default: int) -> int:
try:
return int(value)
except (TypeError, ValueError):
return default
def load_template_defaults() -> Dict[str, Any]:
    """Read the YAML config template; return {} when missing, unparsable, or non-dict."""
    if not CONFIG_TEMPLATE.exists():
        return {}
    try:
        parsed = yaml.safe_load(CONFIG_TEMPLATE.read_text(encoding="utf-8"))
    except Exception:
        # Best-effort: a broken template degrades to built-in defaults downstream.
        return {}
    return parsed if isinstance(parsed, dict) else {}
# Template-derived defaults used by Autonomous Mode.  The literal fall-backs
# below apply when the template is missing or malformed.
TEMPLATE_CFG = load_template_defaults()
# `or [None]` guards against stages being [] or null, so the count is never 0.
TEMPLATE_STAGE_COUNT = max(1, len(TEMPLATE_CFG.get("stages", []) or [None]))
TEMPLATE_QUALITY_GATE = TEMPLATE_CFG.get("quality_gate", {})
if not isinstance(TEMPLATE_QUALITY_GATE, dict):
    TEMPLATE_QUALITY_GATE = {}
TEMPLATE_POST_EVAL = TEMPLATE_CFG.get("post_eval", {})
if not isinstance(TEMPLATE_POST_EVAL, dict):
    TEMPLATE_POST_EVAL = {}
TEMPLATE_HUB = TEMPLATE_CFG.get("hub", {})
if not isinstance(TEMPLATE_HUB, dict):
    TEMPLATE_HUB = {}
# The gate flag may arrive as a bool or as a YAML string ("yes", "on", ...).
_raw_gate_enabled = TEMPLATE_QUALITY_GATE.get("enabled", True)
if isinstance(_raw_gate_enabled, bool):
    DEFAULT_GATE_ENABLED = _raw_gate_enabled
else:
    DEFAULT_GATE_ENABLED = str(_raw_gate_enabled).strip().lower() in {"1", "true", "yes", "y", "on"}
# Quality-gate thresholds (clamped to sane minimums).
DEFAULT_GATE_MIN_ROWS = max(1, _safe_int(TEMPLATE_QUALITY_GATE.get("min_evaluated_rows"), 120))
DEFAULT_GATE_MIN_PASS_AT_1 = max(0.0, _safe_float(TEMPLATE_QUALITY_GATE.get("min_pass_at_1"), 0.01))
DEFAULT_GATE_MIN_PASS_AT_K = max(0.0, _safe_float(TEMPLATE_QUALITY_GATE.get("min_pass_at_k"), 0.06))
# Post-eval sampling defaults for autonomous runs.
DEFAULT_AUTO_EVAL_K = max(1, _safe_int(TEMPLATE_POST_EVAL.get("k"), 4))
DEFAULT_AUTO_EVAL_SAMPLES = max(1, _safe_int(TEMPLATE_POST_EVAL.get("max_samples"), 300))
DEFAULT_AUTO_PUSH_TO_HUB = bool(TEMPLATE_HUB.get("push_to_hub", True))
def now_ts() -> str:
    """Return the current UTC time formatted as 'YYYY-MM-DD HH:MM:SS UTC'.

    Uses the timezone-aware ``datetime.now(timezone.utc)`` instead of the
    deprecated naive ``datetime.utcnow()``; the rendered string is identical.
    """
    return dt.datetime.now(dt.timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
def append_log(lines: List[str], message: str) -> str:
    """Append a timestamped entry to *lines* (mutated in place) and return the joined log.

    The returned text is tail-truncated to 200k characters; the list itself
    is never trimmed.
    """
    lines.append(f"[{now_ts()}] {message}")
    joined = "\n".join(lines)
    return joined[-200_000:] if len(joined) > 200_000 else joined
def summary_text(summary: Dict[str, Any]) -> str:
    """Pretty-print *summary* as ASCII-safe JSON; an empty dict renders as ''."""
    return json.dumps(summary, ensure_ascii=True, indent=2) if summary else ""
def compose_ops_console(log_text: str, summary_json: str) -> str:
    """Merge the run log with an optional 'Mission Summary (JSON)' section."""
    body = (log_text or "").strip()
    extra = (summary_json or "").strip()
    if not extra:
        return body
    section = "Mission Summary (JSON)\n" + extra
    if not body:
        return section
    return body + "\n\n" + ("-" * 52) + "\n" + section
def _as_dict(value: Any) -> Dict[str, Any]:
return value if isinstance(value, dict) else {}
def _parse_summary_json(text: str) -> Dict[str, Any]:
if not text:
return {}
try:
parsed = json.loads(text)
except json.JSONDecodeError:
return {}
return parsed if isinstance(parsed, dict) else {}
def _fmt_pct(value: Any) -> str:
try:
return f"{float(value) * 100:.1f}%"
except (TypeError, ValueError):
return "--"
def _fmt_float(value: Any, digits: int = 3) -> str:
try:
return f"{float(value):.{digits}f}"
except (TypeError, ValueError):
return "--"
def _extract_loss_values(log_text: str, limit: int = 48) -> List[float]:
    """Pull recent non-negative train-loss values from raw log text.

    Keeps at most the last *limit* values; 'eval_loss' lines are skipped so
    evaluation losses do not pollute the train-loss sparkline.
    """
    found: List[float] = []
    for raw_line in log_text.splitlines():
        line = raw_line.lower()
        if "eval_loss" in line:
            continue
        hit = LOSS_LOG_RE.search(line)
        if hit is None:
            continue
        try:
            parsed = float(hit.group(1))
        except (TypeError, ValueError):
            continue
        # `parsed >= 0.0` is False for NaN too, so NaN values are dropped.
        if parsed >= 0.0:
            found.append(parsed)
    return found[-limit:] if len(found) > limit else found
def _extract_summary_loss_values(summary: Dict[str, Any], limit: int = 24) -> List[float]:
losses: List[float] = []
training_summary = _as_dict(summary.get("training_summary"))
stages_ran = training_summary.get("stages_ran")
if not isinstance(stages_ran, list):
return losses
for stage in stages_ran:
if not isinstance(stage, dict):
continue
train_metrics = stage.get("train_metrics")
if not isinstance(train_metrics, dict):
continue
value = train_metrics.get("train_loss")
try:
loss = float(value)
except (TypeError, ValueError):
continue
if loss >= 0.0:
losses.append(loss)
if len(losses) > limit:
losses = losses[-limit:]
return losses
def _build_loss_sparkline(losses: List[float]) -> str:
if not losses:
return "<div class='ops-v-small'>No live loss points yet.</div>"
width = 520
height = 74
pad = 5
min_v = min(losses)
max_v = max(losses)
span = max(max_v - min_v, 1e-9)
points: List[str] = []
for idx, value in enumerate(losses):
x = pad + (idx * (width - 2 * pad) / max(1, len(losses) - 1))
y = pad + ((max_v - value) * (height - 2 * pad) / span)
points.append(f"{x:.2f},{y:.2f}")
polyline = " ".join(points)
latest = losses[-1]
return (
f"<div class='ops-v-small'>Latest train loss: <strong>{_fmt_float(latest, 4)}</strong></div>"
"<div class='ops-spark'>"
f"<svg viewBox='0 0 {width} {height}' preserveAspectRatio='none'>"
f"<polyline points='{polyline}' fill='none' stroke='#f0f0f0' stroke-width='2' />"
"</svg>"
"</div>"
)
def _infer_stage_snapshot(summary: Dict[str, Any], log_text: str) -> Dict[str, Any]:
    """Derive stage-progress telemetry from the run summary plus the log tail.

    Returns start_stage, stage_count, completed count, the currently active
    stage (or None), and a clamped 0..1 progress ratio.
    """
    first_stage = max(1, _safe_int(summary.get("start_stage"), 1))
    total_stages = max(1, _safe_int(summary.get("max_stages"), TEMPLATE_STAGE_COUNT))
    training_summary = _as_dict(summary.get("training_summary"))
    ran = training_summary.get("stages_ran")
    done = min(total_stages, len(ran)) if isinstance(ran, list) else 0
    # Scan the most recent log lines, newest first, for a "[stage N]" marker.
    current = None
    for line in reversed(log_text.splitlines()[-350:]):
        hit = STAGE_LOG_RE.search(line)
        if hit:
            current = _safe_int(hit.group(1), 0)
            break
    if done >= total_stages:
        ratio = 1.0
    else:
        ratio = done / total_stages
        if current and current >= first_stage:
            # Credit a live stage as roughly 35% complete.
            partial = (current - first_stage) + 0.35
            ratio = max(ratio, min(1.0, partial / total_stages))
    return {
        "start_stage": first_stage,
        "stage_count": total_stages,
        "completed": done,
        "active_stage": current,
        "progress": max(0.0, min(1.0, ratio)),
    }
def render_ops_visual(summary: Dict[str, Any], status_text: str, log_text: str) -> str:
    """Render the HTML telemetry panel from a run summary, status line, and log text.

    All summary-derived values are passed through ``html.escape`` before being
    interpolated into the returned markup.
    """
    safe_summary = _as_dict(summary)
    runtime = _as_dict(safe_summary.get("runtime"))
    quality_gate = _as_dict(safe_summary.get("quality_gate"))
    evaluation = _as_dict(safe_summary.get("evaluation"))
    push_report = _as_dict(safe_summary.get("push"))
    run_label = html.escape(str(safe_summary.get("run_label") or "not-started"))
    status_value = html.escape(status_text or "Idle")
    runtime_mode = "GPU READY" if runtime.get("cuda_available") else "CPU FALLBACK"
    runtime_mode = html.escape(runtime_mode)
    device_count = _safe_int(runtime.get("cuda_device_count"), 0)
    gate_enabled = bool(quality_gate.get("enabled"))
    gate_passed = quality_gate.get("passed")
    # Tri-state display: `passed` may be True/False or absent (no verdict yet).
    if not gate_enabled:
        gate_text = "Disabled"
    elif gate_passed is True:
        gate_text = "Passed"
    elif gate_passed is False:
        gate_text = "Failed"
    else:
        gate_text = "Pending"
    stage_meta = _infer_stage_snapshot(safe_summary, log_text)
    progress_pct = int(stage_meta["progress"] * 100)
    active_stage = stage_meta.get("active_stage")
    stage_hint = f"active stage {active_stage}" if active_stage else "awaiting stage telemetry"
    stage_hint = html.escape(stage_hint)
    # Prefer live log losses; fall back to per-stage summary losses when sparse.
    losses = _extract_loss_values(log_text)
    if len(losses) < 2:
        summary_losses = _extract_summary_loss_values(safe_summary)
        if summary_losses:
            losses = summary_losses
    sparkline_html = _build_loss_sparkline(losses)
    pass_k = _fmt_pct(evaluation.get("pass_at_k"))
    pass_1 = _fmt_pct(evaluation.get("pass_at_1"))
    exact_k = _fmt_pct(evaluation.get("exact_at_k"))
    # Map requested/performed push flags onto a single display word.
    push_state = "Pending"
    if push_report:
        requested = bool(push_report.get("requested"))
        performed = bool(push_report.get("performed"))
        if not requested:
            push_state = "Not requested"
        elif performed:
            push_state = "Published"
        else:
            push_state = "Blocked"
    return f"""
<div class="ops-visual">
<div class="ops-visual-head">
<div class="ops-visual-title">Live Tactical Telemetry</div>
<div class="ops-visual-sub">Monochrome Ops Feed</div>
</div>
<div class="ops-grid">
<div class="ops-card">
<div class="ops-k">Run</div>
<div class="ops-v">{run_label}</div>
<div class="ops-v-small">{status_value}</div>
</div>
<div class="ops-card">
<div class="ops-k">Runtime</div>
<div class="ops-v">{runtime_mode}</div>
<div class="ops-v-small">cuda devices: {device_count}</div>
</div>
<div class="ops-card">
<div class="ops-k">Stage Progress</div>
<div class="ops-v">{stage_meta['completed']} / {stage_meta['stage_count']}</div>
<div class="ops-v-small">{stage_hint}</div>
<div class="ops-meter"><div class="ops-meter-fill" style="width:{progress_pct}%"></div></div>
</div>
<div class="ops-card">
<div class="ops-k">Quality Gate</div>
<div class="ops-v">{html.escape(gate_text)}</div>
<div class="ops-v-small">push: {html.escape(push_state)}</div>
</div>
<div class="ops-card">
<div class="ops-k">Eval pass@k</div>
<div class="ops-v">{pass_k}</div>
<div class="ops-v-small">pass@1 {pass_1} | exact@k {exact_k}</div>
</div>
<div class="ops-card">
<div class="ops-k">Loss Stream</div>
{sparkline_html}
</div>
</div>
<div class="ops-foot">dull tactical theme · black / grey / white · anduril/palantir-inspired operations console</div>
</div>
""".strip()
def _token_from_credentials_file(path: Path) -> Optional[str]:
try:
data = json.loads(path.read_text(encoding="utf-8"))
except Exception:
return None
for key in ("token", "key", "api_key", "hf_token"):
value = data.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
return None
def resolve_hf_token() -> Tuple[Optional[str], str, Optional[str]]:
    """Resolve a Hugging Face token from env vars, then known credential files.

    Returns (token, source_tag, credentials_file_path). The path element is
    only set when the token came from a file.
    """
    for env_name in ("HF_TOKEN", "HUGGINGFACE_HUB_TOKEN"):
        candidate = (os.environ.get(env_name) or "").strip()
        if candidate:
            return candidate, f"env:{env_name}", None
    for candidate_path in CREDENTIAL_FILE_CANDIDATES:
        if not candidate_path.exists():
            continue
        file_token = _token_from_credentials_file(candidate_path)
        if file_token:
            return file_token, "credentials_file", str(candidate_path)
    return None, "none", None
def validate_repo_id(repo_id: str, field_name: str) -> str:
    """Validate and normalize an '<owner>/<repo>' identifier.

    Returns the stripped value; raises ValueError (labelled with *field_name*)
    when empty or malformed.
    """
    candidate = (repo_id or "").strip()
    if not candidate:
        raise ValueError(f"{field_name} is required.")
    pattern = r"^[A-Za-z0-9][A-Za-z0-9._-]{0,95}/[A-Za-z0-9][A-Za-z0-9._-]{0,95}$"
    if not re.match(pattern, candidate):
        raise ValueError(
            f"{field_name} must look like '<owner>/<repo>' and may only contain letters, digits, '-', '_' and '.'."
        )
    return candidate
def ensure_workspace() -> None:
    """Create the data and runtime directories (and parents) if absent."""
    for directory in (DATA_DIR, RUNTIME_DIR):
        directory.mkdir(parents=True, exist_ok=True)
def run_runtime_snapshot() -> Dict[str, Any]:
    """Capture interpreter/library versions and CUDA availability for run reports."""
    cuda_ok = bool(torch.cuda.is_available())
    return {
        "python": sys.version.split()[0],
        "gradio": getattr(gr, "__version__", "unknown"),
        "torch": getattr(torch, "__version__", "unknown"),
        "cuda_available": cuda_ok,
        "cuda_device_count": int(torch.cuda.device_count()) if cuda_ok else 0,
    }
def begin_run(run_label: str) -> bool:
    """Atomically claim the single-run slot; returns False when a run is active."""
    global RUN_IN_PROGRESS, CANCEL_REQUESTED, ACTIVE_RUN_LABEL
    with RUN_STATE_LOCK:
        if RUN_IN_PROGRESS:
            return False
        RUN_IN_PROGRESS, CANCEL_REQUESTED = True, False
        ACTIVE_RUN_LABEL = run_label
        return True
def finish_run() -> None:
    """Release the run slot and reset all shared run-state fields."""
    global RUN_IN_PROGRESS, CANCEL_REQUESTED, ACTIVE_PROCESS, ACTIVE_RUN_LABEL
    with RUN_STATE_LOCK:
        ACTIVE_RUN_LABEL = ""
        ACTIVE_PROCESS = None
        CANCEL_REQUESTED = False
        RUN_IN_PROGRESS = False
def set_active_process(proc: subprocess.Popen) -> None:
    """Record *proc* as the currently running child process (lock-protected)."""
    global ACTIVE_PROCESS
    with RUN_STATE_LOCK:
        ACTIVE_PROCESS = proc
def clear_active_process(proc: subprocess.Popen) -> None:
    """Drop the active-process reference, but only if it still points at *proc*.

    The identity check prevents a finishing run from clobbering a reference
    that a newer run has already installed.
    """
    global ACTIVE_PROCESS
    with RUN_STATE_LOCK:
        if ACTIVE_PROCESS is proc:
            ACTIVE_PROCESS = None
def is_cancel_requested() -> bool:
    """Thread-safe read of the shared cancellation flag."""
    with RUN_STATE_LOCK:
        return CANCEL_REQUESTED
def terminate_process_group(proc: subprocess.Popen) -> None:
    """Stop *proc* politely, escalating to a hard kill if it ignores SIGTERM.

    On POSIX the whole process group is signalled (the child is started with
    start_new_session=True, so its pgid equals its pid); elsewhere only the
    process itself is terminated/killed.  Never raises.
    """
    if proc.poll() is not None:
        return  # already exited
    posix = os.name == "posix"
    try:
        if posix:
            os.killpg(proc.pid, signal.SIGTERM)
        else:
            proc.terminate()
        proc.wait(timeout=10)
        return
    except Exception:
        pass  # fall through to the hard kill
    try:
        if posix:
            os.killpg(proc.pid, signal.SIGKILL)
        else:
            proc.kill()
    except Exception:
        pass
def request_cancel() -> str:
    """Mark the active run as cancelled and terminate its subprocess, if any.

    Returns a human-readable status message for the UI.
    """
    global CANCEL_REQUESTED
    with RUN_STATE_LOCK:
        if not RUN_IN_PROGRESS:
            return "No active run."
        CANCEL_REQUESTED = True
        proc = ACTIVE_PROCESS
        run_label = ACTIVE_RUN_LABEL
        # NOTE(review): termination (which may block up to ~10s in wait()) runs
        # while RUN_STATE_LOCK is held — consider moving it outside the lock.
        if proc is not None and proc.poll() is None:
            terminate_process_group(proc)
            return f"Cancellation requested for {run_label}. Terminating subprocess."
        return f"Cancellation requested for {run_label}."
def download_dataset(
    dataset_repo_id: str,
    token: Optional[str],
    log_lines: List[str],
    force_redownload: bool,
) -> Tuple[str, str, str]:
    """Fetch the train/validation/test parquet splits into DATA_DIR.

    Cached local copies are reused unless *force_redownload* is set.  Each
    split is retried up to 3 times with exponential backoff.  Returns the
    local file paths as (train, validation, test).  Raises RuntimeError on
    user cancellation or when a split ultimately fails to download.
    """
    ensure_workspace()
    out_files: Dict[str, str] = {}
    for split in ("train", "validation", "test"):
        # Honor the Stop button between (potentially slow) split downloads.
        if is_cancel_requested():
            raise RuntimeError("Run cancelled by user.")
        out_path = DATA_DIR / f"{split}.parquet"
        if out_path.exists() and not force_redownload:
            append_log(log_lines, f"Using cached {split}.parquet at {out_path}")
            out_files[split] = str(out_path)
            continue
        download_attempts = 3
        for attempt in range(1, download_attempts + 1):
            try:
                cached_path = hf_hub_download(
                    repo_id=dataset_repo_id,
                    repo_type="dataset",
                    filename=f"{split}.parquet",
                    token=token,
                    force_download=force_redownload,
                )
                # Copy out of the shared HF cache into the workspace layout.
                shutil.copy2(cached_path, out_path)
                append_log(log_lines, f"Downloaded {split}.parquet to {out_path}")
                break
            except Exception as exc:
                if attempt >= download_attempts:
                    raise RuntimeError(
                        f"Failed downloading {split}.parquet after {download_attempts} attempts: {exc}"
                    ) from exc
                wait_s = 2 ** attempt  # exponential backoff: 2s then 4s
                append_log(
                    log_lines,
                    f"Download retry {attempt}/{download_attempts - 1} for {split}.parquet after error: {exc}",
                )
                time.sleep(wait_s)
        out_files[split] = str(out_path)
    return out_files["train"], out_files["validation"], out_files["test"]
def write_runtime_config(
    base_model_id: str,
    model_repo_id: str,
    train_file: str,
    validation_file: str,
    test_file: str,
    run_eval: bool,
    eval_k: int,
    eval_samples: int,
    push_to_hub: bool,
    enforce_quality_gate: bool,
    gate_min_pass_at_1: float,
    gate_min_pass_at_k: float,
    gate_min_rows: int,
) -> Path:
    """Materialize the runtime YAML config from the template plus UI overrides.

    Writes the result under RUNTIME_DIR and returns its path.
    """
    cfg = yaml.safe_load(CONFIG_TEMPLATE.read_text(encoding="utf-8"))
    cfg["model"]["base_model"] = base_model_id
    cfg["hub"]["repo_id"] = model_repo_id
    cfg["hub"]["push_to_hub"] = bool(push_to_hub)
    cfg["data"]["default_train_file"] = train_file
    cfg["data"]["default_validation_file"] = validation_file
    cfg["global"]["output_root"] = str(TRAIN_OUTPUT_DIR)
    post_eval = cfg.setdefault("post_eval", {})
    post_eval["enabled"] = bool(run_eval)
    post_eval["eval_file"] = test_file
    post_eval["k"] = int(eval_k)
    post_eval["max_samples"] = int(eval_samples)
    post_eval["output_json"] = str(TRAIN_OUTPUT_DIR / "post_eval_report.json")
    gate = cfg.setdefault("quality_gate", {})
    gate["enabled"] = bool(enforce_quality_gate)
    gate["min_evaluated_rows"] = int(gate_min_rows)
    gate["min_pass_at_1"] = float(gate_min_pass_at_1)
    gate["min_pass_at_k"] = float(gate_min_pass_at_k)
    # The gate can only require post-eval output when evaluation actually runs.
    gate["require_post_eval"] = bool(enforce_quality_gate and run_eval)
    runtime_path = RUNTIME_DIR / "deepseek_math_sota.runtime.yaml"
    runtime_path.write_text(yaml.safe_dump(cfg, sort_keys=False), encoding="utf-8")
    return runtime_path
def stream_subprocess(
    cmd: List[str],
    env: dict,
    cwd: Path,
    log_lines: List[str],
    status_prefix: str,
) -> Generator[Tuple[str, str], None, int]:
    """Run *cmd* and stream its merged stdout/stderr into the shared log.

    Yields (log_text, status_text) tuples while the process runs; the process
    exit code is the generator's return value.  Honors the shared cancellation
    flag by terminating the whole process group.

    NOTE(review): the select()-based readiness check assumes POSIX pipes —
    on Windows select() only supports sockets.  Confirm this app only runs on
    Linux Spaces.
    """
    append_log(log_lines, f"Running command: {' '.join(cmd)}")
    yield "\n".join(log_lines), f"{status_prefix}: running"
    # start_new_session puts the child in its own process group so
    # terminate_process_group() can signal the whole tree.
    proc = subprocess.Popen(
        cmd,
        cwd=str(cwd),
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into the streamed stdout
        text=True,
        bufsize=1,  # line-buffered
        start_new_session=True,
    )
    set_active_process(proc)
    cancelled = False
    ret = 1
    last_heartbeat = time.monotonic()
    heartbeat_interval_s = 20.0
    assert proc.stdout is not None
    try:
        while True:
            # Cooperative cancellation: the Stop button flips the shared flag.
            if is_cancel_requested() and proc.poll() is None:
                cancelled = True
                append_log(log_lines, "Cancellation requested. Terminating subprocess.")
                yield "\n".join(log_lines), f"{status_prefix}: cancelling"
                terminate_process_group(proc)
            # Poll stdout with a short timeout so the loop can emit heartbeats.
            try:
                ready, _, _ = select.select([proc.stdout], [], [], 0.5)
            except (OSError, ValueError):
                # Invalid/closed fd: fall back to a (possibly blocking) readline.
                ready = [proc.stdout]
            if ready:
                line = proc.stdout.readline()
                if line:
                    line = line.rstrip()
                    if line:
                        append_log(log_lines, line)
                        last_heartbeat = time.monotonic()
                        yield "\n".join(log_lines), f"{status_prefix}: running"
            elif proc.poll() is None:
                # No output but the process is alive: emit a periodic heartbeat
                # so the UI does not look frozen during long silent phases.
                now = time.monotonic()
                if now - last_heartbeat >= heartbeat_interval_s:
                    append_log(log_lines, f"{status_prefix} heartbeat: process alive, waiting for next log chunk.")
                    last_heartbeat = now
                    yield "\n".join(log_lines), f"{status_prefix}: running"
            if proc.poll() is not None:
                break
        # Drain any buffered output remaining after the process exited.
        for tail_line in proc.stdout.readlines():
            tail_line = tail_line.rstrip()
            if tail_line:
                append_log(log_lines, tail_line)
                yield "\n".join(log_lines), f"{status_prefix}: running"
        ret = proc.wait()
    finally:
        clear_active_process(proc)
    if cancelled and ret != 0:
        append_log(log_lines, f"{status_prefix} cancelled.")
        yield "\n".join(log_lines), f"{status_prefix}: cancelled"
    else:
        append_log(log_lines, f"{status_prefix} finished with exit code {ret}")
        yield "\n".join(log_lines), f"{status_prefix}: {'ok' if ret == 0 else 'failed'}"
    return ret
def make_copyable_textbox(
    label: str,
    lines: int,
    max_lines: Optional[int] = None,
    value: str = "",
    interactive: bool = False,
) -> gr.Textbox:
    """Build a Textbox with a copy affordance across Gradio versions.

    Newer Gradio releases expose a `buttons` parameter; older ones use
    `show_copy_button` — whichever the installed signature accepts is used.
    """
    kwargs: Dict[str, Any] = {
        "label": label,
        "lines": lines,
        "value": value,
        "interactive": interactive,
    }
    if max_lines is not None:
        kwargs["max_lines"] = max_lines
    accepted = inspect.signature(gr.Textbox.__init__).parameters
    if "buttons" in accepted:
        kwargs["buttons"] = ["copy"]
    elif "show_copy_button" in accepted:
        kwargs["show_copy_button"] = True
    return gr.Textbox(**kwargs)
def clear_outputs() -> Tuple[str, str, str]:
    """Reset the console text, status label, and telemetry panel to idle."""
    return "", "Idle", render_ops_visual({}, "Idle", "")
def cancel_pipeline() -> str:
    """UI callback: forward the Stop click to the shared run-state machinery."""
    return request_cancel()
def run_pipeline_core(
    dataset_repo_id: str,
    model_repo_id: str,
    base_model_id: str,
    autonomous_mode: bool,
    start_stage: int,
    max_stages: int,
    run_eval: bool,
    eval_k: int,
    eval_samples: int,
    enforce_quality_gate: bool,
    gate_min_pass_at_1: float,
    gate_min_pass_at_k: float,
    gate_min_rows: int,
    push_to_hub: bool,
    force_redownload: bool,
    preflight_only: bool,
) -> Generator[Tuple[str, str, str], None, None]:
    """Run the full train/eval pipeline, yielding (log text, status, summary JSON).

    Each yielded tuple feeds the three Gradio outputs. The function validates
    inputs, downloads the dataset, writes a runtime config, streams the
    training subprocess, optionally runs a fallback evaluation subprocess,
    and records everything into ``summary``. Cancellation is cooperative:
    ``is_cancel_requested()`` is checked between subprocess phases.
    """
    log_lines: List[str] = []
    summary: Dict[str, Any] = {}
    # NOTE(review): datetime.utcnow() is deprecated in Python 3.12+;
    # datetime.now(timezone.utc) would be the modern equivalent — confirm
    # target runtime before changing.
    run_label = dt.datetime.utcnow().strftime("run-%Y%m%d-%H%M%S")
    # Single-run guard: refuse to start if another run holds the run lock.
    if not begin_run(run_label):
        append_log(log_lines, "A run is already in progress. Wait for it to finish or click Stop.")
        busy_summary = {
            "result": "busy",
            "message": "A run is already in progress.",
            "timestamp_utc": now_ts(),
        }
        yield "\n".join(log_lines), "Busy", summary_text(busy_summary)
        return
    try:
        # --- Input validation and normalization -------------------------
        token, _, _ = resolve_hf_token()
        dataset_repo_id = validate_repo_id(dataset_repo_id, "Dataset repo")
        model_repo_id = validate_repo_id(model_repo_id, "Model repo")
        base_model_id = (base_model_id or "").strip()
        if not base_model_id:
            raise ValueError("Base model is required.")
        # Sliders arrive as floats from Gradio; coerce to the numeric
        # types the rest of the pipeline expects.
        stage_start = int(start_stage)
        stage_count = int(max_stages)
        eval_k = int(eval_k)
        eval_samples = int(eval_samples)
        gate_min_rows = int(gate_min_rows)
        gate_min_pass_at_1 = float(gate_min_pass_at_1)
        gate_min_pass_at_k = float(gate_min_pass_at_k)
        # Autonomous mode overrides all user knobs with the full-run profile:
        # every stage, evaluation on, default gate thresholds, default push
        # behavior, no redownload, no dry run.
        if autonomous_mode:
            stage_start = 1
            stage_count = TEMPLATE_STAGE_COUNT
            run_eval = True
            eval_k = DEFAULT_AUTO_EVAL_K
            eval_samples = DEFAULT_AUTO_EVAL_SAMPLES
            enforce_quality_gate = bool(DEFAULT_GATE_ENABLED)
            gate_min_rows = DEFAULT_GATE_MIN_ROWS
            gate_min_pass_at_1 = DEFAULT_GATE_MIN_PASS_AT_1
            gate_min_pass_at_k = DEFAULT_GATE_MIN_PASS_AT_K
            push_to_hub = bool(DEFAULT_AUTO_PUSH_TO_HUB)
            force_redownload = False
            preflight_only = False
        # Range checks (applied after the autonomous overrides).
        if stage_start < 1:
            raise ValueError("Start stage must be >= 1.")
        if stage_start > TEMPLATE_STAGE_COUNT:
            raise ValueError(f"Start stage must be <= {TEMPLATE_STAGE_COUNT}.")
        if stage_count < 1:
            raise ValueError("How many stages must be >= 1.")
        if eval_k < 1:
            raise ValueError("Eval K must be >= 1.")
        if eval_samples < 1:
            raise ValueError("Eval max samples must be >= 1.")
        if gate_min_rows < 1:
            raise ValueError("Gate minimum rows must be >= 1.")
        if not 0.0 <= gate_min_pass_at_1 <= 1.0:
            raise ValueError("Gate min pass@1 must be between 0 and 1.")
        if not 0.0 <= gate_min_pass_at_k <= 1.0:
            raise ValueError("Gate min pass@k must be between 0 and 1.")
        # Required scripts/config must exist before anything is downloaded.
        for required_path in (CONFIG_TEMPLATE, TRAIN_SCRIPT):
            if not required_path.exists():
                raise FileNotFoundError(f"Required file is missing: {required_path}")
        if run_eval and not EVAL_SCRIPT.exists():
            raise FileNotFoundError(f"Evaluation script is missing: {EVAL_SCRIPT}")
        # --- Summary bootstrap ------------------------------------------
        runtime = run_runtime_snapshot()
        summary.update(
            {
                "run_label": run_label,
                "started_at_utc": now_ts(),
                "dataset_repo_id": dataset_repo_id,
                "model_repo_id": model_repo_id,
                "base_model_id": base_model_id,
                "autonomous_mode": bool(autonomous_mode),
                "start_stage": stage_start,
                "max_stages": stage_count,
                "run_eval": bool(run_eval),
                "eval_k": eval_k,
                "eval_samples": eval_samples,
                "enforce_quality_gate": bool(enforce_quality_gate),
                "gate_min_rows": gate_min_rows,
                "gate_min_pass_at_1": gate_min_pass_at_1,
                "gate_min_pass_at_k": gate_min_pass_at_k,
                "push_to_hub": bool(push_to_hub),
                "force_redownload": bool(force_redownload),
                "preflight_only": bool(preflight_only),
                "runtime": runtime,
            }
        )
        append_log(log_lines, f"Run {run_label} started.")
        if autonomous_mode:
            append_log(
                log_lines,
                "Autonomous mode active: full-stage training/eval/gating/publish profile applied.",
            )
        append_log(
            log_lines,
            f"Runtime: python={runtime['python']} gradio={runtime['gradio']} torch={runtime['torch']} "
            f"cuda_available={runtime['cuda_available']} devices={runtime['cuda_device_count']}",
        )
        if token:
            append_log(log_lines, "Environment posture validated.")
        else:
            append_log(log_lines, "Restricted mode active. Hub publish disabled for this run.")
        yield "\n".join(log_lines), "Validating environment", summary_text(summary)
        # --- Compute mode -----------------------------------------------
        # NOTE: when preflight_only is set and no GPU is present, neither
        # branch runs, so summary["compute_mode"] stays unset for dry runs.
        if not preflight_only and not torch.cuda.is_available():
            summary["compute_mode"] = "cpu_fallback"
            append_log(
                log_lines,
                "GPU is unavailable. Continuing with CPU fallback mode; training will be slower.",
            )
            yield "\n".join(log_lines), "CPU fallback active", summary_text(summary)
        elif torch.cuda.is_available():
            summary["compute_mode"] = "gpu"
        # Hub push requires a token; silently degrade rather than fail.
        effective_push_to_hub = bool(push_to_hub)
        if effective_push_to_hub and not token:
            effective_push_to_hub = False
            summary["push_to_hub"] = False
            summary["push_disabled_reason"] = "missing_token"
            append_log(log_lines, "Push requested but no token is available. Disabling hub push for this run.")
        # --- Dataset download -------------------------------------------
        append_log(log_lines, "Preparing local workspace.")
        yield "\n".join(log_lines), "Preparing workspace", summary_text(summary)
        train_file, validation_file, test_file = download_dataset(
            dataset_repo_id=dataset_repo_id,
            token=token,
            log_lines=log_lines,
            force_redownload=bool(force_redownload),
        )
        summary["dataset_files"] = {
            "train": train_file,
            "validation": validation_file,
            "test": test_file,
        }
        yield "\n".join(log_lines), "Dataset ready", summary_text(summary)
        if is_cancel_requested():
            raise RuntimeError("Run cancelled by user.")
        # --- Runtime config ---------------------------------------------
        runtime_cfg = write_runtime_config(
            base_model_id=base_model_id,
            model_repo_id=model_repo_id,
            train_file=train_file,
            validation_file=validation_file,
            test_file=test_file,
            run_eval=bool(run_eval),
            eval_k=eval_k,
            eval_samples=eval_samples,
            push_to_hub=effective_push_to_hub,
            enforce_quality_gate=bool(enforce_quality_gate),
            gate_min_pass_at_1=gate_min_pass_at_1,
            gate_min_pass_at_k=gate_min_pass_at_k,
            gate_min_rows=gate_min_rows,
        )
        summary["runtime_config"] = str(runtime_cfg)
        append_log(log_lines, f"Wrote runtime config: {runtime_cfg}")
        yield "\n".join(log_lines), "Config ready", summary_text(summary)
        # --- Subprocess environment -------------------------------------
        # Pass the token via env when present; scrub it otherwise so the
        # child cannot accidentally publish in restricted mode.
        env = os.environ.copy()
        if token:
            env["HF_TOKEN"] = token
            env["HUGGINGFACE_HUB_TOKEN"] = token
        else:
            env.pop("HF_TOKEN", None)
            env.pop("HUGGINGFACE_HUB_TOKEN", None)
        env["PYTHONUNBUFFERED"] = "1"  # unbuffered child stdout for live streaming
        # --- Training subprocess ----------------------------------------
        train_cmd = [
            sys.executable,
            str(TRAIN_SCRIPT),
            "--config",
            str(runtime_cfg),
            "--start-stage",
            str(stage_start),
            "--max-stages",
            str(stage_count),
        ]
        if preflight_only:
            train_cmd.append("--dry-run")
            append_log(log_lines, "Validation mode enabled: running dry validation without full training.")
        train_gen = stream_subprocess(
            cmd=train_cmd,
            env=env,
            cwd=ROOT,
            log_lines=log_lines,
            status_prefix="Training",
        )
        # Drain the streaming generator manually so the subprocess exit
        # code (its `return` value) can be captured from StopIteration.
        train_ret = None
        while True:
            try:
                logs_text, status_text = next(train_gen)
                summary["status"] = status_text
                yield logs_text, status_text, summary_text(summary)
            except StopIteration as stop:
                train_ret = stop.value
                break
        if is_cancel_requested():
            summary["result"] = "cancelled"
            summary["finished_at_utc"] = now_ts()
            append_log(log_lines, "Run cancelled by user.")
            yield "\n".join(log_lines), "Cancelled", summary_text(summary)
            return
        if train_ret != 0:
            summary["result"] = "failed"
            summary["failure_stage"] = "training"
            summary["finished_at_utc"] = now_ts()
            yield "\n".join(log_lines), "Failed", summary_text(summary)
            return
        if preflight_only:
            summary["result"] = "preflight_passed"
            summary["finished_at_utc"] = now_ts()
            append_log(log_lines, "Validation mode completed successfully.")
            yield "\n".join(log_lines), "Preflight complete", summary_text(summary)
            return
        # --- Training summary ingestion ---------------------------------
        training_summary_path = TRAIN_OUTPUT_DIR / "training_summary.json"
        training_summary: Optional[Dict[str, Any]] = None
        if training_summary_path.exists():
            try:
                summary["training_summary_path"] = str(training_summary_path)
                loaded_summary = json.loads(training_summary_path.read_text(encoding="utf-8"))
                if isinstance(loaded_summary, dict):
                    training_summary = loaded_summary
                    summary["training_summary"] = loaded_summary
                else:
                    summary["training_summary"] = {"warning": "Training summary JSON is not an object."}
            except json.JSONDecodeError:
                summary["training_summary_path"] = str(training_summary_path)
                summary["training_summary"] = {"warning": "Unable to parse training summary JSON."}
        if isinstance(training_summary, dict):
            # Surface the child's quality-gate, push, and post-eval reports.
            quality_gate = training_summary.get("quality_gate")
            if isinstance(quality_gate, dict):
                summary["quality_gate"] = quality_gate
                append_log(
                    log_lines,
                    f"Quality gate: passed={quality_gate.get('passed')} enabled={quality_gate.get('enabled')}",
                )
            push_report = training_summary.get("push")
            if isinstance(push_report, dict):
                summary["push"] = push_report
                append_log(
                    log_lines,
                    f"Push decision: requested={push_report.get('requested')} performed={push_report.get('performed')}",
                )
            post_eval_report = training_summary.get("post_eval")
            if run_eval and isinstance(post_eval_report, dict):
                summary["evaluation"] = {
                    "source": "train_post_eval",
                    "evaluated_rows": post_eval_report.get("evaluated_rows"),
                    "pass_at_1": post_eval_report.get("pass_at_1"),
                    "pass_at_k": post_eval_report.get("pass_at_k"),
                    "exact_at_k": post_eval_report.get("exact_at_k"),
                    "composite_score": post_eval_report.get("composite_score"),
                    "k": post_eval_report.get("k"),
                    "report_path": post_eval_report.get("report_path"),
                }
                append_log(log_lines, "Using post-eval metrics emitted by training run.")
        # --- Fallback evaluation subprocess -----------------------------
        # Only runs when the training run did not already emit post-eval
        # metrics (summary["evaluation"] was not populated above).
        if run_eval and "evaluation" not in summary:
            eval_report = WORKSPACE_DIR / "runs" / "latest_eval_report.json"
            eval_cmd = [
                sys.executable,
                str(EVAL_SCRIPT),
                "--config",
                str(runtime_cfg),
                "--base-model",
                base_model_id,
                "--adapter-path",
                str(TRAIN_OUTPUT_DIR / "final_adapter"),
                "--eval-file",
                str(DATA_DIR / "test.parquet"),
                "--k",
                str(eval_k),
                "--max-samples",
                str(eval_samples),
                "--output-json",
                str(eval_report),
            ]
            eval_gen = stream_subprocess(
                cmd=eval_cmd,
                env=env,
                cwd=ROOT,
                log_lines=log_lines,
                status_prefix="Evaluation",
            )
            eval_ret = None
            while True:
                try:
                    logs_text, status_text = next(eval_gen)
                    summary["status"] = status_text
                    yield logs_text, status_text, summary_text(summary)
                except StopIteration as stop:
                    eval_ret = stop.value
                    break
            if is_cancel_requested():
                summary["result"] = "cancelled"
                summary["finished_at_utc"] = now_ts()
                append_log(log_lines, "Run cancelled by user.")
                yield "\n".join(log_lines), "Cancelled", summary_text(summary)
                return
            if eval_ret != 0:
                summary["result"] = "failed"
                summary["failure_stage"] = "evaluation"
                summary["finished_at_utc"] = now_ts()
                yield "\n".join(log_lines), "Failed", summary_text(summary)
                return
            if eval_report.exists():
                report = json.loads(eval_report.read_text(encoding="utf-8"))
                summary["evaluation"] = {
                    "source": "fallback_eval",
                    "evaluated_rows": report.get("evaluated_rows"),
                    "pass_at_1": report.get("pass_at_1"),
                    "pass_at_k": report.get("pass_at_k"),
                    "exact_at_k": report.get("exact_at_k"),
                    "composite_score": report.get("composite_score"),
                    "k": report.get("k"),
                    "report_path": str(eval_report),
                }
                append_log(log_lines, f"Eval summary: {json.dumps(summary['evaluation'])}")
        summary["result"] = "completed"
        summary["finished_at_utc"] = now_ts()
        append_log(log_lines, "Pipeline completed.")
        yield "\n".join(log_lines), "Completed", summary_text(summary)
    except Exception as exc:
        # Distinguish user cancellation (flag or sentinel message) from a
        # genuine failure so the UI reports the right terminal state.
        cancelled = is_cancel_requested() or str(exc) == "Run cancelled by user."
        summary["result"] = "cancelled" if cancelled else "failed"
        summary["error"] = {"type": type(exc).__name__, "message": str(exc)}
        summary["finished_at_utc"] = now_ts()
        append_log(
            log_lines,
            f"Pipeline {'cancelled' if cancelled else 'failed'}: {type(exc).__name__}: {exc}",
        )
        yield "\n".join(log_lines), "Cancelled" if cancelled else "Failed", summary_text(summary)
    finally:
        # Always release the single-run lock, even on error/cancel.
        finish_run()
def run_pipeline(
    dataset_repo_id: str,
    model_repo_id: str,
    base_model_id: str,
    autonomous_mode: bool,
    start_stage: int,
    max_stages: int,
    run_eval: bool,
    eval_k: int,
    eval_samples: int,
    enforce_quality_gate: bool,
    gate_min_pass_at_1: float,
    gate_min_pass_at_k: float,
    gate_min_rows: int,
    push_to_hub: bool,
    force_redownload: bool,
    preflight_only: bool,
) -> Generator[Tuple[str, str, str], None, None]:
    """Drive the core pipeline and render each update for the Gradio outputs.

    Wraps ``run_pipeline_core``: the raw log text and summary JSON are
    composed into the ops console view and the HTML ops visual before being
    yielded to the UI.
    """
    core_updates = run_pipeline_core(
        dataset_repo_id,
        model_repo_id,
        base_model_id,
        autonomous_mode,
        start_stage,
        max_stages,
        run_eval,
        eval_k,
        eval_samples,
        enforce_quality_gate,
        gate_min_pass_at_1,
        gate_min_pass_at_k,
        gate_min_rows,
        push_to_hub,
        force_redownload,
        preflight_only,
    )
    for log_text, run_status, summary_payload in core_updates:
        parsed_summary = _parse_summary_json(summary_payload)
        yield (
            compose_ops_console(log_text, summary_payload),
            run_status,
            render_ops_visual(parsed_summary, run_status, log_text),
        )
# UI layout: inputs for repos/model, run-profile controls, quality-gate
# thresholds, and the run/stop/clear action row, followed by the live
# status + console outputs. Defaults mirror the autonomous run profile.
with gr.Blocks(title="Math Conjecture Trainer Space") as demo:
    gr.HTML(TACTICAL_HEADER_HTML)
    gr.Markdown(PROJECT_DESCRIPTION)
    with gr.Row():
        dataset_repo_id = gr.Textbox(
            label="Dataset Source",
            value="NorthernTribe-Research/math-conjecture-training-corpus",
        )
    with gr.Row():
        model_repo_id = gr.Textbox(
            label="Model Destination",
            value="NorthernTribe-Research/math-conjecture-model",
        )
        base_model_id = gr.Textbox(
            label="Base Model ID",
            value="deepseek-ai/deepseek-math-v2",
        )
    with gr.Row():
        # When checked, run_pipeline_core overrides every knob below with
        # the full autonomous profile.
        autonomous_mode = gr.Checkbox(label="Autonomous Mode", value=True)
    with gr.Row():
        start_stage = gr.Slider(label="Stage Start", minimum=1, maximum=TEMPLATE_STAGE_COUNT, step=1, value=1)
        max_stages = gr.Slider(
            label="Stage Count",
            minimum=1,
            maximum=TEMPLATE_STAGE_COUNT,
            step=1,
            value=TEMPLATE_STAGE_COUNT,
        )
        run_eval = gr.Checkbox(label="Run Evaluation After Training", value=True)
    with gr.Row():
        eval_k = gr.Slider(label="Evaluation K", minimum=1, maximum=8, step=1, value=4)
        eval_samples = gr.Slider(label="Evaluation Max Samples", minimum=50, maximum=1000, step=50, value=300)
    with gr.Row():
        # Gate defaults are clamped into the slider ranges so out-of-range
        # configured defaults cannot break widget construction.
        enforce_quality_gate = gr.Checkbox(label="Enforce Quality Gate", value=DEFAULT_GATE_ENABLED)
        gate_min_pass_at_1 = gr.Slider(
            label="Gate Min pass@1",
            minimum=0.0,
            maximum=0.5,
            step=0.005,
            value=min(max(DEFAULT_GATE_MIN_PASS_AT_1, 0.0), 0.5),
        )
        gate_min_pass_at_k = gr.Slider(
            label="Gate Min pass@k",
            minimum=0.0,
            maximum=1.0,
            step=0.01,
            value=min(max(DEFAULT_GATE_MIN_PASS_AT_K, 0.0), 1.0),
        )
        gate_min_rows = gr.Slider(
            label="Gate Min Rows",
            minimum=10,
            maximum=2000,
            step=10,
            value=min(max(DEFAULT_GATE_MIN_ROWS, 10), 2000),
        )
    with gr.Row():
        push_to_hub = gr.Checkbox(label="Push Adapter to Hub", value=True)
        force_redownload = gr.Checkbox(label="Force Dataset Redownload", value=False)
        preflight_only = gr.Checkbox(label="Validation Mode (No Training)", value=False)
    with gr.Row():
        run_button = gr.Button("Execute Training Run", variant="primary")
        stop_button = gr.Button("Abort Active Run", variant="stop")
        clear_button = gr.Button("Reset Console")
    ops_visual = gr.HTML(value=render_ops_visual({}, "Idle", ""))
    status = gr.Textbox(label="Run Status", value="Idle", interactive=False)
    logs = make_copyable_textbox(
        label="Ops Console (Live Log + Mission JSON)",
        lines=26,
        max_lines=36,
        interactive=False,
    )
    run_button.click(
        fn=run_pipeline,
        inputs=[
            dataset_repo_id,
            model_repo_id,
            base_model_id,
            autonomous_mode,
            start_stage,
            max_stages,
            run_eval,
            eval_k,
            eval_samples,
            enforce_quality_gate,
            gate_min_pass_at_1,
            gate_min_pass_at_k,
            gate_min_rows,
            push_to_hub,
            force_redownload,
            preflight_only,
        ],
        outputs=[logs, status, ops_visual],
    )
    # Stop/clear bypass the queue so they stay responsive during a run.
    stop_button.click(fn=cancel_pipeline, inputs=None, outputs=[status], queue=False)
    clear_button.click(fn=clear_outputs, inputs=None, outputs=[logs, status, ops_visual], queue=False)
    gr.HTML(TACTICAL_FOOTER_HTML)
if __name__ == "__main__":
    # Single-worker queue so only one training run executes at a time.
    demo.queue(default_concurrency_limit=1).launch(css=TACTICAL_CSS)