File size: 2,489 Bytes
a274bd3
83f4f35
a274bd3
3f6b6a4
50913a6
e86b33d
3f6b6a4
4980ab4
3f6b6a4
 
 
50913a6
3f6b6a4
a19113d
cfeea86
3f6b6a4
50913a6
83f4f35
a19113d
3f6b6a4
eaca37f
 
 
 
 
a19113d
 
83f4f35
6fba16a
50913a6
 
6da3384
83f4f35
3f6b6a4
50913a6
83f4f35
3f6b6a4
50913a6
83f4f35
eaca37f
 
 
 
 
 
 
 
3f6b6a4
 
50913a6
83f4f35
3f6b6a4
 
 
83f4f35
 
 
 
 
 
 
 
50913a6
3f6b6a4
a19113d
50913a6
 
3f6b6a4
 
7767edb
0326ffd
65ea084
a19113d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import re
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

# 1. Load the tokenizer and the model
# Both are fetched by the same checkpoint id so they always match.
model_id = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Weights are requested in bfloat16, with device placement left to "auto".
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

# 2. Streaming generation function
def stream_response(message, history):
    """Stream the model's reply to ``message`` as progressively longer text.

    Generation runs in a background thread while this generator drains the
    streamer, so the caller receives partial output as soon as it is decoded.

    Args:
        message: The newest user message.
        history: Earlier conversation turns. Dict entries are forwarded to
            the chat template as-is; list/tuple entries are read as
            ``(user, assistant)`` pairs whose empty or missing halves are
            skipped. Entries of any other type are ignored.

    Yields:
        The whole reply accumulated so far, with newlines, ``-``, ``*``,
        ``#`` and straight/curly double quotes removed.

    Raises:
        Exception: Any error raised by ``model.generate`` in the worker
            thread is re-raised here once the stream has been drained.
    """
    # 1. Build the message list: system prompt, prior turns, then the new message.
    messages = [{"role": "system", "content": "你是一个乐于助人的 AI 助手。"}]
    for turn in history:
        if isinstance(turn, dict):
            messages.append(turn)
        elif isinstance(turn, (list, tuple)):
            # zip() stops at the shorter sequence, so a pair with a missing
            # half (or an empty entry) no longer raises IndexError.
            for role, content in zip(("user", "assistant"), turn):
                if content:
                    messages.append({"role": role, "content": content})
    messages.append({"role": "user", "content": message})

    # 2. Encode the conversation and move the tensors to the model's device.
    model_inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    # 3. Streamer that hands decoded text from the worker to this generator.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # 4. Run generation in a worker thread.
    generate_kwargs = {
        **model_inputs,
        "streamer": streamer,
        "max_new_tokens": 512,
        "do_sample": True,
        "temperature": 0.7,
        "pad_token_id": tokenizer.eos_token_id,
    }
    worker_errors = []

    def run_generation():
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # thread boundary: captured here, re-raised by the consumer
            worker_errors.append(exc)
            # A failed generate() never signals end-of-stream; do it here so
            # the consuming loop below cannot block forever.
            streamer.end()

    thread = Thread(target=run_generation)
    thread.start()

    # 5. Stream the output, stripping symbols on the fly.
    # The character class removes newlines, Markdown list dashes, asterisks,
    # '#', and straight/curly double quotes; edit it to change what is stripped.
    strip_symbols = re.compile(r'[\n\-\*\"”"“#]')
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        # Push the cleaned cumulative text to the UI.
        yield strip_symbols.sub('', partial_text)

    # The stream is finished; wait for the worker and surface any failure.
    thread.join()
    if worker_errors:
        raise worker_errors[0]

# 3. UI setup
# Example caching is switched off; per the original note this avoids a
# 500 error at startup.
demo = gr.ChatInterface(
    stream_response,
    examples=["给我写一首关于春天的诗", "如何用 Python 实现快速排序?"],
    cache_examples=False,
    title="Llama-3.2 流式对话助手 ⚡",
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()