# Siddh12334's picture
# fix: select writable cache root at startup
# bb169f7 verified
"""
Gradio UI for the HF training Space.
Training does NOT start automatically — user must click "Start Training".
Unsloth is pre-installed in the Docker image.
"""
import logging
import os
import sys
import tempfile
import threading
import time
import traceback
from pathlib import Path
def _configure_writable_runtime_dirs():
    """Pick a writable runtime root and point HOME/cache/output env vars at it.

    HF Spaces can run with HOME=/ and /app read-only; use a verified cache root.

    Candidates are tried in order: ``$SPACE_RUNTIME_DIR`` (when set and
    non-empty), a per-uid directory under the system temp dir, then a per-uid
    directory under ``/dev/shm``. If none of them can be prepared, a freshly
    created ``tempfile.mkdtemp`` directory is used as a last resort.

    Returns:
        Path: the first root that ``_apply_runtime_dirs`` accepted.

    Raises:
        RuntimeError: if neither a candidate nor the mkdtemp fallback could be
            created and written to. The fallback's OSError is chained as the
            cause; the last candidate error is included in the message.
    """
    candidates = []
    configured = os.getenv("SPACE_RUNTIME_DIR")
    if configured:
        candidates.append(Path(configured))

    # Fall back to a fixed label on platforms where os.getuid does not exist.
    uid = getattr(os, "getuid", lambda: "user")()
    candidates.extend(
        [
            Path(tempfile.gettempdir()) / f"context-corruption-training-{uid}",
            Path("/dev/shm") / f"context-corruption-training-{uid}",
        ]
    )

    last_error = None
    for runtime_root in candidates:
        try:
            _apply_runtime_dirs(runtime_root)
        except OSError as exc:
            last_error = exc
        else:
            return runtime_root

    try:
        # mkdtemp itself raises OSError when the temp location is unusable, so
        # it has to live inside the try to be reported as the RuntimeError below.
        runtime_root = Path(tempfile.mkdtemp(prefix="context-corruption-training-"))
        _apply_runtime_dirs(runtime_root)
    except OSError as exc:
        raise RuntimeError(
            f"Could not create writable runtime directories; last error: {last_error}"
        ) from exc
    return runtime_root
def _apply_runtime_dirs(runtime_root: Path):
cache_root = runtime_root / "cache"
env_dirs = {
"HOME": runtime_root,
"XDG_CACHE_HOME": cache_root,
"HF_HOME": cache_root / "huggingface",
"HF_HUB_CACHE": cache_root / "huggingface" / "hub",
"TRANSFORMERS_CACHE": cache_root / "huggingface" / "transformers",
"WANDB_DIR": runtime_root / "wandb",
"WANDB_CACHE_DIR": cache_root / "wandb",
"MPLCONFIGDIR": cache_root / "matplotlib",
"OUTPUT_DIR": runtime_root / "checkpoints" / "grpo-qwen-1.5b",
}
for path in env_dirs.values():
path.mkdir(parents=True, exist_ok=True)
probe = runtime_root / ".write-test"
probe.write_text("ok")
probe.unlink(missing_ok=True)
for key, path in env_dirs.items():
os.environ[key] = str(path)
# Select the writable runtime root and export HOME / cache / output env vars
# before the third-party imports below execute.
# NOTE(review): presumably so libraries that read these variables at import
# time see the writable paths — keep this call above the imports.
_runtime_root = _configure_writable_runtime_dirs()
import gradio as gr
from dotenv import load_dotenv
# python-dotenv: picks up variables from a local .env file, if one exists.
load_dotenv()
# In-memory log buffer; _append_log adds to it and _get_logs reads its tail.
_log_lines: list[str] = []
# Lifecycle state of the training run; get_status() maps it to a UI label.
_training_status = "idle" # idle | installing | running | complete | failed
def _append_log(msg: str):
    """Append *msg* to the in-memory UI log, prefixed with ``[HH:MM:SS]``."""
    stamp = time.strftime("%H:%M:%S")
    entry = "[{}] {}".format(stamp, msg)
    _log_lines.append(entry)
class _Tee:
def __init__(self, orig):
self._orig = orig
def write(self, msg):
if msg.strip():
_append_log(msg.rstrip())
self._orig.write(msg)
def flush(self):
self._orig.flush()
class _LogHandler(logging.Handler):
"""Captures Python logging (used by transformers/TRL) into the UI log."""
def emit(self, record):
try:
_append_log(self.format(record))
except Exception:
pass
def _attach_log_capture():
    """Route library logging into the UI log at INFO level.

    Attaches one shared ``_LogHandler`` to the transformers, trl, unsloth,
    peft, accelerate and datasets loggers and to the root logger, and sets
    each of those loggers to INFO.

    Safe to call more than once: a ``_LogHandler`` already attached by an
    earlier call is reused instead of adding a second one, so retrying a
    failed run does not duplicate every captured line.
    """
    names = ("transformers", "trl", "unsloth", "peft", "accelerate", "datasets", "")
    loggers = [logging.getLogger(name) for name in names]

    # Reuse the handler from a previous call, if there is one.
    handler = next(
        (h for lg in loggers for h in lg.handlers if isinstance(h, _LogHandler)),
        None,
    )
    if handler is None:
        handler = _LogHandler()
        handler.setFormatter(logging.Formatter("%(name)s | %(message)s"))
        handler.setLevel(logging.INFO)

    # NOTE(review): the handler sits on both the named loggers and the root
    # logger, so a record from a named logger that propagates is captured
    # twice. Left unchanged here; confirm whether that is wanted.
    for lg in loggers:
        lg.addHandler(handler)  # no-op when this handler is already attached
        lg.setLevel(logging.INFO)
def _patch_torchao_import_compat():
    """Add missing torch symbols that TorchAO's import-time probes look for.

    Training uses bitsandbytes 4-bit loading through Unsloth, not TorchAO
    quantization. The aliases installed here exist only so optional TorchAO
    modules can be imported; they are not real implementations.

    Two shims are applied, each only when the symbol is absent:
      * ``torch.utils._pytree.register_constant`` becomes an identity function.
      * ``torch.int1``..``torch.int7`` alias ``torch.int8`` and
        ``torch.uint1``..``torch.uint7`` alias ``torch.uint8``.
    """
    import torch

    # A failure while patching _pytree is logged and does not stop the dtype
    # shims below.
    try:
        import torch.utils._pytree as _pytree

        if not hasattr(_pytree, "register_constant"):
            _pytree.register_constant = lambda cls: cls
            _append_log("Applied register_constant shim for torchao compat.")
    except Exception as _e:
        _append_log(f"Warning: could not patch _pytree: {_e}")

    stand_ins = (("int", torch.int8), ("uint", torch.uint8))
    added = []
    for prefix, stand_in in stand_ins:
        for width in range(1, 8):
            alias = f"{prefix}{width}"
            if hasattr(torch, alias):
                continue
            setattr(torch, alias, stand_in)
            added.append(alias)

    if added:
        _append_log(
            "Applied torch dtype shims for torchao import: "
            + ", ".join(added)
        )
def _run_training():
    """Run the full training job; used as the background-thread target.

    Steps, in order: mark the run as "running", apply the torch/torchao import
    shims (best-effort), import unsloth, attach logging capture, mirror
    stdout/stderr into the UI log, then import and call
    ``training.train_grpo.main``. Progress and failures are reported through
    ``_append_log`` and the module-level ``_training_status``; nothing is
    returned and no exception from the import or from ``main`` escapes.
    """
    global _training_status
    _training_status = "running"
    _append_log("Thread started β€” patching torch/torchao compat then importing unsloth...")
    # The compat shims are optional: a failure is logged and the run continues.
    try:
        _patch_torchao_import_compat()
    except Exception as _e:
        _append_log(f"Warning: could not patch torchao import compat: {_e}")
    # BaseException (not Exception) so that SystemExit/KeyboardInterrupt raised
    # during the import also end up as a "failed" status with a traceback.
    try:
        import unsloth # noqa: F401
        _append_log(f"βœ… unsloth ready (v{getattr(unsloth, '__version__', 'unknown')})")
    except BaseException as e:
        _training_status = "failed"
        _append_log(f"❌ unsloth import failed: {e}\n{traceback.format_exc()}")
        return
    _append_log("Attaching log capture for transformers/TRL...")
    _attach_log_capture()
    _append_log("Importing training script (loads transformers/TRL β€” another 30–60s)...")
    # Swap the process-wide streams for tees so printed output reaches the UI
    # log; the originals are restored in the finally block below.
    old_out, old_err = sys.stdout, sys.stderr
    sys.stdout = _Tee(old_out)
    sys.stderr = _Tee(old_err)
    try:
        # Put the directory above this file's directory on sys.path so the
        # "training" package can be imported.
        sys.path.insert(0, str(Path(__file__).parent.parent))
        from training.train_grpo import main
        _append_log("βœ… Training script loaded β€” starting main()...")
        main()
        _training_status = "complete"
        _append_log("βœ… Training complete! Model pushed to HF Hub. Check WandB for curves.")
    except BaseException as e:
        _training_status = "failed"
        # Only the last 2000 characters of the traceback are kept in the log.
        _append_log(f"❌ Training failed: {e}\n{traceback.format_exc()[-2000:]}")
    finally:
        sys.stdout = old_out
        sys.stderr = old_err
def start_training():
    """Validate preconditions and launch training in a background thread.

    Checks, in order: no run already in progress or complete, the required
    secrets (``WANDB_API_KEY``, ``HF_TOKEN``, ``HF_HUB_MODEL_ID``) are set,
    and a CUDA GPU is available. A run that previously failed may be started
    again.

    Returns:
        tuple[str, str]: a user-facing message and the current log text.
    """
    global _training_status
    if _training_status == "running":
        return "⚠️ Already in progress.", _get_logs()
    if _training_status == "complete":
        return "βœ… Training already complete.", _get_logs()

    missing = [k for k in ("WANDB_API_KEY", "HF_TOKEN", "HF_HUB_MODEL_ID")
               if not os.getenv(k)]
    if missing:
        return f"❌ Missing secrets: {', '.join(missing)}", _get_logs()

    # Imported lazily so the UI can come up without loading torch.
    import torch
    if not torch.cuda.is_available():
        return "❌ No GPU detected. Upgrade Space hardware to A100 first.", _get_logs()

    gpu = torch.cuda.get_device_name(0)
    _append_log(f"GPU detected: {gpu}")

    # Flip the status here rather than relying on the worker thread to do it:
    # otherwise a second click arriving before the thread's first statement
    # runs would pass the "running" check above and start a second job.
    _training_status = "running"
    threading.Thread(target=_run_training, daemon=True).start()
    return f"πŸš€ Started on {gpu}. Training beginning...", _get_logs()
def _get_logs() -> str:
    """Return the newest 100 log lines joined by newlines, or a placeholder."""
    if not _log_lines:
        return "No logs yet."
    recent = _log_lines[-100:]
    return "\n".join(recent)
def get_status() -> str:
    """Return the UI label for the current training status.

    A status without a label is returned unchanged.
    """
    labels = dict(
        idle="⏸️ Idle β€” ready to start",
        installing="βš™οΈ Installing unsloth on GPU...",
        running="πŸ”„ Training in progress...",
        complete="βœ… Training complete",
        failed="❌ Failed β€” check logs",
    )
    state = _training_status
    if state in labels:
        return labels[state]
    return state
def refresh():
    """Return the ``(status label, log text)`` pair used to repaint the UI."""
    status_text = get_status()
    log_text = _get_logs()
    return status_text, log_text
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Page layout and event wiring. Components are created in display order, and
# the event bindings at the bottom stay inside the Blocks context.
with gr.Blocks(title="ContextCorruption Training") as demo:
    # Intro text with the pre-flight checklist.
    gr.Markdown("""
# ContextCorruption-Env β€” GRPO Training
Fine-tuning **Qwen2-1.5B-Instruct** to identify corrupted documents.
**Before clicking Start, confirm:**
- Space hardware is set to **A100 Large** (Settings β†’ Space hardware)
- Secrets are set: `WANDB_API_KEY` Β· `HF_TOKEN` Β· `HF_HUB_MODEL_ID`
""")
    # Read-only outputs: status label, log tail, and the start button's message.
    status_box = gr.Textbox(label="Status", value=get_status(), interactive=False)
    log_box = gr.Textbox(label="Live Logs", lines=25, interactive=False, value="Waiting...")
    msg_box = gr.Textbox(label="", interactive=False)
    with gr.Row():
        start_btn = gr.Button("πŸš€ Start Training", variant="primary", scale=2)
        refresh_btn = gr.Button("πŸ”„ Refresh", scale=1)
    gr.Markdown("""
---
**Estimated time on A100:** ~30–45 min Β· **Cost:** ~$3–5 from your HF credits
After training, model is pushed to HF Hub and WandB has the reward/loss curves.
""")
    # Periodic refresh trigger; value=10 is presumably seconds — confirm
    # against the gr.Timer documentation.
    timer = gr.Timer(value=10)
    # Start writes its message and the log; every other trigger (manual
    # refresh, page load, timer tick) repaints the status and the log.
    start_btn.click(fn=start_training, outputs=[msg_box, log_box])
    refresh_btn.click(fn=refresh, outputs=[status_box, log_box])
    demo.load(fn=refresh, outputs=[status_box, log_box])
    timer.tick(fn=refresh, outputs=[status_box, log_box])
# Start the Gradio server only when this file is executed as a script;
# it listens on all interfaces (0.0.0.0) on port 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)