Update app.py
Browse files
app.py
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
|
@@ -5,46 +9,69 @@ from threading import Thread
|
|
| 5 |
import logging
|
| 6 |
import time
|
| 7 |
import json
|
|
|
|
| 8 |
|
| 9 |
-
#
|
| 10 |
-
# 最终版 app.py (适配 Gradio 4.x+ 的推荐模式)
|
| 11 |
-
# ===================================================================
|
| 12 |
|
| 13 |
-
#
|
| 14 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
|
| 17 |
-
|
|
|
|
| 18 |
|
| 19 |
logger.info("===== Application Startup =====")
|
| 20 |
logger.info(f"正在加载模型和分词器: {MODEL_ID}")
|
|
|
|
|
|
|
| 21 |
try:
|
| 22 |
-
|
|
|
|
| 23 |
model = AutoModelForCausalLM.from_pretrained(
|
| 24 |
MODEL_ID,
|
| 25 |
-
torch_dtype=
|
| 26 |
-
device_map="auto"
|
|
|
|
| 27 |
)
|
| 28 |
logger.info("模型和分词器加载成功!")
|
| 29 |
except Exception as e:
|
| 30 |
logger.error(f"加载模型时发生致命错误: {e}", exc_info=True)
|
|
|
|
| 31 |
model, tokenizer = None, None
|
|
|
|
|
|
|
| 32 |
|
| 33 |
-
# --- 2. 核心推理函数
|
| 34 |
-
|
|
|
|
|
|
|
| 35 |
"""
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
"""
|
| 38 |
start_time = time.time()
|
| 39 |
logger.info("\n--- [START] 进入 predict 函数 ---")
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
logger.info(f"[INPUT] 收到的 messages 列表:\n{json.dumps(messages, indent=2, ensure_ascii=False)}")
|
| 43 |
|
|
|
|
| 44 |
if model is None or tokenizer is None:
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
try:
|
| 50 |
logger.info("[HANDLER] 正在应用聊天模板...")
|
|
@@ -56,52 +83,78 @@ def predict(messages: list):
|
|
| 56 |
).to(model.device)
|
| 57 |
logger.info(f"[HANDLER] 模板应用成功,输入 token 数量: {prompt_tokens.shape[-1]}")
|
| 58 |
|
|
|
|
| 59 |
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
| 60 |
-
|
|
|
|
| 61 |
generation_kwargs = {
|
| 62 |
"input_ids": prompt_tokens,
|
| 63 |
"streamer": streamer,
|
| 64 |
"max_new_tokens": 1024,
|
| 65 |
"do_sample": True,
|
| 66 |
"temperature": 0.7,
|
| 67 |
-
"top_p": 0.9
|
| 68 |
}
|
| 69 |
|
| 70 |
-
|
| 71 |
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
| 72 |
thread.start()
|
| 73 |
logger.info("[HANDLER] 生成线程已启动,开始从 streamer 中读取数据...")
|
| 74 |
|
|
|
|
| 75 |
buffer = ""
|
| 76 |
token_count = 0
|
| 77 |
for new_text in streamer:
|
| 78 |
token_count += 1
|
| 79 |
-
|
|
|
|
|
|
|
| 80 |
buffer += new_text
|
| 81 |
yield buffer
|
| 82 |
|
| 83 |
logger.info(f"[HANDLER] Streamer 读取完毕,共生成 {token_count} 个 token。")
|
|
|
|
| 84 |
|
| 85 |
except Exception as e:
|
| 86 |
logger.error(f"[HANDLER] 在推理过程中发生错误: {e}", exc_info=True)
|
| 87 |
-
|
|
|
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
gr.Markdown(f"# 你的自定义Qwen模型聊天机器人\n## 模型: {MODEL_ID}")
|
| 95 |
|
|
|
|
| 96 |
chat_interface = gr.ChatInterface(
|
| 97 |
fn=predict,
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
)
|
| 103 |
|
| 104 |
# --- 4. 启动应用 ---
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
logger.info("Gradio
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ===================================================================
|
| 2 |
+
# 优化版 app.py (为 Gradio 5.x 优化)
|
| 3 |
+
# ===================================================================
|
| 4 |
+
|
| 5 |
import gradio as gr
|
| 6 |
import torch
|
| 7 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
|
|
|
| 9 |
import logging
|
| 10 |
import time
|
| 11 |
import json
|
| 12 |
+
import os
|
| 13 |
|
| 14 |
+
# --- 1. Configuration and initialisation ---

# Root logging setup: INFO level, timestamp / logger name / level / message.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The model ID can be overridden through the MODEL_ID environment variable;
# otherwise the default checkpoint below is used.
MODEL_ID = os.environ.get("MODEL_ID", "badanwang/teacher_basic_qwen3-0.6b")

logger.info("===== Application Startup =====")
logger.info(f"正在加载模型和分词器: {MODEL_ID}")

# Load tokenizer and model once at import time. Any failure is logged and
# leaves `model` / `tokenizer` set to None so the UI can still come up and
# report the problem to the user.
try:
    # NOTE(review): trust_remote_code=True lets the model repository execute
    # its own Python code; since MODEL_ID is configurable via the environment,
    # confirm that every repository this may point to is trusted.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype="auto",
        trust_remote_code=True,
    )
    logger.info("模型和分词器加载成功!")
except Exception as e:
    # logger.exception logs at ERROR level and attaches the traceback.
    logger.exception(f"加载模型时发生致命错误: {e}")
    model = tokenizer = None

    def model_load_error_placeholder(*args, **kwargs):
        """Raise a gr.Error telling the user that the model failed to load."""
        raise gr.Error(f"致命错误:无法加载模型 '{MODEL_ID}'。请检查后台日志以获取详细信息。")
|
| 43 |
|
| 44 |
+
# --- 2. 核心推理函数 ---
|
| 45 |
+
|
| 46 |
+
# Gradio 5.x 的 ChatInterface `fn` 函数接收两个参数: message 和 history
|
| 47 |
+
def predict(message: str, history: list[list[str]]):
|
| 48 |
"""
|
| 49 |
+
核心推理函数,接收用户输入和聊天历史,并以流式方式返回模型输出。
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
message (str): 用户的最新输入。
|
| 53 |
+
history (list[list[str]]): 聊天历史,格式为 [[user_msg_1, bot_msg_1], [user_msg_2, bot_msg_2], ...]。
|
| 54 |
"""
|
| 55 |
start_time = time.time()
|
| 56 |
logger.info("\n--- [START] 进入 predict 函数 ---")
|
| 57 |
+
logger.info(f"[INPUT] Message: {message}")
|
| 58 |
+
logger.info(f"[INPUT] History:\n{json.dumps(history, indent=2, ensure_ascii=False)}")
|
|
|
|
| 59 |
|
| 60 |
+
# 如果模型加载失败,使用占位符函数抛出错误
|
| 61 |
if model is None or tokenizer is None:
|
| 62 |
+
model_load_error_placeholder()
|
| 63 |
+
|
| 64 |
+
# 将 Gradio 的 history 格式转换为 Hugging Face 模板所需的格式
|
| 65 |
+
# history 的格式: [[user, assistant], [user, assistant], ...]
|
| 66 |
+
# messages 的格式: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
| 67 |
+
messages = []
|
| 68 |
+
for turn in history:
|
| 69 |
+
user_message, bot_message = turn
|
| 70 |
+
messages.append({"role": "user", "content": user_message})
|
| 71 |
+
messages.append({"role": "assistant", "content": bot_message})
|
| 72 |
+
messages.append({"role": "user", "content": message})
|
| 73 |
+
|
| 74 |
+
logger.info(f"[HANDLER] 转换后的 messages 列表:\n{json.dumps(messages, indent=2, ensure_ascii=False)}")
|
| 75 |
|
| 76 |
try:
|
| 77 |
logger.info("[HANDLER] 正在应用聊天模板...")
|
|
|
|
| 83 |
).to(model.device)
|
| 84 |
logger.info(f"[HANDLER] 模板应用成功,输入 token 数量: {prompt_tokens.shape[-1]}")
|
| 85 |
|
| 86 |
+
# 使用 TextIteratorStreamer 实现流式输出
|
| 87 |
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
| 88 |
+
|
| 89 |
+
# 定义生成参数
|
| 90 |
generation_kwargs = {
|
| 91 |
"input_ids": prompt_tokens,
|
| 92 |
"streamer": streamer,
|
| 93 |
"max_new_tokens": 1024,
|
| 94 |
"do_sample": True,
|
| 95 |
"temperature": 0.7,
|
| 96 |
+
"top_p": 0.9
|
| 97 |
}
|
| 98 |
|
| 99 |
+
# 在单独的线程中运行模型生成,以避免阻塞UI
|
| 100 |
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
| 101 |
thread.start()
|
| 102 |
logger.info("[HANDLER] 生成线程已启动,开始从 streamer 中读取数据...")
|
| 103 |
|
| 104 |
+
# 从 streamer 中逐个 token 地 yield,实现实时流式效果
|
| 105 |
buffer = ""
|
| 106 |
token_count = 0
|
| 107 |
for new_text in streamer:
|
| 108 |
token_count += 1
|
| 109 |
+
if "�" in new_text: # 过滤掉解码不完全的特殊字符
|
| 110 |
+
continue
|
| 111 |
+
logger.debug(f"[STREAM] 正在生成第 {token_count} 个 token: {repr(new_text)}")
|
| 112 |
buffer += new_text
|
| 113 |
yield buffer
|
| 114 |
|
| 115 |
logger.info(f"[HANDLER] Streamer 读取完毕,共生成 {token_count} 个 token。")
|
| 116 |
+
thread.join() # 确保线程执行完毕
|
| 117 |
|
| 118 |
except Exception as e:
|
| 119 |
logger.error(f"[HANDLER] 在推理过程中发生错误: {e}", exc_info=True)
|
| 120 |
+
# 使用 gr.Error 在界面上优雅地显示错误信息
|
| 121 |
+
raise gr.Error(f"抱歉,处理您的请求时遇到了一个内部错误: {e}")
|
| 122 |
|
| 123 |
+
finally:
|
| 124 |
+
end_time = time.time()
|
| 125 |
+
logger.info(f"--- [END] predict 函数结束,总耗时: {end_time - start_time:.2f} 秒 ---")
|
| 126 |
+
|
| 127 |
+
# --- 3. Build and configure the Gradio interface ---

with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"),
    css="footer {visibility: hidden}",
) as demo:
    gr.Markdown(f"# 你的自定义Qwen模型聊天机器人\n## 模型: `{MODEL_ID}`")

    # gr.ChatInterface calls `fn` with (message, history) and manages the
    # conversation state itself.
    #
    # The `retry_btn`, `undo_btn` and `clear_btn` keyword arguments were
    # removed from gr.ChatInterface in Gradio 5 (those actions are now built
    # into the Chatbot component); passing them raises TypeError, so they are
    # intentionally not set here.
    #
    # NOTE(review): `predict` unpacks each history entry as a
    # (user, assistant) pair; confirm the installed Gradio version still
    # delivers history in that format by default.
    chat_interface = gr.ChatInterface(
        fn=predict,
        chatbot=gr.Chatbot(
            height=600,
            show_copy_button=True,
            # (user avatar, bot avatar)
            avatar_images=(None, "https://s2.loli.net/2024/07/17/iPqD3uVgW9eBkbT.png"),
        ),
        title="Qwen 大模型聊天室",
        description="向你的微调Qwen模型提问吧!这是一个流式输出的例子。",
        examples=[
            ["你好,你是谁?"],
            ["用 Python 写一个快速排序算法。"],
            ["解释一下什么是大型语言模型(LLM)。"],
        ],
        submit_btn="发送",
    )
|
| 153 |
|
| 154 |
# --- 4. Launch the application ---

if __name__ == "__main__":
    logger.info("准备启动 Gradio 应用...")
    # Enable request queuing. `concurrency_count` was removed from
    # Blocks.queue() in Gradio 4+; the per-event default concurrency is now
    # set with `default_concurrency_limit`.
    demo.queue(default_concurrency_limit=2).launch(server_name="0.0.0.0", server_port=7860)
    # NOTE(review): launch() blocks the main thread by default when run as a
    # script, so this message is only emitted once the server shuts down.
    logger.info("Gradio 应用已启动。")
|