Spaces:
Running on Zero
Running on Zero
| """app.py — Tweaktron: Omni-Mythos generation Space (ZeroGPU). | |
| Pulls a chosen checkpoint step from the HF model repo and runs generation. | |
| Requires modeling_mythos.py and gdn2.py to be present alongside this file | |
| in the Space repo. | |
| """ | |
| import os | |
| import sys | |
| import subprocess | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| from huggingface_hub import hf_hub_download, list_repo_files | |
| from transformers import AutoTokenizer | |
# ---------------- mamba_src setup ----------------
# Mamba3 runs on pure Triton kernels and never touches the compiled
# selective_scan_cuda extension at runtime -- that extension is only used
# by the legacy Mamba1/Mamba2 fallback path. `pip install mamba-ssm` tries
# to build it anyway, which is fragile/slow in a Space build environment.
# Instead: clone the source directly and stub the dead import.
_MAMBA_SRC = "/home/user/mamba_src"
if not os.path.exists(_MAMBA_SRC):
    subprocess.run(
        ["git", "clone", "--depth", "1",
         "https://github.com/state-spaces/mamba", _MAMBA_SRC],
        check=True,
    )

# Empty stand-in module so `import selective_scan_cuda` resolves from
# _MAMBA_SRC. Created independently of the clone step so the setup is
# idempotent: a checkout that already exists but lacks the stub is repaired.
_SCAN_STUB = os.path.join(_MAMBA_SRC, "selective_scan_cuda.py")
if not os.path.exists(_SCAN_STUB):
    with open(_SCAN_STUB, "w"):
        pass

# Put the checkout first on the import path, without adding duplicates if
# this module is imported more than once.
if _MAMBA_SRC not in sys.path:
    sys.path.insert(0, _MAMBA_SRC)

# verify it actually imports before proceeding
from mamba_ssm.modules.mamba3 import Mamba3  # noqa: E402,F401
# ---------------------------------------------------
| from modeling_mythos import OmniMythosDense, MythosConfig | |
# ---------------- knobs ----------------
# Hub repo that holds the `<label>.safetensors` checkpoints.
REPO = "Johnblick187/TweaktronOmniMythosPrototype"
# Hub repo the tokenizer is loaded from.
TOKENIZER = "Johnblick187/Tweaktron-Omni-Mythos-Mini"
# Passed to the model as `n_loops` on every forward call.
LOOPS = 2
# Initial position of the "Max new tokens" slider.
DEFAULT_MAX_TOKENS = 200
# ----------------------------------------

# Tokenizer is loaded once at import time and shared by every request.
_tok = AutoTokenizer.from_pretrained(TOKENIZER)

# Default model config, with the vocabulary size taken from the tokenizer.
_cfg = MythosConfig()
_cfg.vocab_size = _tok.vocab_size

# step_label -> loaded model (kept in CPU RAM, moved to GPU per-call)
_model_cache = {}
def list_available_steps():
    """List the checkpoint labels that can be offered in the UI.

    Queries the files of ``REPO`` and keeps every top-level file named
    ``step*.safetensors``, with the extension stripped.

    Returns:
        A ``(labels, error)`` tuple. ``labels`` is de-duplicated and sorted
        by the integer after ``"step"``; labels whose suffix is not an
        integer sort first. ``"latest"`` is always the final entry.
        ``error`` is ``None`` on success. If the repo cannot be listed,
        the result is ``(["latest"], <message>)``.
    """
    try:
        repo_files = list_repo_files(REPO)
    except Exception as e:
        return ["latest"], f"Could not list repo files: {e}"

    def _numeric_order(name):
        # Every name reaching here starts with "step"; anything that is not
        # "step<int>" is pushed to the front with key -1.
        try:
            return int(name.replace("step", ""))
        except ValueError:
            return -1

    found = [
        path.replace(".safetensors", "")
        for path in repo_files
        if path.startswith("step") and path.endswith(".safetensors")
    ]
    ordered = sorted(set(found), key=_numeric_order)
    return ordered + ["latest"], None
def load_model(step_label):
    """Return the model for ``step_label``, loading it on a cache miss.

    On a miss, ``<step_label>.safetensors`` is downloaded from ``REPO``,
    loaded into a fresh bfloat16 ``OmniMythosDense`` built from ``_cfg``,
    switched to eval mode, and stored as the single entry of
    ``_model_cache``.

    Args:
        step_label: Checkpoint name without the ``.safetensors`` extension
            (e.g. ``"step1000"`` or ``"latest"``).

    Returns:
        The loaded model in eval mode.

    Raises:
        Whatever ``hf_hub_download``, ``load_file`` or ``load_state_dict``
        raise (missing file, network failure, tensor shape mismatch, ...).
    """
    if step_label in _model_cache:
        return _model_cache[step_label]

    # Only needed on a cache miss.
    import warnings
    from safetensors.torch import load_file

    ckpt_path = hf_hub_download(
        repo_id=REPO, filename=f"{step_label}.safetensors"
    )

    # Evict the previous checkpoint *before* building the new model so two
    # full models are never resident at once. This runs after the download
    # so a failed download leaves the existing cache entry usable.
    _model_cache.clear()

    sd = load_file(ckpt_path)
    model = OmniMythosDense(_cfg).to(torch.bfloat16)

    # strict=False tolerates key mismatches, but they must not pass
    # silently: missing keys leave those parameters at their random
    # initial values.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing or unexpected:
        warnings.warn(
            f"{step_label}: checkpoint/model key mismatch "
            f"({len(missing)} missing, {len(unexpected)} unexpected). "
            f"First missing: {list(missing[:5])}; "
            f"first unexpected: {list(unexpected[:5])}",
            stacklevel=2,
        )

    model.eval()
    _model_cache[step_label] = model
    return model
def _sample_next_token(next_logits, temperature, top_k):
    """Sample one token id from a 1-D logits vector.

    Returns a 1-element long tensor on the same device as ``next_logits``.
    """
    # Sample in float32: the model runs in bfloat16, which is too coarse
    # for a softmax/multinomial over a whole vocabulary.
    scaled = next_logits.float() / max(temperature, 1e-5)

    # Clamp so an oversized top-k cannot make torch.topk raise.
    k = min(int(top_k), scaled.size(-1))
    if k > 0:
        topk_vals, topk_idx = torch.topk(scaled, k)
        probs = torch.softmax(topk_vals, dim=-1)
        return topk_idx[torch.multinomial(probs, 1)]

    probs = torch.softmax(scaled, dim=-1)
    return torch.multinomial(probs, 1)


def generate(prompt, step_label, max_new_tokens, temperature, top_k):
    """Sample a continuation of ``prompt`` from the chosen checkpoint.

    The model is moved to the GPU for the duration of the call and always
    moved back to the CPU afterwards, even if generation fails.

    Args:
        prompt: Text to continue. Blank input returns a hint string.
        step_label: Checkpoint label passed to ``load_model``.
        max_new_tokens: Upper bound on the number of sampled tokens.
        temperature: Logit divisor; floored at 1e-5.
        top_k: Sample only from the ``top_k`` highest logits; a value that
            truncates to 0 samples from the full distribution.

    Returns:
        The decoded prompt plus continuation (special tokens skipped), or a
        short message if there is nothing to generate from.
    """
    if not prompt or not prompt.strip():
        return "Enter a prompt first."

    ids = _tok.encode(prompt, add_special_tokens=False)
    if not ids:
        # An empty sequence would reach the model as a (1, 0) tensor.
        return "Prompt produced no tokens; try different text."

    model = load_model(step_label)
    generated = list(ids)
    try:
        model.to("cuda")
        x = torch.tensor([ids], dtype=torch.long, device="cuda")
        with torch.no_grad():
            for _step in range(int(max_new_tokens)):
                # The model returns three values; only the logits are used.
                logits, _, _ = model(x, n_loops=LOOPS)
                next_tok = _sample_next_token(
                    logits[0, -1, :], temperature, top_k
                )
                next_id = next_tok.item()
                generated.append(next_id)
                if next_id == _tok.eos_token_id:
                    break
                # Extend the sequence on-device instead of rebuilding it
                # from the Python list and re-uploading it every step.
                x = torch.cat([x, next_tok.view(1, 1)], dim=1)
    finally:
        # The cached model object is shared across calls: never leave it
        # on the GPU, whatever happened above.
        model.to("cpu")
        torch.cuda.empty_cache()

    return _tok.decode(generated, skip_special_tokens=True)
# Checkpoint labels for the dropdown, discovered once at startup. On failure
# the list falls back to ["latest"] and the error is shown in the page.
_step_choices, _step_error = list_available_steps()

with gr.Blocks(title="Tweaktron: Omni-Mythos") as demo:
    gr.Markdown("# Tweaktron: Omni-Mythos\nGenerate text from a chosen training checkpoint.")
    if _step_error:
        gr.Markdown(f"⚠️ {_step_error}")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=4, placeholder="Once upon a time...")
            step = gr.Dropdown(choices=_step_choices,
                               value=_step_choices[-1] if _step_choices else "latest",
                               label="Checkpoint step")
            max_tokens = gr.Slider(10, 500, value=DEFAULT_MAX_TOKENS, step=10, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature")
            top_k = gr.Slider(0, 100, value=40, step=1, label="Top-k (0 = disabled)")
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Output", lines=16)
    # ZeroGPU only attaches a GPU while a function registered through
    # spaces.GPU is running, so the handler is wrapped here; calling
    # `generate` unwrapped would reach model.to("cuda") with no device.
    # Outside ZeroGPU the wrapper has no effect.
    run_btn.click(
        spaces.GPU(generate),
        inputs=[prompt, step, max_tokens, temperature, top_k],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()