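"""Gradio chat demo for WeiboAI/VibeThinker-1.5B on a Hugging Face ZeroGPU Space.

Tokens are streamed into a gr.Chatbot, with two watchdogs for common ZeroGPU
failure modes: never receiving a first token (no worker slot) and a stream
that stalls mid-generation.
"""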
import queue, threading, time
import gradio as gr
import torch, spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# ---- Config ----
MODEL_ID = "WeiboAI/VibeThinker-1.5B"
SYSTEM_PROMPT = "You are a concise solver. Give one clear final answer."

MAX_INPUT_TOKENS   = 384       # cap prompt length so first token comes fast
MAX_NEW_TOKENS     = 96        # keep inside ZeroGPU slice
DO_SAMPLE          = False     # deterministic decode = faster/steadier on ZeroGPU
TEMPERATURE        = 0.4       # used only if DO_SAMPLE=True
TOP_P              = 0.9
FIRST_TOKEN_TIMEOUT = 3        # if no token in 3s -> likely no worker slot
NO_TOKEN_HANG_CUTOFF = 8       # safety if stream stalls mid-gen

print(f"⏳ Loading {MODEL_ID} …", flush=True)
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    dtype=torch.bfloat16,      # (use dtype, not torch_dtype)
    device_map="auto",
).eval()
print("✅ Model ready.", flush=True)


def _prepare_inputs(messages):
    prompt_text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    ids = tok([prompt_text], return_tensors="pt")
    # keep only the most recent MAX_INPUT_TOKENS so the latest turn survives truncation
    if ids["input_ids"].shape[-1] > MAX_INPUT_TOKENS:
        ids = {k: v[:, -MAX_INPUT_TOKENS:] for k, v in ids.items()}
    return {k: v.to(model.device) for k, v in ids.items()}


@spaces.GPU(duration=60)   # request a short ZeroGPU slice (more likely to schedule)
def respond(user_message, history):
    history = history or []
    msgs = [{"role": "system", "content": SYSTEM_PROMPT},
            *history,
            {"role": "user", "content": str(user_message)}]

    inputs = _prepare_inputs(msgs)

    # stream decoded text; the short timeout lets the consumer loop poll for stalls
    streamer = TextIteratorStreamer(
        tok, skip_prompt=True, skip_special_tokens=True, timeout=0.05
    )

    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        do_sample=DO_SAMPLE,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        repetition_penalty=1.15,          # tame short loops
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=tok.eos_token_id,
        eos_token_id=tok.eos_token_id,
        use_cache=True,
    )

    # run generate() in a daemon thread so this generator can stream tokens as they arrive
    th = threading.Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
    th.start()

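    # placeholder assistant message appended to the history; filled in as chunks arrive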
    out = list(history) + [{"role": "assistant", "content": ""}]
    got_first = False
    start = time.time()
    last_token_time = start

    try:
        while True:
            try:
                chunk = next(streamer)
            except StopIteration:
                break  # generation finished
            except queue.Empty:
                now = time.time()
                # never got a first token in time -> likely no ZeroGPU worker slot
                if not got_first and (now - start) >= FIRST_TOKEN_TIMEOUT:
                    out[-1]["content"] = "(No ZeroGPU worker slot yet — press Send again.)"
                    yield out
                    return
                # tokens were flowing but the stream stalled -> stop cleanly
                if got_first and (now - last_token_time) >= NO_TOKEN_HANG_CUTOFF:
                    out[-1]["content"] += f"\n\n(Stopped: no tokens for {NO_TOKEN_HANG_CUTOFF}s)"
                    yield out
                    return
                continue  # keep polling the streamer

            got_first = True
            last_token_time = time.time()
            out[-1]["content"] += chunk
            # yield every chunk (true streaming)
            yield out

    except Exception as e:
        out[-1]["content"] = f"⚠️ ZeroGPU worker error: {e}"
        yield out


# ---- UI ----
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 💡 VibeThinker-1.5B — ZeroGPU slice (smooth streaming)")

    chat = gr.Chatbot(type="messages", height=520)  # message-dict format; no 'streaming' kwarg in this Gradio build
    box  = gr.Textbox(placeholder="Ask a question…")
    send = gr.Button("Send", variant="primary")

    def pipeline(msg, hist):
        # generator -> stream into Chatbot
        for hist in respond(msg, hist):
            yield "", hist

    box.submit(pipeline, [box, chat], [box, chat])
    send.click(pipeline, [box, chat], [box, chat])

if __name__ == "__main__":
    demo.queue(max_size=16).launch()