rkihacker committed on
Commit
6c8dce7
·
verified ·
1 Parent(s): 63b36c2

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +19 -32
main.py CHANGED
@@ -20,7 +20,7 @@ if not REPLICATE_API_TOKEN:
20
  # --- FastAPI App Initialization ---
21
  app = FastAPI(
22
  title="Replicate to OpenAI Compatibility Layer",
23
- version="3.0.0 (Production Grade)",
24
  )
25
 
26
  # --- Pydantic Models ---
@@ -36,61 +36,48 @@ class ChatMessage(BaseModel):
36
  class OpenAIChatCompletionRequest(BaseModel):
37
  model: str; messages: List[ChatMessage]; temperature: Optional[float] = 0.7; top_p: Optional[float] = 1.0; max_tokens: Optional[int] = None; stream: Optional[bool] = False
38
 
39
- # --- Model Mapping with Explicit Version Hashes (Inspired by LiteLLM) ---
40
  SUPPORTED_MODELS = {
41
  "llama3-8b-instruct": {
42
  "id": "meta/meta-llama-3-8b-instruct",
43
- "version": "02741d1be9a932e6566058d4c92ab80332f143003b5a874f63c9b743e4f3583c",
44
  "input_type": "messages"
45
  },
46
  "claude-4.5-haiku": {
47
  "id": "anthropic/claude-4.5-haiku",
48
- "version": "311c5ff9b9f71c9ebd401b34a41ce604a8b735def3a4aad56f671302b5c56784",
49
  "input_type": "prompt"
50
  }
51
  }
52
 
53
  # --- Helper Functions ---
54
-
55
- def build_replicate_request_body(request: OpenAIChatCompletionRequest, model_details: dict) -> dict:
56
- """Builds the complete request body, including the crucial version hash."""
57
  input_payload = {}
58
 
59
- # Handle model-specific input format (prompt vs messages)
60
  if model_details["input_type"] == "prompt":
61
  prompt_parts = []
62
  system_prompt = None
63
  for msg in request.messages:
64
- if msg.role == "system":
65
- system_prompt = str(msg.content)
66
- elif msg.role == "user":
67
- prompt_parts.append(f"User: {msg.content}")
68
- elif msg.role == "assistant":
69
- prompt_parts.append(f"Assistant: {msg.content}")
70
- prompt_parts.append("Assistant:") # Cue the model to respond
71
  input_payload["prompt"] = "\n".join(prompt_parts)
72
  if system_prompt: input_payload["system_prompt"] = system_prompt
73
  else: # "messages"
74
  input_payload["messages"] = [msg.dict() for msg in request.messages]
75
 
76
- # Add common parameters
77
  if request.max_tokens is not None: input_payload["max_new_tokens"] = request.max_tokens
78
  if request.temperature is not None: input_payload["temperature"] = request.temperature
79
  if request.top_p is not None: input_payload["top_p"] = request.top_p
80
-
81
- return {
82
- "version": model_details["version"],
83
- "input": input_payload
84
- }
85
 
86
- async def stream_replicate_native_sse(model_id: str, request_body: dict):
87
- """Connects to Replicate's native SSE stream for true token-by-token streaming."""
88
- # Note: We call the generic predictions endpoint when providing a version hash.
89
- url = "https://api.replicate.com/v1/predictions"
90
  headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
91
 
92
- # Add stream=True to the request body
93
- request_body["stream"] = True
94
 
95
  async with httpx.AsyncClient(timeout=300) as client:
96
  prediction = None
@@ -153,18 +140,18 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
153
  raise HTTPException(status_code=404, detail=f"Model not found. Supported models: {list(SUPPORTED_MODELS.keys())}")
154
 
155
  model_details = SUPPORTED_MODELS[model_key]
156
- replicate_request_body = build_replicate_request_body(request, model_details)
157
 
158
  if request.stream:
159
- return EventSourceResponse(stream_replicate_native_sse(model_details["id"], replicate_request_body))
160
 
161
- # Synchronous request
162
- url = "https://api.replicate.com/v1/predictions"
163
  headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"}
164
 
165
  async with httpx.AsyncClient(timeout=150) as client:
166
  try:
167
- response = await client.post(url, headers=headers, json=replicate_request_body)
168
  response.raise_for_status()
169
  prediction = response.json()
170
  output = "".join(prediction.get("output", []))
 
20
  # --- FastAPI App Initialization ---
21
  app = FastAPI(
22
  title="Replicate to OpenAI Compatibility Layer",
23
+ version="4.0.0 (Stable & Correct)",
24
  )
25
 
26
  # --- Pydantic Models ---
 
36
  class OpenAIChatCompletionRequest(BaseModel):
37
  model: str; messages: List[ChatMessage]; temperature: Optional[float] = 0.7; top_p: Optional[float] = 1.0; max_tokens: Optional[int] = None; stream: Optional[bool] = False
38
 
39
+ # --- Model Mapping (Simplified for direct endpoint usage) ---
40
  SUPPORTED_MODELS = {
41
  "llama3-8b-instruct": {
42
  "id": "meta/meta-llama-3-8b-instruct",
 
43
  "input_type": "messages"
44
  },
45
  "claude-4.5-haiku": {
46
  "id": "anthropic/claude-4.5-haiku",
 
47
  "input_type": "prompt"
48
  }
49
  }
50
 
51
  # --- Helper Functions ---
52
+ def prepare_replicate_input(request: OpenAIChatCompletionRequest, model_details: dict) -> Dict[str, Any]:
53
+ """Prepares the 'input' dictionary for Replicate, handling model-specific formats."""
 
54
  input_payload = {}
55
 
 
56
  if model_details["input_type"] == "prompt":
57
  prompt_parts = []
58
  system_prompt = None
59
  for msg in request.messages:
60
+ if msg.role == "system": system_prompt = str(msg.content)
61
+ elif msg.role == "user": prompt_parts.append(f"User: {msg.content}")
62
+ elif msg.role == "assistant": prompt_parts.append(f"Assistant: {msg.content}")
63
+ prompt_parts.append("Assistant:")
 
 
 
64
  input_payload["prompt"] = "\n".join(prompt_parts)
65
  if system_prompt: input_payload["system_prompt"] = system_prompt
66
  else: # "messages"
67
  input_payload["messages"] = [msg.dict() for msg in request.messages]
68
 
 
69
  if request.max_tokens is not None: input_payload["max_new_tokens"] = request.max_tokens
70
  if request.temperature is not None: input_payload["temperature"] = request.temperature
71
  if request.top_p is not None: input_payload["top_p"] = request.top_p
72
+ return input_payload
 
 
 
 
73
 
74
+ async def stream_replicate_native_sse(model_id: str, input_payload: dict):
75
+ """Connects to Replicate's native SSE stream using the model-specific endpoint."""
76
+ url = f"https://api.replicate.com/v1/models/{model_id}/predictions"
 
77
  headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
78
 
79
+ # The request body is now simple and correct
80
+ request_body = {"input": input_payload, "stream": True}
81
 
82
  async with httpx.AsyncClient(timeout=300) as client:
83
  prediction = None
 
140
  raise HTTPException(status_code=404, detail=f"Model not found. Supported models: {list(SUPPORTED_MODELS.keys())}")
141
 
142
  model_details = SUPPORTED_MODELS[model_key]
143
+ replicate_input = prepare_replicate_input(request, model_details)
144
 
145
  if request.stream:
146
+ return EventSourceResponse(stream_replicate_native_sse(model_details["id"], replicate_input))
147
 
148
+ # Synchronous Request
149
+ url = f"https://api.replicate.com/v1/models/{model_details['id']}/predictions"
150
  headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"}
151
 
152
  async with httpx.AsyncClient(timeout=150) as client:
153
  try:
154
+ response = await client.post(url, headers=headers, json={"input": replicate_input})
155
  response.raise_for_status()
156
  prediction = response.json()
157
  output = "".join(prediction.get("output", []))