han145 commited on
Commit
e00d22a
·
verified ·
1 Parent(s): 9094d08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -36
app.py CHANGED
@@ -88,14 +88,19 @@ def load_model():
88
  logger.error(f"模型加载失败: {e}")
89
  return False
90
 
91
- def generate_response(prompt):
92
- """极简响应生成 - 仅使用用户输入"""
93
  if model is None or tokenizer is None:
94
  return {"error": "模型未加载"}
95
 
96
  try:
97
- # 手动构建提示
98
- text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
 
 
 
 
 
99
 
100
  # 编码输入
101
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
@@ -105,15 +110,25 @@ def generate_response(prompt):
105
  with torch.no_grad():
106
  outputs = model.generate(
107
  **inputs,
108
- max_new_tokens=MAX_TOKENS,
109
  do_sample=True,
110
- temperature=0.7,
111
  top_p=0.9,
112
  pad_token_id=tokenizer.eos_token_id
113
  )
114
 
115
- # 解码响应
116
- response = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
117
 
118
  # 立即清理内存
119
  del inputs, outputs
@@ -121,7 +136,7 @@ def generate_response(prompt):
121
  torch.cuda.empty_cache()
122
  gc.collect()
123
 
124
- return {"content": response.strip()}
125
 
126
  except Exception as e:
127
  logger.error(f"生成响应失败: {e}")
@@ -129,9 +144,9 @@ def generate_response(prompt):
129
 
130
  # 创建极简FastAPI应用
131
  app = FastAPI(
132
- title="OpenClaw专用API",
133
  version="1.0",
134
- description="专为OpenClaw优化的API服务"
135
  )
136
 
137
  # 启动时加载模型
@@ -142,19 +157,118 @@ async def startup_event():
142
  if API_AUTH_ENABLED:
143
  logger.info(f"有效的API密钥数量: {len(API_KEYS)}")
144
 
145
- # OpenClaw专用端点
146
- @app.post("/chat/completions")
147
- async def openclaw_chat_api(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  request: Request,
149
  auth_valid: bool = Depends(verify_api_key)
150
  ):
151
- """专为OpenClaw设计的API端点"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  try:
153
  # 解析请求
154
  data = await request.json()
155
  messages = data.get("messages", [])
 
 
 
156
 
157
- # 提取用户消息
158
  user_message = ""
159
  for msg in messages:
160
  if msg.get("role") == "user":
@@ -165,49 +279,62 @@ async def openclaw_chat_api(
165
  return JSONResponse(
166
  status_code=400,
167
  content={
168
- "error": "未找到用户消息"
 
 
 
169
  }
170
  )
171
 
172
  # 生成响应
173
- result = generate_response(user_message)
174
 
175
  if "error" in result:
176
  return JSONResponse(
177
  status_code=500,
178
  content={
179
- "error": result["error"]
 
 
 
180
  }
181
  )
182
 
183
- # 返回OpenClaw专用格式
184
  return {
185
- "choices": [{
186
- "message": {
187
- "content": result["content"]
 
 
 
 
 
 
 
 
 
188
  }
189
- }]
 
 
 
 
 
190
  }
191
 
192
  except Exception as e:
193
- logger.error(f"OpenClaw API错误: {e}")
194
  return JSONResponse(
195
  status_code=500,
196
  content={
197
- "error": f"内部服务器错误: {str(e)}"
 
 
 
198
  }
199
  )
200
 
201
- # 健康检查端点
202
- @app.get("/health")
203
- async def health_check():
204
- return {
205
- "status": "healthy" if model is not None else "loading",
206
- "model_loaded": model is not None,
207
- "api_auth_enabled": API_AUTH_ENABLED,
208
- "timestamp": int(time.time())
209
- }
210
-
211
  if __name__ == "__main__":
212
  import uvicorn
213
 
 
88
  logger.error(f"模型加载失败: {e}")
89
  return False
90
 
91
+ def generate_completion(prompt, max_tokens=256, temperature=0.7):
92
+ """生成OpenAI Completions格式的响应"""
93
  if model is None or tokenizer is None:
94
  return {"error": "模型未加载"}
95
 
96
  try:
97
+ # 构建提示词 - 使用Qwen模型的对话格式
98
+ if "user" in prompt.lower() or "assistant" in prompt.lower():
99
+ # 如果提示词已经包含对话格式,直接使用
100
+ text = prompt
101
+ else:
102
+ # 否则,将提示词包装为对话格式
103
+ text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
104
 
105
  # 编码输入
106
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
 
110
  with torch.no_grad():
111
  outputs = model.generate(
112
  **inputs,
113
+ max_new_tokens=min(max_tokens, MAX_TOKENS),
114
  do_sample=True,
115
+ temperature=temperature,
116
  top_p=0.9,
117
  pad_token_id=tokenizer.eos_token_id
118
  )
119
 
120
+ # 解码完整响应(包括提示词和生成内容)
121
+ full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
122
+
123
+ # 提取生成的文本(去除提示词部分)
124
+ if text in full_response:
125
+ generated_text = full_response[len(text):]
126
+ else:
127
+ # 如果提取失败,使用简单方法
128
+ generated_text = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
129
+
130
+ # 清理特殊标记
131
+ generated_text = generated_text.replace("<|im_end|>", "").strip()
132
 
133
  # 立即清理内存
134
  del inputs, outputs
 
136
  torch.cuda.empty_cache()
137
  gc.collect()
138
 
139
+ return {"text": generated_text}
140
 
141
  except Exception as e:
142
  logger.error(f"生成响应失败: {e}")
 
144
 
145
  # 创建极简FastAPI应用
146
  app = FastAPI(
147
+ title="OpenAI Completions API兼容服务",
148
  version="1.0",
149
+ description="专为OpenClaw优化的OpenAI Completions API兼容服务"
150
  )
151
 
152
  # 启动时加载模型
 
157
  if API_AUTH_ENABLED:
158
  logger.info(f"有效的API密钥数量: {len(API_KEYS)}")
159
 
160
+ # 健康检查端点(无需认证)
161
+ @app.get("/health")
162
+ async def health_check():
163
+ return {
164
+ "status": "healthy" if model is not None else "loading",
165
+ "model_loaded": model is not None,
166
+ "api_auth_enabled": API_AUTH_ENABLED,
167
+ "timestamp": int(time.time())
168
+ }
169
+
170
+ # 根端点(无需认证)
171
+ @app.get("/")
172
+ async def root():
173
+ return {
174
+ "message": "OpenAI Completions API兼容服务运行中",
175
+ "model_loaded": model is not None,
176
+ "api_auth_enabled": API_AUTH_ENABLED,
177
+ "endpoints": {
178
+ "completions": "/v1/completions",
179
+ "chat_completions": "/v1/chat/completions"
180
+ }
181
+ }
182
+
183
+ # OpenAI Completions端点(OpenClaw主要使用这个)
184
+ @app.post("/v1/completions")
185
+ async def create_completion(
186
  request: Request,
187
  auth_valid: bool = Depends(verify_api_key)
188
  ):
189
+ """OpenAI Completions API兼容端点"""
190
+ try:
191
+ # 解析请求
192
+ data = await request.json()
193
+ prompt = data.get("prompt", "")
194
+ model_name = data.get("model", "qwen1.5-0.5b-chat")
195
+ max_tokens = data.get("max_tokens", MAX_TOKENS)
196
+ temperature = data.get("temperature", 0.7)
197
+
198
+ if not prompt:
199
+ return JSONResponse(
200
+ status_code=400,
201
+ content={
202
+ "error": {
203
+ "message": "缺少必需的参数: prompt",
204
+ "type": "invalid_request_error"
205
+ }
206
+ }
207
+ )
208
+
209
+ # 生成响应
210
+ result = generate_completion(prompt, max_tokens, temperature)
211
+
212
+ if "error" in result:
213
+ return JSONResponse(
214
+ status_code=500,
215
+ content={
216
+ "error": {
217
+ "message": result["error"],
218
+ "type": "internal_error"
219
+ }
220
+ }
221
+ )
222
+
223
+ # 返回OpenAI Completions兼容格式
224
+ return {
225
+ "id": f"cmpl-{int(time.time())}",
226
+ "object": "text_completion",
227
+ "created": int(time.time()),
228
+ "model": model_name,
229
+ "choices": [
230
+ {
231
+ "text": result["text"],
232
+ "index": 0,
233
+ "logprobs": None,
234
+ "finish_reason": "stop"
235
+ }
236
+ ],
237
+ "usage": {
238
+ "prompt_tokens": len(tokenizer.encode(prompt)) if tokenizer else 0,
239
+ "completion_tokens": len(tokenizer.encode(result["text"])) if tokenizer else 0,
240
+ "total_tokens": len(tokenizer.encode(prompt)) + len(tokenizer.encode(result["text"])) if tokenizer else 0
241
+ }
242
+ }
243
+
244
+ except Exception as e:
245
+ logger.error(f"Completions API错误: {e}")
246
+ return JSONResponse(
247
+ status_code=500,
248
+ content={
249
+ "error": {
250
+ "message": f"内部服务器错误: {str(e)}",
251
+ "type": "internal_error"
252
+ }
253
+ }
254
+ )
255
+
256
+ # 保持Chat Completions端点兼容性
257
+ @app.post("/v1/chat/completions")
258
+ async def create_chat_completion(
259
+ request: Request,
260
+ auth_valid: bool = Depends(verify_api_key)
261
+ ):
262
+ """OpenAI Chat Completions API兼容端点"""
263
  try:
264
  # 解析请求
265
  data = await request.json()
266
  messages = data.get("messages", [])
267
+ model_name = data.get("model", "qwen1.5-0.5b-chat")
268
+ max_tokens = data.get("max_tokens", MAX_TOKENS)
269
+ temperature = data.get("temperature", 0.7)
270
 
271
+ # 从消息中提取用户提示
272
  user_message = ""
273
  for msg in messages:
274
  if msg.get("role") == "user":
 
279
  return JSONResponse(
280
  status_code=400,
281
  content={
282
+ "error": {
283
+ "message": "未找到用户消息",
284
+ "type": "invalid_request_error"
285
+ }
286
  }
287
  )
288
 
289
  # 生成响应
290
+ result = generate_completion(user_message, max_tokens, temperature)
291
 
292
  if "error" in result:
293
  return JSONResponse(
294
  status_code=500,
295
  content={
296
+ "error": {
297
+ "message": result["error"],
298
+ "type": "internal_error"
299
+ }
300
  }
301
  )
302
 
303
+ # 返回OpenAI Chat Completions兼容格式
304
  return {
305
+ "id": f"chatcmpl-{int(time.time())}",
306
+ "object": "chat.completion",
307
+ "created": int(time.time()),
308
+ "model": model_name,
309
+ "choices": [
310
+ {
311
+ "index": 0,
312
+ "message": {
313
+ "role": "assistant",
314
+ "content": result["text"]
315
+ },
316
+ "finish_reason": "stop"
317
  }
318
+ ],
319
+ "usage": {
320
+ "prompt_tokens": len(tokenizer.encode(user_message)) if tokenizer else 0,
321
+ "completion_tokens": len(tokenizer.encode(result["text"])) if tokenizer else 0,
322
+ "total_tokens": len(tokenizer.encode(user_message)) + len(tokenizer.encode(result["text"])) if tokenizer else 0
323
+ }
324
  }
325
 
326
  except Exception as e:
327
+ logger.error(f"Chat Completions API错误: {e}")
328
  return JSONResponse(
329
  status_code=500,
330
  content={
331
+ "error": {
332
+ "message": f"内部服务器错误: {str(e)}",
333
+ "type": "internal_error"
334
+ }
335
  }
336
  )
337
 
 
 
 
 
 
 
 
 
 
 
338
  if __name__ == "__main__":
339
  import uvicorn
340