# Voinal / model.py
# Last commit bfd8f9c by GovIndLok — refactor: replace ollama with custom
# MiniCPM5-1B pipeline and update Gradio interface to support chat history.
import os
# VoxCPM2 torch.compiles a submodule that crashes TorchDynamo on this stack
# ("Cannot construct ConstantVariable for torch.device"); disable compilation so
# it runs eager. These variables must be in the environment before torch is
# first imported — here that is `import torch` below (and possibly indirectly
# via spaces), plus voxcpm wherever it gets loaded.
# NOTE(review): voxcpm is not imported in this module; presumably it is loaded
# elsewhere in the app after this module runs — confirm that import order holds.
# setdefault only fills in a value when the variable is absent, so an explicit
# setting already present in the environment takes precedence over these.
os.environ.setdefault("TORCHDYNAMO_DISABLE", "1")
os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
import threading
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import threading
import spaces
MODEL_ID = "openbmb/MiniCPM5-1B"

# The tokenizer and model are loaded once, at import time, and shared by every
# call to model_input / generate / generate_stream.
# trust_remote_code=True executes Python code shipped inside the model repo;
# only point MODEL_ID at a repository you trust.
print(f"[llm] Loading tokenizer for {MODEL_ID} ...", flush=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# The tokenizer is a CPU-side object; only the model is moved to the GPU below.
print("[llm] Tokenizer loaded", flush=True)

print(f"[llm] Loading model {MODEL_ID} ...", flush=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
).to("cuda")
# Inference only: switch off training-time behaviour (e.g. dropout).
model.eval()
print("[llm] model is ready", flush=True)
def model_input(messages):
    """Render chat *messages* through the chat template and tokenize them.

    The template is first rendered with ``enable_thinking=False``. If that call
    raises ``TypeError`` (presumably a tokenizer/template that does not accept
    the keyword), it is rendered again without it.

    Args:
        messages: the chat history handed to ``apply_chat_template``
            (presumably a list of ``{"role": ..., "content": ...}`` dicts —
            confirm with callers).

    Returns:
        The tokenizer's dict-style encoding (PyTorch tensors, generation prompt
        appended), moved to ``model.device``.
    """
    template_opts = {
        "tokenize": True,
        "add_generation_prompt": True,
        "return_dict": True,
        "return_tensors": "pt",
    }
    try:
        encoded = tokenizer.apply_chat_template(
            messages, enable_thinking=False, **template_opts
        )
    except TypeError:
        encoded = tokenizer.apply_chat_template(messages, **template_opts)
    return encoded.to(model.device)
@spaces.GPU(duration=120)
def generate(messages, max_new_tokens: int = 100) -> str:
    """Run one blocking chat completion and return only the newly generated text.

    Args:
        messages: chat history, tokenized via ``model_input``.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        The completion decoded without special tokens and with surrounding
        whitespace stripped; the prompt tokens are sliced off before decoding.
    """
    batch = model_input(messages)
    prompt_length = batch["input_ids"].shape[-1]
    with torch.no_grad():
        output_ids = model.generate(
            **batch,
            max_new_tokens=max_new_tokens,
            # EOS doubles as the padding id passed to generate().
            pad_token_id=tokenizer.eos_token_id,
        )
    completion_ids = output_ids[0][prompt_length:]
    text = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return text.strip()
# Streaming counterpart of generate(): yields text while the model is decoding.
@spaces.GPU(duration=100)
def generate_stream(messages, max_new_tokens: int = 120):
    """Yield decoded text fragments of a chat completion as MiniCPM produces them.

    ``model.generate`` runs on a background thread and feeds a
    ``TextIteratorStreamer``; this generator drains the streamer on the
    caller's thread.

    Args:
        messages: chat history, tokenized via ``model_input``.
        max_new_tokens: upper bound on the number of generated tokens.

    Yields:
        Successive text fragments (not necessarily whole lines or words), with
        the prompt and special tokens excluded.

    Raises:
        Whatever exception ``model.generate`` raised in the worker thread,
        re-raised here after the stream is closed. Previously such a failure
        left the consumer blocked forever waiting on the streamer.
    """
    inputs = model_input(messages)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    errors = []  # filled by the worker thread if generation fails

    def _run():
        try:
            with torch.no_grad():
                model.generate(**kwargs)
        except Exception as exc:
            errors.append(exc)
            # generate() closes the streamer itself on success; if it died
            # early, close it here so the consumer loop below terminates.
            streamer.end()

    worker = threading.Thread(target=_run, daemon=True)
    worker.start()
    for piece in streamer:
        yield piece
    worker.join()
    if errors:
        raise errors[0]