import re
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Step 1: load the tokenizer and the model once at import time; every chat
# request below reuses these two module-level objects.
model_id = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # let the library choose where the weights are placed
    torch_dtype=torch.bfloat16,
)
# 2. Streaming generation.

# Characters removed from the reply before it is shown: newline, '-', '*',
# straight double quote, curly double quotes and '#'. This strips Markdown
# list/emphasis/heading markers and quotes, but note it also removes hyphens
# inside words/numbers and collapses all line breaks. Edit the class to taste.
_SYMBOL_PATTERN = re.compile(r'[\n\-\*\"”"“#]')


def stream_response(message, history):
    """Stream the model's reply to *message*, yielding the cleaned text so far.

    Args:
        message: The latest user message.
        history: Previous turns, either as dicts (appended to the prompt
            as-is) or as ``[user, assistant]`` list/tuple pairs, where empty
            entries are skipped.

    Yields:
        The reply accumulated so far, with every character matched by
        ``_SYMBOL_PATTERN`` removed.

    Raises:
        Whatever ``model.generate`` raised in the worker thread, re-raised
        here once the stream has been closed.
    """
    # Build the chat: system prompt, then history (both supported formats),
    # then the new user turn.
    messages = [{"role": "system", "content": "你是一个乐于助人的 AI 助手。"}]
    for h in history:
        if isinstance(h, dict):
            messages.append(h)
        elif isinstance(h, (list, tuple)):
            if h[0]:
                messages.append({"role": "user", "content": h[0]})
            if h[1]:
                messages.append({"role": "assistant", "content": h[1]})
    messages.append({"role": "user", "content": message})

    # Tokenize with the model's chat template and move tensors to its device.
    model_inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    # The streamer yields decoded text pieces as generate() produces tokens.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        **model_inputs,
        "streamer": streamer,
        "max_new_tokens": 512,
        "do_sample": True,
        "temperature": 0.7,
        "pad_token_id": tokenizer.eos_token_id,
    }

    errors = []

    def _generate():
        # If generate() raises, it never signals end-of-stream, so the
        # consumer loop below would block forever. Record the error and
        # close the stream explicitly so the loop can finish.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    # The pattern deletes single characters only, so cleaning each new piece
    # and appending gives exactly the same result as re-cleaning the whole
    # accumulated text on every step, without the quadratic rework.
    clean_display_text = ""
    for new_text in streamer:
        clean_display_text += _SYMBOL_PATTERN.sub("", new_text)
        yield clean_display_text

    thread.join()
    if errors:
        # Surface the worker-thread failure to the caller (Gradio shows it).
        raise errors[0]
# Step 3: wrap the streaming generator in a Gradio chat UI.
demo = gr.ChatInterface(
    examples=["给我写一首关于春天的诗", "如何用 Python 实现快速排序?"],
    # Example caching stays off; the author notes that enabling it produced
    # a 500 error at startup.
    cache_examples=False,
    title="Llama-3.2 流式对话助手 ⚡",
    fn=stream_response,
)
if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    demo.launch()