# Generations / app.py
# Johnblick187's picture
# Update app.py
# ee88bdc verified
# Raw
# History Blame Contribute Delete
# 5.49 kB
"""app.py — Tweaktron: Omni-Mythos generation Space (ZeroGPU).
Pulls a chosen checkpoint step from the HF model repo and runs generation.
Requires modeling_mythos.py and gdn2.py to be present alongside this file
in the Space repo.
"""
import os
import sys
import subprocess
import torch
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download, list_repo_files
from transformers import AutoTokenizer
# ---------------- mamba_src setup ----------------
# Mamba3 runs on pure Triton kernels and never touches the compiled
# selective_scan_cuda extension at runtime -- that extension is only used
# by the legacy Mamba1/Mamba2 fallback path. `pip install mamba-ssm` tries
# to build it anyway, which is fragile/slow in a Space build environment.
# Instead: clone the source directly and stub the dead import.
# Directory the mamba source tree is cloned into; it is also put at the front
# of sys.path below so `mamba_ssm` is imported straight from the checkout.
_MAMBA_SRC = "/home/user/mamba_src"
# Empty stand-in module at the root of the checkout, so that importing
# `selective_scan_cuda` by name succeeds without the compiled extension.
_MAMBA_STUB = os.path.join(_MAMBA_SRC, "selective_scan_cuda.py")

if not os.path.exists(_MAMBA_SRC):
    # NOTE(review): this clones the default branch head, so the code that runs
    # can change between Space restarts -- consider pinning a commit.
    subprocess.run(
        ["git", "clone", "--depth", "1",
         "https://github.com/state-spaces/mamba", _MAMBA_SRC],
        check=True,
    )

# Create the stub whenever it is missing, not only right after a fresh clone:
# a checkout that already exists but lacks the stub would otherwise fail at
# the `mamba_ssm` import that follows.
if not os.path.exists(_MAMBA_STUB):
    with open(_MAMBA_STUB, "w"):
        pass

sys.path.insert(0, _MAMBA_SRC)
# verify it actually imports before proceeding
from mamba_ssm.modules.mamba3 import Mamba3 # noqa: F401
# ---------------------------------------------------
from modeling_mythos import OmniMythosDense, MythosConfig
# ---------------- knobs ----------------
# Hub repo id that checkpoint files are listed from and downloaded from.
REPO = "Johnblick187/TweaktronOmniMythosPrototype"
# Hub repo id handed to AutoTokenizer.from_pretrained below.
TOKENIZER = "Johnblick187/Tweaktron-Omni-Mythos-Mini"
# Passed as `n_loops` to every model forward call during generation.
LOOPS = 2
# Initial position of the "Max new tokens" slider in the UI.
DEFAULT_MAX_TOKENS = 200
# ----------------------------------------
# Tokenizer and model config are created once, at import time.
_tok = AutoTokenizer.from_pretrained(TOKENIZER)
_cfg = MythosConfig()
# NOTE(review): assumes `_tok.vocab_size` equals the vocabulary size the
# checkpoints were trained with (it may not count added special tokens) --
# confirm against the checkpoint's embedding shape.
_cfg.vocab_size = _tok.vocab_size
# load_model() empties this dict before inserting, so it holds at most one entry.
_model_cache = {} # step_label -> loaded model (kept in CPU RAM, moved to GPU per-call)
def list_available_steps():
    """Return the checkpoint labels to offer in the UI, plus an optional error.

    Lists the files of ``REPO`` and keeps every name that starts with
    ``step`` and ends with ``.safetensors``; the label is that name with the
    extension removed.

    Returns:
        A ``(labels, error)`` tuple.

        * ``labels`` is ordered by ascending step number and always ends with
          ``"latest"``. Labels whose part after ``step`` is not an integer
          sort first. NOTE(review): ``"latest"`` is appended without checking
          that ``latest.safetensors`` exists in the repo -- confirm it does.
        * ``error`` is ``None`` on success. If the repo listing fails for any
          reason, the result is ``(["latest"], "<message>")`` so the UI can
          still be built and show the problem.
    """
    suffix = ".safetensors"
    try:
        files = list_repo_files(REPO)
    except Exception as e:  # deliberate best-effort: never block UI start-up
        return ["latest"], f"Could not list repo files: {e}"

    # Slice off exactly one trailing extension instead of str.replace(), which
    # would also remove the text if it appeared in the middle of a name.
    labels = {
        name[: -len(suffix)]
        for name in files
        if name.startswith("step") and name.endswith(suffix)
    }

    def step_num(label):
        # Only the leading "step" prefix is dropped; anything that is not a
        # plain integer afterwards is ranked before all numbered steps.
        try:
            return int(label[len("step"):])
        except ValueError:
            return -1

    # The label itself breaks ties, so the order does not depend on set
    # iteration order when several labels share the same numeric key.
    steps = sorted(labels, key=lambda label: (step_num(label), label))
    steps.append("latest")
    return steps, None
def load_model(step_label):
    """Return the model for ``step_label``, loading it from the Hub if needed.

    The checkpoint file ``"<step_label>.safetensors"`` is downloaded from
    ``REPO``, loaded into a fresh ``OmniMythosDense(_cfg)`` cast to bfloat16,
    switched to eval mode and stored in ``_model_cache``. The cache holds at
    most one model: a hit returns it directly, a miss replaces it.

    Args:
        step_label: Checkpoint label such as ``"step1000"`` or ``"latest"``.

    Returns:
        The loaded model, on whatever device ``OmniMythosDense`` constructs it
        (no device move happens here).

    Raises:
        Whatever ``hf_hub_download`` or ``load_file`` raise, e.g. when the
        file is not present in the repo. A failed download leaves the
        previously cached model in place.
    """
    cached = _model_cache.get(step_label)
    if cached is not None:
        return cached

    import warnings
    from safetensors.torch import load_file

    filename = f"{step_label}.safetensors"
    ckpt_path = hf_hub_download(repo_id=REPO, filename=filename)

    # Release the previous checkpoint before allocating the next one so two
    # full models are never resident at once. This is done only after the
    # download succeeded, so a bad label does not evict a working model.
    _model_cache.clear()

    state_dict = load_file(ckpt_path)
    model = OmniMythosDense(_cfg).to(torch.bfloat16)

    # strict=False tolerates key mismatches; report them instead of silently
    # running with partially initialised weights.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint {filename!r} does not match the model exactly: "
            f"{len(missing)} missing key(s), {len(unexpected)} unexpected key(s).",
            RuntimeWarning,
            stacklevel=2,
        )

    model.eval()
    _model_cache[step_label] = model
    return model
@spaces.GPU
def generate(prompt, step_label, max_new_tokens, temperature, top_k):
    """Sample a continuation of ``prompt`` from the chosen checkpoint.

    Args:
        prompt: Text to continue. Empty or whitespace-only input is rejected.
        step_label: Checkpoint label passed to ``load_model``.
        max_new_tokens: Upper bound on sampled tokens (converted with ``int``).
        temperature: Logit divisor; values below ``1e-5`` are raised to it.
        top_k: If greater than 0, sample only among the ``top_k`` highest
            logits (clamped to the vocabulary size); 0 samples from all.

    Returns:
        The decoded text of the prompt tokens followed by the sampled tokens,
        with special tokens skipped, or ``"Enter a prompt first."`` when there
        is nothing to condition on.

    Sampling stops early once the tokenizer's EOS id is drawn. The model is
    expected to return a 3-tuple whose first item is logits indexed as
    ``[batch, position, vocab]``; the other two items are ignored.
    """
    if not prompt or not prompt.strip():
        return "Enter a prompt first."

    model = load_model(step_label)
    generated = list(_tok.encode(prompt, add_special_tokens=False))
    if not generated:
        # A prompt that tokenizes to nothing would give the model an empty
        # sequence and fail on the last-position lookup below.
        return "Enter a prompt first."

    temperature = max(float(temperature), 1e-5)
    top_k = int(top_k)
    eos_id = _tok.eos_token_id

    try:
        model.to("cuda")
        x = torch.tensor([generated], dtype=torch.long, device="cuda")
        with torch.no_grad():
            for _step in range(int(max_new_tokens)):
                logits, _, _ = model(x, n_loops=LOOPS)
                # Sample in float32: the model runs in bfloat16, which is
                # coarse for a softmax over a full vocabulary.
                next_logits = logits[0, -1, :].float() / temperature
                if top_k > 0:
                    # torch.topk raises if k exceeds the vocabulary size.
                    k = min(top_k, next_logits.size(-1))
                    topk_vals, topk_idx = torch.topk(next_logits, k)
                    probs = torch.softmax(topk_vals, dim=-1)
                    next_token = topk_idx[torch.multinomial(probs, 1)]
                else:
                    probs = torch.softmax(next_logits, dim=-1)
                    next_token = torch.multinomial(probs, 1)

                next_id = next_token.item()
                generated.append(next_id)
                if next_id == eos_id:
                    break
                # Extend the sequence on the device instead of rebuilding it
                # from the Python list and re-uploading it every step.
                x = torch.cat([x, next_token.view(1, 1)], dim=1)
    finally:
        # Always hand the cached model back to CPU RAM, even when generation
        # fails, so the next call starts from a known state.
        model.to("cpu")
        torch.cuda.empty_cache()

    return _tok.decode(generated, skip_special_tokens=True)
# Resolve the checkpoint list once at import time; a listing failure comes
# back as an error message that is shown as a banner in the page.
_step_choices, _step_error = list_available_steps()

with gr.Blocks(title="Tweaktron: Omni-Mythos") as demo:
    gr.Markdown("# Tweaktron: Omni-Mythos\nGenerate text from a chosen training checkpoint.")
    if _step_error:
        gr.Markdown(f"⚠️ {_step_error}")

    with gr.Row():
        # Left column: prompt and sampling controls.
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                lines=4,
                placeholder="Once upon a time...",
            )
            # Pre-select the last entry of the list when there is one.
            step = gr.Dropdown(
                choices=_step_choices,
                value=_step_choices[-1] if _step_choices else "latest",
                label="Checkpoint step",
            )
            max_tokens = gr.Slider(
                minimum=10,
                maximum=500,
                value=DEFAULT_MAX_TOKENS,
                step=10,
                label="Max new tokens",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.8,
                step=0.05,
                label="Temperature",
            )
            top_k = gr.Slider(
                minimum=0,
                maximum=100,
                value=40,
                step=1,
                label="Top-k (0 = disabled)",
            )
            run_btn = gr.Button("Generate", variant="primary")

        # Right column: generated text.
        with gr.Column():
            output = gr.Textbox(label="Output", lines=16)

    # The input order must match generate()'s positional parameters.
    run_btn.click(
        fn=generate,
        inputs=[prompt, step, max_tokens, temperature, top_k],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()