Spaces: Running on Zero
| import os, time, threading | |
| import gradio as gr | |
| import torch, spaces | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
# --- Model & generation settings --------------------------------------------
MODEL_ID = "WeiboAI/VibeThinker-1.5B"  # repo id handed to from_pretrained()
SYSTEM_PROMPT = "You are a concise solver. Give one clear final answer."

# Token budgets: the templated prompt is clipped to the last MAX_INPUT_TOKENS
# tokens, and each reply is capped at MAX_NEW_TOKENS generated tokens.
MAX_INPUT_TOKENS = 384
MAX_NEW_TOKENS = 128

# Sampling parameters forwarded to model.generate().
TEMPERATURE = 0.4
TOP_P = 0.9

# Seconds without a new token before the reply is stopped.
NO_TOKEN_TIMEOUT = 8
# Load the tokenizer and weights once at import time; every request below
# reuses these module-level objects.
print(f"⏳ Loading {MODEL_ID} …", flush=True)
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,  # `dtype` keyword (older transformers releases named it `torch_dtype`)
    device_map="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
model.eval()  # inference mode (e.g. disables dropout); returns the same module
print("✅ Model ready.", flush=True)
def _apply_template(messages):
    """Render a list of chat messages into the model's prompt string.

    The tokenizer's chat template is applied without tokenizing, and the
    generation prompt is requested so the model continues as the next speaker
    (the assistant).
    """
    rendered = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
    return rendered
def _clip_inputs(prompt_text, max_tokens):
    """Tokenize *prompt_text* and keep at most its last *max_tokens* tokens.

    Truncation is from the left, so the end of the prompt (newest user turn and
    the assistant generation header) survives while the oldest context --
    including the system prompt -- is dropped first when the budget is exceeded.

    Args:
        prompt_text: an already chat-templated prompt string.
        max_tokens: positive token budget.

    Returns:
        A dict of the tokenizer's outputs (e.g. ``input_ids``), each a
        single-row ``(1, seq_len)`` tensor moved to ``model.device``.

    Raises:
        ValueError: if *max_tokens* is not positive.
    """
    if max_tokens < 1:
        # v[:, -0:] would keep every token, and a negative budget would slice
        # off the *start* of the prompt instead -- neither is a real clip.
        raise ValueError(f"max_tokens must be positive, got {max_tokens}")
    # The chat template already inserted the special tokens the model needs;
    # letting the tokenizer add its own (e.g. a second BOS) would duplicate them.
    encoded = tok([prompt_text], return_tensors="pt", add_special_tokens=False)
    if encoded["input_ids"].shape[-1] > max_tokens:
        encoded = {k: v[:, -max_tokens:] for k, v in encoded.items()}
    return {k: v.to(model.device) for k, v in encoded.items()}
def _event_stopping_criteria(stop_event):
    """Build a StoppingCriteriaList that ends ``model.generate`` once *stop_event* is set.

    Lets the streaming consumer cancel a background ``generate()`` call (stall,
    client disconnect) instead of leaving it running on the GPU.
    """
    from transformers import StoppingCriteria, StoppingCriteriaList

    class _StopOnEvent(StoppingCriteria):
        def __call__(self, input_ids, scores, **kwargs):
            # One flag per sequence in the batch, all equal to the event state.
            return torch.full(
                (input_ids.shape[0],),
                stop_event.is_set(),
                dtype=torch.bool,
                device=input_ids.device,
            )

    return StoppingCriteriaList([_StopOnEvent()])


@spaces.GPU(duration=90)
def respond(message, history):
    """Stream an assistant reply to *message*, continuing the chat in *history*.

    Args:
        message: the new user input (coerced with ``str``; ``None`` counts as empty).
        history: prior turns as ``{"role": ..., "content": ...}`` dicts (the
            Gradio ``type="messages"`` format), or ``None``/empty for a new chat.

    Yields:
        The full updated history: prior turns, the new user turn, and an
        assistant turn whose ``content`` grows as text arrives. The user turn is
        yielded immediately; later updates are throttled to about one every 4 s,
        and a final update is always yielded.

    If no new text arrives for ``NO_TOKEN_TIMEOUT`` seconds (this includes the
    first token), or ``generate()`` raises, generation is cancelled and a
    "(Stopped: ...)" note is appended to the reply.
    """
    import queue  # TextIteratorStreamer signals its timeout with queue.Empty

    history = list(history or [])
    user_text = "" if message is None else str(message)
    if not user_text.strip():
        # Nothing to answer: hand the chat back unchanged instead of generating.
        yield history
        return

    msgs = [{"role": "system", "content": SYSTEM_PROMPT}, *history,
            {"role": "user", "content": user_text}]
    inputs = _clip_inputs(_apply_template(msgs), MAX_INPUT_TOKENS)

    assistant = {"role": "assistant", "content": ""}
    # The user turn must be part of the returned history: the Chatbot displays
    # it, and the next call receives it back as `history` for the prompt.
    out = [*history, {"role": "user", "content": user_text}, assistant]
    yield out  # show the user's message right away

    stop_event = threading.Event()
    # With a timeout, iterating the streamer raises queue.Empty instead of
    # blocking forever when generate() stalls or dies in the worker thread.
    streamer = TextIteratorStreamer(
        tok, skip_prompt=True, skip_special_tokens=True, timeout=NO_TOKEN_TIMEOUT
    )
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        stopping_criteria=_event_stopping_criteria(stop_event),
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        repetition_penalty=1.18,
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=tok.eos_token_id,
        use_cache=True,
    )
    errors = []

    def _generate():
        # Worker thread: decoded text reaches the consumer through the streamer.
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:  # surface the failure rather than hang the UI
            errors.append(exc)  # recorded before end() so the consumer sees it
            print(f"⚠️ generate() failed: {exc!r}", flush=True)
            streamer.end()  # release the consumer immediately

    worker = threading.Thread(target=_generate, daemon=True)
    worker.start()

    stalled = False
    last_yield = 0.0
    try:
        for chunk in streamer:
            assistant["content"] += chunk
            now = time.time()
            # Throttle UI updates to roughly one every 4 s.
            if now - last_yield >= 4:
                yield out
                last_yield = now
    except queue.Empty:
        stalled = True
    finally:
        # Normal end, stall, or the client abandoning this generator
        # (GeneratorExit): in every case tell generate() to stop.
        stop_event.set()

    worker.join(timeout=NO_TOKEN_TIMEOUT)

    if errors:
        assistant["content"] += (
            f"\n\n(Stopped: generation failed with {type(errors[0]).__name__})"
        )
    elif stalled or worker.is_alive():
        assistant["content"] += f"\n\n(Stopped: no tokens for {NO_TOKEN_TIMEOUT}s)"
    yield out
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 💡 VibeThinker-1.5B — ZeroGPU slice (stable streaming)")
    chat = gr.Chatbot(type="messages", height=520)
    box = gr.Textbox(placeholder="Ask a question…")
    send = gr.Button("Send", variant="primary")

    def pipeline(msg, hist):
        """Relay each respond() update to the UI, clearing the textbox every time."""
        for updated_history in respond(msg, hist):
            yield "", updated_history

    # Pressing Enter in the textbox and clicking Send run the same handler.
    box.submit(fn=pipeline, inputs=[box, chat], outputs=[box, chat])
    send.click(fn=pipeline, inputs=[box, chat], outputs=[box, chat])

if __name__ == "__main__":
    # Gradio 4.x: queue() has no concurrency_count; max_size bounds how many
    # requests may wait in the queue.
    demo.queue(max_size=16)  # queue() returns the same Blocks, so this matches chaining
    demo.launch()