badanwang commited on
Commit
ed1d652
·
verified ·
1 Parent(s): 548ffa6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -119
app.py CHANGED
@@ -1,131 +1,94 @@
1
- import gradio as gr
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
- from threading import Thread
5
- import logging
6
- import time
7
- import json
8
  import os
9
 
 
10
 
11
- # 日志记录配置
12
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
13
- logger = logging.getLogger(__name__)
14
-
15
- # 从环境变量或默认值加载模型ID,增加灵活性
16
  MODEL_ID = os.getenv("MODEL_ID", "badanwang/teacher_basic_qwen3-0.6b")
17
-
18
- logger.info("===== Application Startup =====")
19
- logger.info(f"正在加载模型和分词器: {MODEL_ID}")
20
-
21
- try:
22
- # 推荐使用 trust_remote_code=True 以确保所有模型组件正确加载
23
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
24
- model = AutoModelForCausalLM.from_pretrained(
25
- MODEL_ID,
26
- torch_dtype="auto", # 推荐使用 "auto" 以获得最佳性能和兼容性
27
- device_map="auto",
28
- trust_remote_code=True
29
- )
30
- logger.info("模型和分词器加载成功!")
31
- except Exception as e:
32
- logger.error(f"加载模型时发生致命错误: {e}", exc_info=True)
33
- # 在无法加载模型时,创建一个占位符函数,以便Gradio界面仍能启动并显示错误
34
- model, tokenizer = None, None
35
- def model_load_error_placeholder(*args, **kwargs):
36
- raise gr.Error(f"致命错误:无法加载模型 '{MODEL_ID}'。请检查后台日志以获取详细信息。")
37
-
38
- # --- 2. 核心推理函数 (无变动) ---
39
-
40
- def predict(message: str, history: list[list[str]]):
41
- start_time = time.time()
42
- logger.info("\n--- [START] 进入 predict 函数 ---")
43
- logger.info(f"[INPUT] Message: {message}")
44
- logger.info(f"[INPUT] History:\n{json.dumps(history, indent=2, ensure_ascii=False)}")
45
-
46
- if model is None or tokenizer is None:
47
- model_load_error_placeholder()
48
-
49
  messages = []
50
- for turn in history:
51
- user_message, bot_message = turn
52
  messages.append({"role": "user", "content": user_message})
53
  messages.append({"role": "assistant", "content": bot_message})
54
- messages.append({"role": "user", "content": message})
55
-
56
- logger.info(f"[HANDLER] 转换后的 messages 列表:\n{json.dumps(messages, indent=2, ensure_ascii=False)}")
57
-
58
- try:
59
- logger.info("[HANDLER] 正在应用聊天模板...")
60
- prompt_tokens = tokenizer.apply_chat_template(
61
- messages,
62
- add_generation_prompt=True,
63
- tokenize=True,
64
- return_tensors="pt"
65
- ).to(model.device)
66
- logger.info(f"[HANDLER] 模板应用成功,输入 token 数量: {prompt_tokens.shape[-1]}")
67
-
68
- streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
69
-
70
- generation_kwargs = {
71
- "input_ids": prompt_tokens,
72
- "streamer": streamer,
73
- "max_new_tokens": 1024,
74
- "do_sample": True,
75
- "temperature": 0.7,
76
- "top_p": 0.9
77
- }
78
-
79
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
80
- thread.start()
81
- logger.info("[HANDLER] 生成线程已启动,开始从 streamer 中读取数据...")
82
-
83
- buffer = ""
84
- token_count = 0
85
- for new_text in streamer:
86
- token_count += 1
87
- if "�" in new_text:
88
- continue
89
- logger.debug(f"[STREAM] 正在生成第 {token_count} 个 token: {repr(new_text)}")
90
- buffer += new_text
91
- yield buffer
92
-
93
- logger.info(f"[HANDLER] Streamer 读取完毕,共生成 {token_count} 个 token。")
94
- thread.join()
95
-
96
- except Exception as e:
97
- logger.error(f"[HANDLER] 在推理过程中发生错误: {e}", exc_info=True)
98
- raise gr.Error(f"抱歉,处理您的请求时遇到了一个内部错误: {e}")
99
-
100
- finally:
101
- end_time = time.time()
102
- logger.info(f"--- [END] predict 函数结束,总耗时: {end_time - start_time:.2f} 秒 ---")
103
-
104
- # --- 3. 创建并配置Gradio界面 (已优化) ---
105
 
106
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), css="footer {visibility: hidden}") as demo:
107
- gr.Markdown(f"# 你的自定义Qwen模型聊天机器人\n## 模型: `{MODEL_ID}`")
108
 
109
- chat_interface = gr.ChatInterface(
110
- fn=predict,
111
- chatbot=gr.Chatbot(
112
- height=600,
113
- show_copy_button=True,
114
- avatar_images=(None, "https://s2.loli.net/2024/07/17/iPqD3uVgW9eBkbT.png")
115
- ),
116
- title="Qwen 大模型聊天室",
117
- description="向你的微调Qwen模型提问吧!这是一个流式输出的例子。",
118
- examples=[
119
- ["你好,你是谁?"],
120
- ["用 Python 写一个快速排序算法。"],
121
- ["解释一下什么是大型语言模型(LLM)。"]
122
- ],
123
- submit_btn="发送",
124
- )
125
 
126
  if __name__ == "__main__":
127
- logger.info("准备启动 Gradio 应用...")
128
- # .queue() 对于处理多个并发用户至关重要
129
- # 在Hugging Face Spaces上部署时,share=True 不是必需的,但有助于本地测试
130
- demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)
131
- logger.info("Gradio 应用已启动。")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
3
  import os
4
 
5
+ # --- 1. 配置与模型加载 ---
6
 
7
+ # 从环境变量或默认值加载模型ID
 
 
 
 
8
  MODEL_ID = os.getenv("MODEL_ID", "badanwang/teacher_basic_qwen3-0.6b")
9
+ print(f"正在加载模型: {MODEL_ID}")
10
+
11
+ # 加载分词器和模型
12
+ # trust_remote_code=True 是加载Qwen等模型所必需的
13
+ # device_map="auto" 会自动将模型分配到可用的硬件上(如GPU)
14
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
15
+ model = AutoModelForCausalLM.from_pretrained(
16
+ MODEL_ID,
17
+ torch_dtype="auto",
18
+ device_map="auto",
19
+ trust_remote_code=True
20
+ )
21
+ print("模型加载成功!")
22
+
23
+
24
+ # --- 2. 核心推理函数 (API) ---
25
+
26
+ def get_response(prompt: str, history: list[list[str]] = None):
27
+ """
28
+ 一个简单的函数,用于与模型进行单次对话。
29
+
30
+ Args:
31
+ prompt (str): 用户当前输入的问题。
32
+ history (list[list[str]], optional): 对话历史,格式为 [[user_msg_1, bot_msg_1], ...]。默认为 None。
33
+
34
+ Returns:
35
+ str: 模型生成的回复。
36
+ """
37
+ if history is None:
38
+ history = []
39
+
40
+ # 1. 构建消息列表
41
  messages = []
42
+ for user_message, bot_message in history:
 
43
  messages.append({"role": "user", "content": user_message})
44
  messages.append({"role": "assistant", "content": bot_message})
45
+ messages.append({"role": "user", "content": prompt})
46
+
47
+ # 2. 应用聊天模板并进行分词
48
+ # 这是与聊天模型正确交互的关键步骤
49
+ input_ids = tokenizer.apply_chat_template(
50
+ messages,
51
+ add_generation_prompt=True,
52
+ tokenize=True,
53
+ return_tensors="pt"
54
+ ).to(model.device)
55
+
56
+ # 3. 生成回复
57
+ # 这是一个阻塞式调用,会等待模型生成完毕
58
+ outputs = model.generate(
59
+ input_ids,
60
+ max_new_tokens=1024,
61
+ do_sample=True,
62
+ temperature=0.7,
63
+ top_p=0.9
64
+ )
65
+
66
+ # 4. 解码生成的文本
67
+ # `outputs[0]` 包含了输入的token和新生成的token,我们需要切片只获取新生成的部分
68
+ response_ids = outputs[0][input_ids.shape[-1]:]
69
+ response_text = tokenizer.decode(response_ids, skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
+ return response_text
 
72
 
73
+ # --- 3. 使用示例 ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  if __name__ == "__main__":
76
+ # 示例1: 单轮对话
77
+ print("\n--- 示例 1: 单轮对话 ---")
78
+ question1 = "你好,你是谁?"
79
+ print(f"用户: {question1}")
80
+ answer1 = get_response(question1)
81
+ print(f"模型: {answer1}")
82
+
83
+ # 示例2: 多轮对话
84
+ print("\n--- 示例 2: 多轮对话 ---")
85
+ # 首先,定义一个对话历史
86
+ chat_history = [
87
+ ["用Python写一个快速排序", "当然,这是快速排序的Python实现:\n```python\ndef quick_sort(arr):\n if len(arr) <= 1:\n return arr\n pivot = arr[len(arr) // 2]\n left = [x for x in arr if x < pivot]\n middle = [x for x in arr if x == pivot]\n right = [x for x in arr if x > pivot]\n return quick_sort(left) + middle + quick_sort(right)\n\nprint(quick_sort())\n```"]
88
+ ]
89
+ question2 = "很好,你能解释一下它的工作原理吗?"
90
+ print(f"历史: {chat_history}")
91
+ print(f"用户: {question2}")
92
+ # 调用时传入历史记录
93
+ answer2 = get_response(question2, history=chat_history)
94
+ print(f"模型: {answer2}")