File size: 2,384 Bytes
5b38e09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfd8f9c
5b38e09
 
 
 
 
 
 
 
 
 
 
 
 
bfd8f9c
5b38e09
 
 
 
 
 
 
 
bfd8f9c
5b38e09
bfd8f9c
5b38e09
 
 
 
 
 
 
 
bfd8f9c
5b38e09
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os 
# VoxCPM2 torch.compiles a submodule that crashes TorchDynamo on this stack
# ("Cannot construct ConstantVariable for torch.device"); disable compilation so
# it runs eager. Must be set before torch is imported (via spaces / voxcpm).
# setdefault leaves any value already present in the environment untouched, so
# an operator can still override these flags from outside the process.
os.environ.setdefault("TORCHDYNAMO_DISABLE", "1")
os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
import threading
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# NOTE(review): the two imports below repeat `threading` and `spaces` from
# above; they are redundant no-ops and can be dropped.
import threading
import spaces

# Hugging Face Hub id of the chat model loaded by this module.
MODEL_ID = "openbmb/MiniCPM5-1B"

print(f"[llm] Loading tokenizer for {MODEL_ID} ...", flush=True)
# trust_remote_code=True allows the checkpoint's own Python code to be executed.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# This line is emitted after the tokenizer is ready and right before the model
# weights are loaded and moved to CUDA, so the message says exactly that.
print(f"[llm] Tokenizer loaded; loading model {MODEL_ID} onto GPU ...", flush=True)

# Load the weights in bfloat16 and move the whole model to the CUDA device.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
).to("cuda")
# Switch to evaluation mode; this module only runs generation.
model.eval()
print("[llm] model is ready", flush=True)

def model_input(messages):
    """Turn chat messages into tensors placed on the model's device.

    The messages are rendered through the tokenizer's chat template with a
    generation prompt appended and returned as a dict of PyTorch tensors.
    The template is first called with ``enable_thinking=False``; if that call
    raises ``TypeError`` it is retried without that argument.
    """
    template_args = {
        "tokenize": True,
        "add_generation_prompt": True,
        "return_dict": True,
        "return_tensors": "pt",
    }

    try:
        encoded = tokenizer.apply_chat_template(
            messages, enable_thinking=False, **template_args
        )
    except TypeError:
        # Retry without the extra flag when the first call rejects it.
        encoded = tokenizer.apply_chat_template(messages, **template_args)

    return encoded.to(model.device)

@spaces.GPU(duration=120)
def generate(messages, max_new_tokens: int = 100) -> str:
    """Run one complete chat completion and return the reply text.

    This is the blocking path: the whole reply is generated before anything
    is returned. Only the tokens produced after the prompt are decoded, with
    special tokens skipped and surrounding whitespace stripped.
    """
    encoded = model_input(messages)
    # Length of the prompt, used below to slice it off the generated sequence.
    prompt_len = encoded["input_ids"].shape[-1]

    with torch.no_grad():
        sequences = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )

    completion_ids = sequences[0][prompt_len:]
    reply = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return reply.strip()


# Test live generation
@spaces.GPU(duration=100)
def generate_stream(messages, max_new_tokens: int = 120):
    """Yield the model's reply piece by piece while it is being generated.

    Generation runs in a background thread that feeds a
    ``TextIteratorStreamer``; this generator yields each decoded text piece
    as soon as the streamer produces it (prompt and special tokens skipped).

    Raises:
        Exception: whatever ``model.generate`` raised in the background
            thread, re-raised here once the stream has been drained.
    """
    inputs = model_input(messages)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    failures = []

    def _run():
        try:
            with torch.no_grad():
                model.generate(**kwargs)
        except Exception as exc:
            # Without this, a failed generate() never signals end-of-stream
            # and the consumer loop below would block forever.
            failures.append(exc)
            streamer.end()

    worker = threading.Thread(target=_run, daemon=True)
    worker.start()

    for piece in streamer:
        yield piece

    # The stream has ended, so the worker is finishing; surface its error, if any.
    worker.join()
    if failures:
        raise failures[0]