Spaces:
Running
Running
| import os | |
| import time | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Hugging Face model repository to load; the MODEL_ID environment variable
# overrides the default.
MODEL_ID = os.getenv("MODEL_ID", "openbmb/MiniCPM5-1B-SFT")

# Human-readable validation note shown in the status box at the top of the UI.
SYSTEM_NOTE = (
    "MiniCPM5-1B is a text-only language model. "
    "Local validation is currently cleanest for English, Chinese, code snippets with explicit constraints, "
    "and tool-planning prompts. "
    "Persian and native Arabic are not marked supported yet."
)

# Example rows for the UI, each ordered as:
# [prompt, max_new_tokens, temperature, top_p].
EXAMPLES = [
    [
        "Briefly introduce yourself as a local AI assistant in two sentences.",
        96,
        0.2,
        0.95,
    ],
    [
        "请用中文用三点总结:为什么本地小模型对隐私有帮助?",
        160,
        0.3,
        0.95,
    ],
    [
        "Return only Python code. Write count_jsonl_rows(path) that counts lines in a JSONL file without using json.load.",
        160,
        0.2,
        0.95,
    ],
    [
        "Give exactly two numbered steps to inspect a local README and summarize it safely. Do not say you cannot inspect files; write the tool-use plan.",
        192,
        0.2,
        0.95,
    ],
]

# Lazily populated by load_model() on the first generation request.
tokenizer, model = None, None
def load_model():
    """Populate the module-level ``tokenizer`` and ``model`` on first use.

    Returns immediately when ``model`` is already set. Otherwise loads the
    tokenizer and causal LM for ``MODEL_ID``: float16 with
    ``device_map="auto"`` when CUDA is available, float32 with no device map
    when it is not. The model is switched to eval mode before being stored.
    """
    global tokenizer, model

    if model is None:
        has_cuda = torch.cuda.is_available()
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        loaded = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16 if has_cuda else torch.float32,
            device_map="auto" if has_cuda else None,
        )
        # Publish the global only after eval() succeeds, so a failed load
        # is retried on the next call.
        model = loaded.eval()
def generate(prompt, max_new_tokens, temperature, top_p):
    """Generate a single-turn reply for ``prompt`` and report run metrics.

    Args:
        prompt: User text. ``None``, empty, or whitespace-only input is
            rejected before the model is loaded.
        max_new_tokens: Upper bound on generated tokens (coerced to ``int``).
        temperature: Sampling temperature; a value ``<= 0`` selects greedy
            decoding.
        top_p: Nucleus-sampling threshold; only used when sampling.

    Returns:
        A ``(text, metrics)`` tuple of strings. ``text`` is the decoded
        continuation only (never the prompt), with anything up to and
        including a ``</think>`` marker removed. ``metrics`` summarizes the
        new-token count, throughput, elapsed time and model id.
    """
    if not prompt or not prompt.strip():
        return "Enter a prompt first.", ""

    load_model()
    # perf_counter is monotonic, so the elapsed time cannot go negative or
    # jump if the wall clock is adjusted mid-run.
    start = time.perf_counter()

    rendered = tokenizer.apply_chat_template(
        [
            {
                "role": "system",
                "content": "Answer directly and concisely. Do not include hidden reasoning or thinking process text.",
            },
            {"role": "user", "content": prompt},
        ],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = tokenizer(rendered, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[-1]

    do_sample = temperature > 0
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            # Sampling knobs are explicitly cleared for greedy decoding.
            temperature=float(temperature) if do_sample else None,
            top_p=float(top_p) if do_sample else None,
            do_sample=do_sample,
            pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
        )

    # The returned sequence starts with the prompt tokens. Decode only the
    # continuation: searching the decoded text for `rendered` is unreliable
    # because skip_special_tokens removes the template markers `rendered`
    # still contains, and a "</think>" typed by the user must not be treated
    # as the model's own marker.
    text = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
    if "</think>" in text:
        text = text.split("</think>", 1)[1]
    text = text.strip()

    new_tokens = max(0, output_ids.shape[-1] - prompt_len)
    elapsed = max(time.perf_counter() - start, 1e-6)  # guard the division below
    metrics = f"{new_tokens} new tokens | {new_tokens / elapsed:.2f} tok/s | {elapsed:.2f}s | model: {MODEL_ID}"
    return text, metrics
# Custom stylesheet handed to gr.Blocks(css=...). It styles the "status-box"
# <div> emitted by the gr.HTML banner: a light bordered card with darker
# <strong> labels.
css = """
.status-box {
border: 1px solid #d8dee8;
border-radius: 8px;
padding: 12px 14px;
background: #f8fafc;
color: #263244;
}
.status-box strong {
color: #101827;
}
"""
# Banner markup: validation note plus the model id resolved at startup.
_status_html = (
    f"<div class='status-box'><strong>Validation status:</strong> {SYSTEM_NOTE}"
    f"<br><strong>Runtime model:</strong> {MODEL_ID}</div>"
)

# UI layout: prompt and button on the left, sampling controls on the right,
# output and run metrics underneath.
with gr.Blocks(title="MiniCPM5-1B Chat", theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# MiniCPM5-1B Chat")
    gr.HTML(_status_html)

    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=8, value=EXAMPLES[0][0])
            run = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(
                minimum=16, maximum=512, value=128, step=1, label="Max new tokens"
            )
            temperature = gr.Slider(
                minimum=0, maximum=1.5, value=0.2, step=0.05, label="Temperature"
            )
            top_p = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"
            )

    output = gr.Textbox(label="Output", lines=14)
    metrics = gr.Textbox(label="Run metrics", interactive=False)

    # The same four inputs feed both the example picker and the click handler,
    # in the order generate() expects its arguments.
    _controls = [prompt, max_new_tokens, temperature, top_p]
    gr.Examples(examples=EXAMPLES, inputs=_controls)
    run.click(fn=generate, inputs=_controls, outputs=[output, metrics])

if __name__ == "__main__":
    demo.launch()