# NOTE(review): the following lines are Hugging Face web-UI residue captured
# when this file was exported (page header, author avatar caption, commit
# message, and commit hash). They are not Python and must stay commented out:
#   Rewrite-Space / app.py
#   morpheuslord's picture
#   Add training pipeline
#   85c5132
"""
app.py β€” Rewrite Training Space
Runs train_and_upgrade.py in the background, streams logs live,
and exports the upgraded model to morpheuslord/rewrite on completion.
"""
# ── Patch 1: Jinja2 LRU cache ─────────────────────────────────────────────────
# Jinja2 >= 3.1.4 puts environment globals (a dict) into the LRU cache key,
# which is unhashable. Convert any unhashable element to a hashable equivalent.
def _patch_jinja2_lru_cache():
    """Teach jinja2's LRUCache to retry with a hashable key on TypeError.

    Jinja2 >= 3.1.4 can place environment globals (a dict) into the LRU
    cache key, which is unhashable; each wrapped method first tries the
    original key and only converts it when hashing fails.
    """
    import jinja2.utils as _ju

    def _freeze(obj):
        # dicts -> frozensets of frozen pairs; lists/tuples -> tuples of
        # frozen items (always tuple, never list); everything else as-is.
        if isinstance(obj, dict):
            return frozenset((_freeze(k), _freeze(v)) for k, v in obj.items())
        if isinstance(obj, (list, tuple)):
            return tuple(_freeze(item) for item in obj)
        return obj

    def _with_fallback(original):
        # Build a shim around one cache method: pass the key through
        # unchanged, and on TypeError retry once with a frozen key.
        def shim(self, key, *extra, **kw):
            try:
                return original(self, key, *extra, **kw)
            except TypeError:
                return original(self, _freeze(key), *extra, **kw)
        return shim

    _ju.LRUCache.__getitem__ = _with_fallback(_ju.LRUCache.__getitem__)
    _ju.LRUCache.__setitem__ = _with_fallback(_ju.LRUCache.__setitem__)
    _ju.LRUCache.get = _with_fallback(_ju.LRUCache.get)
_patch_jinja2_lru_cache()
del _patch_jinja2_lru_cache
# ── Patch 2: Starlette TemplateResponse API change ────────────────────────────
# Gradio 4.44.0 calls: templates.TemplateResponse(name: str, context: dict)
# Newer Starlette (0.29+) changed the signature to:
# TemplateResponse(request, name: str, context: dict)
# so the context dict ends up as `name`, causing 'dict has no attribute split'.
# Detect the old calling convention by checking if args[0] is a string, and
# reorder arguments to match what newer Starlette expects.
def _patch_starlette_template_response():
    """Accept both the old and new Starlette TemplateResponse call orders.

    Gradio 4.44.0 still calls TemplateResponse(name, context), while
    Starlette 0.29+ expects TemplateResponse(request, name, context); the
    shim detects the legacy convention (first positional arg is a string)
    and reorders arguments for the newer signature.
    """
    import starlette.templating as _st

    _legacy = _st.Jinja2Templates.TemplateResponse

    def _shim(self, *args, **kwargs):
        # New API already? (request, name, context, ...) — pass straight through.
        if not (args and isinstance(args[0], str)):
            return _legacy(self, *args, **kwargs)
        # Old API: (name: str, context: dict, ...) — pull the request out of
        # the context dict and re-order for the new signature.
        name = args[0]
        context = args[1] if len(args) > 1 else {}
        request = context.get("request")
        return _legacy(self, request, name, context, *args[2:], **kwargs)

    _st.Jinja2Templates.TemplateResponse = _shim
_patch_starlette_template_response()
del _patch_starlette_template_response
# ─────────────────────────────────────────────────────────────────────────────
# ── Compatibility patch ───────────────────────────────────────────────────────
# HF Spaces forces gradio 4.44.0 which imports HfFolder from huggingface_hub.
# huggingface_hub >= 0.30 removed HfFolder. Patch it back before gradio loads.
try:
    from huggingface_hub import HfFolder  # noqa: F401 -- check if it exists
except ImportError:
    import os as _os
    import huggingface_hub as _hfhub
    from huggingface_hub import constants as _hfconst

    class HfFolder:
        """Minimal stand-in for the HfFolder API removed in huggingface_hub 0.30."""

        # Location of the on-disk token file, mirroring the old attribute.
        path_token = _hfconst.HF_TOKEN_PATH

        @classmethod
        def get_token(cls):
            # Environment variables take precedence over the token file.
            for var in ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"):
                token = _os.environ.get(var)
                if token:
                    return token
            try:
                with open(cls.path_token) as fh:
                    contents = fh.read().strip()
            except Exception:
                # Missing/unreadable token file -> behave as "no token".
                return None
            return contents or None

        @classmethod
        def save_token(cls, token):
            _os.makedirs(_os.path.dirname(cls.path_token), exist_ok=True)
            with open(cls.path_token, "w") as fh:
                fh.write(token)

        @classmethod
        def delete_token(cls):
            try:
                _os.remove(cls.path_token)
            except FileNotFoundError:
                pass

    # Re-export so `from huggingface_hub import HfFolder` keeps working.
    _hfhub.HfFolder = HfFolder
# ─────────────────────────────────────────────────────────────────────────────
import gradio as gr
import subprocess
import threading
import os
import json
from pathlib import Path
from datetime import datetime
# ── State ─────────────────────────────────────────────────────────────────────
# Module-level state shared between the Gradio callbacks and the worker thread.
training_process = None  # subprocess.Popen handle for the active training run
log_lines = []  # rolling buffer of timestamped log lines (capped at 500)
is_training = False  # busy flag: True while the worker thread is running
last_scores = {}  # metrics loaded from baseline_score.json after a successful run
def get_status_text():
    """Return a one-line, human-readable summary of the trainer's state."""
    if is_training:
        return "🟑 Training in progress..."
    if not last_scores:
        return "βšͺ Ready β€” no training run yet."
    # A finished run left its metrics behind — surface the headline numbers.
    composite = last_scores.get('composite', 0)
    gleu = last_scores.get('gleu', 0)
    bert_f1 = last_scores.get('bert_f1', 0)
    return (
        f"βœ… Last run complete β€” "
        f"Composite: {composite:.4f} | "
        f"GLEU: {gleu:.4f} | "
        f"BERTScore: {bert_f1:.4f}"
    )
# ── Training thread ────────────────────────────────────────────────────────────
def _training_thread():
    """Run train_and_upgrade.py as a subprocess and stream its output live.

    Executed on a daemon thread started by start_training(). Mutates the
    module-level state: log_lines (rolling buffer), is_training (busy flag),
    training_process (Popen handle), last_scores (metrics read from
    baseline_score.json on success).
    """
    global training_process, log_lines, is_training, last_scores
    log_lines = [f"[{datetime.now().strftime('%H:%M:%S')}] Starting training pipeline..."]
    is_training = True
    try:
        # Line-buffered text mode; stderr folded into stdout so a single
        # stream carries the whole pipeline output.
        training_process = subprocess.Popen(
            ["python", "train_and_upgrade.py"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
        )
        for line in training_process.stdout:
            line = line.rstrip()
            if line:
                log_lines.append(f"[{datetime.now().strftime('%H:%M:%S')}] {line}")
            # Keep last 500 lines to avoid memory bloat
            if len(log_lines) > 500:
                log_lines = log_lines[-500:]
        # stdout is exhausted; wait() reaps the child and sets returncode.
        training_process.wait()
        rc = training_process.returncode
        if rc == 0:
            log_lines.append("βœ… Training pipeline finished successfully.")
            # Try to read saved baseline scores
            if Path("baseline_score.json").exists():
                with open("baseline_score.json") as f:
                    last_scores = json.load(f)
        else:
            log_lines.append(f"❌ Training process exited with code {rc}.")
    except Exception as e:
        # Best-effort: surface the failure in the log rather than crash the thread.
        log_lines.append(f"❌ Error: {e}")
    finally:
        # Always clear the busy flag so the UI can start another run.
        is_training = False
def start_training():
    """Launch the training pipeline on a background daemon thread.

    Returns a status string for the UI. Refuses to start when a run is
    already in flight or when no Hugging Face token is configured.
    """
    global is_training
    if is_training:
        return "⚠️ Training already running. Wait for it to finish."
    # Accept either token variable, matching the HfFolder shim above.
    if not (os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")):
        return "❌ HF_TOKEN secret not set. Go to Space Settings β†’ Secrets and add it."
    # Claim the busy flag *before* the thread is scheduled so a rapid second
    # click cannot launch a concurrent training process (the thread re-sets
    # it, but only once it actually starts running).
    is_training = True
    thread = threading.Thread(target=_training_thread, daemon=True)
    thread.start()
    return "πŸš€ Training started! Click 'Refresh Logs' every few minutes to see progress."
def get_logs():
    """Return the newest buffered log lines joined as one display string."""
    if not log_lines:
        return "No logs yet. Start training first."
    recent = log_lines[-100:]  # show only the tail to keep the textbox light
    return "\n".join(recent)
def get_status():
    # Thin wrapper so the Gradio click handler has a stable, zero-arg callable.
    return get_status_text()
# ── UI ─────────────────────────────────────────────────────────────────────────
with gr.Blocks(title="Rewrite β€” Training Space") as demo:
    # Static intro / instructions panel (rendered as Markdown).
    gr.Markdown("""
# 🧠 Rewrite β€” Model Training & Upgrade Space
This Space trains an upgraded version of
[morpheuslord/rewrite](https://huggingface.co/morpheuslord/rewrite)
and pushes it back to the model repo **only if it beats the previous score**.
### What the upgrade does
- LoRA rank: **r=8 β†’ r=16** (warm-started from existing adapter)
- Epochs: **5 β†’ 10**
- Loss: **CE only β†’ CE + Style + Semantic** (the combined loss that was designed
but never wired into the original trainer)
- Effective batch size: **32 β†’ 64**
- Evaluation: **GLEU + BERTScore F1 + (1 - WER)** composite gate
### Before starting
Make sure `HF_TOKEN` is set in **Space Settings β†’ Secrets** with write access
to `morpheuslord/rewrite`.
> ⚠️ **CPU Basic tier**: Training will take 12–24 hours.
> For faster results, run `train_and_upgrade.py` locally on your GPU.
""")
    # One-line status readout; refreshed manually via the buttons below.
    status_box = gr.Textbox(
        label="Status",
        value=get_status_text(),
        interactive=False,
    )
    with gr.Row():
        start_btn = gr.Button("πŸš€ Start Training", variant="primary", scale=2)
        refresh_btn = gr.Button("πŸ”„ Refresh Logs", variant="secondary", scale=1)
        status_btn = gr.Button("πŸ“Š Refresh Status", variant="secondary", scale=1)
    # Read-only log viewer; updated on demand (no automatic polling).
    log_box = gr.Textbox(
        label="Training Logs (last 100 lines)",
        lines=25,
        interactive=False,
        placeholder="Logs will appear here. Click 'Refresh Logs' to update.",
    )
    gr.Markdown("""
### Output
On success, the model repo will be updated with:
- The new LoRA adapter (main branch)
- The merged full model weights
- A commit message showing all metric scores
""")
    # Wire the buttons to their callbacks.
    start_btn.click(fn=start_training, outputs=status_box)
    refresh_btn.click(fn=get_logs, outputs=log_box)
    status_btn.click(fn=get_status, outputs=status_box)
# Bind to all interfaces so the Space's reverse proxy can reach the app.
demo.launch(server_name="0.0.0.0")