rkihacker committed on
Commit
5d3f475
·
verified ·
1 Parent(s): b236837

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +18 -10
main.py CHANGED
@@ -3,7 +3,7 @@ import httpx
3
  import json
4
  import time
5
  import asyncio
6
- import secrets # <-- Added for new ID generation
7
  from fastapi import FastAPI, HTTPException, Security, Depends, status
8
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
9
  from fastapi.responses import StreamingResponse
@@ -22,7 +22,7 @@ if not SERVER_API_KEY:
22
  raise ValueError("SERVER_API_KEY environment variable not set. This is required to protect your server.")
23
 
24
  # FastAPI Init
25
- app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.6 (Dynamic Tokens & ID)")
26
 
27
  # --- Authentication ---
28
  security = HTTPBearer()
@@ -288,8 +288,6 @@ async def create_chat_completion(request: ChatCompletionRequest):
288
  replicate_model_id = SUPPORTED_MODELS[request.model]
289
  formatted = format_messages_for_replicate(request.messages, request.functions)
290
 
291
- # ### MAJOR FIX HERE (Max Tokens) ###
292
- # Build the payload dynamically.
293
  replicate_input = {
294
  "prompt": formatted["prompt"],
295
  "temperature": request.temperature or 0.7,
@@ -297,14 +295,12 @@ async def create_chat_completion(request: ChatCompletionRequest):
297
  }
298
 
299
  # Only add max_new_tokens if the user *actually* provided it.
300
- # If not provided, Replicate will use the model's own default.
301
  if request.max_tokens is not None:
302
  replicate_input["max_new_tokens"] = request.max_tokens
303
 
304
  if formatted["system_prompt"]: replicate_input["system_prompt"] = formatted["system_prompt"]
305
  if formatted["image"]: replicate_input["image"] = formatted["image"]
306
 
307
- # ### MAJOR FIX HERE (Request ID) ###
308
  request_id = generate_request_id()
309
 
310
  if request.stream:
@@ -323,8 +319,19 @@ async def create_chat_completion(request: ChatCompletionRequest):
323
  resp = await client.post(url, headers=headers, json={"input": replicate_input}, timeout=300.0)
324
  resp.raise_for_status()
325
  pred = resp.json()
326
- output = "".join(pred.get("output", []))
327
-
 
 
 
 
 
 
 
 
 
 
 
328
  output = output.strip() # Clean up any leading/trailing whitespace
329
 
330
  end_time = time.time()
@@ -360,6 +367,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
360
  except httpx.HTTPStatusError as e:
361
  raise HTTPException(status_code=e.response.status_code, detail=f"Error from Replicate API: {e.response.text}")
362
  except Exception as e:
 
363
  raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
364
 
365
  @app.get("/")
@@ -367,7 +375,7 @@ async def root():
367
  """
368
  Root endpoint for health checks. Does not require authentication.
369
  """
370
- return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.6"}
371
 
372
  @app.middleware("http")
373
  async def add_performance_headers(request, call_next):
@@ -375,5 +383,5 @@ async def add_performance_headers(request, call_next):
375
  response = await call_next(request)
376
  process_time = time.time() - start_time
377
  response.headers["X-Process-Time"] = str(round(process_time, 3))
378
- response.headers["X-API-Version"] = "9.2.6"
379
  return response
 
3
  import json
4
  import time
5
  import asyncio
6
+ import secrets
7
  from fastapi import FastAPI, HTTPException, Security, Depends, status
8
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
9
  from fastapi.responses import StreamingResponse
 
22
  raise ValueError("SERVER_API_KEY environment variable not set. This is required to protect your server.")
23
 
24
  # FastAPI Init
25
+ app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.7 (Non-Stream Fix)")
26
 
27
  # --- Authentication ---
28
  security = HTTPBearer()
 
288
  replicate_model_id = SUPPORTED_MODELS[request.model]
289
  formatted = format_messages_for_replicate(request.messages, request.functions)
290
 
 
 
291
  replicate_input = {
292
  "prompt": formatted["prompt"],
293
  "temperature": request.temperature or 0.7,
 
295
  }
296
 
297
  # Only add max_new_tokens if the user *actually* provided it.
 
298
  if request.max_tokens is not None:
299
  replicate_input["max_new_tokens"] = request.max_tokens
300
 
301
  if formatted["system_prompt"]: replicate_input["system_prompt"] = formatted["system_prompt"]
302
  if formatted["image"]: replicate_input["image"] = formatted["image"]
303
 
 
304
  request_id = generate_request_id()
305
 
306
  if request.stream:
 
319
  resp = await client.post(url, headers=headers, json={"input": replicate_input}, timeout=300.0)
320
  resp.raise_for_status()
321
  pred = resp.json()
322
+
323
+ # ### MAJOR FIX HERE (Non-Streaming Join Error) ###
324
+ # Robustly handle the 'output' field which could be a list, string, or null
325
+ raw_output = pred.get("output")
326
+
327
+ if isinstance(raw_output, list):
328
+ output = "".join(raw_output) # Expected case: list of strings
329
+ elif isinstance(raw_output, str):
330
+ output = raw_output # Handle if it's just a single string
331
+ else:
332
+ # Handle None, null, int, bool, or other unexpected types
333
+ output = ""
334
+
335
  output = output.strip() # Clean up any leading/trailing whitespace
336
 
337
  end_time = time.time()
 
367
  except httpx.HTTPStatusError as e:
368
  raise HTTPException(status_code=e.response.status_code, detail=f"Error from Replicate API: {e.response.text}")
369
  except Exception as e:
370
+ # Catch the join error and any others
371
  raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
372
 
373
  @app.get("/")
 
375
  """
376
  Root endpoint for health checks. Does not require authentication.
377
  """
378
+ return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.7"}
379
 
380
  @app.middleware("http")
381
  async def add_performance_headers(request, call_next):
 
383
  response = await call_next(request)
384
  process_time = time.time() - start_time
385
  response.headers["X-Process-Time"] = str(round(process_time, 3))
386
+ response.headers["X-API-Version"] = "9.2.7"
387
  return response