# app.py — Gradio streaming-chat Space (model loaded locally via transformers)
import gradio as gr
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
import os
# --- Configuration ---
MODEL_ID = "badanwang/teacher_basic_qwen3-0.6b"

# --- Load the tokenizer and model at startup ---
print("开始加载模型和分词器...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",        # pick the checkpoint's native dtype
        device_map="auto",         # place weights on GPU if available, else CPU
        trust_remote_code=True,
    )
    print("模型和分词器加载成功!")
except Exception as e:
    # Surface the failure both in the Space logs and as a Gradio error.
    print(f"模型加载失败: {e}")
    raise gr.Error(f"关键错误:无法加载模型 {MODEL_ID}。错误信息: {e}")
# --- 核心对话函数 ---
def predict(message, history):
    """Stream an assistant reply for *message*, given the Gradio chat history.

    Parameters
    ----------
    message : str
        The latest user message.
    history : list
        Prior turns. Both Gradio history formats are supported:
        ``[(user, assistant), ...]`` tuples and
        ``[{"role": ..., "content": ...}, ...]`` message dicts.

    Yields
    ------
    str
        The accumulated assistant response so far (grows as tokens stream in).
    """
    # Rebuild the conversation in the chat-template message format.
    messages = []
    for turn in history:
        if isinstance(turn, dict):
            # Gradio "messages" format: already a role/content dict.
            messages.append({"role": turn["role"], "content": turn["content"]})
        else:
            # Gradio "tuples" format: (user, assistant) pair.
            user_msg, assistant_msg = turn
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    model_inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Run generate() on a background thread and stream decoded text through
    # the TextIteratorStreamer so partial output can be yielded immediately.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=300.0, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        inputs=model_inputs,
        streamer=streamer,
        max_new_tokens=2048,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    full_response = ""
    for new_text in streamer:
        full_response += new_text
        yield full_response

    # The streamer is exhausted once generation ends; join to make sure the
    # worker thread has fully finished before this generator returns.
    thread.join()
# --- Build and launch the Gradio interface ---
# NOTE: examples / cache_examples were removed to fix an error raised when
# clicking an example.
demo = gr.ChatInterface(
    fn=predict,
    title="小Q老师 - 基础问答 (本地加载)",
    description=f"直接在Space中运行 {MODEL_ID} 模型进行流式对话。CPU推理可能较慢,请耐心等待。",
)

if __name__ == "__main__":
    # share=True to allow cross-origin WebSocket connections.
    demo.launch(share=True)