# llm / app.py
# Author: han145 — last commit: "Update app.py" (83f4f35, verified)
import torch
import re
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
# 1. Load the tokenizer and model once at import time; every chat request
#    below reuses these module-level objects.
model_id = "meta-llama/Llama-3.2-3B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # let transformers choose device placement for the weights
    torch_dtype=torch.bfloat16,  # load weights in bfloat16
)
# 2. Streaming generation function.

# Characters stripped from the model output before it is shown in the UI:
# newline, "-" (Markdown list marker), "*" (Markdown emphasis), ASCII and
# curly double quotes, and "#" (Markdown heading marker).
# NOTE: *every* newline is removed, so multi-line replies (poems, code) are
# displayed on a single line. Edit the character class to change this.
_STRIP_CHARS_RE = re.compile(r'[\n\-\*"”“#]')


def _build_messages(message, history):
    """Return the chat-template message list for *message* given *history*.

    *history* may be in Gradio's "messages" format (dicts, appended as-is)
    or the legacy "tuples" format ([user_text, assistant_text] pairs, where
    an empty/None side is skipped).
    """
    messages = [{"role": "system", "content": "你是一个乐于助人的 AI 助手。"}]
    for turn in history:
        if isinstance(turn, dict):
            messages.append(turn)
        elif isinstance(turn, (list, tuple)):
            user_text, assistant_text = turn[0], turn[1]
            if user_text:
                messages.append({"role": "user", "content": user_text})
            if assistant_text:
                messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})
    return messages


def _generate_into_streamer(streamer, errors, generate_kwargs):
    """Run ``model.generate`` in the worker thread, always closing *streamer*.

    If generation raises, the exception is appended to *errors* and
    ``streamer.end()`` is called so the consumer's ``for ... in streamer``
    loop terminates instead of waiting forever for text that never comes.
    """
    try:
        model.generate(**generate_kwargs)
    except Exception as exc:  # re-raised on the consumer side by stream_response
        errors.append(exc)
        streamer.end()


def stream_response(message, history):
    """Stream a cleaned-up model reply to *message* for ``gr.ChatInterface``.

    Args:
        message: The latest user message text.
        history: Previous turns, in Gradio "messages" or "tuples" format.

    Yields:
        The whole reply generated so far, with the characters matched by
        ``_STRIP_CHARS_RE`` removed.

    Raises:
        Exception: whatever ``model.generate`` raised, re-raised after the
            stream has been closed so Gradio can report it instead of hanging.
    """
    messages = _build_messages(message, history)

    # Encode the conversation with the model's chat template; the generation
    # prompt makes the model continue as the assistant.
    model_inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    # skip_prompt: do not echo the input; skip_special_tokens: drop control tokens.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = {
        **model_inputs,
        "streamer": streamer,
        "max_new_tokens": 512,
        "do_sample": True,
        "temperature": 0.7,
        "pad_token_id": tokenizer.eos_token_id,
    }
    # generate() blocks until completion, so it runs in a worker thread while
    # this generator consumes decoded text from the streamer.
    errors = []
    thread = Thread(
        target=_generate_into_streamer,
        args=(streamer, errors, generate_kwargs),
    )
    thread.start()

    # The pattern only deletes single characters, so cleaning each new chunk
    # and appending yields exactly what re-cleaning the whole reply would,
    # without the quadratic rescans.
    clean_display_text = ""
    for new_text in streamer:
        clean_display_text += _STRIP_CHARS_RE.sub("", new_text)
        # Push the cleaned text so far to Gradio.
        yield clean_display_text

    thread.join()
    if errors:
        raise errors[0]
# 3. UI setup: a chat interface wired to the streaming generator above.
demo = gr.ChatInterface(
    fn=stream_response,
    title="Llama-3.2 流式对话助手 ⚡",
    examples=["给我写一首关于春天的诗", "如何用 Python 实现快速排序?"],
    # Example caching stays off; per the original author's note this avoids
    # a 500 error at startup (root cause not visible in this file).
    cache_examples=False,
)


def _main():
    """Start the Gradio server for the chat demo."""
    demo.launch()


if __name__ == "__main__":
    _main()