Spaces:
Sleeping
Sleeping
"""
Gradio UI for the HF training Space.
Training does NOT start automatically — user must click "Start Training".
Unsloth is pre-installed in the Docker image.

Importing this module has side effects: before the UI is built it redirects
HOME, the HF Hub / transformers / wandb / matplotlib cache directories and
OUTPUT_DIR to a writable runtime directory (via
_configure_writable_runtime_dirs()), and loads a local .env file.
"""
import logging
import os
import sys
import tempfile
import threading
import time
import traceback
from pathlib import Path
def _configure_writable_runtime_dirs():
    """Pick a writable runtime root and point HOME/cache/output env vars at it.

    HF Spaces can run with HOME=/ and /app read-only, so the directories that
    libraries write to must be redirected to a verified-writable location.
    Candidates are tried in order:

    1. ``$SPACE_RUNTIME_DIR`` (only if set and non-empty);
    2. ``<system temp dir>/context-corruption-training-<uid>``;
    3. ``/dev/shm/context-corruption-training-<uid>``;
    4. a fresh ``tempfile.mkdtemp`` directory as a last resort.

    Returns:
        Path: the runtime root that was successfully configured.

    Raises:
        RuntimeError: if no candidate, including the mkdtemp fallback, can be
            created and written to.
    """
    candidates = []
    space_dir = os.getenv("SPACE_RUNTIME_DIR")
    if space_dir:
        candidates.append(Path(space_dir))
    # os.getuid is POSIX-only; use a fixed tag on platforms without it.
    uid = getattr(os, "getuid", lambda: "user")()
    candidates.extend(
        [
            Path(tempfile.gettempdir()) / f"context-corruption-training-{uid}",
            Path("/dev/shm") / f"context-corruption-training-{uid}",
        ]
    )
    last_error = None
    for runtime_root in candidates:
        try:
            _apply_runtime_dirs(runtime_root)
        except OSError as exc:
            last_error = exc  # remember why this candidate failed, try the next
        else:
            return runtime_root
    try:
        # mkdtemp itself can fail (e.g. the temp dir is not writable); keep it
        # inside the try so callers always get the descriptive RuntimeError.
        runtime_root = Path(tempfile.mkdtemp(prefix="context-corruption-training-"))
        _apply_runtime_dirs(runtime_root)
    except OSError as exc:
        raise RuntimeError(
            f"Could not create writable runtime directories; last error: {last_error}"
        ) from exc
    return runtime_root
def _apply_runtime_dirs(runtime_root: Path):
    """Create the runtime directory tree under *runtime_root* and export it.

    Every target directory is created first. Then a probe file is written and
    removed to prove *runtime_root* is actually writable: ``mkdir`` with
    ``exist_ok=True`` succeeds on a pre-existing directory we cannot write to.
    Only after both steps succeed are the environment variables updated, so a
    rejected candidate leaves ``os.environ`` untouched.

    Raises:
        OSError: if a directory cannot be created or the probe write fails.
    """
    cache = runtime_root / "cache"
    hf_home = cache / "huggingface"
    targets = dict(
        HOME=runtime_root,
        XDG_CACHE_HOME=cache,
        HF_HOME=hf_home,
        HF_HUB_CACHE=hf_home / "hub",
        TRANSFORMERS_CACHE=hf_home / "transformers",
        WANDB_DIR=runtime_root / "wandb",
        WANDB_CACHE_DIR=cache / "wandb",
        MPLCONFIGDIR=cache / "matplotlib",
        OUTPUT_DIR=runtime_root / "checkpoints" / "grpo-qwen-1.5b",
    )
    for directory in targets.values():
        directory.mkdir(parents=True, exist_ok=True)

    marker = runtime_root / ".write-test"
    marker.write_text("ok")
    marker.unlink(missing_ok=True)

    os.environ.update({name: str(directory) for name, directory in targets.items()})
# Configure writable HOME/cache/output dirs BEFORE the third-party imports
# below. Presumably this is so gradio and the libraries imported later
# (torch, unsloth, transformers) resolve their cache paths under the writable
# root — keep this ordering. The returned Path is not read elsewhere in the
# visible code.
_runtime_root = _configure_writable_runtime_dirs()
import gradio as gr
from dotenv import load_dotenv
# Load variables from a local .env file (if present) into os.environ.
load_dotenv()
# In-memory UI log. Appended by _append_log (called from the training thread
# and from Gradio handlers such as start_training), read by _get_logs.
_log_lines: list[str] = []
# Current lifecycle state rendered by get_status().
# NOTE(review): "installing" is never assigned anywhere in the visible code.
_training_status = "idle"  # idle | installing | running | complete | failed
def _append_log(msg: str, *, max_lines: int = 5000):
    """Append a timestamped line to the in-memory UI log.

    This is called from the training thread (directly, and indirectly via the
    stdout/stderr tee and the logging handler) as well as from Gradio request
    handlers. Without a bound, a long run with chatty output would grow
    ``_log_lines`` indefinitely even though the UI only shows the tail. The
    list is therefore trimmed in place to its newest *max_lines* entries.

    Args:
        msg: Message text; it may span several lines.
        max_lines: Maximum number of entries retained (must be positive).
    """
    ts = time.strftime("%H:%M:%S")
    _log_lines.append(f"[{ts}] {msg}")
    overflow = len(_log_lines) - max_lines
    if overflow > 0:
        # Delete in place so every reader keeps seeing the same list object.
        del _log_lines[:overflow]
class _Tee:
    """File-like wrapper that mirrors writes into the UI log.

    Every write is forwarded unchanged to the wrapped stream. Non-blank writes
    are also appended (right-stripped) to the UI log.

    Because an instance replaces ``sys.stdout``/``sys.stderr`` while training
    runs, it must behave like a full text stream. ``write`` therefore returns
    the wrapped stream's result (the character count, per ``io.TextIOBase``).
    Any attribute not defined here (``isatty``, ``fileno``, ``encoding``, ...)
    is delegated to the wrapped stream instead of raising ``AttributeError``.
    """

    def __init__(self, orig):
        self._orig = orig

    def write(self, msg):
        """Log *msg* if non-blank, write it to the wrapped stream, return its result."""
        if msg.strip():
            _append_log(msg.rstrip())
        return self._orig.write(msg)

    def flush(self):
        self._orig.flush()

    def __getattr__(self, name):
        # Only reached for attributes missing on the instance and class.
        # Read _orig through __dict__ to avoid infinite recursion if it has
        # not been set yet (e.g. during copying/unpickling).
        try:
            orig = self.__dict__["_orig"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(orig, name)
class _LogHandler(logging.Handler):
    """Captures Python logging (used by transformers/TRL) into the UI log."""

    def emit(self, record):
        """Format *record* and append it to the UI log.

        A failure, typically a malformed ``%``-format in the logging call, is
        reported through ``logging.Handler.handleError``, the standard-library
        convention, rather than being silently discarded. ``RecursionError``
        is re-raised, mirroring ``logging.StreamHandler.emit``.
        """
        try:
            _append_log(self.format(record))
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)
def _first_delivery_only(record):
    """Logging filter: accept a record the first time it is seen, reject repeats."""
    if getattr(record, "_ui_log_captured", False):
        return False
    record._ui_log_captured = True
    return True


def _attach_log_capture():
    """Route INFO-and-above records from training-related loggers into the UI log.

    One shared ``_LogHandler`` is installed on the named library loggers and
    on the root logger, and each of those loggers is lowered to INFO.

    Idempotent: a failed run can be retried from the UI, and each retry calls
    this again. A logger that already carries a ``_LogHandler`` does not get
    a second one.

    A propagating child logger and the root logger both hold the handler, so
    ``logging`` would deliver the same record to it twice. The
    ``_first_delivery_only`` filter marks each record on first sight and
    drops the repeat.
    """
    handler = _LogHandler()
    handler.setFormatter(logging.Formatter("%(name)s | %(message)s"))
    handler.setLevel(logging.INFO)
    handler.addFilter(_first_delivery_only)
    for name in ("transformers", "trl", "unsloth", "peft", "accelerate", "datasets", ""):
        lg = logging.getLogger(name)
        if not any(isinstance(h, _LogHandler) for h in lg.handlers):
            lg.addHandler(handler)
        lg.setLevel(logging.INFO)
def _patch_torchao_import_compat():
    """Install stand-in torch symbols so optional TorchAO modules can import.

    TorchAO import probes expect symbols from newer/nightly torch builds.
    Training uses bitsandbytes 4-bit loading through Unsloth, not TorchAO
    quantization. These aliases exist only so the optional TorchAO modules
    import without error:

    * ``torch.utils._pytree.register_constant`` -> identity function;
    * ``torch.int1`` .. ``torch.int7`` -> ``torch.int8``;
    * ``torch.uint1`` .. ``torch.uint7`` -> ``torch.uint8``.

    Symbols the installed torch already defines are left untouched. Every
    shim that is applied, and any failure to patch ``_pytree``, is reported
    via ``_append_log``.
    """
    import torch

    try:
        import torch.utils._pytree as pytree_mod

        if not hasattr(pytree_mod, "register_constant"):
            pytree_mod.register_constant = lambda cls: cls
            _append_log("Applied register_constant shim for torchao compat.")
    except Exception as _e:
        _append_log(f"Warning: could not patch _pytree: {_e}")

    shimmed = []
    for prefix, stand_in in (("int", torch.int8), ("uint", torch.uint8)):
        for bits in range(1, 8):
            alias = f"{prefix}{bits}"
            if hasattr(torch, alias):
                continue
            setattr(torch, alias, stand_in)
            shimmed.append(alias)

    if shimmed:
        _append_log(
            "Applied torch dtype shims for torchao import: "
            + ", ".join(shimmed)
        )
def _run_training():
    """Background-thread entry point: import Unsloth, then run the training script.

    Sets ``_training_status`` to ``"running"`` on entry, then to
    ``"complete"`` or ``"failed"``, and streams progress into the UI log.

    While the training script is imported and ``main()`` runs,
    ``sys.stdout``/``sys.stderr`` are wrapped in ``_Tee`` so printed output
    also reaches the UI. They are restored in ``finally``.

    Failures are caught as ``BaseException``, so a ``SystemExit`` raised by
    the script also marks the run as failed. Log-capture setup is
    best-effort, so an error there cannot kill the thread and leave the
    status stuck at ``"running"``.
    """
    global _training_status
    _training_status = "running"
    _append_log("Thread started — patching torch/torchao compat then importing unsloth...")
    try:
        _patch_torchao_import_compat()
    except Exception as _e:
        _append_log(f"Warning: could not patch torchao import compat: {_e}")
    try:
        # Imported before the training script; presumably so Unsloth's patches
        # are in place before transformers/TRL load — confirm with Unsloth docs.
        import unsloth  # noqa: F401
        _append_log(f"✅ unsloth ready (v{getattr(unsloth, '__version__', 'unknown')})")
    except BaseException as e:
        _training_status = "failed"
        _append_log(f"❌ unsloth import failed: {e}\n{traceback.format_exc()}")
        return
    _append_log("Attaching log capture for transformers/TRL...")
    try:
        _attach_log_capture()
    except Exception as e:
        # Log capture only affects what the UI shows; training can proceed.
        _append_log(f"Warning: could not attach log capture: {e}")
    _append_log("Importing training script (loads transformers/TRL — another 30–60s)...")
    # Make the repo root importable as ``training.*`` without stacking a
    # duplicate sys.path entry each time a failed run is retried.
    repo_root = str(Path(__file__).parent.parent)
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    old_out, old_err = sys.stdout, sys.stderr
    sys.stdout = _Tee(old_out)
    sys.stderr = _Tee(old_err)
    try:
        from training.train_grpo import main

        _append_log("✅ Training script loaded — starting main()...")
        main()
        _training_status = "complete"
        _append_log("✅ Training complete! Model pushed to HF Hub. Check WandB for curves.")
    except BaseException as e:
        _training_status = "failed"
        # Keep only the tail of long tracebacks so the UI log stays readable.
        _append_log(f"❌ Training failed: {e}\n{traceback.format_exc()[-2000:]}")
    finally:
        sys.stdout = old_out
        sys.stderr = old_err
def start_training():
    """Gradio handler for the "Start Training" button.

    Checks the following preconditions in order:

    1. no run is in progress and none has completed;
    2. the required secrets are set;
    3. a CUDA GPU is visible.

    If all pass, ``_run_training`` is launched on a daemon thread.

    Returns:
        tuple[str, str]: a one-line message for the UI and the current log tail.
    """
    global _training_status
    if _training_status == "running":
        return "⚠️ Already in progress.", _get_logs()
    if _training_status == "complete":
        return "✅ Training already complete.", _get_logs()
    missing = [k for k in ("WANDB_API_KEY", "HF_TOKEN", "HF_HUB_MODEL_ID")
               if not os.getenv(k)]
    if missing:
        return f"❌ Missing secrets: {', '.join(missing)}", _get_logs()
    import torch
    if not torch.cuda.is_available():
        return "❌ No GPU detected. Upgrade Space hardware to A100 first.", _get_logs()
    gpu = torch.cuda.get_device_name(0)
    _append_log(f"GPU detected: {gpu}")
    # Claim the "running" state before spawning the thread. Otherwise a second
    # request handled before the new thread executes its first statement
    # would still see "idle"/"failed" and start a duplicate run.
    _training_status = "running"
    threading.Thread(target=_run_training, daemon=True).start()
    return f"🚀 Started on {gpu}. Training beginning...", _get_logs()
def _get_logs() -> str:
    """Return the newest 100 UI log entries joined by newlines.

    Returns the placeholder ``"No logs yet."`` when nothing has been logged.
    """
    if not _log_lines:
        return "No logs yet."
    tail = _log_lines[-100:]
    return "\n".join(tail)
def get_status() -> str:
    """Return the UI label for the current ``_training_status``.

    Unknown status values are returned unchanged.
    """
    icons = {
        "idle": "⏸️ Idle — ready to start",
        # NOTE(review): no visible code sets "installing" (Unsloth ships in
        # the Docker image); the entry is kept for compatibility.
        "installing": "⚙️ Installing unsloth on GPU...",
        "running": "🔄 Training in progress...",
        "complete": "✅ Training complete",
        "failed": "❌ Failed — check logs",
    }
    return icons.get(_training_status, _training_status)
def refresh():
    """Return ``(status label, log tail)`` for the status and log textboxes."""
    status = get_status()
    logs = _get_logs()
    return status, logs
# ── Gradio UI ─────────────────────────────────────────────────────────────────
with gr.Blocks(title="ContextCorruption Training") as demo:
    gr.Markdown("""
# ContextCorruption-Env — GRPO Training

Fine-tuning **Qwen2-1.5B-Instruct** to identify corrupted documents.

**Before clicking Start, confirm:**

- Space hardware is set to **A100 Large** (Settings → Space hardware)
- Secrets are set: `WANDB_API_KEY` · `HF_TOKEN` · `HF_HUB_MODEL_ID`
""")
    status_box = gr.Textbox(label="Status", value=get_status(), interactive=False)
    log_box = gr.Textbox(label="Live Logs", lines=25, interactive=False, value="Waiting...")
    msg_box = gr.Textbox(label="", interactive=False)
    with gr.Row():
        start_btn = gr.Button("🚀 Start Training", variant="primary", scale=2)
        refresh_btn = gr.Button("🔄 Refresh", scale=1)
    gr.Markdown("""
---
**Estimated time on A100:** ~30–45 min · **Cost:** ~$3–5 from your HF credits

After training, model is pushed to HF Hub and WandB has the reward/loss curves.
""")
    # Poll every 10 s so status and logs update without clicking Refresh.
    timer = gr.Timer(value=10)
    # start_training only returns (message, logs). Chain a refresh so the
    # status box updates right after the click instead of on the next tick.
    start_btn.click(fn=start_training, outputs=[msg_box, log_box]).then(
        fn=refresh, outputs=[status_box, log_box]
    )
    refresh_btn.click(fn=refresh, outputs=[status_box, log_box])
    demo.load(fn=refresh, outputs=[status_box, log_box])
    timer.tick(fn=refresh, outputs=[status_box, log_box])
if __name__ == "__main__":
    # Listen on all interfaces on port 7860, the default app port Hugging Face
    # Spaces routes traffic to.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)