rkihacker committed on
Commit
b236837
·
verified ·
1 Parent(s): 9f14d65

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +21 -6
main.py CHANGED
@@ -3,6 +3,7 @@ import httpx
3
  import json
4
  import time
5
  import asyncio
 
6
  from fastapi import FastAPI, HTTPException, Security, Depends, status
7
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
8
  from fastapi.responses import StreamingResponse
@@ -13,7 +14,7 @@ from dotenv import load_dotenv
13
  # Load environment variables
14
  load_dotenv()
15
  REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
16
- SERVER_API_KEY = os.getenv("SERVER_API_KEY") # <-- New key for server auth
17
 
18
  if not REPLICATE_API_TOKEN:
19
  raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
@@ -21,7 +22,7 @@ if not SERVER_API_KEY:
21
  raise ValueError("SERVER_API_KEY environment variable not set. This is required to protect your server.")
22
 
23
  # FastAPI Init
24
- app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.4 (Server Auth Added)")
25
 
26
  # --- Authentication ---
27
  security = HTTPBearer()
@@ -132,6 +133,11 @@ SUPPORTED_MODELS = {
132
  }
133
 
134
  # --- Core Logic ---
 
 
 
 
 
135
  def format_messages_for_replicate(messages: List[ChatMessage], functions: Optional[List[FunctionDefinition]] = None) -> Dict[str, Any]:
136
  prompt_parts = []
137
  system_prompt = None
@@ -281,16 +287,25 @@ async def create_chat_completion(request: ChatCompletionRequest):
281
 
282
  replicate_model_id = SUPPORTED_MODELS[request.model]
283
  formatted = format_messages_for_replicate(request.messages, request.functions)
 
 
 
284
  replicate_input = {
285
  "prompt": formatted["prompt"],
286
- "max_new_tokens": request.max_tokens or 512,
287
  "temperature": request.temperature or 0.7,
288
  "top_p": request.top_p or 1.0
289
  }
 
 
 
 
 
 
290
  if formatted["system_prompt"]: replicate_input["system_prompt"] = formatted["system_prompt"]
291
  if formatted["image"]: replicate_input["image"] = formatted["image"]
292
 
293
- request_id = f"chatcmpl-{int(time.time())}"
 
294
 
295
  if request.stream:
296
  return StreamingResponse(
@@ -352,7 +367,7 @@ async def root():
352
  """
353
  Root endpoint for health checks. Does not require authentication.
354
  """
355
- return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.4"}
356
 
357
  @app.middleware("http")
358
  async def add_performance_headers(request, call_next):
@@ -360,5 +375,5 @@ async def add_performance_headers(request, call_next):
360
  response = await call_next(request)
361
  process_time = time.time() - start_time
362
  response.headers["X-Process-Time"] = str(round(process_time, 3))
363
- response.headers["X-API-Version"] = "9.2.4"
364
  return response
 
3
  import json
4
  import time
5
  import asyncio
6
+ import secrets # <-- Added for new ID generation
7
  from fastapi import FastAPI, HTTPException, Security, Depends, status
8
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
9
  from fastapi.responses import StreamingResponse
 
14
  # Load environment variables
15
  load_dotenv()
16
  REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
17
+ SERVER_API_KEY = os.getenv("SERVER_API_KEY") # <-- Key for server auth
18
 
19
  if not REPLICATE_API_TOKEN:
20
  raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
 
22
  raise ValueError("SERVER_API_KEY environment variable not set. This is required to protect your server.")
23
 
24
  # FastAPI Init
25
+ app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.6 (Dynamic Tokens & ID)")
26
 
27
  # --- Authentication ---
28
  security = HTTPBearer()
 
133
  }
134
 
135
  # --- Core Logic ---
136
+
137
+ def generate_request_id() -> str:
138
+ """Generates a unique request ID in the user-specified format."""
139
+ return f"gen-{int(time.time())}-{secrets.token_hex(8)}"
140
+
141
  def format_messages_for_replicate(messages: List[ChatMessage], functions: Optional[List[FunctionDefinition]] = None) -> Dict[str, Any]:
142
  prompt_parts = []
143
  system_prompt = None
 
287
 
288
  replicate_model_id = SUPPORTED_MODELS[request.model]
289
  formatted = format_messages_for_replicate(request.messages, request.functions)
290
+
291
+ # ### MAJOR FIX HERE (Max Tokens) ###
292
+ # Build the payload dynamically.
293
  replicate_input = {
294
  "prompt": formatted["prompt"],
 
295
  "temperature": request.temperature or 0.7,
296
  "top_p": request.top_p or 1.0
297
  }
298
+
299
+ # Only add max_new_tokens if the user *actually* provided it.
300
+ # If not provided, Replicate will use the model's own default.
301
+ if request.max_tokens is not None:
302
+ replicate_input["max_new_tokens"] = request.max_tokens
303
+
304
  if formatted["system_prompt"]: replicate_input["system_prompt"] = formatted["system_prompt"]
305
  if formatted["image"]: replicate_input["image"] = formatted["image"]
306
 
307
+ # ### MAJOR FIX HERE (Request ID) ###
308
+ request_id = generate_request_id()
309
 
310
  if request.stream:
311
  return StreamingResponse(
 
367
  """
368
  Root endpoint for health checks. Does not require authentication.
369
  """
370
+ return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.6"}
371
 
372
  @app.middleware("http")
373
  async def add_performance_headers(request, call_next):
 
375
  response = await call_next(request)
376
  process_time = time.time() - start_time
377
  response.headers["X-Process-Time"] = str(round(process_time, 3))
378
+ response.headers["X-API-Version"] = "9.2.6"
379
  return response