# Hugging Face Space app: caobin LLM chatbot (Gradio + transformers).
# (Removed non-code page residue — Space status lines, git blame hashes,
# and a line-number gutter — that was scraped in with the file.)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# -------------------------------
# Model loading (module-level side effect: downloads/loads weights on import)
# -------------------------------
MODEL_ID = "caobin/llm-caobin"
# NOTE(review): trust_remote_code=True executes Python shipped inside the
# model repo — acceptable only because this is the author's own model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto", # auto-placement; on a CPU-only host everything lands on CPU
trust_remote_code=True
)
# -------------------------------
# Helper: normalize chat history
# -------------------------------
def clean_history(history):
    """Return *history* with every message's content coerced to a string.

    Gradio can deliver a message's ``content`` as a list of parts; joining
    the parts into one string avoids empty model replies downstream.
    """
    normalized = []
    for entry in history:
        text = entry["content"]
        if isinstance(text, list):
            # list of parts -> single space-joined string
            text = " ".join(str(part) for part in text)
        normalized.append({"role": entry["role"], "content": text})
    return normalized
# -------------------------------
# Chat function
# -------------------------------
def chat_fn(message, history):
    """Generate a reply to *message* given the prior chat *history*.

    history is a list of {"role": ..., "content": ...} dicts; only the most
    recent 6 messages (3 user/assistant turns) are kept to bound prompt size.
    Returns the model's reply as a stripped string.
    """
    history = clean_history(history)
    recent_history = history[-6:]  # keep the last 3 conversation turns
    # Rebuild the prompt in the model's <|user|>/<|assistant|> template.
    full_prompt = ""
    for msg in recent_history:
        if msg["role"] == "user":
            full_prompt += f"<|user|>{msg['content']}<|assistant|>"
        elif msg["role"] == "assistant":
            full_prompt += msg['content']
    # Append the current user question.
    full_prompt += f"<|user|>{message}<|assistant|>"
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    # Inference only — no_grad avoids building autograd state.
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=128,
            temperature=0.3,
            top_p=0.3,
            do_sample=True,
        )
    # FIX: generate() returns prompt + completion. Decode only the newly
    # generated tokens instead of decoding everything and hoping the
    # "<|assistant|>" marker survives skip_special_tokens=True.
    prompt_len = inputs["input_ids"].shape[-1]
    output_text = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
    # Defensive: if the marker still leaks through, keep only the final reply.
    if "<|assistant|>" in output_text:
        output_text = output_text.split("<|assistant|>")[-1]
    return output_text.strip()
# -------------------------------
# Gradio UI
# -------------------------------
with gr.Blocks(title="caobin LLM Chatbot") as demo:
    gr.Markdown("# 🤖 caobin's AI assistant")
    # FIX: respond() appends OpenAI-style message dicts, which requires the
    # "messages" format — Gradio 4+ defaults to the legacy tuple format.
    chatbot = gr.Chatbot(height=450, type="messages")
    msg = gr.Textbox(label="输入你的问题")

    def respond(message, chat_history):
        """Run one chat turn: generate a reply, append both messages,
        clear the textbox, and return the updated history."""
        response = chat_fn(message, chat_history)
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": response})
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])

# -------------------------------
# Launch (guarded so importing this module doesn't start a server;
# also drops the stray " |" scrape artifact from the original line)
# -------------------------------
if __name__ == "__main__":
    demo.launch()