rkihacker commited on
Commit
7b0c05f
·
verified ·
1 Parent(s): 0f4286e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +111 -99
main.py CHANGED
@@ -1,14 +1,13 @@
1
-
2
  import os
3
  import httpx
4
  import json
5
  import time
6
  from fastapi import FastAPI, HTTPException
7
- from fastapi.responses import JSONResponse
8
  from pydantic import BaseModel, Field
9
  from typing import List, Dict, Any, Optional, Union, Literal
10
  from dotenv import load_dotenv
11
- from sse_starlette.sse import EventSourceResponse
12
 
13
  # Load environment variables
14
  load_dotenv()
@@ -17,23 +16,36 @@ if not REPLICATE_API_TOKEN:
17
  raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
18
 
19
  # FastAPI Init
20
- app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.0.0 (Definitive Streaming Fix)")
21
 
22
  # --- Pydantic Models ---
23
  class ModelCard(BaseModel):
24
- id: str; object: str = "model"; created: int = Field(default_factory=lambda: int(time.time())); owned_by: str = "replicate"
 
 
 
 
25
  class ModelList(BaseModel):
26
- object: str = "list"; data: List[ModelCard] = []
 
 
27
  class ChatMessage(BaseModel):
28
- role: Literal["system", "user", "assistant", "tool"]; content: Union[str, List[Dict[str, Any]]]
 
 
29
  class OpenAIChatCompletionRequest(BaseModel):
30
- model: str; messages: List[ChatMessage]; temperature: Optional[float] = 0.7; top_p: Optional[float] = 1.0; max_tokens: Optional[int] = None; stream: Optional[bool] = False
 
 
 
 
 
31
 
32
  # --- Supported Models ---
33
  SUPPORTED_MODELS = {
34
  "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
35
- "claude-4.5-haiku": "anthropic/claude-4.5-haiku",
36
- "claude-4.5-sonnet": "anthropic/claude-4.5-sonnet",
37
  "llava-13b": "yorickvp/llava-13b:e272157381e2a3bf12df3a8edd1f38d1dbd736bbb7437277c8b34175f8fce358"
38
  }
39
 
@@ -80,140 +92,137 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, A
80
 
81
  return payload
82
 
83
- async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
84
- """Handles the full streaming lifecycle with correct whitespace preservation."""
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
86
  headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
87
 
 
 
 
88
  async with httpx.AsyncClient(timeout=60.0) as client:
 
89
  try:
90
  response = await client.post(url, headers=headers, json={"input": input_payload, "stream": True})
91
  response.raise_for_status()
92
  prediction = response.json()
93
  stream_url = prediction.get("urls", {}).get("stream")
94
- prediction_id = prediction.get("id", "stream-unknown")
 
95
  if not stream_url:
96
- yield f"data: {json.dumps({'error': {'message': 'Model did not return a stream URL.'}})}\n\n"
 
97
  return
 
98
  except httpx.HTTPStatusError as e:
99
  error_details = e.response.text
100
  try:
101
  error_json = e.response.json()
102
  error_details = error_json.get("detail", error_details)
103
  except json.JSONDecodeError: pass
104
- yield f"data: {json.dumps({'error': {'message': f'Upstream Error: {error_details}', 'type': 'replicate_error'}})}\n\n"
 
105
  return
106
-
 
107
  try:
108
  async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
109
  current_event = None
110
  async for line in sse.aiter_lines():
111
- if not line: # Skip empty lines
112
  continue
113
  if line.startswith("event:"):
114
  current_event = line[len("event:"):].strip()
115
  elif line.startswith("data:"):
116
- # FIXED: Preserve all whitespace including leading/trailing spaces
117
- raw_data = line[5:] # Remove "data:" prefix
118
-
119
- # Handle empty data lines (preserve them)
120
- if not raw_data:
121
- continue
122
-
123
- # Remove only the optional single space after data: if present
124
- # This is per SSE spec and preserves actual content spaces
125
- if raw_data.startswith(" "):
126
- data_content = raw_data[1:] # Remove the first space only
127
- else:
128
- data_content = raw_data
129
 
 
 
 
 
130
  if current_event == "output":
131
- if not data_content:
132
  continue
133
-
134
  content_token = ""
135
  try:
136
- # Handle JSON-encoded strings properly (including spaces)
137
- content_token = json.loads(data_content)
 
138
  except (json.JSONDecodeError, TypeError):
139
- # Handle plain text tokens (preserve as-is)
140
- content_token = data_content
141
 
142
- # Create chunk with exact format you specified
143
  chunk = {
 
 
 
 
 
144
  "choices": [{
 
145
  "delta": {"content": content_token},
146
  "finish_reason": None,
147
- "index": 0,
148
  "logprobs": None,
149
  "native_finish_reason": None
150
- }],
151
- "created": int(time.time()),
152
- "id": f"gen-{int(time.time())}-{prediction_id[-12:]}", # Format like your example
153
- "model": replicate_model_id,
154
- "object": "chat.completion.chunk",
155
- "provider": "Anthropic" if "anthropic" in replicate_model_id else "Replicate"
156
  }
157
  yield f"data: {json.dumps(chunk)}\n\n"
158
-
159
  elif current_event == "done":
160
- # Send usage chunk before done
161
- usage_chunk = {
162
- "choices": [{
163
- "delta": {},
164
- "finish_reason": None,
165
- "index": 0,
166
- "logprobs": None,
167
- "native_finish_reason": None
168
- }],
169
- "created": int(time.time()),
170
- "id": f"gen-{int(time.time())}-{prediction_id[-12:]}",
171
- "model": replicate_model_id,
172
- "object": "chat.completion.chunk",
173
- "provider": "Anthropic" if "anthropic" in replicate_model_id else "Replicate",
174
- "usage": {
175
- "cache_discount": 0,
176
- "completion_tokens": 0,
177
- "completion_tokens_details": {"image_tokens": 0, "reasoning_tokens": 0},
178
- "cost": 0,
179
- "cost_details": {
180
- "upstream_inference_completions_cost": 0,
181
- "upstream_inference_cost": None,
182
- "upstream_inference_prompt_cost": 0
183
- },
184
- "input_tokens": 0,
185
- "is_byok": False,
186
- "prompt_tokens": 0,
187
- "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
188
- "total_tokens": 0
189
- }
190
- }
191
- yield f"data: {json.dumps(usage_chunk)}\n\n"
192
-
193
- # Send final chunk with stop reason
194
- final_chunk = {
195
- "choices": [{
196
- "delta": {},
197
- "finish_reason": "stop",
198
- "index": 0,
199
- "logprobs": None,
200
- "native_finish_reason": "end_turn"
201
- }],
202
- "created": int(time.time()),
203
- "id": f"gen-{int(time.time())}-{prediction_id[-12:]}",
204
- "model": replicate_model_id,
205
- "object": "chat.completion.chunk",
206
- "provider": "Anthropic" if "anthropic" in replicate_model_id else "Replicate"
207
- }
208
- yield f"data: {json.dumps(final_chunk)}\n\n"
209
  break
210
  except httpx.ReadTimeout:
211
- yield f"data: {json.dumps({'error': {'message': 'Stream timed out.', 'type': 'timeout_error'}})}\n\n"
 
212
  return
213
 
214
- # Send [DONE] event
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  yield "data: [DONE]\n\n"
216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  # --- Endpoints ---
218
  @app.get("/v1/models")
219
  async def list_models():
@@ -224,13 +233,16 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
224
  if request.model not in SUPPORTED_MODELS:
225
  raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
226
 
 
227
  replicate_input = prepare_replicate_input(request)
228
 
229
  if request.stream:
230
- return EventSourceResponse(stream_replicate_sse(SUPPORTED_MODELS[request.model], replicate_input), media_type="text/event-stream")
 
 
231
 
232
  # Non-streaming fallback
233
- url = f"https://api.replicate.com/v1/models/{SUPPORTED_MODELS[request.model]}/predictions"
234
  headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"}
235
  async with httpx.AsyncClient() as client:
236
  try:
@@ -244,4 +256,4 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
244
  "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
245
  }
246
  except httpx.HTTPStatusError as e:
247
- raise HTTPException(status_code=e.response.status_code, detail=f"Error from Replicate API: {e.response.text}")
 
 
1
  import os
2
  import httpx
3
  import json
4
  import time
5
  from fastapi import FastAPI, HTTPException
6
+ from fastapi.responses import Response
7
  from pydantic import BaseModel, Field
8
  from typing import List, Dict, Any, Optional, Union, Literal
9
  from dotenv import load_dotenv
10
+ import asyncio
11
 
12
  # Load environment variables
13
  load_dotenv()
 
16
  raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
17
 
18
  # FastAPI Init
19
+ app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="10.0.0 (Enhanced Chunk Formatting)")
20
 
21
  # --- Pydantic Models ---
22
  class ModelCard(BaseModel):
23
+ id: str
24
+ object: str = "model"
25
+ created: int = Field(default_factory=lambda: int(time.time()))
26
+ owned_by: str = "replicate"
27
+
28
  class ModelList(BaseModel):
29
+ object: str = "list"
30
+ data: List[ModelCard] = []
31
+
32
  class ChatMessage(BaseModel):
33
+ role: Literal["system", "user", "assistant", "tool"]
34
+ content: Union[str, List[Dict[str, Any]]]
35
+
36
  class OpenAIChatCompletionRequest(BaseModel):
37
+ model: str
38
+ messages: List[ChatMessage]
39
+ temperature: Optional[float] = 0.7
40
+ top_p: Optional[float] = 1.0
41
+ max_tokens: Optional[int] = None
42
+ stream: Optional[bool] = False
43
 
44
  # --- Supported Models ---
45
  SUPPORTED_MODELS = {
46
  "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
47
+ "claude-4.5-haiku": "anthropic/claude-4.5-haiku", # Note: Name changed for clarity
48
+ "claude-4.5-sonnet": "anthropic/claude-4.5-sonnet", # Note: Name changed for clarity
49
  "llava-13b": "yorickvp/llava-13b:e272157381e2a3bf12df3a8edd1f38d1dbd736bbb7437277c8b34175f8fce358"
50
  }
51
 
 
92
 
93
  return payload
94
 
95
+ def get_provider(replicate_model_id: str) -> str:
96
+ """Infers the provider from the Replicate model ID."""
97
+ if replicate_model_id.startswith("meta/"):
98
+ return "Meta"
99
+ if replicate_model_id.startswith("anthropic/"):
100
+ return "Anthropic"
101
+ if "llava" in replicate_model_id:
102
+ return "Llava"
103
+ return "Replicate"
104
+
105
+ async def stream_replicate_sse(replicate_model_id: str, requested_model_name: str, input_payload: dict):
106
+ """
107
+ Handles the full streaming lifecycle with corrected whitespace preservation
108
+ and the new, detailed chunk format.
109
+ """
110
  url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
111
  headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
112
 
113
+ # Identify provider for the response chunks
114
+ provider = get_provider(replicate_model_id)
115
+
116
  async with httpx.AsyncClient(timeout=60.0) as client:
117
+ # 1. Create the prediction and get the stream URL
118
  try:
119
  response = await client.post(url, headers=headers, json={"input": input_payload, "stream": True})
120
  response.raise_for_status()
121
  prediction = response.json()
122
  stream_url = prediction.get("urls", {}).get("stream")
123
+ prediction_id = prediction.get("id", f"stream-{int(time.time())}")
124
+
125
  if not stream_url:
126
+ error_chunk = { "error": {"message": "Model did not return a stream URL."} }
127
+ yield f"data: {json.dumps(error_chunk)}\n\n"
128
  return
129
+
130
  except httpx.HTTPStatusError as e:
131
  error_details = e.response.text
132
  try:
133
  error_json = e.response.json()
134
  error_details = error_json.get("detail", error_details)
135
  except json.JSONDecodeError: pass
136
+ error_chunk = {"error": {"message": f"Upstream Error: {error_details}", "type": "replicate_error"}}
137
+ yield f"data: {json.dumps(error_chunk)}\n\n"
138
  return
139
+
140
+ # 2. Connect to the SSE stream and yield formatted chunks
141
  try:
142
  async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
143
  current_event = None
144
  async for line in sse.aiter_lines():
145
+ if not line:
146
  continue
147
  if line.startswith("event:"):
148
  current_event = line[len("event:"):].strip()
149
  elif line.startswith("data:"):
150
+ # Get the raw payload after "data:"
151
+ raw_payload = line[len("data:"):]
 
 
 
 
 
 
 
 
 
 
 
152
 
153
+ # The SSE spec allows an optional leading space. Remove it.
154
+ # This robustly prevents parsing errors without destroying content.
155
+ payload = raw_payload.lstrip(" ")
156
+
157
  if current_event == "output":
158
+ if not payload:
159
  continue
160
+
161
  content_token = ""
162
  try:
163
+ # This handles JSON-encoded strings like "\" Hello\"" and correctly
164
+ # preserves all whitespace, including single spaces. This is the fix.
165
+ content_token = json.loads(payload)
166
  except (json.JSONDecodeError, TypeError):
167
+ # Fallback for plain text tokens if Replicate changes format
168
+ content_token = payload
169
 
170
+ # Build the new, detailed chunk structure
171
  chunk = {
172
+ "id": prediction_id,
173
+ "object": "chat.completion.chunk",
174
+ "created": int(time.time()),
175
+ "model": requested_model_name,
176
+ "provider": provider,
177
  "choices": [{
178
+ "index": 0,
179
  "delta": {"content": content_token},
180
  "finish_reason": None,
 
181
  "logprobs": None,
182
  "native_finish_reason": None
183
+ }]
 
 
 
 
 
184
  }
185
  yield f"data: {json.dumps(chunk)}\n\n"
186
+
187
  elif current_event == "done":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  break
189
  except httpx.ReadTimeout:
190
+ error_chunk = {"error": {"message": "Stream timed out.", "type": "timeout_error"}}
191
+ yield f"data: {json.dumps(error_chunk)}\n\n"
192
  return
193
 
194
+ # 3. Send the final chunk with finish_reason
195
+ final_chunk = {
196
+ "id": prediction_id,
197
+ "object": "chat.completion.chunk",
198
+ "created": int(time.time()),
199
+ "model": requested_model_name,
200
+ "provider": provider,
201
+ "choices": [{
202
+ "index": 0,
203
+ "delta": {},
204
+ "finish_reason": "stop",
205
+ "logprobs": None,
206
+ "native_finish_reason": "end_turn"
207
+ }]
208
+ }
209
+ yield f"data: {json.dumps(final_chunk)}\n\n"
210
  yield "data: [DONE]\n\n"
211
 
212
+ # A simple EventSourceResponse implementation if sse-starlette is not preferred
213
+ async def create_sse_response(generator):
214
+ headers = {
215
+ 'Content-Type': 'text/event-stream',
216
+ 'Cache-Control': 'no-cache',
217
+ 'Connection': 'keep-alive',
218
+ }
219
+ async def stream():
220
+ async for chunk in generator:
221
+ yield chunk
222
+ await asyncio.sleep(0) # Yield control to the event loop
223
+ return Response(stream(), headers=headers)
224
+
225
+
226
  # --- Endpoints ---
227
  @app.get("/v1/models")
228
  async def list_models():
 
233
  if request.model not in SUPPORTED_MODELS:
234
  raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
235
 
236
+ replicate_model_id = SUPPORTED_MODELS[request.model]
237
  replicate_input = prepare_replicate_input(request)
238
 
239
  if request.stream:
240
+ # Use the custom generator with the detailed chunk format
241
+ generator = stream_replicate_sse(replicate_model_id, request.model, replicate_input)
242
+ return await create_sse_response(generator)
243
 
244
  # Non-streaming fallback
245
+ url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
246
  headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"}
247
  async with httpx.AsyncClient() as client:
248
  try:
 
256
  "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
257
  }
258
  except httpx.HTTPStatusError as e:
259
+ raise HTTPException(status_code=e.response.status_code, detail=f"Error from Replicate API: {e.response.text}")