""" app.py — Rewrite Training Space Runs train_and_upgrade.py in the background, streams logs live, and exports the upgraded model to morpheuslord/rewrite on completion. """ # ── Patch 1: Jinja2 LRU cache ───────────────────────────────────────────────── # Jinja2 >= 3.1.4 puts environment globals (a dict) into the LRU cache key, # which is unhashable. Convert any unhashable element to a hashable equivalent. def _patch_jinja2_lru_cache(): import jinja2.utils as _ju def _make_hashable(obj): if isinstance(obj, dict): return frozenset((_make_hashable(k), _make_hashable(v)) for k, v in obj.items()) if isinstance(obj, (list, tuple)): return tuple(_make_hashable(i) for i in obj) # always tuple, never list return obj _orig_gi = _ju.LRUCache.__getitem__ _orig_si = _ju.LRUCache.__setitem__ _orig_get = _ju.LRUCache.get def _gi(self, key): try: return _orig_gi(self, key) except TypeError: return _orig_gi(self, _make_hashable(key)) def _si(self, key, value): try: _orig_si(self, key, value) except TypeError: _orig_si(self, _make_hashable(key), value) def _get(self, key, default=None): try: return _orig_get(self, key, default) except TypeError: return _orig_get(self, _make_hashable(key), default) _ju.LRUCache.__getitem__ = _gi _ju.LRUCache.__setitem__ = _si _ju.LRUCache.get = _get _patch_jinja2_lru_cache() del _patch_jinja2_lru_cache # ── Patch 2: Starlette TemplateResponse API change ──────────────────────────── # Gradio 4.44.0 calls: templates.TemplateResponse(name: str, context: dict) # Newer Starlette (0.29+) changed the signature to: # TemplateResponse(request, name: str, context: dict) # so the context dict ends up as `name`, causing 'dict has no attribute split'. # Detect the old calling convention by checking if args[0] is a string, and # reorder arguments to match what newer Starlette expects. 
def _patch_starlette_template_response():
    """Accept gradio's old (name, context) calling convention on new Starlette."""
    import starlette.templating as _st

    original = _st.Jinja2Templates.TemplateResponse

    def _compat(self, *args, **kwargs):
        # Old API: (name: str, context: dict, ...)
        # New API: (request, name: str, context: dict, ...)
        if args and isinstance(args[0], str):
            template_name = args[0]
            template_context = args[1] if len(args) > 1 else {}
            # Gradio passes the request inside the context dict; pull it out
            # to satisfy the new positional signature.
            request = template_context.get("request")
            return original(
                self, request, template_name, template_context,
                *args[2:], **kwargs,
            )
        return original(self, *args, **kwargs)

    _st.Jinja2Templates.TemplateResponse = _compat


_patch_starlette_template_response()
del _patch_starlette_template_response
# ─────────────────────────────────────────────────────────────────────────────

# ── Compatibility patch ───────────────────────────────────────────────────────
# HF Spaces forces gradio 4.44.0 which imports HfFolder from huggingface_hub.
# huggingface_hub >= 0.30 removed HfFolder. Patch it back before gradio loads.
try:
    from huggingface_hub import HfFolder  # noqa: F401 -- check if it exists
except ImportError:
    import os as _os

    import huggingface_hub as _hfhub
    from huggingface_hub import constants as _hfconst

    class HfFolder:
        """Minimal stand-in for the removed huggingface_hub.HfFolder API."""

        # Default on-disk token location, taken from huggingface_hub constants.
        path_token = _hfconst.HF_TOKEN_PATH

        @classmethod
        def get_token(cls):
            # Environment variables take priority over the token file.
            env_token = _os.environ.get("HF_TOKEN") or _os.environ.get(
                "HUGGING_FACE_HUB_TOKEN"
            )
            if env_token:
                return env_token
            try:
                with open(cls.path_token) as token_file:
                    return token_file.read().strip() or None
            except Exception:
                # Missing/unreadable token file means "no token", not an error.
                return None

        @classmethod
        def save_token(cls, token):
            _os.makedirs(_os.path.dirname(cls.path_token), exist_ok=True)
            with open(cls.path_token, "w") as token_file:
                token_file.write(token)

        @classmethod
        def delete_token(cls):
            try:
                _os.remove(cls.path_token)
            except FileNotFoundError:
                pass

    _hfhub.HfFolder = HfFolder
# ─────────────────────────────────────────────────────────────────────────────

# NOTE: gradio is imported only after the patches above have been applied.
import gradio as gr
import subprocess
import threading
import os
import json
from pathlib import Path
from datetime import datetime

# ── State ─────────────────────────────────────────────────────────────────────
training_process = None
log_lines = []
is_training = False
last_scores = {}


def get_status_text():
    """Return a one-line, emoji-prefixed summary of the trainer's state."""
    if is_training:
        return "🟡 Training in progress..."
    if last_scores:
        return (
            f"✅ Last run complete — "
            f"Composite: {last_scores.get('composite', 0):.4f} | "
            f"GLEU: {last_scores.get('gleu', 0):.4f} | "
            f"BERTScore: {last_scores.get('bert_f1', 0):.4f}"
        )
    return "⚪ Ready — no training run yet."


# ── Training thread ───────────────────────────────────────────────────────────
def _training_thread():
    """Run train_and_upgrade.py and stream its output into log_lines.

    Executes in a daemon thread; communicates with the UI callbacks through
    the module-level globals (log_lines, is_training, last_scores).
    """
    global training_process, log_lines, is_training, last_scores
    import sys  # local import: keeps the top-of-file import block untouched

    log_lines = [f"[{datetime.now().strftime('%H:%M:%S')}] Starting training pipeline..."]
    is_training = True
    try:
        # FIX: use sys.executable instead of bare "python" so the subprocess
        # runs under the same interpreter/venv as this app ("python" may
        # resolve to a different interpreter on PATH, or not exist at all).
        training_process = subprocess.Popen(
            [sys.executable, "train_and_upgrade.py"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr so one stream carries everything
            text=True,
            bufsize=1,  # line-buffered: logs stream as they are produced
        )
        for line in training_process.stdout:
            line = line.rstrip()
            if line:
                log_lines.append(f"[{datetime.now().strftime('%H:%M:%S')}] {line}")
                # Keep last 500 lines to avoid memory bloat
                if len(log_lines) > 500:
                    log_lines = log_lines[-500:]
        training_process.wait()
        rc = training_process.returncode
        if rc == 0:
            log_lines.append("✅ Training pipeline finished successfully.")
            # Try to read saved baseline scores written by the pipeline.
            if Path("baseline_score.json").exists():
                with open("baseline_score.json") as f:
                    last_scores = json.load(f)
        else:
            log_lines.append(f"❌ Training process exited with code {rc}.")
    except Exception as e:
        # Surface the failure in the log view rather than killing the thread
        # silently; the UI has no other error channel.
        log_lines.append(f"❌ Error: {e}")
    finally:
        is_training = False


def start_training():
    """Start-button callback: spawn the training thread, return a status string."""
    global is_training
    if is_training:
        return "⚠️ Training already running. Wait for it to finish."
    # FIX: accept either token variable, matching HfFolder.get_token's lookup
    # order elsewhere in this file (previously only HF_TOKEN was checked).
    if not (os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")):
        return "❌ HF_TOKEN secret not set. Go to Space Settings → Secrets and add it."
    # FIX: mark as running *before* spawning the thread so a rapid double-click
    # cannot start two pipelines (the thread re-asserts the flag itself and
    # clears it in its finally block).
    is_training = True
    thread = threading.Thread(target=_training_thread, daemon=True)
    thread.start()
    return "🚀 Training started! Click 'Refresh Logs' every few minutes to see progress."


def get_logs():
    """Refresh-button callback: return the newest log lines as one string."""
    if not log_lines:
        return "No logs yet. Start training first."
    return "\n".join(log_lines[-100:])  # Show last 100 lines


def get_status():
    """Status-button callback: wrap get_status_text for gradio."""
    return get_status_text()


# ── UI ────────────────────────────────────────────────────────────────────────
with gr.Blocks(title="Rewrite — Training Space") as demo:
    gr.Markdown("""
# 🧠 Rewrite — Model Training & Upgrade Space

This Space trains an upgraded version of [morpheuslord/rewrite](https://huggingface.co/morpheuslord/rewrite)
and pushes it back to the model repo **only if it beats the previous score**.

### What the upgrade does
- LoRA rank: **r=8 → r=16** (warm-started from existing adapter)
- Epochs: **5 → 10**
- Loss: **CE only → CE + Style + Semantic** (the combined loss that was designed
  but never wired into the original trainer)
- Effective batch size: **32 → 64**
- Evaluation: **GLEU + BERTScore F1 + (1 - WER)** composite gate

### Before starting
Make sure `HF_TOKEN` is set in **Space Settings → Secrets** with write access
to `morpheuslord/rewrite`.

> ⚠️ **CPU Basic tier**: Training will take 12–24 hours.
> For faster results, run `train_and_upgrade.py` locally on your GPU.
""")

    status_box = gr.Textbox(
        label="Status",
        value=get_status_text(),
        interactive=False,
    )

    with gr.Row():
        start_btn = gr.Button("🚀 Start Training", variant="primary", scale=2)
        refresh_btn = gr.Button("🔄 Refresh Logs", variant="secondary", scale=1)
        status_btn = gr.Button("📊 Refresh Status", variant="secondary", scale=1)

    log_box = gr.Textbox(
        label="Training Logs (last 100 lines)",
        lines=25,
        interactive=False,
        placeholder="Logs will appear here. Click 'Refresh Logs' to update.",
    )

    gr.Markdown("""
### Output
On success, the model repo will be updated with:
- The new LoRA adapter (main branch)
- The merged full model weights
- A commit message showing all metric scores
""")

    start_btn.click(fn=start_training, outputs=status_box)
    refresh_btn.click(fn=get_logs, outputs=log_box)
    status_btn.click(fn=get_status, outputs=status_box)

demo.launch(server_name="0.0.0.0")