# my-fresh-gen / app.py.old.gpu.work
import time, threading
from queue import Empty  # raised by the streamer when no token arrives in time
import gradio as gr
import torch, spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
MODEL_ID = "WeiboAI/VibeThinker-1.5B"
SYSTEM_PROMPT = "You are a concise solver. Give one clear final answer."
MAX_INPUT_TOKENS = 384   # left-truncate prompts longer than this
MAX_NEW_TOKENS = 128     # cap on generated tokens per reply
TEMPERATURE = 0.4
TOP_P = 0.9
NO_TOKEN_TIMEOUT = 8     # seconds with no new token -> stop streaming
print(f"⏳ Loading {MODEL_ID} …", flush=True)
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    dtype=torch.bfloat16,  # recent transformers accept `dtype`; older releases use `torch_dtype`
    device_map="auto",
).eval()
print("✅ Model ready.", flush=True)
def _apply_template(messages):
    # Render the chat messages into the model's prompt format.
    return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
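# Schematic example (the literal markers come from the model's bundled chat
# template, so this is illustrative rather than exact output): a list such as
#     [{"role": "system", "content": "..."}, {"role": "user", "content": "hi"}]
# is rendered into one prompt string, and add_generation_prompt=True appends
# the assistant-turn marker so generation starts in the assistant role.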
def _clip_inputs(prompt_text, max_tokens):
    # Tokenize, then keep only the most recent `max_tokens` tokens.
    ids = tok([prompt_text], return_tensors="pt")
    if ids["input_ids"].shape[-1] > max_tokens:
        ids = {k: v[:, -max_tokens:] for k, v in ids.items()}
    return {k: v.to(model.device) for k, v in ids.items()}
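# Example (hypothetical numbers): with max_tokens = 4, a six-token prompt
# [t0, t1, t2, t3, t4, t5] is clipped to [t2, t3, t4, t5]. Left-truncation
# keeps the newest context, though on very long histories it can cut through
# the chat-template header.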
@spaces.GPU(duration=90)
def respond(message, history):
    history = history or []
    msgs = [{"role": "system", "content": SYSTEM_PROMPT}, *history,
            {"role": "user", "content": str(message)}]
    prompt = _apply_template(msgs)
    inputs = _clip_inputs(prompt, MAX_INPUT_TOKENS)
    # timeout= makes iteration raise queue.Empty after NO_TOKEN_TIMEOUT seconds
    # of silence; without it the for-loop below would block forever on a stall.
    streamer = TextIteratorStreamer(
        tok, skip_prompt=True, skip_special_tokens=True, timeout=NO_TOKEN_TIMEOUT
    )
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        repetition_penalty=1.18,
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=tok.eos_token_id,
        use_cache=True,
    )
    # Run generate() on a worker thread; the streamer feeds tokens back here.
    th = threading.Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
    th.start()
    assistant = {"role": "assistant", "content": ""}
    out = list(history) + [assistant]
    last_yield = 0.0
    try:
        for chunk in streamer:
            assistant["content"] += chunk
            # heartbeat: yield at most every ~4 s so the frontend never stalls
            now = time.time()
            if now - last_yield >= 4:
                yield out
                last_yield = now
    except Empty:
        # no new token within NO_TOKEN_TIMEOUT seconds -> abort the stream
        assistant["content"] += f"\n\n(Stopped: no tokens for {NO_TOKEN_TIMEOUT}s)"
    yield out  # flush the final (or partial) answer
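# Usage sketch (illustrative; `pipeline` below does exactly this): respond()
# is a generator of full message lists, so a caller can render progress with
#     for msgs in respond("What is 7 * 8?", []):
#         print(msgs[-1]["content"])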
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 💡 VibeThinker-1.5B — ZeroGPU slice (stable streaming)")
    chat = gr.Chatbot(type="messages", height=520)  # history as role/content dicts
    box = gr.Textbox(placeholder="Ask a question…")
    send = gr.Button("Send", variant="primary")

    def pipeline(msg, hist):
        # Clear the textbox immediately, then stream the updated history.
        for hist in respond(msg, hist):
            yield "", hist

    box.submit(pipeline, [box, chat], [box, chat])
    send.click(pipeline, [box, chat], [box, chat])
if __name__ == "__main__":
    # Gradio 4.x: queue() has no concurrency_count; keep max_size if desired
    demo.queue(max_size=16).launch()
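# Local-run sketch (assumption: the package list is illustrative, not a pinned
# requirements file from this repo):
#     pip install torch transformers gradio spaces
#     python app.py
# Outside a ZeroGPU Space the @spaces.GPU decorator should act as a plain
# pass-through, so the script can also run on an ordinary CUDA box (or, very
# slowly, on CPU).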