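"""Gradio chat demo for WeiboAI/VibeThinker-1.5B on a ZeroGPU slice.

Streams tokens from a background generation thread, emits a heartbeat
update every few seconds so the frontend never stalls, and aborts if no
new token arrives within NO_TOKEN_TIMEOUT seconds.
"""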
import time, threading, queue
import gradio as gr
import torch, spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
MODEL_ID = "WeiboAI/VibeThinker-1.5B"
SYSTEM_PROMPT = "You are a concise solver. Give one clear final answer."
MAX_INPUT_TOKENS = 384
MAX_NEW_TOKENS = 128
TEMPERATURE = 0.4
TOP_P = 0.9
NO_TOKEN_TIMEOUT = 8 # seconds with no new token -> stop
print(f"⏳ Loading {MODEL_ID} …", flush=True)
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    dtype=torch.bfloat16,  # <- use dtype (not torch_dtype)
    device_map="auto",
).eval()
print("✅ Model ready.", flush=True)
def _apply_template(messages):
    """Render chat messages into the model's prompt format."""
    return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
def _clip_inputs(prompt_text, max_tokens):
    """Tokenize the prompt and left-truncate it to at most max_tokens."""
    ids = tok([prompt_text], return_tensors="pt")
    if ids["input_ids"].shape[-1] > max_tokens:
        ids = {k: v[:, -max_tokens:] for k, v in ids.items()}
    return {k: v.to(model.device) for k, v in ids.items()}
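# Example: with MAX_INPUT_TOKENS = 384, a 500-token prompt keeps only its
# last 384 tokens, so the generation prompt that the chat template appends
# at the end is what survives truncation.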
@spaces.GPU(duration=90)
def respond(message, history):
    """Stream an assistant reply to `message`, yielding the growing history."""
    history = history or []
    msgs = [{"role": "system", "content": SYSTEM_PROMPT}, *history,
            {"role": "user", "content": str(message)}]
    prompt = _apply_template(msgs)
    inputs = _clip_inputs(prompt, MAX_INPUT_TOKENS)
    # timeout= makes iteration raise queue.Empty when no token arrives for
    # NO_TOKEN_TIMEOUT seconds; without it the for-loop below would block
    # forever on a stalled generation and the watchdog could never fire.
    streamer = TextIteratorStreamer(
        tok, skip_prompt=True, skip_special_tokens=True, timeout=NO_TOKEN_TIMEOUT
    )
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        repetition_penalty=1.18,
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=tok.eos_token_id,
        use_cache=True,
    )
    # run generation in a daemon thread; the streamer feeds its tokens back here
    th = threading.Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
    th.start()
    assistant = {"role": "assistant", "content": ""}
    out = list(history) + [assistant]
    last_yield = 0.0
    stalled = False
    try:
        for chunk in streamer:
            assistant["content"] += chunk
            # heartbeat every ~4 s so the frontend never stalls
            now = time.time()
            if now - last_yield >= 4:
                yield out
                last_yield = now
    except queue.Empty:
        # no new token within NO_TOKEN_TIMEOUT seconds -> stop streaming
        stalled = True
    yield out
    if stalled:
        assistant["content"] += f"\n\n(Stopped: no tokens for {NO_TOKEN_TIMEOUT}s)"
        yield out
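# Sketch of a direct (non-Gradio) call, for reference; assumes an attached
# GPU, or an environment where the @spaces.GPU decorator is a no-op:
#
#     for turns in respond("What is 17 * 24?", []):
#         pass
#     print(turns[-1]["content"])  # final assistant message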
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 💡 VibeThinker-1.5B — ZeroGPU slice (stable streaming)")
    chat = gr.Chatbot(type="messages", height=520)
    box = gr.Textbox(placeholder="Ask a question…")
    send = gr.Button("Send", variant="primary")

    def pipeline(msg, hist):
        # clear the textbox on the first yield, then keep streaming history
        for hist in respond(msg, hist):
            yield "", hist

    box.submit(pipeline, [box, chat], [box, chat])
    send.click(pipeline, [box, chat], [box, chat])
if __name__ == "__main__":
    # Gradio 4.x: queue() has no concurrency_count; keep max_size if desired
    demo.queue(max_size=16).launch()
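# To try this locally (assumptions: a CUDA-capable machine, and the `spaces`
# package degrading gracefully outside a Hugging Face Space):
#
#     pip install gradio torch transformers spaces
#     python app.py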