# nomos-1-zerogpu / app.py
# Deployed by GravityShares ("Deploy Nomos ZeroGPU app", rev 353f0fe, verified)
#!/usr/bin/env python3
import os
import threading
from typing import Any
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Prefer the real ZeroGPU `spaces` package; outside Hugging Face Spaces it is
# unavailable, so degrade to a shim whose `GPU` decorator is a no-op.
try:
    import spaces
except Exception:
    class _SpacesFallback:
        """Stand-in for the `spaces` module exposing a no-op `GPU` factory."""

        @staticmethod
        def GPU(duration: int = 60):
            # `duration` is accepted for signature parity and ignored.
            def _decorator(fn):
                return fn

            return _decorator

    spaces = _SpacesFallback()
# --- Configuration (each value overridable via an environment variable) ------
# Full-precision upstream checkpoint; only tried when PREFER_FULL is set.
DEFAULT_FULL_MODEL = "NousResearch/nomos-1"
# Comma-separated quantized fallbacks, tried left to right (8-bit before 4-bit).
DEFAULT_MODEL_CANDIDATES = "cyankiwi/nomos-1-AWQ-8bit,cyankiwi/nomos-1-AWQ-4bit"
# Seconds of GPU time requested per ZeroGPU invocation.
GPU_DURATION_SECONDS = int(os.getenv("GPU_DURATION_SECONDS", "120"))
# Prompts longer than this (in tokens) are truncated from the left in generate().
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "2048"))
# Default value shown on the "Max new tokens" slider.
MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "256"))
# Both flags compare case-insensitively against "true"; anything else is False.
TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "true").lower() == "true"
PREFER_FULL = os.getenv("PREFER_FULL", "false").lower() == "true"

# --- Lazily populated model state (writes guarded by _MODEL_LOCK) ------------
_MODEL_LOCK = threading.Lock()
_MODEL: Any = None  # loaded AutoModelForCausalLM, or None until first use
_TOKENIZER: Any = None  # matching AutoTokenizer, or None until first use
_MODEL_ID: str | None = None  # repo id of the loaded model, or None
_LOAD_ERRORS: list[str] = []  # per-candidate load failures, shown in status UI
def _ordered_candidates() -> list[str]:
    """Return model repo ids to try, in priority order.

    Reads MODEL_CANDIDATES from the environment (falling back to the module
    default), drops empty entries, and — when PREFER_FULL is enabled — puts
    the full-precision checkpoint at the front if it is not already listed.
    """
    raw = os.getenv("MODEL_CANDIDATES", DEFAULT_MODEL_CANDIDATES)
    names = [part.strip() for part in raw.split(",") if part.strip()]
    if PREFER_FULL and DEFAULT_FULL_MODEL not in names:
        names.insert(0, DEFAULT_FULL_MODEL)
    return names
def _load_model_if_needed() -> tuple[str | None, str]:
    """Load the first working model candidate; idempotent and thread-safe.

    Returns:
        (model_id, message) — model_id is the loaded repo id, or None when
        every candidate failed. On success the model/tokenizer are cached in
        module globals so later calls return immediately.
    """
    global _MODEL, _TOKENIZER, _MODEL_ID
    # Fast path without the lock; a stale read here is benign — worst case we
    # take the lock below and re-check.
    if _MODEL is not None and _TOKENIZER is not None and _MODEL_ID is not None:
        return _MODEL_ID, "model already loaded"
    with _MODEL_LOCK:
        # Double-checked: another thread may have finished loading while we
        # waited on the lock.
        if _MODEL is not None and _TOKENIZER is not None and _MODEL_ID is not None:
            return _MODEL_ID, "model already loaded"
        errors: list[str] = []
        for candidate in _ordered_candidates():
            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    candidate,
                    trust_remote_code=TRUST_REMOTE_CODE,
                )
                model = AutoModelForCausalLM.from_pretrained(
                    candidate,
                    device_map="auto",
                    trust_remote_code=TRUST_REMOTE_CODE,
                    low_cpu_mem_usage=True,
                )
                model.eval()
                # Publish tokenizer before model id so readers of the fast
                # path above never see a half-initialized triple as complete.
                _TOKENIZER = tokenizer
                _MODEL = model
                _MODEL_ID = candidate
                _LOAD_ERRORS.clear()  # forget stale failures on success
                return candidate, "loaded"
            except Exception as exc:
                # Record and move on to the next candidate (network errors,
                # missing quantization deps, OOM, etc.).
                errors.append(f"{candidate}: {type(exc).__name__}: {exc}")
        # Every candidate failed; expose the errors to the status panel.
        # In-place slice assignment keeps the shared list object identity.
        _LOAD_ERRORS[:] = errors
        return None, "load failed"
def _status_text() -> str:
    """Build the Markdown status panel: loaded model, candidates, limits,
    and up to three of the most recent load errors."""
    parts = [
        f"Loaded model: `{_MODEL_ID or 'none'}`",
        "",
        f"Candidates: `{', '.join(_ordered_candidates())}`",
        "",
        f"GPU duration: `{GPU_DURATION_SECONDS}s` | "
        f"Max input tokens: `{MAX_INPUT_TOKENS}`",
    ]
    text = "\n".join(parts)
    if _LOAD_ERRORS:
        recent = "\n".join(f"- {e}" for e in _LOAD_ERRORS[-3:])
        text = text + "\n\nRecent load errors:\n" + recent
    return text
@spaces.GPU(duration=GPU_DURATION_SECONDS)
def generate(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    do_sample: bool,
) -> tuple[str, str]:
    """Generate a completion for `prompt`.

    Args:
        prompt: User prompt; blank input short-circuits with a hint message.
        max_new_tokens: Upper bound on generated tokens.
        temperature / top_p / top_k: Sampling knobs, applied only when sampling.
        do_sample: Sample when True; greedy decode when False.

    Returns:
        (generated_text, status_markdown) for the two UI outputs.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        return "Provide a prompt.", _status_text()
    model_id, _ = _load_model_if_needed()
    if model_id is None:
        return "Model load failed. Check status and Space logs.", _status_text()
    tokenizer = _TOKENIZER
    model = _MODEL
    messages = [{"role": "user", "content": prompt}]
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    if input_ids.shape[-1] > MAX_INPUT_TOKENS:
        # Keep the most recent tokens. NOTE(review): left-truncation can clip
        # the chat-template prefix; acceptable for a demo, but worth confirming
        # output quality on very long prompts.
        input_ids = input_ids[:, -MAX_INPUT_TOKENS:]
    # Fix: the UI temperature slider reaches 0.0, and transformers rejects
    # do_sample=True with temperature == 0 — fall back to greedy decoding.
    do_sample = bool(do_sample) and float(temperature) > 0.0
    gen_kwargs: dict[str, Any] = {
        "input_ids": input_ids,
        # Fix: pass an explicit all-ones mask. The template output has no
        # padding, and with pad_token_id == eos_token_id transformers cannot
        # infer the mask itself (it warns and may mask incorrectly).
        "attention_mask": torch.ones_like(input_ids),
        "max_new_tokens": int(max_new_tokens),
        "do_sample": do_sample,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        gen_kwargs.update(
            {
                "temperature": float(temperature),
                "top_p": float(top_p),
                "top_k": int(top_k),
            }
        )
    with torch.no_grad():
        output_ids = model.generate(**gen_kwargs)
    # Decode only the newly generated suffix, not the echoed prompt.
    generated_ids = output_ids[0][input_ids.shape[-1]:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    if not text:
        # Some models emit nothing new; fall back to decoding everything.
        text = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    return text, _status_text()
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks(title="Nomos ZeroGPU Inference") as demo:
    gr.Markdown(
        "# Nomos Remote Inference (ZeroGPU)\n"
        "This app tries model candidates in order and keeps the first that loads."
    )
    with gr.Row():
        # Left column: prompt plus decoding controls.
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                label="Prompt",
                lines=10,
                placeholder="Ask for a concise proof or solution sketch...",
            )
            with gr.Row():
                max_new_tokens = gr.Slider(
                    minimum=32,
                    maximum=1024,
                    value=MAX_NEW_TOKENS_DEFAULT,
                    step=1,
                    label="Max new tokens",
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=20,
                    step=1,
                    label="Top-k",
                )
            with gr.Row():
                # NOTE: minimum 0.0 is valid here; generate() handles the
                # zero-temperature case.
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.5,
                    value=0.6,
                    step=0.01,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.05,
                    maximum=1.0,
                    value=0.95,
                    step=0.01,
                    label="Top-p",
                )
            do_sample = gr.Checkbox(value=True, label="Sample")
            run_btn = gr.Button("Generate")
        # Right column: generated text plus the model-load status panel.
        with gr.Column(scale=2):
            output = gr.Textbox(label="Output", lines=18)
            status = gr.Markdown(value=_status_text())
    # Wire the button to generate(); also exposed as the `generate` API route.
    run_btn.click(
        fn=generate,
        inputs=[prompt, max_new_tokens, temperature, top_p, top_k, do_sample],
        outputs=[output, status],
        api_name="generate",
    )
    gr.Examples(
        examples=[
            ["Solve: Find all integers n such that n^2 + n + 1 is prime."],
            ["Give a proof sketch that there are infinitely many primes."],
        ],
        inputs=prompt,
    )

# Bound the request queue so concurrent ZeroGPU calls don't pile up unboundedly.
demo.queue(max_size=32)

if __name__ == "__main__":
    demo.launch()