File size: 2,664 Bytes
31dd87b
22dabde
55d16ab
31dd87b
22dabde
 
 
a90f54e
a622589
22dabde
55d16ab
 
22dabde
a622589
55d16ab
f41e6d1
22dabde
 
 
 
073fff5
22dabde
073fff5
22dabde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efda0f1
 
22dabde
55d16ab
073fff5
22dabde
3a8e995
 
 
f41e6d1
22dabde
 
 
a90f54e
22dabde
4e7b745
 
55d16ab
3a8e995
 
22dabde
 
 
3a8e995
 
4e7b745
31dd87b
22dabde
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# -------------------------------
# Model loading
# -------------------------------
# Hugging Face Hub repo id of the chat model served by this app.
MODEL_ID = "caobin/llm-caobin"

# trust_remote_code=True: the repo may ship custom tokenizer/model code —
# only safe because MODEL_ID is a fixed, trusted repository.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",  # automatically maps to CPU when no GPU is present
    trust_remote_code=True
)

# -------------------------------
# 工具函数:清理历史
# -------------------------------
def clean_history(history):
    """Normalize chat history into ``[{"role": ..., "content": str}, ...]``.

    Accepts both history formats Gradio may hand us:

    * messages format — dicts with ``role``/``content`` keys, where
      ``content`` may be a string, a list (joined with spaces; lists
      previously caused empty model answers), or ``None`` (treated as "");
    * legacy pair format — ``[user_text, bot_text]`` tuples/lists from a
      default-format ``gr.Chatbot``; each pair expands into a user message
      plus, when the bot text is not ``None``, an assistant message.

    Unrecognized entries are skipped rather than raising.

    :param history: list of message dicts and/or 2-item pairs.
    :return: new list of role/content dicts with string content.
    """
    cleaned = []
    for msg in history:
        if isinstance(msg, dict):
            content = msg.get("content", "")
            if isinstance(content, list):
                # list -> str (multimodal/list content would otherwise
                # confuse the prompt builder and yield empty answers)
                content = " ".join(str(c) for c in content)
            elif content is None:
                content = ""
            cleaned.append({"role": msg["role"], "content": str(content)})
        elif isinstance(msg, (list, tuple)) and len(msg) == 2:
            user_text, bot_text = msg
            cleaned.append({
                "role": "user",
                "content": "" if user_text is None else str(user_text),
            })
            if bot_text is not None:
                cleaned.append({"role": "assistant", "content": str(bot_text)})
    return cleaned

# -------------------------------
# 聊天函数
# -------------------------------
def chat_fn(message, history):
    """Generate one assistant reply for *message* given the chat *history*.

    Builds a ``<|user|>...<|assistant|>...`` prompt from the most recent
    turns, runs ``model.generate``, and returns only the newly generated
    text.

    :param message: the user's current input string.
    :param history: prior messages; normalized via :func:`clean_history`.
    :return: the model's reply, stripped of surrounding whitespace.
    """
    history = clean_history(history)
    recent_history = history[-6:]  # keep the last 3 user/assistant rounds
    full_prompt = ""

    for msg in recent_history:
        if msg["role"] == "user":
            full_prompt += f"<|user|>{msg['content']}<|assistant|>"
        elif msg["role"] == "assistant":
            full_prompt += msg['content']

    # current user question
    full_prompt += f"<|user|>{message}<|assistant|>"

    # tokenizer -> tensor on the model's device
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    # generate the answer
    output_ids = model.generate(
        **inputs,
        max_new_tokens=128,
        temperature=0.3,
        top_p=0.3,
        do_sample=True,
    )

    # BUG FIX: generate() returns prompt + continuation, and with
    # skip_special_tokens=True the <|assistant|> markers are stripped from
    # the decoded text, so splitting on them could not remove the echoed
    # prompt. Decode only the tokens generated past the input length.
    prompt_len = inputs["input_ids"].shape[1]
    output_text = tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    )
    # Defensive fallback in case the tokenizer keeps the marker as plain text.
    if "<|assistant|>" in output_text:
        output_text = output_text.split("<|assistant|>")[-1]
    return output_text.strip()

# -------------------------------
# Gradio UI
# -------------------------------
with gr.Blocks(title="caobin LLM Chatbot") as demo:
    gr.Markdown("# 🤖 caobin's AI assistant")
    # BUG FIX: respond() appends {"role": ..., "content": ...} dicts, so the
    # Chatbot must use the messages format; the default tuple format rejects
    # dicts (and would also crash clean_history on pair-style entries).
    chatbot = gr.Chatbot(height=450, type="messages")
    msg = gr.Textbox(label="输入你的问题")

    def respond(message, chat_history):
        """Handle one turn: generate a reply, extend the history, clear the box.

        :param message: user's input from the textbox.
        :param chat_history: current messages-format history from the Chatbot.
        :return: ("", updated history) — empty string resets the textbox.
        """
        response = chat_fn(message, chat_history)
        # append the turn in messages (dict) format
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": response})
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])

# -------------------------------
# Launch
# -------------------------------
demo.launch()