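"""Gradio chat demo for WeiboAI/VibeThinker-1.5B on a Hugging Face ZeroGPU Space.

Streams tokens from a background model.generate thread into a gr.Chatbot,
throttling UI updates and aborting cleanly if generation stalls.
"""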
import time, threading
from queue import Empty  # raised by the streamer's iterator when its timeout expires
import gradio as gr
import torch, spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "WeiboAI/VibeThinker-1.5B"
SYSTEM_PROMPT = "You are a concise solver. Give one clear final answer."
MAX_INPUT_TOKENS = 384
MAX_NEW_TOKENS   = 128
TEMPERATURE      = 0.4
TOP_P            = 0.9
NO_TOKEN_TIMEOUT = 8  # seconds with no new token -> stop

print(f"⏳ Loading {MODEL_ID} …", flush=True)
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    dtype=torch.bfloat16,    # `dtype` supersedes the deprecated `torch_dtype` kwarg
    device_map="auto",
).eval()
print("✅ Model ready.", flush=True)

def _apply_template(messages):
    # Render the message list with the model's chat template, ending at the
    # assistant marker so generation continues from there.
    return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def _clip_inputs(prompt_text, max_tokens):
    # Left-truncate to the most recent max_tokens tokens; this keeps the
    # newest context but can clip away the system prompt on long histories.
    ids = tok([prompt_text], return_tensors="pt")
    if ids["input_ids"].shape[-1] > max_tokens:
        ids = {k: v[:, -max_tokens:] for k, v in ids.items()}
    return {k: v.to(model.device) for k, v in ids.items()}

# @spaces.GPU borrows a ZeroGPU slice for up to `duration` seconds per call;
# outside a ZeroGPU Space the decorator should behave as a no-op.
@spaces.GPU(duration=90)
def respond(message, history):
    history = history or []
    user = {"role": "user", "content": str(message)}
    msgs = [{"role": "system", "content": SYSTEM_PROMPT}, *history, user]

    prompt = _apply_template(msgs)
    inputs = _clip_inputs(prompt, MAX_INPUT_TOKENS)

    # A timeout makes iteration raise queue.Empty if no token arrives within
    # NO_TOKEN_TIMEOUT s, so a stalled generate() thread cannot hang the UI.
    streamer = TextIteratorStreamer(
        tok, skip_prompt=True, skip_special_tokens=True, timeout=NO_TOKEN_TIMEOUT
    )
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        repetition_penalty=1.18,
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=tok.eos_token_id,
        use_cache=True,
    )

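    # Run generate() on a daemon thread so this generator can consume the
    # streamer and yield partial output while tokens arrive.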
    th = threading.Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
    th.start()

    assistant = {"role": "assistant", "content": ""}
    # Include the user turn in the yielded history so it renders in the chat.
    out = list(history) + [user, assistant]

    last_yield = 0.0
    try:
        for chunk in streamer:
            assistant["content"] += chunk
            # Throttle UI updates to ~one every 4 s: frequent enough that the
            # frontend never stalls, sparse enough not to flood the queue.
            now = time.time()
            if now - last_yield >= 4:
                yield out
                last_yield = now
    except Empty:
        # Watchdog fired: the streamer produced no token for NO_TOKEN_TIMEOUT s.
        assistant["content"] += f"\n\n(Stopped: no tokens for {NO_TOKEN_TIMEOUT}s)"

    yield out  # flush anything that arrived after the last throttled yield

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 💡 VibeThinker-1.5B — ZeroGPU slice (stable streaming)")
    chat = gr.Chatbot(type="messages", height=520)
    box  = gr.Textbox(placeholder="Ask a question…")
    send = gr.Button("Send", variant="primary")

    def pipeline(msg, hist):
        # Clear the textbox on each update while streaming the growing history.
        for hist in respond(msg, hist):
            yield "", hist

    box.submit(pipeline, [box, chat], [box, chat])
    send.click(pipeline, [box, chat], [box, chat])

if __name__ == "__main__":
    # Gradio 4.x: queue() has no concurrency_count; keep max_size if desired
    demo.queue(max_size=16).launch()
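
# Assumed Space dependencies (not pinned in this file): gradio, torch,
# transformers, spaces, and accelerate (required for device_map="auto").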