"""Gradio chat demo that streams replies from a locally loaded causal LM."""

import os
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# --- Configuration ---
MODEL_ID = "badanwang/teacher_basic_qwen3-0.6b"

# --- Load model and tokenizer ---
print("开始加载模型和分词器...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
    )
    print("模型和分词器加载成功!")
except Exception as e:
    print(f"模型加载失败: {e}")
    # Chain the original exception so the full traceback stays in the logs.
    raise gr.Error(f"关键错误:无法加载模型 {MODEL_ID}。错误信息: {e}") from e


# --- Core chat function ---
def _history_to_messages(history, message):
    """Convert Gradio chat history plus the new message into chat-template messages.

    Accepts both history formats Gradio may pass:
      * legacy "tuples" format: ``[user_text, assistant_text]`` pairs, where an
        entry may be ``None`` (e.g. a turn without a reply yet);
      * "messages" format: dicts with at least ``"role"`` and ``"content"``
        (extra keys such as ``"metadata"`` are ignored).
    Only string contents for the ``user``/``assistant`` roles are kept.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str):
                messages.append({"role": role, "content": content})
        else:
            user_msg, assistant_msg = turn
            if isinstance(user_msg, str):
                messages.append({"role": "user", "content": user_msg})
            if isinstance(assistant_msg, str):
                messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    return messages


def predict(message, history):
    """Stream a model reply to *message* given the prior chat *history*.

    Args:
        message: The latest user message (a string for a text-only ChatInterface).
        history: Prior turns as supplied by ``gr.ChatInterface``, in either the
            "tuples" or the "messages" format.

    Yields:
        The cumulative response text so far; each yield replaces the reply
        shown in the UI.

    Raises:
        gr.Error: If generation fails in the background thread.
    """
    messages = _history_to_messages(history, message)

    # return_dict=True also returns the attention mask, so generate() does
    # not have to guess which tokens are padding.
    model_inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=300.0, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=2048,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
    )

    errors = []

    def _generate():
        # Runs in a worker thread; capture failures so they can be reported
        # in the UI instead of leaving the consumer loop blocked.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            # Signal end-of-stream; otherwise the loop below would wait for
            # the full streamer timeout before failing.
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    full_response = ""
    for new_text in streamer:
        full_response += new_text
        yield full_response

    thread.join()
    if errors:
        raise gr.Error(f"生成失败: {errors[0]}") from errors[0]


# --- Build and launch the Gradio UI ---
# The examples / cache_examples arguments were removed to fix an error that
# occurred when clicking an example.
demo = gr.ChatInterface(
    fn=predict,
    title="小Q老师 - 基础问答 (本地加载)",
    description=f"直接在Space中运行 {MODEL_ID} 模型进行流式对话。CPU推理可能较慢,请耐心等待。",
)

if __name__ == "__main__":
    # share=True asks Gradio for a temporary public share link when run
    # locally; it is not a CORS/WebSocket setting and is not needed on
    # Hugging Face Spaces.
    demo.launch(share=True)