ray-lei committed
Commit 31e27ad · verified · 1 Parent(s): f68e026

Update app.py

Files changed (1)
  1. app.py +148 -58
app.py CHANGED
@@ -11,7 +11,7 @@ from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
 import json
 
 # Initialize the FastAPI app
@@ -29,6 +29,7 @@ app.add_middleware(
 # Global variables
 model = None
 tokenizer = None
+model_name = None
 
 # Pydantic model definitions
 class Message(BaseModel):
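
The `class Message(BaseModel):` context line above is the only part of the request schema this view shows. For orientation, here is a minimal sketch of the Pydantic models implied by how the handlers below use the request; the field names come from that usage, while the defaults and the `model` field are assumptions, not the repo's actual definitions:

from typing import List
from pydantic import BaseModel

class Message(BaseModel):
    role: str      # "system", "user", or "assistant", per format_messages_simple
    content: str

class ChatCompletionRequest(BaseModel):
    # Fields inferred from generate_response(...) and the /v1/chat/completions
    # handler; default values here are illustrative assumptions only.
    model: str = "qwen2.5-coder-7b-instruct"
    messages: List[Message]
    temperature: float = 0.7
    max_tokens: int = 512
    top_p: float = 0.9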
@@ -73,82 +74,158 @@ class ModelListResponse(BaseModel):
 
 def load_model():
     """Load the Qwen Coder model"""
-    global model, tokenizer
-
-    model_name = "Qwen/Qwen2.5-Coder-7B-Instruct"
-
-    print("Loading tokenizer...")
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_name,
-        trust_remote_code=True
-    )
-
-    print("Loading model...")
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        trust_remote_code=True,
-        low_cpu_mem_usage=True
-    )
-
-    print("Model loaded successfully!")
+    global model, tokenizer, model_name
+
+    # Model candidates, in priority order
+    model_candidates = [
+        "Qwen/Qwen2.5-Coder-7B-Instruct",
+        "Qwen/Qwen2.5-Coder-3B-Instruct",
+        "Qwen/Qwen2.5-Coder-1.5B-Instruct"
+    ]
+
+    for candidate_model in model_candidates:
+        try:
+            print(f"Attempting to load model: {candidate_model}")
+
+            # Test-load the tokenizer first
+            print("Loading tokenizer...")
+            test_tokenizer = AutoTokenizer.from_pretrained(
+                candidate_model,
+                trust_remote_code=True,
+                use_fast=False,
+                revision="main"
+            )
+
+            # If the tokenizer loaded, continue with the model
+            print("Loading model...")
+            test_model = AutoModelForCausalLM.from_pretrained(
+                candidate_model,
+                torch_dtype=torch.float16,
+                device_map="auto",
+                trust_remote_code=True,
+                low_cpu_mem_usage=True,
+                revision="main"
+            )
+
+            # On success, assign to the global variables
+            tokenizer = test_tokenizer
+            model = test_model
+            model_name = candidate_model
+
+            print(f"Successfully loaded model: {candidate_model}")
+            return
+
+        except Exception as e:
+            print(f"Failed to load {candidate_model}: {str(e)}")
+            continue
+
+    # If every candidate failed, raise
+    raise Exception("Failed to load any Qwen model. Please check your configuration.")
+
+def format_messages_simple(messages: List[Message]) -> str:
+    """Simple message formatting (fallback)"""
+    formatted = ""
+    for msg in messages:
+        if msg.role == "system":
+            formatted += f"System: {msg.content}\n\n"
+        elif msg.role == "user":
+            formatted += f"User: {msg.content}\n\n"
+        elif msg.role == "assistant":
+            formatted += f"Assistant: {msg.content}\n\n"
+
+    formatted += "Assistant: "
+    return formatted
 
 def format_messages(messages: List[Message]) -> str:
     """Format messages for Qwen"""
-    formatted_messages = []
-    for msg in messages:
-        formatted_messages.append({
-            "role": msg.role,
-            "content": msg.content
-        })
-
-    # Use the tokenizer's chat template
-    text = tokenizer.apply_chat_template(
-        formatted_messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
-    return text
+    try:
+        formatted_messages = []
+        for msg in messages:
+            formatted_messages.append({
+                "role": msg.role,
+                "content": msg.content
+            })
+
+        # Try the tokenizer's chat template first
+        if hasattr(tokenizer, 'apply_chat_template'):
+            text = tokenizer.apply_chat_template(
+                formatted_messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+            return text
+        else:
+            # No chat_template available; use the simple formatting
+            return format_messages_simple(messages)
+
+    except Exception as e:
+        print(f"Error in format_messages, using simple format: {str(e)}")
+        return format_messages_simple(messages)
 
 def generate_response(prompt: str, temperature: float, max_tokens: int, top_p: float) -> str:
     """Generate a model response"""
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
-    with torch.no_grad():
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=max_tokens,
-            temperature=temperature,
-            top_p=top_p,
-            do_sample=True,
-            pad_token_id=tokenizer.eos_token_id
-        )
-
-    # Return only the newly generated part
-    response = tokenizer.decode(
-        outputs[0][inputs['input_ids'].shape[1]:],
-        skip_special_tokens=True
-    )
-
-    return response.strip()
+    try:
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
+
+        # Move inputs to the model's device
+        if hasattr(model, 'device'):
+            inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+        with torch.no_grad():
+            # Set up the generation parameters
+            generation_config = GenerationConfig(
+                max_new_tokens=min(max_tokens, 2048),
+                temperature=temperature,
+                top_p=top_p,
+                do_sample=True,
+                pad_token_id=tokenizer.eos_token_id,
+                eos_token_id=tokenizer.eos_token_id,
+                repetition_penalty=1.1
+            )
+
+            outputs = model.generate(
+                **inputs,
+                generation_config=generation_config
+            )
+
+        # Return only the newly generated part
+        response = tokenizer.decode(
+            outputs[0][inputs['input_ids'].shape[1]:],
+            skip_special_tokens=True
+        )
+
+        return response.strip()
+
+    except Exception as e:
+        print(f"Error in generate_response: {str(e)}")
+        return f"Sorry, an error occurred while generating the response: {str(e)}"
 
 @app.on_event("startup")
 async def startup_event():
     """Load the model on application startup"""
-    load_model()
+    try:
+        load_model()
+    except Exception as e:
+        print(f"Failed to load model during startup: {str(e)}")
+        # Don't fail startup; surface the error at request time instead
 
 @app.get("/")
 async def root():
-    return {"message": "Qwen Coder API Server is running!"}
+    return {
+        "message": "Qwen Coder API Server is running!",
+        "model_loaded": model is not None,
+        "current_model": model_name
+    }
 
 @app.get("/v1/models", response_model=ModelListResponse)
 async def list_models():
     """List the available models"""
+    model_id = "qwen2.5-coder-7b-instruct" if model_name is None else model_name.split("/")[-1].lower()
+
     return ModelListResponse(
         data=[
             Model(
-                id="qwen2.5-coder-7b-instruct",
+                id=model_id,
                 created=int(time.time()),
                 owned_by="qwen"
            )
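
For a sense of what the chat-template path in `format_messages` produces: Qwen2.5 instruct tokenizers ship a ChatML-style template, so a standalone call looks roughly like the sketch below (the smallest candidate model is used to keep the download light; the output is shown as its expected shape, not a captured run):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-1.5B-Instruct")
msgs = [
    {"role": "system", "content": "You are a helpful coding assistant."},
    {"role": "user", "content": "Write hello world in Python."},
]
print(tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True))
# Expected shape (ChatML-style markup):
# <|im_start|>system
# You are a helpful coding assistant.<|im_end|>
# <|im_start|>user
# Write hello world in Python.<|im_end|>
# <|im_start|>assistant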
@@ -160,7 +237,11 @@ async def chat_completions(request: ChatCompletionRequest):
     """Handle chat completion requests"""
     try:
         if model is None or tokenizer is None:
-            raise HTTPException(status_code=503, detail="Model not loaded")
+            # Try to reload the model on demand
+            try:
+                load_model()
+            except:
+                raise HTTPException(status_code=503, detail="Model not loaded and failed to load on demand")
 
         # Format the messages
         prompt = format_messages(request.messages)
@@ -177,8 +258,13 @@ async def chat_completions(request: ChatCompletionRequest):
         completion_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
 
         # Count token usage (simplified)
-        prompt_tokens = len(tokenizer.encode(prompt))
-        completion_tokens = len(tokenizer.encode(response_text))
+        try:
+            prompt_tokens = len(tokenizer.encode(prompt))
+            completion_tokens = len(tokenizer.encode(response_text))
+        except:
+            # If tokenizer encoding fails, fall back to a rough estimate
+            prompt_tokens = len(prompt.split()) * 2
+            completion_tokens = len(response_text.split()) * 2
 
         response = ChatCompletionResponse(
             id=completion_id,
@@ -202,14 +288,18 @@ async def chat_completions(request: ChatCompletionRequest):
 
     except Exception as e:
         print(f"Error processing request: {str(e)}")
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 @app.get("/health")
 async def health_check():
     """Health-check endpoint"""
     return {
-        "status": "healthy",
-        "model_loaded": model is not None and tokenizer is not None
+        "status": "healthy" if model is not None and tokenizer is not None else "unhealthy",
+        "model_loaded": model is not None and tokenizer is not None,
+        "current_model": model_name,
+        "torch_version": torch.__version__,
+        "cuda_available": torch.cuda.is_available(),
+        "device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0
     }
 
 if __name__ == "__main__":
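
Since the endpoints mirror the OpenAI chat-completions API, the server can be smoke-tested with plain `requests`. The base URL and the response layout (`choices[0].message.content`) are assumptions here, because the `uvicorn.run(...)` call and the `ChatCompletionResponse` definition are truncated in this view:

import requests

BASE = "http://localhost:8000"  # assumed host/port; adjust to the actual launch config

# Health probe: reports load status, the active model, and CUDA availability
print(requests.get(f"{BASE}/health").json())

# OpenAI-style chat completion; response access assumes the usual
# choices[0].message.content layout of the OpenAI schema
payload = {
    "model": "qwen2.5-coder-7b-instruct",
    "messages": [{"role": "user", "content": "Write FizzBuzz in Python."}],
    "temperature": 0.7,
    "max_tokens": 256,
    "top_p": 0.9,
}
r = requests.post(f"{BASE}/v1/chat/completions", json=payload, timeout=300)
print(r.json()["choices"][0]["message"]["content"])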
 