Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import os | |
| # --- 1. 配置与模型加载 --- | |
| # 假设运行环境的硬件资源是充足的。 | |
| MODEL_ID = os.getenv("MODEL_ID", "badanwang/teacher_basic_qwen3-0.6b") | |
| print(f"INFO: 正在加载模型: {MODEL_ID}") | |
| # 使用 try-except 来捕获任何可能的加载错误 (例如网络问题、模型名称错误等) | |
| try: | |
| # 加载分词器和模型 | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) | |
| # device_map="auto" 会自动利用可用的硬件 (如 CPU 或 GPU) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| MODEL_ID, | |
| torch_dtype="auto", # 自动选择最佳数据类型 | |
| device_map="auto", | |
| trust_remote_code=True | |
| ) | |
| print("INFO: 模型和分词器加载成功!") | |
| # 将核心推理逻辑定义为一个函数 | |
| # 只有在模型成功加载后,这个函数才会被有效定义 | |
| def predict(prompt: str, history: list[list[str]]): | |
| """ | |
| 接收用户输入和对话历史,返回更新后的完整对话历史。 | |
| Gradio 会自动为这个函数创建 API 端点。 | |
| """ | |
| print(f"INFO: 收到API/UI请求: prompt='{prompt}'") | |
| # 1. 构建符合模型要求的消息列表 | |
| messages = [] | |
| for user_message, bot_message in history: | |
| messages.append({"role": "user", "content": user_message}) | |
| messages.append({"role": "assistant", "content": bot_message}) | |
| messages.append({"role": "user", "content": prompt}) | |
| # 2. 应用聊天模板并进行分词 | |
| input_ids = tokenizer.apply_chat_template( | |
| messages, | |
| add_generation_prompt=True, | |
| tokenize=True, | |
| return_tensors="pt" | |
| ).to(model.device) | |
| # 3. 生成回复 | |
| # 使用简单的 .generate(),不加额外的采样参数以保持简洁 | |
| outputs = model.generate(input_ids, max_new_tokens=1024) | |
| # 4. 解码生成的文本,跳过输入的token | |
| response_text = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True) | |
| print(f"INFO: 生成回复: {response_text}") | |
| # 5. 更新并返回对话历史 | |
| history.append([prompt, response_text]) | |
| return history | |
| except Exception as e: | |
| print(f"FATAL: 加载模型或分词器时发生致命错误: {e}") | |
| # 如果模型加载失败,则定义一个专门用于报错的函数 | |
| # 这能确保Gradio界面依然可以启动,并向用户显示一个清晰的错误信息 | |
| def predict(*args, **kwargs): | |
| raise gr.Error(f"模型未能加载,应用无法工作。请检查后台日志获取详细错误信息。错误: {e}") | |
| # --- 2. 创建并启动 Gradio 应用 --- | |
| # 使用 gr.Blocks 来自定义界面布局 | |
| with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo: | |
| gr.Markdown(f"## 模型聊天机器人\n当前模型: `{MODEL_ID}`") | |
| # 定义聊天机器人组件和输入框 | |
| chatbot = gr.Chatbot(label="对话历史", height=600) | |
| msg_input = gr.Textbox(label="在这里输入你的问题...", placeholder="例如:你好,你是谁?") | |
| clear_button = gr.Button("清除对话") | |
| # 设定组件的交互逻辑 | |
| # 当用户在输入框中按回车时,调用 predict 函数 | |
| msg_input.submit(predict, [msg_input, chatbot], chatbot) | |
| # 当用户点击“清除对话”按钮时,清空聊天机器人组件 | |
| clear_button.click(lambda: [], None, chatbot) | |
| # --- 3. 启动应用并开放API --- | |
| print("INFO: 准备启动Gradio应用...") | |
| # .queue() 使应用能够处理多个排队的请求,并且在 4.29.0 版本中会自动开放API。 | |
| # share=True 是解决CORS问题的关键。它会生成一个公开的、已配置好CORS的 .gradio.live 网址。 | |
| # *** 已移除 'api_open=True' 参数以适配 gradio==4.29.0 *** | |
| demo.queue().launch(share=True) |