han145 commited on
Commit
85e708c
·
verified ·
1 Parent(s): 844798d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -95
app.py CHANGED
@@ -108,19 +108,14 @@ def generate_completion(prompt, max_tokens=256, temperature=0.7):
108
  # 记录文本类型和长度
109
  logger.info(f"输入文本类型: {type(text)}, 长度: {len(text)}")
110
 
111
- # 使用更安全的编码方式
112
- # 确保输入是字符串类型
113
- input_text = str(text)
114
-
115
- # 编码输入
116
- inputs = tokenizer(
117
- input_text,
118
- return_tensors="pt",
119
- truncation=True,
120
- max_length=1024,
121
- padding=True
122
- )
123
- inputs = inputs.to(DEVICE)
124
 
125
  # 生成响应
126
  with torch.no_grad():
@@ -153,9 +148,9 @@ def generate_completion(prompt, max_tokens=256, temperature=0.7):
153
 
154
  # 创建极简FastAPI应用
155
  app = FastAPI(
156
- title="OpenAI Completions API兼容服务",
157
  version="1.0",
158
- description="专为OpenClaw优化的OpenAI Completions API兼容服务"
159
  )
160
 
161
  # 启动时加载模型
@@ -180,95 +175,49 @@ async def health_check():
180
  @app.get("/")
181
  async def root():
182
  return {
183
- "message": "OpenAI Completions API兼容服务运行中",
184
  "model_loaded": model is not None,
185
  "api_auth_enabled": API_AUTH_ENABLED,
186
  "endpoints": {
187
- "completions": "/v1/completions",
188
  "chat_completions": "/v1/chat/completions"
189
  }
190
  }
191
 
192
- # OpenAI Completions端点(OpenClaw使用这个
193
- @app.post("/v1/completions")
194
- async def create_completion(
195
- request: Request,
196
- auth_valid: bool = Depends(verify_api_key)
197
- ):
198
- """OpenAI Completions API兼容端点"""
199
- try:
200
- # 解析请求
201
- data = await request.json()
202
- prompt = data.get("prompt", "")
203
- model_name = data.get("model", "qwen1.5-0.5b-chat")
204
- max_tokens = data.get("max_tokens", MAX_TOKENS)
205
- temperature = data.get("temperature", 0.7)
206
-
207
- if not prompt:
208
- return JSONResponse(
209
- status_code=400,
210
- content={
211
- "error": {
212
- "message": "缺少必需的参数: prompt",
213
- "type": "invalid_request_error"
214
- }
215
- }
216
- )
217
-
218
- # 生成响应
219
- result = generate_completion(prompt, max_tokens, temperature)
220
-
221
- if "error" in result:
222
- return JSONResponse(
223
- status_code=500,
224
- content={
225
- "error": {
226
- "message": result["error"],
227
- "type": "internal_error"
228
- }
229
- }
230
- )
231
-
232
- # 返回OpenAI Completions兼容格式
233
- return {
234
- "id": f"cmpl-{int(time.time())}",
235
- "object": "text_completion",
236
- "created": int(time.time()),
237
- "model": model_name,
238
- "choices": [
239
- {
240
- "text": result["text"],
241
- "index": 0,
242
- "logprobs": None,
243
- "finish_reason": "stop"
244
- }
245
- ],
246
- "usage": {
247
- "prompt_tokens": len(tokenizer.encode(prompt)) if tokenizer else 0,
248
- "completion_tokens": len(tokenizer.encode(result["text"])) if tokenizer else 0,
249
- "total_tokens": len(tokenizer.encode(prompt)) + len(tokenizer.encode(result["text"])) if tokenizer else 0
250
- }
251
  }
252
-
253
- except Exception as e:
254
- logger.error(f"Completions API错误: {e}")
255
- return JSONResponse(
256
- status_code=500,
257
- content={
258
- "error": {
259
- "message": f"内部服务器错误: {str(e)}",
260
- "type": "internal_error"
261
- }
 
 
 
 
262
  }
263
- )
 
264
 
265
- # OpenAI Chat Completions端点
266
  @app.post("/v1/chat/completions")
267
  async def create_chat_completion(
268
  request: Request,
269
  auth_valid: bool = Depends(verify_api_key)
270
  ):
271
- """OpenAI Chat Completions API兼容端点"""
272
  try:
273
  # 解析请求
274
  data = await request.json()
@@ -277,6 +226,9 @@ async def create_chat_completion(
277
  max_tokens = data.get("max_tokens", MAX_TOKENS)
278
  temperature = data.get("temperature", 0.7)
279
 
 
 
 
280
  # 检查消息格式
281
  if not messages or not isinstance(messages, list):
282
  return JSONResponse(
@@ -322,7 +274,7 @@ async def create_chat_completion(
322
  )
323
 
324
  # 返回OpenAI Chat Completions兼容格式
325
- return {
326
  "id": f"chatcmpl-{int(time.time())}",
327
  "object": "chat.completion",
328
  "created": int(time.time()),
@@ -344,6 +296,9 @@ async def create_chat_completion(
344
  }
345
  }
346
 
 
 
 
347
  except Exception as e:
348
  logger.error(f"Chat Completions API错误: {e}")
349
  return JSONResponse(
@@ -356,13 +311,23 @@ async def create_chat_completion(
356
  }
357
  )
358
 
359
- # 添加OpenClaw专用端点(简化版)
 
 
 
 
 
 
 
 
 
 
360
  @app.post("/api/chat")
361
- async def openclaw_chat_api(
362
  request: Request,
363
  auth_valid: bool = Depends(verify_api_key)
364
  ):
365
- """专为OpenClaw设计的API端点"""
366
  try:
367
  # 解析请求
368
  data = await request.json()
@@ -403,7 +368,7 @@ async def openclaw_chat_api(
403
  }
404
  )
405
 
406
- # 返回OpenClaw专用格式
407
  return {
408
  "choices": [{
409
  "message": {
@@ -413,7 +378,7 @@ async def openclaw_chat_api(
413
  }
414
 
415
  except Exception as e:
416
- logger.error(f"OpenClaw API错误: {e}")
417
  return JSONResponse(
418
  status_code=500,
419
  content={
 
108
  # 记录文本类型和长度
109
  logger.info(f"输入文本类型: {type(text)}, 长度: {len(text)}")
110
 
111
+ # 使用更基础的编码方式
112
+ input_ids = tokenizer.encode(text, truncation=True, max_length=1024, return_tensors="pt")
113
+ attention_mask = torch.ones_like(input_ids)
114
+
115
+ inputs = {
116
+ "input_ids": input_ids.to(DEVICE),
117
+ "attention_mask": attention_mask.to(DEVICE)
118
+ }
 
 
 
 
 
119
 
120
  # 生成响应
121
  with torch.no_grad():
 
148
 
149
  # 创建极简FastAPI应用
150
  app = FastAPI(
151
+ title="OpenAI API兼容服务",
152
  version="1.0",
153
+ description="专为OpenClaw优化的OpenAI API兼容服务"
154
  )
155
 
156
  # 启动时加载模型
 
175
  @app.get("/")
176
  async def root():
177
  return {
178
+ "message": "OpenAI API兼容服务运行中",
179
  "model_loaded": model is not None,
180
  "api_auth_enabled": API_AUTH_ENABLED,
181
  "endpoints": {
182
+ "v1": "/v1",
183
  "chat_completions": "/v1/chat/completions"
184
  }
185
  }
186
 
187
+ # 添加/v1端点(OpenClaw可能需要)
188
+ @app.get("/v1")
189
+ async def v1_root():
190
+ return {
191
+ "message": "OpenAI v1 API端点",
192
+ "endpoints": {
193
+ "models": "/v1/models",
194
+ "chat_completions": "/v1/chat/completions"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  }
196
+ }
197
+
198
+ # 添加模型列表端点(OpenAI兼容)
199
+ @app.get("/v1/models")
200
+ async def list_models():
201
+ """返回可用的模型列表"""
202
+ return {
203
+ "object": "list",
204
+ "data": [
205
+ {
206
+ "id": "qwen1.5-0.5b-chat",
207
+ "object": "model",
208
+ "created": int(time.time()),
209
+ "owned_by": "qwen"
210
  }
211
+ ]
212
+ }
213
 
214
+ # OpenAI Chat Completions端点(主要端点)
215
  @app.post("/v1/chat/completions")
216
  async def create_chat_completion(
217
  request: Request,
218
  auth_valid: bool = Depends(verify_api_key)
219
  ):
220
+ """OpenAI Chat Completions API兼容端点 - 这是OpenClaw使用的主要端点"""
221
  try:
222
  # 解析请求
223
  data = await request.json()
 
226
  max_tokens = data.get("max_tokens", MAX_TOKENS)
227
  temperature = data.get("temperature", 0.7)
228
 
229
+ # 记录请求详情
230
+ logger.info(f"收到Chat Completions请求: model={model_name}, messages_count={len(messages)}")
231
+
232
  # 检查消息格式
233
  if not messages or not isinstance(messages, list):
234
  return JSONResponse(
 
274
  )
275
 
276
  # 返回OpenAI Chat Completions兼容格式
277
+ response_data = {
278
  "id": f"chatcmpl-{int(time.time())}",
279
  "object": "chat.completion",
280
  "created": int(time.time()),
 
296
  }
297
  }
298
 
299
+ logger.info(f"成功生成响应: {len(result['text'])} 字符")
300
+ return response_data
301
+
302
  except Exception as e:
303
  logger.error(f"Chat Completions API错误: {e}")
304
  return JSONResponse(
 
311
  }
312
  )
313
 
314
+ # 添加兼容性端点(为不同本的OpenClaw提供支持
315
+ @app.post("/chat/completions")
316
+ async def legacy_chat_completion(
317
+ request: Request,
318
+ auth_valid: bool = Depends(verify_api_key)
319
+ ):
320
+ """兼容旧版本OpenClaw的端点"""
321
+ # 直接转发到/v1/chat/completions
322
+ return await create_chat_completion(request, auth_valid)
323
+
324
+ # 添加通用聊天端点
325
  @app.post("/api/chat")
326
+ async def generic_chat_api(
327
  request: Request,
328
  auth_valid: bool = Depends(verify_api_key)
329
  ):
330
+ """通用聊天API端点"""
331
  try:
332
  # 解析请求
333
  data = await request.json()
 
368
  }
369
  )
370
 
371
+ # 返回用格式
372
  return {
373
  "choices": [{
374
  "message": {
 
378
  }
379
 
380
  except Exception as e:
381
+ logger.error(f"通用聊天API错误: {e}")
382
  return JSONResponse(
383
  status_code=500,
384
  content={