Spaces:
Sleeping
Sleeping
| """ | |
| app.py β Rewrite Training Space | |
| Runs train_and_upgrade.py in the background, streams logs live, | |
| and exports the upgraded model to morpheuslord/rewrite on completion. | |
| """ | |
# ── Patch 1: Jinja2 LRU cache ─────────────────────────────────────────────────
# Jinja2 >= 3.1.4 puts environment globals (a dict) into the LRU cache key,
# which is unhashable. Convert any unhashable element to a hashable equivalent.
def _patch_jinja2_lru_cache():
    """Monkey-patch jinja2's LRUCache so unhashable keys don't raise TypeError."""
    import jinja2.utils as _ju

    def _hashable(value):
        # Dicts become frozensets of (key, value) pairs; sequences become tuples.
        if isinstance(value, dict):
            return frozenset(
                (_hashable(k), _hashable(v)) for k, v in value.items()
            )
        if isinstance(value, (list, tuple)):
            return tuple(_hashable(item) for item in value)
        return value

    orig_getitem = _ju.LRUCache.__getitem__
    orig_setitem = _ju.LRUCache.__setitem__
    orig_get = _ju.LRUCache.get

    def patched_getitem(self, key):
        try:
            return orig_getitem(self, key)
        except TypeError:
            # Key contained an unhashable element — retry with a hashable form.
            return orig_getitem(self, _hashable(key))

    def patched_setitem(self, key, value):
        try:
            orig_setitem(self, key, value)
        except TypeError:
            orig_setitem(self, _hashable(key), value)

    def patched_get(self, key, default=None):
        try:
            return orig_get(self, key, default)
        except TypeError:
            return orig_get(self, _hashable(key), default)

    _ju.LRUCache.__getitem__ = patched_getitem
    _ju.LRUCache.__setitem__ = patched_setitem
    _ju.LRUCache.get = patched_get


_patch_jinja2_lru_cache()
del _patch_jinja2_lru_cache
# ── Patch 2: Starlette TemplateResponse API change ────────────────────────────
# Gradio 4.44.0 calls: templates.TemplateResponse(name: str, context: dict)
# Newer Starlette (0.29+) changed the signature to:
#   TemplateResponse(request, name: str, context: dict)
# so the context dict ends up as `name`, causing 'dict has no attribute split'.
# Detect the old calling convention by checking if args[0] is a string, and
# reorder arguments to match what newer Starlette expects.
def _patch_starlette_template_response():
    """Shim old-style Jinja2Templates.TemplateResponse calls onto the new API."""
    import starlette.templating as _st

    orig_response = _st.Jinja2Templates.TemplateResponse

    def compat_response(self, *args, **kwargs):
        # New API already in use (or keyword-only call) — pass straight through.
        if not args or not isinstance(args[0], str):
            return orig_response(self, *args, **kwargs)
        # Old API: (name: str, context: dict, ...). Pull the request out of the
        # context dict and re-order to (request, name, context, ...).
        name = args[0]
        context = args[1] if len(args) > 1 else {}
        request = context.get("request")
        return orig_response(self, request, name, context, *args[2:], **kwargs)

    _st.Jinja2Templates.TemplateResponse = compat_response


_patch_starlette_template_response()
del _patch_starlette_template_response
# ──────────────────────────────────────────────────────────────────────────────
# ── Compatibility patch ───────────────────────────────────────────────────────
# HF Spaces forces gradio 4.44.0 which imports HfFolder from huggingface_hub.
# huggingface_hub >= 0.30 removed HfFolder. Patch it back before gradio loads.
try:
    from huggingface_hub import HfFolder  # noqa: F401 -- check if it exists
except ImportError:
    import os as _os
    import huggingface_hub as _hfhub
    from huggingface_hub import constants as _hfconst

    class HfFolder:
        """Minimal stand-in for the removed huggingface_hub.HfFolder helper."""

        # Location of the cached token file, as defined by huggingface_hub.
        path_token = _hfconst.HF_TOKEN_PATH

        # FIX: the real HfFolder exposes these as classmethods and callers use
        # HfFolder.get_token() on the class. Without @classmethod the methods
        # were plain functions and such calls raised TypeError (missing `cls`).
        @classmethod
        def get_token(cls):
            """Return the auth token from the environment or token file, else None."""
            env = _os.environ.get("HF_TOKEN") or _os.environ.get("HUGGING_FACE_HUB_TOKEN")
            if env:
                return env
            try:
                with open(cls.path_token) as f:
                    return f.read().strip() or None
            except Exception:
                # Best-effort lookup: a missing/unreadable file means "no token".
                return None

        @classmethod
        def save_token(cls, token):
            """Persist *token* to the token file, creating parent dirs as needed."""
            _os.makedirs(_os.path.dirname(cls.path_token), exist_ok=True)
            with open(cls.path_token, "w") as f:
                f.write(token)

        @classmethod
        def delete_token(cls):
            """Remove the token file; a missing file is not an error."""
            try:
                _os.remove(cls.path_token)
            except FileNotFoundError:
                pass

    _hfhub.HfFolder = HfFolder
# ──────────────────────────────────────────────────────────────────────────────
# Standard library first, then third-party (PEP 8 grouping).
import json
import os
import subprocess
import threading
from datetime import datetime
from pathlib import Path

import gradio as gr
# ── State ─────────────────────────────────────────────────────────────────────
training_process = None  # Popen handle for the running pipeline, if any
log_lines = []           # rolling buffer of timestamped log lines
is_training = False      # True while the training thread is active
last_scores = {}         # metrics loaded from baseline_score.json after a run


def get_status_text():
    """Return a one-line human-readable status string for the status textbox."""
    if is_training:
        return "π‘ Training in progress..."
    if not last_scores:
        # No run has completed yet in this process.
        return "βͺ Ready β no training run yet."
    return (
        f"β Last run complete β "
        f"Composite: {last_scores.get('composite', 0):.4f} | "
        f"GLEU: {last_scores.get('gleu', 0):.4f} | "
        f"BERTScore: {last_scores.get('bert_f1', 0):.4f}"
    )
# ── Training thread ───────────────────────────────────────────────────────────
def _training_thread():
    """Run train_and_upgrade.py as a subprocess, streaming output into log_lines."""
    global training_process, log_lines, is_training, last_scores

    def stamp(text):
        # Prefix a message with the current wall-clock time.
        return f"[{datetime.now().strftime('%H:%M:%S')}] {text}"

    log_lines = [stamp("Starting training pipeline...")]
    is_training = True
    try:
        training_process = subprocess.Popen(
            ["python", "train_and_upgrade.py"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr into the stdout stream
            text=True,
            bufsize=1,  # line-buffered so output arrives promptly
        )
        for raw in training_process.stdout:
            raw = raw.rstrip()
            if not raw:
                continue
            log_lines.append(stamp(raw))
            # Keep last 500 lines to avoid memory bloat
            if len(log_lines) > 500:
                log_lines = log_lines[-500:]
        training_process.wait()
        code = training_process.returncode
        if code == 0:
            log_lines.append("β Training pipeline finished successfully.")
            # Try to read saved baseline scores
            if Path("baseline_score.json").exists():
                with open("baseline_score.json") as f:
                    last_scores = json.load(f)
        else:
            log_lines.append(f"β Training process exited with code {code}.")
    except Exception as e:
        # Boundary of the worker thread: record the failure instead of dying silently.
        log_lines.append(f"β Error: {e}")
    finally:
        is_training = False
def start_training():
    """Kick off the background training thread; return a status message for the UI."""
    global is_training
    if is_training:
        return "β οΈ Training already running. Wait for it to finish."
    if not os.environ.get("HF_TOKEN"):
        # Without a write token the pipeline cannot push the upgraded model.
        return "β HF_TOKEN secret not set. Go to Space Settings β Secrets and add it."
    worker = threading.Thread(target=_training_thread, daemon=True)
    worker.start()
    return "π Training started! Click 'Refresh Logs' every few minutes to see progress."
def get_logs():
    """Return the most recent buffered log lines as one newline-joined string."""
    if not log_lines:
        return "No logs yet. Start training first."
    recent = log_lines[-100:]  # Show last 100 lines
    return "\n".join(recent)
def get_status():
    """Thin wrapper so the UI refresh button can re-query the status text."""
    return get_status_text()
# ── UI ────────────────────────────────────────────────────────────────────────
with gr.Blocks(title="Rewrite β Training Space") as demo:
    # Intro / instructions panel.
    gr.Markdown("""
# π§ Rewrite β Model Training & Upgrade Space
This Space trains an upgraded version of
[morpheuslord/rewrite](https://huggingface.co/morpheuslord/rewrite)
and pushes it back to the model repo **only if it beats the previous score**.
### What the upgrade does
- LoRA rank: **r=8 β r=16** (warm-started from existing adapter)
- Epochs: **5 β 10**
- Loss: **CE only β CE + Style + Semantic** (the combined loss that was designed
but never wired into the original trainer)
- Effective batch size: **32 β 64**
- Evaluation: **GLEU + BERTScore F1 + (1 - WER)** composite gate
### Before starting
Make sure `HF_TOKEN` is set in **Space Settings β Secrets** with write access
to `morpheuslord/rewrite`.
> β οΈ **CPU Basic tier**: Training will take 12β24 hours.
> For faster results, run `train_and_upgrade.py` locally on your GPU.
""")

    status_box = gr.Textbox(
        label="Status",
        value=get_status_text(),
        interactive=False,
    )

    with gr.Row():
        start_btn = gr.Button("π Start Training", variant="primary", scale=2)
        refresh_btn = gr.Button("π Refresh Logs", variant="secondary", scale=1)
        status_btn = gr.Button("π Refresh Status", variant="secondary", scale=1)

    log_box = gr.Textbox(
        label="Training Logs (last 100 lines)",
        lines=25,
        interactive=False,
        placeholder="Logs will appear here. Click 'Refresh Logs' to update.",
    )

    gr.Markdown("""
### Output
On success, the model repo will be updated with:
- The new LoRA adapter (main branch)
- The merged full model weights
- A commit message showing all metric scores
""")

    # Wire the buttons to their handlers.
    start_btn.click(fn=start_training, outputs=status_box)
    refresh_btn.click(fn=get_logs, outputs=log_box)
    status_btn.click(fn=get_status, outputs=status_box)

# Bind to all interfaces so the Space's proxy can reach the server.
demo.launch(server_name="0.0.0.0")