"""
Gradio UI for the HF training Space.
Training does NOT start automatically — the user must click "Start Training".
Unsloth is pre-installed in the Docker image.
"""
| import logging |
| import os |
| import sys |
| import tempfile |
| import threading |
| import time |
| import traceback |
| from pathlib import Path |
|
|
|
|
def _configure_writable_runtime_dirs():
    """Pick the first usable runtime root and point the process environment at it.

    HF Spaces can run with HOME=/ and /app read-only, so candidate roots are
    tried in order until ``_apply_runtime_dirs`` succeeds on one of them:

    1. the ``SPACE_RUNTIME_DIR`` environment variable, when set and non-empty,
    2. a per-user directory under the system temp dir,
    3. a per-user directory under ``/dev/shm``,
    4. as a last resort, a fresh ``tempfile.mkdtemp`` directory.

    Returns:
        The ``Path`` of the root that was successfully applied.

    Raises:
        RuntimeError: if even the ``mkdtemp`` fallback cannot be set up; the
            message carries the last error seen while trying the candidates.
    """
    # Platforms without os.getuid fall back to the literal "user".
    uid = os.getuid() if hasattr(os, "getuid") else "user"
    dir_name = f"context-corruption-training-{uid}"

    override = os.getenv("SPACE_RUNTIME_DIR")
    roots = [Path(override)] if override else []
    roots.append(Path(tempfile.gettempdir()) / dir_name)
    roots.append(Path("/dev/shm") / dir_name)

    failure = None
    for root in roots:
        try:
            _apply_runtime_dirs(root)
        except OSError as err:
            # Remember why this candidate was rejected and try the next one.
            failure = err
        else:
            return root

    fallback = Path(tempfile.mkdtemp(prefix="context-corruption-training-"))
    try:
        _apply_runtime_dirs(fallback)
    except OSError as err:
        raise RuntimeError(
            f"Could not create writable runtime directories; last error: {failure}"
        ) from err
    return fallback
|
|
|
|
def _apply_runtime_dirs(runtime_root: Path):
    """Create the runtime directory tree under *runtime_root* and export it.

    All directories are created first, then a small probe file is written to
    and removed from *runtime_root* to confirm it is writable. Only after both
    steps succeed are the environment variables updated, so a failing
    candidate leaves ``os.environ`` untouched.

    Args:
        runtime_root: Directory that should hold HOME, caches, WandB data and
            checkpoints.

    Raises:
        OSError: if a directory cannot be created or the probe file cannot be
            written.
    """
    cache = runtime_root / "cache"
    hf_cache = cache / "huggingface"
    targets = [
        ("HOME", runtime_root),
        ("XDG_CACHE_HOME", cache),
        ("HF_HOME", hf_cache),
        ("HF_HUB_CACHE", hf_cache / "hub"),
        ("TRANSFORMERS_CACHE", hf_cache / "transformers"),
        ("WANDB_DIR", runtime_root / "wandb"),
        ("WANDB_CACHE_DIR", cache / "wandb"),
        ("MPLCONFIGDIR", cache / "matplotlib"),
        ("OUTPUT_DIR", runtime_root / "checkpoints" / "grpo-qwen-1.5b"),
    ]

    for _, directory in targets:
        directory.mkdir(parents=True, exist_ok=True)

    # Write-then-delete probe: mkdir(exist_ok=True) alone does not prove the
    # root is writable when it already existed.
    marker = runtime_root / ".write-test"
    marker.write_text("ok")
    marker.unlink(missing_ok=True)

    os.environ.update({name: str(directory) for name, directory in targets})
|
|
|
|
# Runs at import time, before the third-party imports below, so the
# HOME/cache environment variables are already set when those packages load
# (presumably because they read them on import - confirm before reordering).
_runtime_root = _configure_writable_runtime_dirs()


import gradio as gr
from dotenv import load_dotenv


# Pull settings from a .env file via python-dotenv.
load_dotenv()


# Timestamped log lines collected for display in the UI.
_log_lines: list[str] = []
# Training state shown in the UI: starts as "idle"; the training code sets
# it to "running", "complete" or "failed".
_training_status = "idle"
|
|
|
|
# Upper bound on retained log lines. The UI only renders the most recent 100
# (see _get_logs), so older entries are dropped to keep memory bounded while
# stdout/stderr and library logging are being captured for a long run.
_MAX_LOG_LINES = 5000


def _append_log(msg: str):
    """Append *msg* to the in-memory UI log, prefixed with an HH:MM:SS timestamp.

    The buffer is trimmed in place once it exceeds ``_MAX_LOG_LINES`` so that
    a long training run cannot grow it without limit.

    Args:
        msg: Text to record; stored as a single entry even if it spans lines.
    """
    ts = time.strftime("%H:%M:%S")
    _log_lines.append(f"[{ts}] {msg}")

    overflow = len(_log_lines) - _MAX_LOG_LINES
    if overflow > 0:
        # In-place delete keeps the same list object that readers hold.
        del _log_lines[:overflow]
|
|
|
|
class _Tee:
    """Stream wrapper that copies non-blank writes into the UI log.

    Instances stand in for ``sys.stdout`` / ``sys.stderr`` while training
    runs. Everything written is forwarded to the wrapped stream; writes that
    contain non-whitespace text are additionally recorded via ``_append_log``.

    Attributes not defined here (``isatty``, ``fileno``, ``encoding``, ...)
    are delegated to the wrapped stream so that code probing the stream does
    not fail with ``AttributeError``.
    """

    def __init__(self, orig):
        # The real stream that keeps receiving every write.
        self._orig = orig

    def write(self, msg):
        """Log *msg* when it is not blank, then forward it unchanged.

        Returns whatever the wrapped stream's ``write`` returns (the number
        of characters written for text streams).
        """
        if msg.strip():
            _append_log(msg.rstrip())
        return self._orig.write(msg)

    def flush(self):
        """Flush the wrapped stream."""
        self._orig.flush()

    def __getattr__(self, name):
        """Delegate any other attribute lookup to the wrapped stream."""
        # Guard against infinite recursion if _orig itself is not set yet.
        if name == "_orig":
            raise AttributeError(name)
        return getattr(self._orig, name)
|
|
|
|
class _LogHandler(logging.Handler):
    """Captures Python logging (used by transformers/TRL) into the UI log."""

    def emit(self, record):
        """Format *record* and append it to the UI log.

        Failures never propagate to the code that issued the log call; they
        are routed through ``logging.Handler.handleError``, which reports the
        problem on ``sys.stderr`` when ``logging.raiseExceptions`` is true
        instead of discarding it silently.
        """
        try:
            _append_log(self.format(record))
        except Exception:
            self.handleError(record)
|
|
|
|
def _attach_log_capture():
    """Route INFO-and-above library logging into the UI log.

    A single ``_LogHandler`` is attached to the transformers, trl, unsloth,
    peft, accelerate and datasets loggers and to the root logger, and each of
    those loggers is set to INFO.

    Two safeguards prevent duplicated lines:

    * the handler carries a filter that accepts each record only once, so a
      record that reaches it on a named logger and again on the root logger
      via propagation is captured a single time;
    * a logger that already has a ``_LogHandler`` is not given another one,
      which makes repeated calls (e.g. a retry after a failed run) safe.
    """

    def _first_delivery(record):
        # Mark the record on first sight; reject it on any later delivery.
        if getattr(record, "_ui_log_captured", False):
            return False
        record._ui_log_captured = True
        return True

    handler = _LogHandler()
    handler.setFormatter(logging.Formatter("%(name)s | %(message)s"))
    handler.setLevel(logging.INFO)
    handler.addFilter(_first_delivery)

    for name in ("transformers", "trl", "unsloth", "peft", "accelerate", "datasets", ""):
        lg = logging.getLogger(name)
        if not any(isinstance(existing, _LogHandler) for existing in lg.handlers):
            lg.addHandler(handler)
        lg.setLevel(logging.INFO)
|
|
|
|
def _patch_torchao_import_compat():
    """Add missing torch symbols so optional TorchAO modules can be imported.

    Training uses bitsandbytes 4-bit loading through Unsloth, not TorchAO
    quantization; the shims below exist only to satisfy import-time probes.

    Two patches are applied, each only when the symbol is absent:

    * ``torch.utils._pytree.register_constant`` becomes a pass-through that
      returns the class it is given;
    * ``torch.int1`` .. ``torch.int7`` alias ``torch.int8`` and
      ``torch.uint1`` .. ``torch.uint7`` alias ``torch.uint8``.

    Every applied shim is reported through ``_append_log``.
    """
    import torch

    try:
        import torch.utils._pytree as pytree_mod

        if not hasattr(pytree_mod, "register_constant"):
            pytree_mod.register_constant = lambda cls: cls
            _append_log("Applied register_constant shim for torchao compat.")
    except Exception as err:
        _append_log(f"Warning: could not patch _pytree: {err}")

    added = []
    # Signed aliases first, then unsigned, each for widths 1 through 7.
    for prefix, stand_in in (("int", torch.int8), ("uint", torch.uint8)):
        for bits in range(1, 8):
            alias = f"{prefix}{bits}"
            if hasattr(torch, alias):
                continue
            setattr(torch, alias, stand_in)
            added.append(alias)

    if added:
        _append_log(
            "Applied torch dtype shims for torchao import: " + ", ".join(added)
        )
|
|
|
|
def _run_training():
    """Run the training pipeline; used as the background thread target.

    Steps, each reported to the UI log:

    1. mark the status as "running" and apply the torch/TorchAO import shims
       (a failure there is logged as a warning and does not abort);
    2. import ``unsloth``; on any failure mark the status "failed" and stop;
    3. attach logging capture, mirror stdout/stderr into the UI log, import
       ``training.train_grpo.main`` and call it;
    4. set the status to "complete" or "failed" and always restore the
       original stdout/stderr.

    ``BaseException`` is caught on purpose so that the status is updated even
    when the training code exits via ``SystemExit`` or similar.
    """
    global _training_status
    _training_status = "running"
    _append_log("Thread started — patching torch/torchao compat then importing unsloth...")

    try:
        _patch_torchao_import_compat()
    except Exception as _e:
        _append_log(f"Warning: could not patch torchao import compat: {_e}")

    try:
        import unsloth
        _append_log(f"✅ unsloth ready (v{getattr(unsloth, '__version__', 'unknown')})")
    except BaseException as e:
        _training_status = "failed"
        _append_log(f"❌ unsloth import failed: {e}\n{traceback.format_exc()}")
        return

    _append_log("Attaching log capture for transformers/TRL...")
    _attach_log_capture()

    _append_log("Importing training script (loads transformers/TRL — another 30–60s)...")
    old_out, old_err = sys.stdout, sys.stderr
    sys.stdout = _Tee(old_out)
    sys.stderr = _Tee(old_err)
    try:
        # Make the directory two levels above this file importable so that
        # the `training` package resolves.
        sys.path.insert(0, str(Path(__file__).parent.parent))
        from training.train_grpo import main
        _append_log("✅ Training script loaded — starting main()...")
        main()
        _training_status = "complete"
        _append_log("✅ Training complete! Model pushed to HF Hub. Check WandB for curves.")
    except BaseException as e:
        _training_status = "failed"
        # Only the tail of the traceback is kept to limit the log entry size.
        _append_log(f"❌ Training failed: {e}\n{traceback.format_exc()[-2000:]}")
    finally:
        sys.stdout = old_out
        sys.stderr = old_err
|
|
|
|
# Serializes the "is a run already active?" check with the status update so
# that two overlapping button clicks cannot both start a training thread.
_start_lock = threading.Lock()


def start_training():
    """Validate prerequisites and launch training in a daemon thread.

    Returns:
        A ``(message, logs)`` tuple for the UI. The message explains why
        nothing was started (already running, already complete, missing
        secrets, no GPU) or confirms the start together with the GPU name.
    """
    global _training_status

    def _blocked_reply():
        # Message explaining why a new run must not start, or None if it may.
        if _training_status == "running":
            return "⚠️ Already in progress."
        if _training_status == "complete":
            return "✅ Training already complete."
        return None

    reply = _blocked_reply()
    if reply is not None:
        return reply, _get_logs()

    missing = [k for k in ("WANDB_API_KEY", "HF_TOKEN", "HF_HUB_MODEL_ID")
               if not os.getenv(k)]
    if missing:
        return f"❌ Missing secrets: {', '.join(missing)}", _get_logs()

    import torch
    if not torch.cuda.is_available():
        return "❌ No GPU detected. Upgrade Space hardware to A100 first.", _get_logs()

    gpu = torch.cuda.get_device_name(0)

    # Re-check and claim the "running" state atomically: the worker thread
    # only sets the status once it is scheduled, which would otherwise leave
    # a window in which a second click passes the check above.
    with _start_lock:
        reply = _blocked_reply()
        if reply is not None:
            return reply, _get_logs()
        _training_status = "running"

    _append_log(f"GPU detected: {gpu}")
    threading.Thread(target=_run_training, daemon=True).start()
    return f"🚀 Started on {gpu}. Training beginning...", _get_logs()
|
|
|
|
def _get_logs() -> str:
    """Return the newest 100 log lines joined by newlines.

    When nothing has been logged yet, the placeholder "No logs yet." is
    returned instead.
    """
    if not _log_lines:
        return "No logs yet."
    recent = _log_lines[-100:]
    return "\n".join(recent)
|
|
|
|
def get_status() -> str:
    """Return the display label for the current training status.

    Known states map to an icon plus description; an unrecognized state is
    returned verbatim.
    """
    labels = {
        "idle": "⏸️ Idle — ready to start",
        "installing": "⚙️ Installing unsloth on GPU...",
        "running": "🔄 Training in progress...",
        "complete": "✅ Training complete",
        "failed": "❌ Failed — check logs",
    }
    return labels.get(_training_status, _training_status)
|
|
|
|
def refresh():
    """Return the current ``(status label, log text)`` pair for the UI."""
    status_text = get_status()
    log_text = _get_logs()
    return status_text, log_text
|
|
|
|
| |
|
|
# UI layout and event wiring. The Markdown bodies are kept flush-left so
# their rendering does not depend on any dedenting of the string.
with gr.Blocks(title="ContextCorruption Training") as demo:
    gr.Markdown("""
# ContextCorruption-Env — GRPO Training
Fine-tuning **Qwen2-1.5B-Instruct** to identify corrupted documents.

**Before clicking Start, confirm:**
- Space hardware is set to **A100 Large** (Settings → Space hardware)
- Secrets are set: `WANDB_API_KEY` · `HF_TOKEN` · `HF_HUB_MODEL_ID`
""")

    status_box = gr.Textbox(label="Status", value=get_status(), interactive=False)
    log_box = gr.Textbox(label="Live Logs", lines=25, interactive=False, value="Waiting...")
    msg_box = gr.Textbox(label="", interactive=False)

    with gr.Row():
        start_btn = gr.Button("🚀 Start Training", variant="primary", scale=2)
        refresh_btn = gr.Button("🔄 Refresh", scale=1)

    gr.Markdown("""
---
**Estimated time on A100:** ~30–45 min · **Cost:** ~$3–5 from your HF credits
After training, model is pushed to HF Hub and WandB has the reward/loss curves.
""")

    # Ticks every 10 seconds to refresh the status and log boxes.
    timer = gr.Timer(value=10)

    # Start reports into the message box; every other trigger (manual
    # refresh, page load, timer tick) just re-reads status and logs.
    start_btn.click(fn=start_training, outputs=[msg_box, log_box])
    refresh_btn.click(fn=refresh, outputs=[status_box, log_box])
    demo.load(fn=refresh, outputs=[status_box, log_box])
    timer.tick(fn=refresh, outputs=[status_box, log_box])
|
|
|
|
# Script entry point: serve the UI on all network interfaces, port 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|