# llm_assistant / app.py
# Author: caobin — "Update app.py" (commit efda0f1, verified)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# -------------------------------
# Model loading
# -------------------------------
# Hugging Face model repo to serve; tokenizer and model are loaded once at
# module import time so every chat request reuses them.
MODEL_ID = "caobin/llm-caobin"
# trust_remote_code=True executes code shipped with the model repo —
# acceptable only because the repo is the author's own.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto", # on a CPU-only host this maps everything to CPU
trust_remote_code=True
)
# -------------------------------
# 工具函数:清理历史
# -------------------------------
def clean_history(history):
    """Normalize chat history so every message's ``content`` is a string.

    Gradio can deliver a message's content as a list (e.g. multimodal
    parts); joining the parts into one string prevents a list from being
    interpolated into the prompt, which would yield an empty answer.

    Args:
        history: list of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        A new list of dicts with string-only ``content`` values.
    """
    def _as_text(content):
        # Flatten list-valued content into a single space-joined string.
        if isinstance(content, list):
            return " ".join(str(part) for part in content)
        return content

    return [
        {"role": msg["role"], "content": _as_text(msg["content"])}
        for msg in history
    ]
# -------------------------------
# 聊天函数
# -------------------------------
def chat_fn(message, history):
    """Generate a model reply for ``message`` given the chat ``history``.

    Builds a ``<|user|>...<|assistant|>...`` prompt from the most recent
    turns, runs generation, and returns only the newly generated text.

    Args:
        message: the current user question (str).
        history: list of ``{"role", "content"}`` dicts of prior turns.

    Returns:
        The assistant's reply, stripped of surrounding whitespace.
    """
    history = clean_history(history)
    recent_history = history[-6:]  # keep the last 3 user/assistant turns
    full_prompt = ""
    for msg in recent_history:
        if msg["role"] == "user":
            full_prompt += f"<|user|>{msg['content']}<|assistant|>"
        elif msg["role"] == "assistant":
            full_prompt += msg['content']
    # Current user question
    full_prompt += f"<|user|>{message}<|assistant|>"
    # tokenizer -> tensor on the model's device
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    # Disable autograd during generation: no gradients are needed for
    # inference, and tracking them wastes memory/time.
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=128,
            temperature=0.3,
            top_p=0.3,
            do_sample=True,
        )
    # Decode only the newly generated tokens (generate() echoes the prompt
    # at the front of the output). This is more robust than decoding the
    # whole sequence and splitting on "<|assistant|>", which breaks if the
    # marker is a special token (stripped by skip_special_tokens) or occurs
    # inside message content.
    prompt_len = inputs["input_ids"].shape[1]
    output_text = tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    )
    return output_text.strip()
# -------------------------------
# Gradio UI
# -------------------------------
# -------------------------------
# Gradio UI
# -------------------------------
with gr.Blocks(title="caobin LLM Chatbot") as demo:
    gr.Markdown("# 🤖 caobin's AI assistant")
    # type="messages" is required: `respond` appends OpenAI-style
    # {"role", "content"} dicts, and the Chatbot's default (tuple) format
    # rejects them on modern Gradio versions.
    chatbot = gr.Chatbot(height=450, type="messages")
    msg = gr.Textbox(label="输入你的问题")

    def respond(message, chat_history):
        """Handle a submit: generate a reply, append both turns, clear the box."""
        response = chat_fn(message, chat_history)
        # Append messages in dict (messages) format
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": response})
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
# -------------------------------
# Launch
# -------------------------------
demo.launch()