# mac
# rewrite ui
# d70f24f
# MiniCPM5-1B Demo
import os
import logging
import threading
from typing import Generator
import spaces
import torch
from fastapi.responses import HTMLResponse
from gradio import Server
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from utils_chatbot import organize_messages
# Module-wide logging: INFO and above go to the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Hub repo id of the model served by this demo.
MODEL_PATH = "openbmb/MiniCPM5-1B"

# Authenticate with the Hub only when a token is supplied via the environment.
# Without it, downloads of private/gated repos will fail.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    logger.info("Logged in to Hugging Face Hub")
else:
    # Fixed mis-encoded em dash (was the mojibake "β€”").
    logger.warning("HF_TOKEN not set — private/gated models will be inaccessible")
# Load tokenizer and weights once, at import time, so every request reuses them.
# trust_remote_code=True allows the Hub repo's own modelling/tokenizer code to run.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
# Move the bf16 weights onto the CUDA device used by `predict`.
model = model.to("cuda")

# Application object; `predict` and `homepage` are registered on it below.
demo = Server()
@demo.api()
@spaces.GPU(duration=60)  # NOTE(review): presumably requests a ZeroGPU slot for up to 60 s — confirm.
def predict(
    message: str,
    history: list[list] | None = None,
    thinking_mode: bool = True,
    temperature: float = 0.9,
    top_p: float = 0.95,
) -> Generator[str, None, None]:
    """Stream the model's reply to ``message``, yielding the cumulative text.

    Args:
        message: The new user message.
        history: Prior conversation turns, converted to chat messages by
            ``organize_messages`` (its exact schema is defined there).
        thinking_mode: Forwarded to the chat template as ``enable_thinking``.
        temperature: If > 0, sampling is used with this temperature and
            ``top_p``; otherwise decoding is greedy (``do_sample=False``).
        top_p: Nucleus-sampling threshold, used only when ``temperature > 0``.

    Yields:
        The full generated text so far. Each value extends the previous one.
        Special tokens are kept in the output (``skip_special_tokens=False``).

    Raises:
        Any exception raised by ``model.generate`` in the worker thread is
        re-raised here once the stream has been drained.
    """
    messages = organize_messages(message, history)
    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=thinking_mode,
    )
    model_inputs = tokenizer([prompt_text], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=False,
    )
    gen_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=4096,
    )
    if temperature > 0:
        gen_kwargs.update(temperature=temperature, top_p=top_p, do_sample=True)
    else:
        gen_kwargs.update(do_sample=False)

    # Failures from the worker thread, handed back to this generator.
    errors: list[Exception] = []

    def _generate() -> None:
        """Run generation; on failure record the error and unblock the streamer."""
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:  # re-raised below in the consuming thread
            errors.append(exc)
            # generate() did not reach its own streamer.end(). Without this call,
            # iterating the streamer (no timeout) would block forever.
            streamer.end()

    thread = threading.Thread(target=_generate)
    thread.start()

    # NOTE(review): if the client disconnects and this generator is closed,
    # generation keeps running until max_new_tokens. Consider a
    # StoppingCriteria tied to a threading.Event.
    full_text = ""
    for new_token_text in streamer:
        if not new_token_text:
            continue
        full_text += new_token_text
        yield full_text
    thread.join()

    if errors:
        raise errors[0]
@demo.get("/", response_class=HTMLResponse)
async def homepage():
    """Return the contents of ``index.html`` located next to this module.

    The file is read from disk on every request, decoded as UTF-8.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    page_path = os.path.join(module_dir, "index.html")
    with open(page_path, encoding="utf-8") as page:
        html = page.read()
    return html
if __name__ == "__main__":
    # Script entry point: start the server with error reporting enabled.
    launch_options = {"show_error": True}
    demo.launch(**launch_options)