Javedalam committed on
Commit
7ba7d6e
·
verified ·
1 Parent(s): ef477f4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, time, threading
2
+ import gradio as gr
3
+ import torch, spaces
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
5
+
6
# Hub id of the checkpoint loaded at import time.
MODEL_ID = "WeiboAI/VibeThinker-1.5B"
# System message placed first in every conversation sent to the model.
SYSTEM_PROMPT = "You are a concise solver. Give one clear final answer."
# The rendered prompt is clipped to its last MAX_INPUT_TOKENS tokens.
MAX_INPUT_TOKENS = 384
# Upper bound on tokens generated per reply (passed as max_new_tokens).
MAX_NEW_TOKENS = 128
# Sampling settings passed to generate() (do_sample=True).
TEMPERATURE = 0.4
TOP_P = 0.9
NO_TOKEN_TIMEOUT = 8  # seconds with no new token -> stop
13
+
14
# Load the tokenizer and model once at import time; every request reuses them.
print(f"⏳ Loading {MODEL_ID} …", flush=True)

# NOTE(review): trust_remote_code=True executes Python shipped with the
# checkpoint repository; confirm MODEL_ID points at a trusted source.
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    dtype=torch.bfloat16,  # `dtype` is used here rather than `torch_dtype`
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
# Switch to inference mode; eval() returns the module itself, so calling it
# as a separate statement leaves `model` bound to the same object.
model.eval()

print("✅ Model ready.", flush=True)
24
+
25
def _apply_template(messages):
    """Render a list of chat messages into one prompt string.

    Uses the module-level tokenizer's chat template, asking for text rather
    than token ids and for the generation prompt to be appended.
    """
    rendered = tok.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered
27
+
28
def _clip_inputs(prompt_text, max_tokens):
    """Tokenize ``prompt_text`` and keep at most its last ``max_tokens`` tokens.

    When the encoded prompt is longer than ``max_tokens``, every tensor in
    the encoding is sliced from the left so the end of the prompt survives.
    Returns a plain dict of tensors moved to ``model.device``.
    """
    encoded = tok([prompt_text], return_tensors="pt")
    too_long = encoded["input_ids"].shape[-1] > max_tokens

    on_device = {}
    for name, tensor in encoded.items():
        if too_long:
            # Negative-start slice keeps the trailing max_tokens columns.
            tensor = tensor[:, -max_tokens:]
        on_device[name] = tensor.to(model.device)
    return on_device
33
+
34
@spaces.GPU(duration=90)
def respond(message, history):
    """Stream the model's reply to ``message`` as a growing chat history.

    Args:
        message: The user's new message; converted with ``str()``.
        history: Prior conversation as a list of ``{"role": ..., "content": ...}``
            dicts, or ``None``/empty for a fresh chat.

    Yields:
        A list of message dicts: the prior history, the new user message, and
        the assistant message whose ``content`` grows as text arrives. The
        same list object is yielded every time.
    """
    # Local import: queue.Empty is what the streamer raises when its
    # ``timeout`` elapses without new text.
    from queue import Empty

    history = history or []
    user = {"role": "user", "content": str(message)}
    msgs = [{"role": "system", "content": SYSTEM_PROMPT}, *history, user]

    prompt = _apply_template(msgs)
    inputs = _clip_inputs(prompt, MAX_INPUT_TOKENS)

    # Without ``timeout`` the iterator below blocks forever if generation
    # stalls or the worker thread dies before signalling the end of stream.
    streamer = TextIteratorStreamer(
        tok,
        skip_prompt=True,
        timeout=NO_TOKEN_TIMEOUT,
        skip_special_tokens=True,
    )
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        repetition_penalty=1.18,
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=tok.eos_token_id,
        use_cache=True,
    )

    # generate() blocks, so it runs in a worker thread while this generator
    # drains the streamer.
    th = threading.Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
    th.start()

    assistant = {"role": "assistant", "content": ""}
    # Include the user's message so it shows in the chat and is part of the
    # history passed back in on the next turn.
    out = [*history, user, assistant]

    # -inf makes the first chunk yield immediately; after that, UI updates
    # are throttled to at most one every ~4 seconds.
    last_yield = float("-inf")
    stalled = False

    try:
        for chunk in streamer:
            assistant["content"] += chunk
            now = time.monotonic()
            if now - last_yield >= 4:
                yield out
                last_yield = now
    except Empty:
        # No text arrived within NO_TOKEN_TIMEOUT seconds.
        stalled = True

    if not stalled:
        # Stream ended normally; give the worker a bounded time to exit.
        th.join(timeout=NO_TOKEN_TIMEOUT)

    # Always flush whatever text accumulated since the last throttled yield.
    yield out

    if stalled or th.is_alive():
        assistant["content"] += f"\n\n(Stopped: no tokens for {NO_TOKEN_TIMEOUT}s)"
        yield out
82
+
83
# Chat UI: a chatbot pane, a text box, and a send button. Both submitting the
# text box and clicking the button run the same streaming pipeline.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 💡 VibeThinker-1.5B — ZeroGPU slice (stable streaming)")
    chat = gr.Chatbot(type="messages", height=520)
    box = gr.Textbox(placeholder="Ask a question…")
    send = gr.Button("Send", variant="primary")

    def pipeline(msg, hist):
        """Forward a submission to ``respond`` and stream UI updates.

        Yields ``("", history)`` pairs: the empty string clears the text box
        and ``history`` replaces the chatbot contents.
        """
        # Ignore blank submissions instead of running generation on them.
        if not str(msg or "").strip():
            yield "", hist or []
            return

        # Distinct loop name so the ``hist`` parameter is not rebound.
        for updated in respond(msg, hist):
            yield "", updated

    box.submit(pipeline, [box, chat], [box, chat])
    send.click(pipeline, [box, chat], [box, chat])
95
+
96
if __name__ == "__main__":
    # Gradio 4.x: queue() takes no concurrency_count; max_size is kept.
    queued_app = demo.queue(max_size=16)
    queued_app.launch()