File size: 4,277 Bytes
328e8d9
 
 
 
 
 
 
 
80a9976
328e8d9
 
80a9976
 
328e8d9
 
 
80a9976
 
 
 
328e8d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80a9976
 
 
 
 
 
 
 
 
 
 
 
 
328e8d9
 
 
 
 
 
 
 
 
 
 
 
80a9976
 
 
 
328e8d9
 
 
 
 
 
80a9976
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328e8d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import time

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


# Model identifier handed to `from_pretrained`; can be overridden through the
# MODEL_ID environment variable.
MODEL_ID = os.environ.get("MODEL_ID", "openbmb/MiniCPM5-1B-SFT")

# Human-readable validation note rendered in the status banner of the UI.
SYSTEM_NOTE = (
    "MiniCPM5-1B is a text-only language model. Local validation is currently cleanest for English, Chinese, "
    "code snippets with explicit constraints, and tool-planning prompts. Persian and native Arabic are not marked supported yet."
)

# Example rows for gr.Examples. Each row is positional:
# [prompt, max_new_tokens, temperature, top_p].
EXAMPLES = [
    ["Briefly introduce yourself as a local AI assistant in two sentences.", 96, 0.2, 0.95],
    ["请用中文用三点总结:为什么本地小模型对隐私有帮助?", 160, 0.3, 0.95],
    ["Return only Python code. Write count_jsonl_rows(path) that counts lines in a JSONL file without using json.load.", 160, 0.2, 0.95],
    ["Give exactly two numbered steps to inspect a local README and summarize it safely. Do not say you cannot inspect files; write the tool-use plan.", 192, 0.2, 0.95],
]


# Lazily populated by load_model(); both stay None until the first call.
tokenizer = None
model = None


def load_model():
    """Fill the module-level ``tokenizer`` and ``model`` on first use.

    Later calls are no-ops once ``model`` is set. Weights are loaded in
    float16 with ``device_map="auto"`` when CUDA is available, otherwise in
    float32 with no device map. The model is switched to eval mode.
    """
    global tokenizer, model
    if model is None:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        has_cuda = torch.cuda.is_available()
        load_kwargs = {
            "torch_dtype": torch.float16 if has_cuda else torch.float32,
            "device_map": "auto" if has_cuda else None,
        }
        loaded = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
        model = loaded.eval()


def generate(prompt, max_new_tokens, temperature, top_p):
    """Generate a single-turn reply for ``prompt`` and report run metrics.

    Args:
        prompt: User text. ``None``, empty, or whitespace-only input is
            rejected with a hint instead of running the model.
        max_new_tokens: Upper bound on generated tokens (cast to ``int``).
        temperature: Sampling temperature; a value of 0 or less selects
            greedy decoding.
        top_p: Nucleus-sampling threshold, only forwarded when sampling.

    Returns:
        A ``(text, metrics)`` tuple: the decoded reply and a one-line summary
        of token count, throughput, elapsed time, and model id. For rejected
        input the tuple is ``("Enter a prompt first.", "")``.
    """
    if not prompt or not prompt.strip():
        return "Enter a prompt first.", ""
    load_model()
    # perf_counter is monotonic, so the interval cannot be skewed by
    # wall-clock adjustments.
    start = time.perf_counter()
    rendered = tokenizer.apply_chat_template(
        [
            {
                "role": "system",
                "content": "Answer directly and concisely. Do not include hidden reasoning or thinking process text.",
            },
            {"role": "user", "content": prompt},
        ],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    # NOTE(review): `rendered` already went through the chat template; confirm
    # whether add_special_tokens=False is needed here to avoid a duplicate BOS.
    inputs = tokenizer(rendered, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[-1]
    do_sample = temperature > 0
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            # Explicit None keeps sampling-only settings out of greedy runs.
            temperature=float(temperature) if do_sample else None,
            top_p=float(top_p) if do_sample else None,
            do_sample=do_sample,
            pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
        )
    # Decode only the continuation. Matching `rendered` against the decoded
    # full sequence is unreliable because skip_special_tokens removes the
    # template markers that `rendered` still contains, which left the
    # system/user text in the output.
    generated_ids = output_ids[0][prompt_len:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    # Drop any thinking preamble the model emitted despite enable_thinking=False.
    if "</think>" in text:
        text = text.split("</think>", 1)[1]
    text = text.strip()
    new_tokens = int(generated_ids.shape[-1])
    elapsed = max(time.perf_counter() - start, 1e-6)
    metrics = f"{new_tokens} new tokens | {new_tokens / elapsed:.2f} tok/s | {elapsed:.2f}s | model: {MODEL_ID}"
    return text, metrics


# Custom stylesheet passed to gr.Blocks; styles the `status-box` class used by
# the HTML status banner.
css = """
.status-box {
  border: 1px solid #d8dee8;
  border-radius: 8px;
  padding: 12px 14px;
  background: #f8fafc;
  color: #263244;
}
.status-box strong {
  color: #101827;
}
"""


# Build the UI: a title, a status banner, a prompt column, a sampling-controls
# column, and output/metrics boxes wired to generate().
with gr.Blocks(title="MiniCPM5-1B Chat", theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# MiniCPM5-1B Chat")
    # Banner showing the validation note and the model id in use.
    gr.HTML(f"<div class='status-box'><strong>Validation status:</strong> {SYSTEM_NOTE}<br><strong>Runtime model:</strong> {MODEL_ID}</div>")
    with gr.Row():
        # First column (scale=3): prompt entry, pre-filled with the first example.
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=8, value=EXAMPLES[0][0])
            run = gr.Button("Generate", variant="primary")
        # Second column (scale=1): generation settings passed to generate().
        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(16, 512, value=128, step=1, label="Max new tokens")
            temperature = gr.Slider(0, 1.5, value=0.2, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
    output = gr.Textbox(label="Output", lines=14)
    metrics = gr.Textbox(label="Run metrics", interactive=False)
    # Example rows map positionally onto the four input components.
    gr.Examples(EXAMPLES, inputs=[prompt, max_new_tokens, temperature, top_p])
    # generate() returns (text, metrics), matching the two output components.
    run.click(generate, inputs=[prompt, max_new_tokens, temperature, top_p], outputs=[output, metrics])


# Launch the demo only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()