my-fresh-gen / app.py
Javedalam's picture
Create app.py
5c97114 verified
raw
history blame
4.25 kB
import time, threading
import gradio as gr
import torch, spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# ---- Config ----
# Checkpoint to load and the system instruction placed first in every prompt.
MODEL_ID = "WeiboAI/VibeThinker-1.5B"
SYSTEM_PROMPT = "You are a concise solver. Give one clear final answer."

# Token budgets.
MAX_INPUT_TOKENS = 384  # prompt is clipped to this many tokens so the first token comes fast
MAX_NEW_TOKENS = 96  # generation cap, chosen to stay inside the ZeroGPU slice

# Decoding settings.
DO_SAMPLE = False  # deterministic decode = faster/steadier on ZeroGPU
TEMPERATURE = 0.4  # only relevant when DO_SAMPLE=True
TOP_P = 0.9

# Watchdog thresholds, in seconds.
FIRST_TOKEN_TIMEOUT = 3  # no token within this window -> likely no worker slot
NO_TOKEN_HANG_CUTOFF = 8  # safety limit if the stream stalls mid-generation

# Load tokenizer and model once at import time; both are module-level globals.
print(f"⏳ Loading {MODEL_ID} …", flush=True)
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    dtype=torch.bfloat16,  # (use dtype, not torch_dtype)
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
model = model.eval()  # switch to inference mode
print("✅ Model ready.", flush=True)
def _prepare_inputs(messages):
    """Turn a list of chat messages into model-ready tensors.

    The messages are rendered to text with the tokenizer's chat template
    (generation prompt appended), tokenized as a batch of one, clipped to the
    trailing ``MAX_INPUT_TOKENS`` positions when the prompt is longer than
    that, and moved to ``model.device``.

    Args:
        messages: Chat messages in the form accepted by
            ``tok.apply_chat_template``.

    Returns:
        A plain dict mapping each tokenizer output name to its tensor on the
        model's device.
    """
    rendered = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    encoded = tok([rendered], return_tensors="pt")

    # Decide once whether clipping is needed, based on the input_ids length.
    needs_clip = encoded["input_ids"].shape[-1] > MAX_INPUT_TOKENS

    prepared = {}
    for name, tensor in encoded.items():
        if needs_clip:
            # Keep the tail so the most recent part of the prompt survives.
            tensor = tensor[:, -MAX_INPUT_TOKENS:]
        prepared[name] = tensor.to(model.device)
    return prepared
@spaces.GPU(duration=60)  # request a short ZeroGPU slice (more likely to schedule)
def respond(user_message, history):
    """Stream the model's reply to ``user_message`` as a growing transcript.

    ``model.generate`` runs in a daemon thread that feeds a
    ``TextIteratorStreamer``; this generator polls the streamer and yields the
    transcript each time a new chunk arrives.

    Args:
        user_message: The user's input; coerced to text with ``str()``.
        history: Earlier turns as a list of ``{"role": ..., "content": ...}``
            dicts, or ``None``/empty for a new conversation.

    Yields:
        A list of message dicts: the earlier history, the new user turn, and
        an assistant turn whose ``content`` grows as chunks arrive. If
        generation fails, stalls, or never produces a first token, the
        assistant turn carries a short status message instead.
    """
    # Local import: the streamer reports a poll timeout by raising queue.Empty.
    import queue

    history = history or []
    user_turn = {"role": "user", "content": str(user_message)}
    msgs = [{"role": "system", "content": SYSTEM_PROMPT}, *history, user_turn]
    inputs = _prepare_inputs(msgs)

    # Fine-grained streaming. The short timeout makes each poll return quickly
    # (by raising queue.Empty) so the watchdog checks below can run.
    streamer = TextIteratorStreamer(
        tok, skip_prompt=True, skip_special_tokens=True, timeout=0.05
    )

    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        do_sample=DO_SAMPLE,
        repetition_penalty=1.15,  # tame short loops
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=tok.eos_token_id,
        eos_token_id=tok.eos_token_id,
        use_cache=True,
    )
    if DO_SAMPLE:
        # Sampling knobs are only passed when sampling is actually enabled.
        gen_kwargs.update(temperature=TEMPERATURE, top_p=TOP_P)

    worker_errors = []

    def _generate():
        """Run generation, recording any failure so the UI can report it."""
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            raise  # keep the default thread traceback in the server log

    # Run generate in a daemon thread so it never blocks future calls.
    th = threading.Thread(target=_generate, daemon=True)
    th.start()

    # Include the user's turn so the chat shows it and later calls see it.
    out = list(history) + [user_turn, {"role": "assistant", "content": ""}]
    yield out

    got_first = False
    stalled = False
    last_token_time = time.time()
    try:
        while True:
            try:
                chunk = next(streamer)
            except StopIteration:
                break  # the streamer delivered its end-of-stream signal
            except queue.Empty:
                # Nothing arrived within the poll window; that is not an error.
                if not th.is_alive() and streamer.text_queue.empty():
                    break  # worker is gone and nothing is left to read
                limit = NO_TOKEN_HANG_CUTOFF if got_first else FIRST_TOKEN_TIMEOUT
                if time.time() - last_token_time >= limit:
                    # Give up waiting. The daemon thread cannot be cancelled
                    # from here and is left to finish on its own.
                    stalled = True
                    break
                continue

            got_first = True
            last_token_time = time.time()
            out[-1]["content"] += chunk
            # yield every chunk (true streaming)
            yield out

        if worker_errors:
            out[-1]["content"] = f"⚠️ ZeroGPU worker error: {worker_errors[0]}"
            yield out
        elif stalled and got_first:
            out[-1]["content"] += f"\n\n(Stopped: no tokens for {NO_TOKEN_HANG_CUTOFF}s)"
            yield out
        elif stalled:
            # Never got a token: tell the user it was likely a ZeroGPU miss.
            out[-1]["content"] = "(No ZeroGPU worker slot yet — press Send again.)"
            yield out
    except Exception as e:
        # UI boundary: surface unexpected failures in the chat instead of crashing.
        out[-1]["content"] = f"⚠️ ZeroGPU worker error: {e}"
        yield out
# ---- UI ----
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 💡 VibeThinker-1.5B — ZeroGPU slice (smooth streaming)")
    chat = gr.Chatbot(type="messages", height=520)  # no 'streaming' kwarg (not in your build)
    box = gr.Textbox(placeholder="Ask a question…")
    send = gr.Button("Send", variant="primary")

    def pipeline(msg, hist):
        """Relay each transcript from ``respond`` to the UI.

        Every yielded pair is ``("", transcript)``: the empty string goes to
        the textbox output and the transcript goes to the chatbot output.
        """
        for transcript in respond(msg, hist):
            yield "", transcript

    # Pressing Enter in the textbox and clicking Send run the same handler
    # with the same inputs and outputs.
    for trigger in (box.submit, send.click):
        trigger(pipeline, [box, chat], [box, chat])

if __name__ == "__main__":
    # Enable the request queue (at most 16 waiting), then start the app.
    demo.queue(max_size=16)
    demo.launch()