""" Gradio UI for the HF training Space. Training does NOT start automatically — user must click "Start Training". Unsloth is pre-installed in the Docker image. """ import logging import os import sys import tempfile import threading import time import traceback from pathlib import Path def _configure_writable_runtime_dirs(): """HF Spaces can run with HOME=/ and /app read-only; use a verified cache root.""" candidates = [] if os.getenv("SPACE_RUNTIME_DIR"): candidates.append(Path(os.environ["SPACE_RUNTIME_DIR"])) uid = getattr(os, "getuid", lambda: "user")() candidates.extend( [ Path(tempfile.gettempdir()) / f"context-corruption-training-{uid}", Path("/dev/shm") / f"context-corruption-training-{uid}", ] ) last_error = None for runtime_root in candidates: try: _apply_runtime_dirs(runtime_root) return runtime_root except OSError as exc: last_error = exc runtime_root = Path(tempfile.mkdtemp(prefix="context-corruption-training-")) try: _apply_runtime_dirs(runtime_root) except OSError as exc: raise RuntimeError( f"Could not create writable runtime directories; last error: {last_error}" ) from exc return runtime_root def _apply_runtime_dirs(runtime_root: Path): cache_root = runtime_root / "cache" env_dirs = { "HOME": runtime_root, "XDG_CACHE_HOME": cache_root, "HF_HOME": cache_root / "huggingface", "HF_HUB_CACHE": cache_root / "huggingface" / "hub", "TRANSFORMERS_CACHE": cache_root / "huggingface" / "transformers", "WANDB_DIR": runtime_root / "wandb", "WANDB_CACHE_DIR": cache_root / "wandb", "MPLCONFIGDIR": cache_root / "matplotlib", "OUTPUT_DIR": runtime_root / "checkpoints" / "grpo-qwen-1.5b", } for path in env_dirs.values(): path.mkdir(parents=True, exist_ok=True) probe = runtime_root / ".write-test" probe.write_text("ok") probe.unlink(missing_ok=True) for key, path in env_dirs.items(): os.environ[key] = str(path) _runtime_root = _configure_writable_runtime_dirs() import gradio as gr from dotenv import load_dotenv load_dotenv() _log_lines: list[str] = [] 
_training_status = "idle"  # idle | installing | running | complete | failed

# Guards the "is a run already active?" check together with claiming the run,
# so that two near-simultaneous clicks cannot both start a training thread.
_start_lock = threading.Lock()

# Soft cap on retained log lines. The UI only ever renders the last 100, so
# older lines are dropped once the buffer grows past twice this size.
_MAX_LOG_LINES = 5000

# Single shared logging handler; created lazily by _attach_log_capture().
_ui_log_handler = None


def _append_log(msg: str):
    """Append *msg* to the in-memory UI log, prefixed with an HH:MM:SS stamp."""
    ts = time.strftime("%H:%M:%S")
    _log_lines.append(f"[{ts}] {msg}")
    # Trim in bulk (with slack) rather than on every append.
    if len(_log_lines) > 2 * _MAX_LOG_LINES:
        del _log_lines[:-_MAX_LOG_LINES]


class _Tee:
    """Stream wrapper that copies every non-blank write into the UI log.

    Installed in place of ``sys.stdout`` / ``sys.stderr`` while training runs.
    Anything other than ``write`` and ``flush`` is delegated to the wrapped
    stream so that code probing the stream (``isatty()``, ``fileno()``,
    ``encoding`` ...) keeps working.
    """

    def __init__(self, orig):
        self._orig = orig

    def write(self, msg):
        """Log *msg* if it has visible content, then forward it unchanged."""
        if msg.strip():
            _append_log(msg.rstrip())
        # Return the wrapped stream's result, as file-like write() is
        # expected to report how much was written.
        return self._orig.write(msg)

    def flush(self):
        self._orig.flush()

    def __getattr__(self, name):
        # Only reached for attributes not defined on _Tee itself. Guard
        # "_orig" so a half-initialised instance cannot recurse forever.
        if name == "_orig":
            raise AttributeError(name)
        return getattr(self._orig, name)


class _LogHandler(logging.Handler):
    """Captures Python logging (used by transformers/TRL) into the UI log."""

    def emit(self, record):
        # The same record object is handed to every handler along the
        # propagation chain. This handler is attached to several loggers
        # including the root, so mark the record to log it only once.
        if getattr(record, "_ui_captured", False):
            return
        record._ui_captured = True
        try:
            _append_log(self.format(record))
        except Exception:
            # Deliberate best-effort: a formatting problem in the UI log
            # must never break the training run that emitted the record.
            pass


def _attach_log_capture():
    """Route INFO+ records from the ML libraries and the root logger to the UI.

    Safe to call repeatedly (e.g. when training is retried after a failure):
    one shared handler instance is reused, and ``Logger.addHandler`` ignores
    a handler that is already attached.
    """
    global _ui_log_handler
    if _ui_log_handler is None:
        handler = _LogHandler()
        handler.setFormatter(logging.Formatter("%(name)s | %(message)s"))
        handler.setLevel(logging.INFO)
        _ui_log_handler = handler

    for name in ("transformers", "trl", "unsloth", "peft", "accelerate", "datasets", ""):
        lg = logging.getLogger(name)
        lg.addHandler(_ui_log_handler)
        lg.setLevel(logging.INFO)


def _patch_torchao_import_compat():
    """Patch TorchAO import probes that expect newer/nightly torch symbols.

    Training uses bitsandbytes 4-bit loading through Unsloth, not TorchAO
    quantization. These aliases are only to let optional TorchAO modules import.
    """
    import torch

    try:
        import torch.utils._pytree as _pytree

        if not hasattr(_pytree, "register_constant"):
            # Identity stand-in: returns the class it is given, unchanged.
            _pytree.register_constant = lambda cls: cls
            _append_log("Applied register_constant shim for torchao compat.")
    except Exception as _e:
        _append_log(f"Warning: could not patch _pytree: {_e}")

    patched_dtypes = []
    # int1..int7 alias torch.int8 and uint1..uint7 alias torch.uint8, but
    # only for names this torch build does not already define.
    dtype_aliases = {
        **{f"int{i}": torch.int8 for i in range(1, 8)},
        **{f"uint{i}": torch.uint8 for i in range(1, 8)},
    }
    for name, fallback in dtype_aliases.items():
        if not hasattr(torch, name):
            setattr(torch, name, fallback)
            patched_dtypes.append(name)

    if patched_dtypes:
        _append_log(
            "Applied torch dtype shims for torchao import: " + ", ".join(patched_dtypes)
        )


def _run_training():
    """Thread target: import unsloth and the training script, then run it.

    Progress and failures are reported through ``_training_status`` and the
    UI log; nothing is raised to the caller. ``sys.stdout`` / ``sys.stderr``
    are wrapped in ``_Tee`` for the duration of the run and always restored.
    """
    global _training_status
    _training_status = "running"
    _append_log("Thread started — patching torch/torchao compat then importing unsloth...")

    try:
        _patch_torchao_import_compat()
    except Exception as _e:
        _append_log(f"Warning: could not patch torchao import compat: {_e}")

    # BaseException on purpose: whatever ends the import must be reflected
    # in the status shown to the user.
    try:
        import unsloth  # noqa: F401

        _append_log(f"✅ unsloth ready (v{getattr(unsloth, '__version__', 'unknown')})")
    except BaseException as e:
        _training_status = "failed"
        _append_log(f"❌ unsloth import failed: {e}\n{traceback.format_exc()}")
        return

    _append_log("Attaching log capture for transformers/TRL...")
    _attach_log_capture()

    _append_log("Importing training script (loads transformers/TRL — another 30–60s)...")
    old_out, old_err = sys.stdout, sys.stderr
    sys.stdout = _Tee(old_out)
    sys.stderr = _Tee(old_err)
    try:
        # Put the directory above this file's directory on sys.path so that
        # "training.train_grpo" resolves; skip if a previous attempt did it.
        project_root = str(Path(__file__).parent.parent)
        if project_root not in sys.path:
            sys.path.insert(0, project_root)
        from training.train_grpo import main

        _append_log("✅ Training script loaded — starting main()...")
        main()
        _training_status = "complete"
        _append_log("✅ Training complete! Model pushed to HF Hub. Check WandB for curves.")
    except BaseException as e:
        _training_status = "failed"
        # Keep only the tail of the traceback to bound the log entry size.
        _append_log(f"❌ Training failed: {e}\n{traceback.format_exc()[-2000:]}")
    finally:
        sys.stdout = old_out
        sys.stderr = old_err


def _busy_message():
    """Return the refusal message if a run is active or finished, else None."""
    if _training_status in ("running",):
        return "⚠️ Already in progress."
    if _training_status == "complete":
        return "✅ Training already complete."
    return None


def start_training():
    """Validate prerequisites and launch training in a daemon thread.

    Returns:
        A ``(message, logs)`` tuple for the message box and the log box.
    """
    global _training_status

    busy = _busy_message()
    if busy:
        return busy, _get_logs()

    missing = [k for k in ("WANDB_API_KEY", "HF_TOKEN", "HF_HUB_MODEL_ID") if not os.getenv(k)]
    if missing:
        return f"❌ Missing secrets: {', '.join(missing)}", _get_logs()

    import torch

    if not torch.cuda.is_available():
        return "❌ No GPU detected. Upgrade Space hardware to A100 first.", _get_logs()

    gpu = torch.cuda.get_device_name(0)

    # Claim the run atomically. The status is flipped here rather than only
    # inside the worker thread, otherwise a second click arriving before the
    # thread gets scheduled would still see "idle" and start a second run.
    with _start_lock:
        busy = _busy_message()
        if busy:
            return busy, _get_logs()
        _training_status = "running"

    _append_log(f"GPU detected: {gpu}")
    try:
        threading.Thread(target=_run_training, daemon=True).start()
    except RuntimeError as exc:
        # The thread never ran, so release the claim instead of leaving the
        # status stuck on "running".
        _training_status = "failed"
        _append_log(f"❌ Could not start training thread: {exc}")
        return f"❌ Could not start training thread: {exc}", _get_logs()
    return f"🚀 Started on {gpu}. Training beginning...", _get_logs()


def _get_logs() -> str:
    """Return the most recent 100 log lines joined by newlines."""
    return "\n".join(_log_lines[-100:]) if _log_lines else "No logs yet."


def get_status() -> str:
    """Return the display label for the current training status."""
    icons = {
        "idle": "⏸️ Idle — ready to start",
        "installing": "⚙️ Installing unsloth on GPU...",
        "running": "🔄 Training in progress...",
        "complete": "✅ Training complete",
        "failed": "❌ Failed — check logs",
    }
    # An unknown status is shown verbatim.
    return icons.get(_training_status, _training_status)


def refresh():
    """Return ``(status, logs)`` for the status box and the log box."""
    return get_status(), _get_logs()


# ── Gradio UI ─────────────────────────────────────────────────────────────────

with gr.Blocks(title="ContextCorruption Training") as demo:
    gr.Markdown("""
    # ContextCorruption-Env — GRPO Training
    Fine-tuning **Qwen2-1.5B-Instruct** to identify corrupted documents.

    **Before clicking Start, confirm:**
    - Space hardware is set to **A100 Large** (Settings → Space hardware)
    - Secrets are set: `WANDB_API_KEY` · `HF_TOKEN` · `HF_HUB_MODEL_ID`
    """)

    status_box = gr.Textbox(label="Status", value=get_status(), interactive=False)
    log_box = gr.Textbox(label="Live Logs", lines=25, interactive=False, value="Waiting...")
    msg_box = gr.Textbox(label="", interactive=False)

    with gr.Row():
        start_btn = gr.Button("🚀 Start Training", variant="primary", scale=2)
        refresh_btn = gr.Button("🔄 Refresh", scale=1)

    gr.Markdown("""
    ---
    **Estimated time on A100:** ~30–45 min · **Cost:** ~$3–5 from your HF credits
    After training, model is pushed to HF Hub and WandB has the reward/loss curves.
    """)

    # Re-poll status and logs every 10 seconds.
    timer = gr.Timer(value=10)

    start_btn.click(fn=start_training, outputs=[msg_box, log_box])
    refresh_btn.click(fn=refresh, outputs=[status_box, log_box])
    demo.load(fn=refresh, outputs=[status_box, log_box])
    timer.tick(fn=refresh, outputs=[status_box, log_box])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)