import queue  # queue.Empty is raised by the streamer timeout below
import time, threading

import gradio as gr
import torch, spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# ---- Config ----
MODEL_ID = "WeiboAI/VibeThinker-1.5B"
SYSTEM_PROMPT = "You are a concise solver. Give one clear final answer."
MAX_INPUT_TOKENS = 384    # cap prompt length so first token comes fast
MAX_NEW_TOKENS = 96       # keep inside ZeroGPU slice
DO_SAMPLE = False         # deterministic decode = faster/steadier on ZeroGPU
TEMPERATURE = 0.4         # used only if DO_SAMPLE=True
TOP_P = 0.9
FIRST_TOKEN_TIMEOUT = 3   # if no token in 3s -> likely no worker slot
NO_TOKEN_HANG_CUTOFF = 8  # safety if stream stalls mid-gen
print(f"⏳ Loading {MODEL_ID} …", flush=True)
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    dtype=torch.bfloat16,  # (use dtype, not torch_dtype)
    device_map="auto",
).eval()
print("✅ Model ready.", flush=True)
def _prepare_inputs(messages):
    prompt_text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    ids = tok([prompt_text], return_tensors="pt")
    # clip to keep within MAX_INPUT_TOKENS
    if ids["input_ids"].shape[-1] > MAX_INPUT_TOKENS:
        ids = {k: v[:, -MAX_INPUT_TOKENS:] for k, v in ids.items()}
    return {k: v.to(model.device) for k, v in ids.items()}
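
# Optional sanity check of the prompt path (an assumed example, not part of the
# original Space; the question string and shape check are illustrative only):
#   _probe = _prepare_inputs([{"role": "user", "content": "What is 2 + 2?"}])
#   assert _probe["input_ids"].shape[-1] <= MAX_INPUT_TOKENS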
# request a short ZeroGPU slice (more likely to schedule)
@spaces.GPU(duration=60)  # 60 s is an assumed budget; tune to the typical response time
def respond(user_message, history):
    history = history or []
    msgs = [{"role": "system", "content": SYSTEM_PROMPT},
            *history,
            {"role": "user", "content": str(user_message)}]
    inputs = _prepare_inputs(msgs)
    # fine-grained streaming (the short timeout lets the loop below poll for stalls)
    streamer = TextIteratorStreamer(
        tok, skip_prompt=True, skip_special_tokens=True, timeout=0.05
    )
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        do_sample=DO_SAMPLE,
        repetition_penalty=1.15,  # tame short loops
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=tok.eos_token_id,
        eos_token_id=tok.eos_token_id,
        use_cache=True,
    )
    if DO_SAMPLE:
        # sampling knobs only apply when sampling is enabled; passing them with
        # do_sample=False just triggers a transformers warning
        gen_kwargs.update(temperature=TEMPERATURE, top_p=TOP_P)
    # run generate in a daemon thread so it never blocks future calls
    th = threading.Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
    th.start()

    out = list(history) + [{"role": "assistant", "content": ""}]
    got_first = False
    start = time.time()
    last_token_time = start
    try:
        while True:
            try:
                chunk = next(streamer)  # raises queue.Empty after 0.05s of silence
            except StopIteration:
                break                   # generation finished normally
            except queue.Empty:
                now = time.time()
                # if we never got a token, it was likely a ZeroGPU miss
                if not got_first and (now - start) >= FIRST_TOKEN_TIMEOUT:
                    out[-1]["content"] = "(No ZeroGPU worker slot yet — press Send again.)"
                    yield out
                    return
                # safety: tokens were flowing but stopped arriving for a while
                if got_first and (now - last_token_time) >= NO_TOKEN_HANG_CUTOFF:
                    out[-1]["content"] += f"\n\n(Stopped: no tokens for {NO_TOKEN_HANG_CUTOFF}s)"
                    yield out
                    return
                continue
            got_first = True
            last_token_time = time.time()
            out[-1]["content"] += chunk
            # yield every token (true streaming)
            yield out
    except Exception as e:
        out[-1]["content"] = f"⚠️ ZeroGPU worker error: {e}"
        yield out
# ---- UI ----
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 💡 VibeThinker-1.5B — ZeroGPU slice (smooth streaming)")
    chat = gr.Chatbot(type="messages", height=520)  # no 'streaming' kwarg (not available in this Gradio build)
    box = gr.Textbox(placeholder="Ask a question…")
    send = gr.Button("Send", variant="primary")

    def pipeline(msg, hist):
        # generator -> stream into the Chatbot
        for hist in respond(msg, hist):
            yield "", hist

    box.submit(pipeline, [box, chat], [box, chat])
    send.click(pipeline, [box, chat], [box, chat])
if __name__ == "__main__":
    demo.queue(max_size=16).launch()
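
# Assumed Space dependencies (requirements.txt), listed here for reference only:
#   gradio
#   torch
#   transformers
#   accelerate    # needed for device_map="auto"
# The `spaces` package is provided in the ZeroGPU environment; add it to
# requirements.txt only if the build complains about the import.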