"""app.py — Tweaktron: Omni-Mythos generation Space (ZeroGPU). Pulls a chosen checkpoint step from the HF model repo and runs generation. Requires modeling_mythos.py and gdn2.py to be present alongside this file in the Space repo. """ import os import sys import subprocess import torch import gradio as gr import spaces from huggingface_hub import hf_hub_download, list_repo_files from transformers import AutoTokenizer # ---------------- mamba_src setup ---------------- # Mamba3 runs on pure Triton kernels and never touches the compiled # selective_scan_cuda extension at runtime -- that extension is only used # by the legacy Mamba1/Mamba2 fallback path. `pip install mamba-ssm` tries # to build it anyway, which is fragile/slow in a Space build environment. # Instead: clone the source directly and stub the dead import. _MAMBA_SRC = "/home/user/mamba_src" if not os.path.exists(_MAMBA_SRC): subprocess.run( ["git", "clone", "--depth", "1", "https://github.com/state-spaces/mamba", _MAMBA_SRC], check=True, ) open(os.path.join(_MAMBA_SRC, "selective_scan_cuda.py"), "w").close() sys.path.insert(0, _MAMBA_SRC) # verify it actually imports before proceeding from mamba_ssm.modules.mamba3 import Mamba3 # noqa: F401 # --------------------------------------------------- from modeling_mythos import OmniMythosDense, MythosConfig # ---------------- knobs ---------------- REPO = "Johnblick187/TweaktronOmniMythosPrototype" TOKENIZER = "Johnblick187/Tweaktron-Omni-Mythos-Mini" LOOPS = 2 DEFAULT_MAX_TOKENS = 200 # ---------------------------------------- _tok = AutoTokenizer.from_pretrained(TOKENIZER) _cfg = MythosConfig() _cfg.vocab_size = _tok.vocab_size _model_cache = {} # step_label -> loaded model (kept in CPU RAM, moved to GPU per-call) def list_available_steps(): """Discover which stepN.safetensors files actually exist in the repo.""" try: files = list_repo_files(REPO) except Exception as e: return ["latest"], f"Could not list repo files: {e}" steps = [] for f in files: if f.startswith("step") and f.endswith(".safetensors"): label = f.replace(".safetensors", "") steps.append(label) def step_num(label): if label == "latest": return float("inf") try: return int(label.replace("step", "")) except ValueError: return -1 steps = sorted(set(steps), key=step_num) steps.append("latest") return steps, None def load_model(step_label): if step_label in _model_cache: return _model_cache[step_label] filename = f"{step_label}.safetensors" ckpt_path = hf_hub_download(repo_id=REPO, filename=filename) from safetensors.torch import load_file sd = load_file(ckpt_path) model = OmniMythosDense(_cfg).to(torch.bfloat16) missing, unexpected = model.load_state_dict(sd, strict=False) model.eval() _model_cache.clear() # only keep one checkpoint in memory at a time _model_cache[step_label] = model return model @spaces.GPU def generate(prompt, step_label, max_new_tokens, temperature, top_k): if not prompt or not prompt.strip(): return "Enter a prompt first." model = load_model(step_label) model = model.to("cuda") ids = _tok.encode(prompt, add_special_tokens=False) x = torch.tensor([ids], dtype=torch.long, device="cuda") generated = list(ids) with torch.no_grad(): for _ in range(int(max_new_tokens)): logits, _, _ = model(x, n_loops=LOOPS) next_logits = logits[0, -1, :] / max(temperature, 1e-5) if top_k > 0: topk_vals, topk_idx = torch.topk(next_logits, int(top_k)) probs = torch.softmax(topk_vals, dim=-1) next_id = topk_idx[torch.multinomial(probs, 1)].item() else: probs = torch.softmax(next_logits, dim=-1) next_id = torch.multinomial(probs, 1).item() generated.append(next_id) x = torch.tensor([generated], dtype=torch.long, device="cuda") if next_id == _tok.eos_token_id: break model.to("cpu") torch.cuda.empty_cache() return _tok.decode(generated, skip_special_tokens=True) _step_choices, _step_error = list_available_steps() with gr.Blocks(title="Tweaktron: Omni-Mythos") as demo: gr.Markdown("# Tweaktron: Omni-Mythos\nGenerate text from a chosen training checkpoint.") if _step_error: gr.Markdown(f"⚠️ {_step_error}") with gr.Row(): with gr.Column(): prompt = gr.Textbox(label="Prompt", lines=4, placeholder="Once upon a time...") step = gr.Dropdown(choices=_step_choices, value=_step_choices[-1] if _step_choices else "latest", label="Checkpoint step") max_tokens = gr.Slider(10, 500, value=DEFAULT_MAX_TOKENS, step=10, label="Max new tokens") temperature = gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature") top_k = gr.Slider(0, 100, value=40, step=1, label="Top-k (0 = disabled)") run_btn = gr.Button("Generate", variant="primary") with gr.Column(): output = gr.Textbox(label="Output", lines=16) run_btn.click(generate, inputs=[prompt, step, max_tokens, temperature, top_k], outputs=output) if __name__ == "__main__": demo.launch()