Update app.py
Browse files
app.py
CHANGED
|
@@ -2,82 +2,106 @@ import gradio as gr
|
|
| 2 |
import torch
|
| 3 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
| 4 |
from threading import Thread
|
|
|
|
|
|
|
|
|
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
MODEL_ID = "badanwang/teacher_basic_qwen3-0.6b"
|
| 8 |
|
| 9 |
-
|
|
|
|
| 10 |
try:
|
| 11 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 12 |
-
|
| 13 |
-
# --- 【关键修改 1】---
|
| 14 |
-
# 移除 torch_dtype=torch.bfloat16,使用默认的 float32,这是在 CPU 上最稳妥的选择。
|
| 15 |
model = AutoModelForCausalLM.from_pretrained(
|
| 16 |
MODEL_ID,
|
|
|
|
| 17 |
device_map="auto"
|
| 18 |
)
|
| 19 |
-
|
| 20 |
except Exception as e:
|
| 21 |
-
|
| 22 |
model, tokenizer = None, None
|
| 23 |
|
| 24 |
-
# ---
|
| 25 |
-
def predict(
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
if model is None or tokenizer is None:
|
| 31 |
-
|
| 32 |
-
yield "
|
| 33 |
return
|
| 34 |
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
chat_history_for_model,
|
| 45 |
-
add_generation_prompt=True,
|
| 46 |
-
tokenize=True,
|
| 47 |
-
return_tensors="pt"
|
| 48 |
-
).to(model.device)
|
| 49 |
-
print(f"[DEBUG] 模板应用成功,输入 token 数量: {prompt_tokens.shape[-1]}")
|
| 50 |
|
| 51 |
-
|
| 52 |
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
-
|
|
|
|
| 76 |
|
| 77 |
-
#
|
| 78 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 79 |
-
#
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
)
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import torch
|
| 3 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
| 4 |
from threading import Thread
|
| 5 |
+
import logging
|
| 6 |
+
import time
|
| 7 |
+
import json
|
| 8 |
|
| 9 |
+
# ===================================================================
# Final app.py (recommended pattern for Gradio 4.x+)
# ===================================================================

# 1. Verbose logging so every startup step is visible in the Space logs.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

MODEL_ID = "badanwang/teacher_basic_qwen3-0.6b"

logger.info("===== Application Startup =====")
logger.info(f"正在加载模型和分词器: {MODEL_ID}")
try:
    # Load tokenizer first; if the repo id is wrong this fails fast.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        # float32 on CPU for maximum stability/compatibility
        torch_dtype=torch.float32,
        device_map="auto",
    )
    logger.info("模型和分词器加载成功!")
except Exception as e:
    # Keep the app importable even when loading fails; predict() reports
    # the error to the user instead of crashing the whole Space.
    logger.error(f"加载模型时发生致命错误: {e}", exc_info=True)
    model, tokenizer = None, None
|
| 32 |
|
| 33 |
+
# --- 2. Core inference function (fixed for the Gradio ChatInterface API) ---
def predict(message, history=None):
    """Stream the model's reply for a chat turn.

    BUG FIX: ``gr.ChatInterface(..., type="messages")`` invokes its callback
    as ``fn(message, history)`` with TWO positional arguments (the latest
    user string plus the prior OpenAI-format history). The previous
    one-argument signature ``predict(messages)`` raised ``TypeError`` on
    every chat turn. This version accepts both conventions:

    Args:
        message: the latest user message (str), or — legacy convention —
            a complete OpenAI-format conversation list when *history* is None.
        history: prior conversation as a list of ``{"role", "content"}``
            dicts, or None.

    Yields:
        str: the accumulated response text so far (streaming).
    """
    start_time = time.time()
    logger.info("\n--- [START] 进入 predict 函数 ---")

    # Normalize both call conventions into one OpenAI-format message list.
    if history is None and isinstance(message, list):
        messages = message  # legacy: caller passed the full conversation
    else:
        messages = list(history or [])
        messages.append({"role": "user", "content": message})

    # Pretty-print the payload so the log is easy to read.
    logger.info(f"[INPUT] 收到的 messages 列表:\n{json.dumps(messages, indent=2, ensure_ascii=False)}")

    if model is None or tokenizer is None:
        # Startup load failed; surface a user-visible error instead of crashing.
        logger.warning("[HANDLER] 模型或分词器为 None,返回错误信息。")
        yield "错误:模型未能成功加载,请检查后台日志。"
        return

    try:
        logger.info("[HANDLER] 正在应用聊天模板...")
        prompt_tokens = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt"
        ).to(model.device)
        logger.info(f"[HANDLER] 模板应用成功,输入 token 数量: {prompt_tokens.shape[-1]}")

        # Stream tokens out of the background generate() call as they arrive.
        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = {
            "input_ids": prompt_tokens,
            "streamer": streamer,
            "max_new_tokens": 1024,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
        }

        logger.info("[HANDLER] 准备启动生成线程...")
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        logger.info("[HANDLER] 生成线程已启动,开始从 streamer 中读取数据...")

        buffer = ""
        token_count = 0
        for new_text in streamer:
            token_count += 1
            logger.info(f"[STREAM] 正在生成第 {token_count} 个 token: {repr(new_text)}")
            buffer += new_text
            yield buffer

        logger.info(f"[HANDLER] Streamer 读取完毕,共生成 {token_count} 个 token。")
        # Reap the worker thread so no generate() call outlives this request.
        thread.join()

    except Exception as e:
        logger.error(f"[HANDLER] 在推理过程中发生错误: {e}", exc_info=True)
        yield "抱歉,处理您的请求时遇到了一个内部错误。"

    end_time = time.time()
    logger.info(f"--- [END] predict 函数结束,总耗时: {end_time - start_time:.2f} 秒 ---")
| 91 |
|
| 92 |
+
# --- 3. Build the Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# 你的自定义Qwen模型聊天机器人\n## 模型: {MODEL_ID}")

    chat_interface = gr.ChatInterface(
        fn=predict,
        title="聊天机器人",
        description="向你的微调Qwen模型提问吧!",
        examples=[["你好,你是谁?"], ["用Python写一个快速排序算法"]],
        # KEY: tell Gradio to use the new OpenAI-style message format.
        type="messages"
    )

# --- 4. Launch the app ---
logger.info("准备启动 Gradio 应用...")
# queue() enables request queuing, required for streaming generators.
# launch() blocks until the server shuts down, so the final log line
# only fires on exit.
demo.queue().launch()
logger.info("Gradio 应用已启动。")
|