han145 commited on
Commit
847d3f0
·
verified ·
1 Parent(s): 69fd688

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +199 -101
app.py CHANGED
@@ -2,34 +2,25 @@ import os
2
  import time
3
  import json
4
  import logging
5
- import gc
6
- from typing import List, Dict, Any, Union
7
-
8
  from fastapi import FastAPI, Request, HTTPException, Depends, status
9
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
10
  from fastapi.responses import JSONResponse
11
  from transformers import AutoTokenizer, AutoModelForCausalLM
12
  import torch
 
13
 
14
- # -----------------------------------------------------------------------------
15
- # 日志配置
16
- # -----------------------------------------------------------------------------
17
- logging.basicConfig(
18
- level=logging.INFO,
19
- format='%(asctime)s - %(levelname)s - %(message)s'
20
- )
21
  logger = logging.getLogger(__name__)
22
 
23
- # -----------------------------------------------------------------------------
24
- # 全局变量与配置
25
- # -----------------------------------------------------------------------------
26
  model = None
27
  tokenizer = None
28
 
29
  # 配置
30
  MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"
31
  MAX_TOKENS = 256
32
- DEVICE = "cpu" # 强制使用CPU,如需GPU请改为 "cuda"
33
 
34
  # API密钥配置
35
  API_KEYS = os.getenv("API_KEYS", "your-secret-key-1,your-secret-key-2").split(",")
@@ -38,10 +29,6 @@ API_AUTH_ENABLED = os.getenv("API_AUTH_ENABLED", "true").lower() == "true"
38
  # 创建Bearer认证方案
39
  security = HTTPBearer()
40
 
41
- # -----------------------------------------------------------------------------
42
- # 辅助函数
43
- # -----------------------------------------------------------------------------
44
-
45
  def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
46
  """验证API密钥"""
47
  if not API_AUTH_ENABLED:
@@ -68,9 +55,6 @@ def load_model():
68
  """极简模型加载"""
69
  global model, tokenizer
70
 
71
- if model is not None:
72
- return True
73
-
74
  try:
75
  logger.info(f"开始加载模型: {MODEL_NAME}")
76
 
@@ -93,7 +77,7 @@ def load_model():
93
  trust_remote_code=True
94
  )
95
 
96
- # 移动到设备
97
  model = model.to(DEVICE)
98
  model.eval() # 设置为评估模式
99
 
@@ -104,67 +88,85 @@ def load_model():
104
  logger.error(f"模型加载失败: {e}")
105
  return False
106
 
107
- def generate_chat_response(messages: List[Dict[str, str]], max_tokens=256, temperature=0.7):
108
  """
109
- 使用 apply_chat_template 生成响应,支持完整对话历史。
110
- 这是 Hugging Face 推荐的标准方式。
111
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  if model is None or tokenizer is None:
113
  return {"error": "模型未加载"}
114
-
115
  try:
116
- # 使用 apply_chat_template 构建输入,它会自动处理 <|im_start|> 等格式
117
- # tokenize=False 返回字符串,方便调试和确保输入类型正确
118
- text = tokenizer.apply_chat_template(
119
- messages,
120
- tokenize=False,
121
- add_generation_prompt=True
122
- )
123
-
124
- # 记录生成的文本提示(用于调试)
125
- # logger.info(f"生成的Prompt片段: {text[:100]}...")
126
 
127
- # 编码输入 - 直接传入字符串,不放入列表,避免某些 tokenizer 版本的批处理歧义
128
- model_inputs = tokenizer(
129
- text,
130
- return_tensors="pt"
131
- ).to(DEVICE)
132
-
133
- # 生成响应
 
 
 
 
 
 
 
134
  with torch.no_grad():
135
- generated_ids = model.generate(
136
- model_inputs.input_ids,
137
  max_new_tokens=min(max_tokens, MAX_TOKENS),
138
  do_sample=True,
139
  temperature=temperature,
140
- top_p=0.9,
141
- pad_token_id=tokenizer.eos_token_id
 
 
142
  )
143
-
144
- # 获取新生成的token(去掉输入的token)
145
- generated_ids = [
146
- output_ids[len(input_ids):]
147
- for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
148
- ]
149
 
150
- # 解码响应
151
- response_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
152
-
153
- # 清理内存
154
- del model_inputs, generated_ids
155
- if torch.cuda.is_available():
156
- torch.cuda.empty_cache()
 
 
157
  gc.collect()
158
-
159
- return {"text": response_text}
160
-
161
  except Exception as e:
162
- logger.error(f"生成响应失败: {e}", exc_info=True)
163
- return {"error": f"生成失败: {str(e)}"}
164
 
165
- # -----------------------------------------------------------------------------
166
- # FastAPI 应用
167
- # -----------------------------------------------------------------------------
168
  app = FastAPI(
169
  title="OpenAI API兼容服务",
170
  version="1.0",
@@ -179,7 +181,7 @@ async def startup_event():
179
  if API_AUTH_ENABLED:
180
  logger.info(f"有效的API密钥数量: {len(API_KEYS)}")
181
 
182
- # 健康检查端点
183
  @app.get("/health")
184
  async def health_check():
185
  return {
@@ -189,6 +191,31 @@ async def health_check():
189
  "timestamp": int(time.time())
190
  }
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  @app.get("/v1/models")
193
  async def list_models():
194
  """返回可用的模型列表"""
@@ -204,42 +231,37 @@ async def list_models():
204
  ]
205
  }
206
 
207
- # OpenAI Chat Completions端点
208
  @app.post("/v1/chat/completions")
209
  async def create_chat_completion(
210
  request: Request,
211
  auth_valid: bool = Depends(verify_api_key)
212
  ):
213
- """OpenAI Chat Completions API兼容端点"""
214
  try:
215
- # 解析请求
216
  data = await request.json()
217
  messages = data.get("messages", [])
218
- model_name = data.get("model", "qwen1.5-0.5b-chat")
219
  max_tokens = data.get("max_tokens", MAX_TOKENS)
220
  temperature = data.get("temperature", 0.7)
221
-
222
- logger.info(f"收到Chat Completions请求: model={model_name}, messages_count={len(messages)}")
223
-
224
- # 检查消息格式
225
  if not messages or not isinstance(messages, list):
226
- return JSONResponse(
227
- status_code=400,
228
- content={"error": {"message": "无效的消息格式", "type": "invalid_request_error"}}
229
- )
230
-
231
- # 使用新的生成函数,直接传递 messages 列表
232
  result = generate_chat_response(messages, max_tokens, temperature)
233
-
234
  if "error" in result:
235
- return JSONResponse(
236
- status_code=500,
237
- content={"error": {"message": result["error"], "type": "internal_error"}}
238
- )
239
-
240
- # 返回OpenAI Chat Completions兼容格式
 
241
  response_data = {
242
- "id": f"chatcmpl-{int(time.time())}",
243
  "object": "chat.completion",
244
  "created": int(time.time()),
245
  "model": model_name,
@@ -254,33 +276,109 @@ async def create_chat_completion(
254
  }
255
  ],
256
  "usage": {
257
- "prompt_tokens": -1,
258
- "completion_tokens": -1,
259
- "total_tokens": -1
260
  }
261
  }
262
-
263
- logger.info(f"成功生成响应: {len(result['text'])} 字符")
264
  return response_data
265
-
266
  except Exception as e:
267
- logger.error(f"Chat Completions API错误: {e}", exc_info=True)
268
  return JSONResponse(
269
  status_code=500,
270
  content={
271
  "error": {
272
- "message": f"内部服务器错误: {str(e)}",
273
- "type": "internal_error"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  }
276
  )
277
 
278
  if __name__ == "__main__":
279
  import uvicorn
 
 
280
  uvicorn.run(
281
  app,
282
  host="0.0.0.0",
283
  port=7860,
284
- workers=1,
285
  log_level="info"
286
  )
 
2
  import time
3
  import json
4
  import logging
 
 
 
5
  from fastapi import FastAPI, Request, HTTPException, Depends, status
6
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
7
  from fastapi.responses import JSONResponse
8
  from transformers import AutoTokenizer, AutoModelForCausalLM
9
  import torch
10
+ import gc
11
 
12
+ # 极简日志配置
13
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 
 
 
 
14
  logger = logging.getLogger(__name__)
15
 
16
+ # 全局变量
 
 
17
  model = None
18
  tokenizer = None
19
 
20
  # 配置
21
  MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"
22
  MAX_TOKENS = 256
23
+ DEVICE = "cpu" # 强制使用CPU
24
 
25
  # API密钥配置
26
  API_KEYS = os.getenv("API_KEYS", "your-secret-key-1,your-secret-key-2").split(",")
 
29
  # 创建Bearer认证方案
30
  security = HTTPBearer()
31
 
 
 
 
 
32
  def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
33
  """验证API密钥"""
34
  if not API_AUTH_ENABLED:
 
55
  """极简模型加载"""
56
  global model, tokenizer
57
 
 
 
 
58
  try:
59
  logger.info(f"开始加载模型: {MODEL_NAME}")
60
 
 
77
  trust_remote_code=True
78
  )
79
 
80
+ # 移动到CPU
81
  model = model.to(DEVICE)
82
  model.eval() # 设置为评估模式
83
 
 
88
  logger.error(f"模型加载失败: {e}")
89
  return False
90
 
91
+ def apply_chat_template(messages):
92
  """
93
+ OpenAI 格式 messages 转为 Qwen 的 chat template 格式
 
94
  """
95
+ text = ""
96
+ for msg in messages:
97
+ role = msg.get("role", "").lower()
98
+ content = msg.get("content", "").strip()
99
+ if not content:
100
+ continue
101
+
102
+ if role == "system":
103
+ text += f"<|im_start|>system\n{content}<|im_end|>\n"
104
+ elif role == "user":
105
+ text += f"<|im_start|>user\n{content}<|im_end|>\n"
106
+ elif role == "assistant":
107
+ text += f"<|im_start|>assistant\n{content}<|im_end|>\n"
108
+ else:
109
+ # 忽略其他 role
110
+ continue
111
+
112
+ # 最后加上 assistant 的开头
113
+ text += "<|im_start|>assistant\n"
114
+ return text
115
+
116
+
117
+ def generate_chat_response(messages, max_tokens=256, temperature=0.7):
118
+ """生成完整对话回复"""
119
  if model is None or tokenizer is None:
120
  return {"error": "模型未加载"}
121
+
122
  try:
123
+ # 转换为 Qwen 的对话格式
124
+ prompt = apply_chat_template(messages)
 
 
 
 
 
 
 
 
125
 
126
+ logger.info(f"输入文本类型: {type(prompt)}, 长度: {len(prompt)}")
127
+ logger.debug(f"完整prompt前100字符: {prompt[:100]}...")
128
+
129
+ # 分词(注意这里使用列表包一层字符串)
130
+ inputs = tokenizer(
131
+ [prompt], # 必须是 list[str]
132
+ return_tensors="pt",
133
+ truncation=True,
134
+ max_length=3072, # Qwen1.5 支持较长上下文
135
+ padding=True
136
+ )
137
+ inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
138
+
139
+ # 生成
140
  with torch.no_grad():
141
+ outputs = model.generate(
142
+ **inputs,
143
  max_new_tokens=min(max_tokens, MAX_TOKENS),
144
  do_sample=True,
145
  temperature=temperature,
146
+ top_p=0.85,
147
+ repetition_penalty=1.05, # 轻微防止重复
148
+ pad_token_id=tokenizer.eos_token_id,
149
+ eos_token_id=tokenizer.eos_token_id,
150
  )
 
 
 
 
 
 
151
 
152
+ # 只取新生成的 token
153
+ generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
154
+ response = tokenizer.decode(generated_ids, skip_special_tokens=True)
155
+
156
+ # 清理可能的结束标记
157
+ response = response.split("<|im_end|>")[0].strip()
158
+
159
+ # 内存清理
160
+ del inputs, outputs
161
  gc.collect()
162
+
163
+ return {"text": response}
164
+
165
  except Exception as e:
166
+ logger.error(f"生成失败: {str(e)}", exc_info=True)
167
+ return {"error": str(e)}
168
 
169
+ # 创建极简FastAPI应用
 
 
170
  app = FastAPI(
171
  title="OpenAI API兼容服务",
172
  version="1.0",
 
181
  if API_AUTH_ENABLED:
182
  logger.info(f"有效的API密钥数量: {len(API_KEYS)}")
183
 
184
+ # 健康检查端点(无需认证)
185
  @app.get("/health")
186
  async def health_check():
187
  return {
 
191
  "timestamp": int(time.time())
192
  }
193
 
194
+ # 根端点(无需认证)
195
+ @app.get("/")
196
+ async def root():
197
+ return {
198
+ "message": "OpenAI API兼容服务运行中",
199
+ "model_loaded": model is not None,
200
+ "api_auth_enabled": API_AUTH_ENABLED,
201
+ "endpoints": {
202
+ "v1": "/v1",
203
+ "chat_completions": "/v1/chat/completions"
204
+ }
205
+ }
206
+
207
+ # 添加/v1端点(OpenClaw可能需要)
208
+ @app.get("/v1")
209
+ async def v1_root():
210
+ return {
211
+ "message": "OpenAI v1 API端点",
212
+ "endpoints": {
213
+ "models": "/v1/models",
214
+ "chat_completions": "/v1/chat/completions"
215
+ }
216
+ }
217
+
218
+ # 添加模型列表端点(OpenAI兼容)
219
  @app.get("/v1/models")
220
  async def list_models():
221
  """返回可用的模型列表"""
 
231
  ]
232
  }
233
 
234
+ # OpenAI Chat Completions端点(主要端点)
235
  @app.post("/v1/chat/completions")
236
  async def create_chat_completion(
237
  request: Request,
238
  auth_valid: bool = Depends(verify_api_key)
239
  ):
 
240
  try:
 
241
  data = await request.json()
242
  messages = data.get("messages", [])
243
+ model_name = data.get("model", MODEL_NAME)
244
  max_tokens = data.get("max_tokens", MAX_TOKENS)
245
  temperature = data.get("temperature", 0.7)
246
+
247
+ logger.info(f"收到请求: model={model_name}, messages_count={len(messages)}")
248
+
 
249
  if not messages or not isinstance(messages, list):
250
+ raise ValueError("messages 必须是非空列表")
251
+
252
+ # 生成回复
 
 
 
253
  result = generate_chat_response(messages, max_tokens, temperature)
254
+
255
  if "error" in result:
256
+ raise RuntimeError(result["error"])
257
+
258
+ # 计算粗略 token 数(仅供参考)
259
+ prompt_text = "".join([m["content"] for m in messages if m.get("content")])
260
+ prompt_tokens = len(tokenizer.encode(prompt_text)) if tokenizer else 0
261
+ completion_tokens = len(tokenizer.encode(result["text"])) if tokenizer else 0
262
+
263
  response_data = {
264
+ "id": f"chatcmpl-{int(time.time()*1000)}",
265
  "object": "chat.completion",
266
  "created": int(time.time()),
267
  "model": model_name,
 
276
  }
277
  ],
278
  "usage": {
279
+ "prompt_tokens": prompt_tokens,
280
+ "completion_tokens": completion_tokens,
281
+ "total_tokens": prompt_tokens + completion_tokens
282
  }
283
  }
284
+
 
285
  return response_data
286
+
287
  except Exception as e:
288
+ logger.error(f"Chat Completions 错误: {str(e)}", exc_info=True)
289
  return JSONResponse(
290
  status_code=500,
291
  content={
292
  "error": {
293
+ "message": str(e),
294
+ "type": "internal_server_error"
295
+ }
296
+ }
297
+ )
298
+
299
+ # 添加兼容性端点(为不同版本的OpenClaw提供支持)
300
+ @app.post("/chat/completions")
301
+ async def legacy_chat_completion(
302
+ request: Request,
303
+ auth_valid: bool = Depends(verify_api_key)
304
+ ):
305
+ """兼容旧版本OpenClaw的端点"""
306
+ # 直接转发到/v1/chat/completions
307
+ return await create_chat_completion(request, auth_valid)
308
+
309
+ # 添加通用聊天端点
310
+ @app.post("/api/chat")
311
+ async def generic_chat_api(
312
+ request: Request,
313
+ auth_valid: bool = Depends(verify_api_key)
314
+ ):
315
+ """通用聊天API端点"""
316
+ try:
317
+ # 解析请求
318
+ data = await request.json()
319
+ messages = data.get("messages", [])
320
+
321
+ # 检查消息格式
322
+ if not messages or not isinstance(messages, list):
323
+ return JSONResponse(
324
+ status_code=400,
325
+ content={
326
+ "error": "无效的消息格式"
327
+ }
328
+ )
329
+
330
+ # 提取用户消息
331
+ user_message = ""
332
+ for msg in messages:
333
+ if isinstance(msg, dict) and msg.get("role") == "user":
334
+ user_message = msg.get("content", "")
335
+ break
336
+
337
+ if not user_message:
338
+ return JSONResponse(
339
+ status_code=400,
340
+ content={
341
+ "error": "未找到用户消息"
342
+ }
343
+ )
344
+
345
+ # 生成响应
346
+ result = generate_completion(user_message)
347
+
348
+ if "error" in result:
349
+ return JSONResponse(
350
+ status_code=500,
351
+ content={
352
+ "error": result["error"]
353
  }
354
+ )
355
+
356
+ # 返回通用格式
357
+ return {
358
+ "choices": [{
359
+ "message": {
360
+ "content": result["text"]
361
+ }
362
+ }]
363
+ }
364
+
365
+ except Exception as e:
366
+ logger.error(f"通用聊天API错误: {e}")
367
+ return JSONResponse(
368
+ status_code=500,
369
+ content={
370
+ "error": f"内部服务器错误: {str(e)}"
371
  }
372
  )
373
 
374
  if __name__ == "__main__":
375
  import uvicorn
376
+
377
+ # 极简UVicorn配置
378
  uvicorn.run(
379
  app,
380
  host="0.0.0.0",
381
  port=7860,
382
+ workers=1, # 单worker减少内存占用
383
  log_level="info"
384
  )