# HuggingFace Space: caobin LLM chatbot demo (Gradio + Transformers)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# -------------------------------
# Model loading
# -------------------------------
MODEL_ID = "caobin/llm-caobin"

# Load tokenizer and model once at module import time.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",  # on a CPU-only host this maps everything to CPU
    trust_remote_code=True,
)
# -------------------------------
# Helper: normalize chat history
# -------------------------------
def clean_history(history):
    """Normalize every message's ``content`` to a plain string.

    Gradio may deliver a message's ``content`` as a list (e.g. text plus
    attachment parts); feeding a list into the prompt builder produced
    empty answers. Robustness fix: previously only ``list`` content was
    handled, so other non-string values (int, None, tuple) leaked through —
    now everything is coerced to ``str``.

    Args:
        history: iterable of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        A new list of dicts with the same roles and string contents.
    """
    cleaned = []
    for msg in history:
        content = msg["content"]
        if isinstance(content, list):
            # list -> single space-joined string
            content = " ".join(str(part) for part in content)
        elif not isinstance(content, str):
            # any other non-string value -> str
            content = str(content)
        cleaned.append({"role": msg["role"], "content": content})
    return cleaned
# -------------------------------
# Chat function
# -------------------------------
def chat_fn(message, history):
    """Generate a reply to *message*, conditioned on recent *history*.

    Args:
        message: the user's new question (str).
        history: list of ``{"role", "content"}`` dicts of past turns.

    Returns:
        The model's answer as a stripped string.
    """
    history = clean_history(history)
    recent_history = history[-6:]  # keep the last 3 user/assistant exchanges

    # Rebuild the prompt in the model's <|user|>/<|assistant|> template.
    parts = []
    for msg in recent_history:
        if msg["role"] == "user":
            parts.append(f"<|user|>{msg['content']}<|assistant|>")
        elif msg["role"] == "assistant":
            parts.append(msg["content"])
    parts.append(f"<|user|>{message}<|assistant|>")
    full_prompt = "".join(parts)

    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():  # inference only; skip autograd bookkeeping
        output_ids = model.generate(
            **inputs,
            max_new_tokens=128,
            temperature=0.3,
            top_p=0.3,
            do_sample=True,
        )

    # Bug fix: the old code decoded the WHOLE sequence with
    # skip_special_tokens=True and then split on the literal "<|assistant|>".
    # If that marker is a registered special token it is stripped during
    # decoding, the split finds nothing, and the echoed prompt leaks into
    # the answer. Decoding only the newly generated tokens is unambiguous.
    new_tokens = output_ids[0][prompt_len:]
    output_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return output_text.strip()
# -------------------------------
# Gradio UI
# -------------------------------
with gr.Blocks(title="caobin LLM Chatbot") as demo:
    gr.Markdown("# 🤖 caobin's AI assistant")
    chatbot = gr.Chatbot(height=450)
    msg = gr.Textbox(label="输入你的问题")

    def respond(message, chat_history):
        """Run the model, then record both turns in dict message format."""
        response = chat_fn(message, chat_history)
        chat_history.extend(
            [
                {"role": "user", "content": message},
                {"role": "assistant", "content": response},
            ]
        )
        # Clear the textbox and push the updated transcript to the Chatbot.
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])

# -------------------------------
# Launch
# -------------------------------
demo.launch()