han145 committed on
Commit
24ea806
·
verified ·
1 Parent(s): 4a5c42e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -96
app.py CHANGED
@@ -1,125 +1,217 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
- import json
 
5
 
6
# Process-wide handles, filled in lazily by load_model() so the weights are
# loaded at most once instead of on every request.
model = None
tokenizer = None
 
9
 
10
def load_model():
    """Load the tokenizer and model into the module-level globals.

    Both objects are built in local variables and published together, so a
    failure part-way through never leaves a tokenizer behind without a model.
    Errors are printed and swallowed; callers detect failure by checking
    whether the global ``model`` is still ``None`` afterwards.
    """
    global model, tokenizer

    model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    try:
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Padding needs a pad token; fall back to EOS when none is defined.
        if loaded_tokenizer.pad_token is None:
            loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
    except Exception as e:
        print(f"模型加载失败: {e}")
        return

    # Publish only after both loads succeeded.
    tokenizer, model = loaded_tokenizer, loaded_model
    print("模型加载成功!")
 
28
 
29
def predict_api(message):
    """Generate one assistant reply and wrap it in an OpenAI-style payload.

    Args:
        message: The user's text for a single-turn conversation.

    Returns:
        A dict shaped like
        ``{"choices": [{"message": {"role": "assistant", "content": <reply>}}]}``.

    Raises:
        RuntimeError: If the model or tokenizer is still unavailable after a
            load attempt.
    """
    if model is None or tokenizer is None:
        load_model()
    # load_model() swallows its own errors, so re-check before using the globals.
    if model is None or tokenizer is None:
        raise RuntimeError("模型加载失败,请稍后重试")

    # NOTE(review): ChatML-style prompt kept as written; confirm it matches the
    # chat template this checkpoint was trained with.
    prompt = f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
    ).to(model.device)  # inputs must live on the same device as the weights
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )

    # Decode only the newly generated tokens instead of searching the echoed
    # prompt for role markers, which breaks when the user text contains them.
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=False)
    generated_text = completion.split("<|im_end|>")[0]
    if tokenizer.eos_token:
        generated_text = generated_text.replace(tokenizer.eos_token, "")
    generated_text = generated_text.strip() or "抱歉,我无法生成合适的回复。"

    # OpenAI-compatible response envelope.
    return {
        "choices": [{
            "message": {
                "role": "assistant",
                "content": generated_text,
            }
        }]
    }
83
-
84
# Build the Gradio UI: a chat panel plus a JSON endpoint for programmatic use.
with gr.Blocks() as demo:
    gr.Markdown("# DeepSeek-R1 API 服务")

    # Chat widgets
    chatbot = gr.Chatbot(label="DeepSeek-R1")
    msg = gr.Textbox(label="输入消息")
    clear = gr.Button("清除")

    def respond(message, chat_history):
        """Append one (user, assistant) turn to the history and clear the input.

        Returns:
            A pair ``("", chat_history)``: the empty string resets the textbox
            and the updated history is shown in the chatbot.
        """
        response = predict_api(message)
        bot_message = response["choices"][0]["message"]["content"]
        chat_history.append((message, bot_message))
        return "", chat_history

    # Event wiring
    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

    # Endpoint that returns the raw JSON payload from predict_api.
    gr.Interface(
        fn=predict_api,
        inputs=gr.Textbox(label="输入消息", lines=2),
        outputs=gr.JSON(label="API响应"),
        title="OpenAI兼容API",
        description="使用此端点进行API调用",
        api_name="predict",
    )
116
 
117
# Preloading is optional: call load_model() here to pay the load cost at
# startup instead of on the first request.
# load_model()

if __name__ == "__main__":
    # Listen on all interfaces on port 7860, without a public share link.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
+ import logging
5
+ import time
6
 
7
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Lazily populated model state shared by the request handlers.
model = None
tokenizer = None
last_load_time = 0  # time.time() of the last successful load; 0 until then
15
 
16
def safe_load_model():
    """Load the tokenizer and model into the module globals, retrying on failure.

    An already-loaded model is always reused. The previous time-based check
    let the first call after five minutes fall through and reload the whole
    checkpoint.

    Returns:
        True when a model and tokenizer are available, False when every load
        attempt failed.
    """
    global model, tokenizer, last_load_time

    if model is not None and tokenizer is not None:
        return True

    model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    max_retries = 3

    for attempt in range(max_retries):
        try:
            logger.info(f"尝试加载模型,第{attempt + 1}次...")

            # Release cached GPU memory left over from a failed attempt.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # NOTE(review): trust_remote_code executes code shipped with the
            # checkpoint; confirm it is actually required for this model.
            loaded_tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                resume_download=True,  # resume interrupted downloads
            )

            # Padding needs a pad token; fall back to EOS when none is defined.
            if loaded_tokenizer.pad_token is None:
                loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

            loaded_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                resume_download=True,
            )
        except Exception as e:
            logger.error(f"模型加载失败(尝试{attempt + 1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                time.sleep(5)  # brief back-off before the next attempt
            continue

        # Publish both objects together so callers never see a half-loaded state.
        tokenizer, model = loaded_tokenizer, loaded_model
        last_load_time = time.time()
        logger.info("模型加载成功!")
        return True

    return False
64
+
65
def _generate_once(message):
    """Run one generation pass and return the cleaned assistant text."""
    # Cap the input size to keep memory use bounded.
    if len(message) > 2000:
        message = message[:2000] + "...(内容过长已截断)"

    # NOTE(review): ChatML-style prompt kept as written; confirm it matches the
    # chat template this checkpoint was trained with.
    prompt = f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    ).to(model.device)  # inputs must live on the same device as the weights
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,  # keep replies short
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )

    # Decode only the newly generated tokens instead of searching the echoed
    # prompt for role markers, which breaks when the user text contains them.
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=False)
    reply = completion.split("<|im_end|>")[0]
    if tokenizer.eos_token:
        reply = reply.replace(tokenizer.eos_token, "")
    return reply.strip() or "抱歉,我无法生成合适的回复。"


def generate_response_safe(message, max_retries=2):
    """Generate a reply for ``message``, retrying when generation raises.

    Args:
        message: The user's text; input longer than 2000 characters is truncated.
        max_retries: Number of attempts. Values below 1 are treated as 1 so
            the function always returns a string.

    Returns:
        The assistant reply, or a user-facing error string when the model
        cannot be loaded or every attempt failed.
    """
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            if not safe_load_model():
                return "模型加载失败,请稍后重试"
            return _generate_once(message)
        except Exception as e:
            logger.error(f"生成响应失败(尝试{attempt + 1}/{attempts}): {e}")
            if attempt == attempts - 1:
                return f"生成失败: {str(e)}"
            time.sleep(2)
        finally:
            # Free cached GPU memory on failure too, so a retry after an
            # out-of-memory error starts from a clean cache.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
130
+
131
def process_chat(message, history):
    """Produce the assistant reply for one chat turn.

    Args:
        message: Text the user just submitted.
        history: Prior conversation turns. Accepted but not used, so every
            reply is generated from ``message`` alone.

    Returns:
        The generated reply, or a generic apology if generation raised.
    """
    try:
        return generate_response_safe(message)
    except Exception as e:
        logger.error(f"聊天处理异常: {e}")
        return "抱歉,处理您的请求时出现了问题,请稍后重试。"
139
 
140
# Build the simplified Gradio chat interface.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 800px !important;
    }
    """
) as demo:

    gr.Markdown("""
    # DeepSeek-R1 聊天助手
    *基于DeepSeek-R1-Distill-Qwen-1.5B模型*
    """)

    # Conversation display
    chatbot = gr.Chatbot(
        label="对话历史",
        height=400,
        show_copy_button=True
    )

    with gr.Row():
        msg = gr.Textbox(
            label="输入消息",
            placeholder="请输入您的问题...",
            scale=4,
            max_lines=3
        )
        submit_btn = gr.Button("发送", variant="primary", scale=1)

    clear_btn = gr.Button("清除对话")

    def respond(message, chat_history):
        """Handle one submitted message and return the updated chat state.

        Empty or whitespace-only input is ignored. Otherwise a
        ``[message, reply]`` pair is appended to the history.

        Returns:
            A pair ``("", chat_history)``: the empty string resets the textbox
            and the history is shown in the chatbot.
        """
        # Tolerate a missing history or message instead of raising.
        chat_history = chat_history or []
        if not message or not message.strip():
            return "", chat_history

        # Add the user turn first, then fill in the reply.
        chat_history.append([message, ""])
        response = process_chat(message, chat_history)
        chat_history[-1][1] = response

        return "", chat_history

    # Event wiring: Enter in the textbox and the send button behave the same.
    respond_event = dict(
        fn=respond,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=True,
        show_progress="hidden",
    )
    msg_submit = msg.submit(**respond_event)
    btn_click = submit_btn.click(**respond_event)

    # Reset the chatbot value.
    clear_btn.click(
        lambda: None, None, chatbot, queue=False
    )
200
 
201
# Optional warm-up: load the model at import time so the first request does not
# pay for it. safe_load_model() reports failure through its return value, so
# that value is checked as well as any exception.
try:
    preloaded = safe_load_model()
except Exception as e:
    logger.warning(f"预加载模型失败: {e}")
else:
    if not preloaded:
        logger.warning("预加载模型失败,将在首次请求时重试")

if __name__ == "__main__":
    # Launch settings
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        debug=False,  # debug mode off to reduce output
        max_threads=2,  # cap the number of worker threads
        quiet=True,  # reduce log output
    )