badanwang's picture
Update app.py
d3329dd verified
raw
history blame
2.93 kB
import gradio as gr
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
import os
# --- 配置 ---
# 我们不再需要API Token,因为模型在本地运行
MODEL_ID = "badanwang/teacher_basic_qwen3-0.6b"
print("开始加载模型和分词器...")
try:
# 确保使用 trust_remote_code=True,因为Qwen模型需要加载自定义代码
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype="auto", # 使用适合CPU的类型,如torch.float32
device_map="auto", # 自动将模型加载到可用设备(这里是CPU)
trust_remote_code=True
)
print("模型和分词器加载成功!")
except Exception as e:
print(f"模型加载失败: {e}")
# 如果模型加载失败,应用将无法工作,这里可以抛出异常或退出
raise gr.Error(f"关键错误:无法加载模型 {MODEL_ID}。错误信息: {e}")
# --- 核心对话函数 ---
def predict(message, history):
"""
主函数,使用加载到本地的模型进行流式对话。
"""
# 1. 格式化对话历史
# Qwen的模板要求一个特殊的列表格式
messages = []
for turn in history:
user_msg, assistant_msg = turn
messages.append({"role": "user", "content": user_msg})
messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": message})
# 使用分词器的 apply_chat_template 方法来正确格式化输入
model_inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt"
).to(model.device) # 确保输入张量和模型在同一设备上
# 2. 设置流式输出
streamer = TextIteratorStreamer(tokenizer, timeout=300.0, skip_prompt=True, skip_special_tokens=True)
# 3. 在一个单独的线程中运行生成,以避免阻塞UI
generation_kwargs = dict(
inputs=model_inputs,
streamer=streamer,
max_new_tokens=2048,
do_sample=True,
temperature=0.7,
top_p=0.95,
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# 4. 从streamer中yield每个新生成的token
full_response = ""
for new_text in streamer:
full_response += new_text
yield full_response
# --- 创建并启动Gradio界面 ---
demo = gr.ChatInterface(
fn=predict,
title="小Q老师 - 基础问答 (本地加载)",
description=f"直接在Space中运行 {MODEL_ID} 模型进行流式对话。CPU推理可能较慢,请耐心等待。",
examples=[["你好"], ["请用python写一个快速排序算法"], ["给我讲个笑话吧"]],
cache_examples=False,
)
if __name__ == "__main__":
demo.launch()