rkihacker committed
Commit 91b7eb3 · verified · Parent(s): e014ad9

Update main.py

Files changed (1):
  1. main.py +32 -51
main.py CHANGED
@@ -16,7 +16,7 @@ if not REPLICATE_API_TOKEN:
     raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
 
 # FastAPI Init
-app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="4.1.0 (Context Fixed)")
+app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="4.2.0 (Prompt Format Fixed)")
 
 # --- Pydantic Models ---
 class ModelCard(BaseModel):
@@ -29,47 +29,52 @@ class OpenAIChatCompletionRequest(BaseModel):
     model: str; messages: List[ChatMessage]; temperature: Optional[float] = 0.7; top_p: Optional[float] = 1.0; max_tokens: Optional[int] = None; stream: Optional[bool] = False
 
 # --- Supported Models ---
-# Maps OpenAI-friendly names to Replicate model paths
 SUPPORTED_MODELS = {
     "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
     "claude-4.5-haiku": "anthropic/claude-4.5-haiku"
-    # You can add more models here
 }
 
 # --- Core Logic ---
 def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
     """
-    Formats the input for Replicate API, preserving the conversational context.
+    Formats the input for the Replicate API. This function now correctly builds
+    a single prompt string from the message history, which is required by
+    Replicate's endpoints for models like Claude and Llama 3.
     """
     payload = {}
 
-    # --- CONTEXT FIX START ---
-    # Modern chat models on Replicate (like Llama 3 and Claude 4.5) expect
-    # the 'messages' array directly, just like OpenAI.
-    # We no longer need to flatten the conversation into a single prompt string.
-
-    # Extract system prompt if it exists, as some models take it as a separate parameter.
-    messages_for_payload = []
+    # --- PROMPT FORMAT FIX START ---
+    prompt_parts = []
     system_prompt = None
+
     for msg in request.messages:
         if msg.role == "system":
-            # Claude and some other models prefer a dedicated system_prompt field.
+            # Extract the system prompt, as it's a separate parameter for many models
             system_prompt = str(msg.content)
-        else:
-            # Handle user/assistant roles. Convert Pydantic model to a standard dict.
-            messages_for_payload.append(msg.dict())
-
-    # The main input for conversation is the 'messages' array.
-    payload["messages"] = messages_for_payload
+        elif msg.role == "user":
+            # Format user messages
+            content = msg.content
+            if isinstance(content, list):  # Handle potential future vision models
+                text_parts = [item.get("text", "") for item in content if item.get("type") == "text"]
+                content = " ".join(text_parts)
+            prompt_parts.append(f"User: {content}")
+        elif msg.role == "assistant":
+            # Format assistant messages
+            prompt_parts.append(f"Assistant: {msg.content}")
+
+    # Add a final "Assistant:" turn to prompt the model for a response.
+    # This is a standard convention for many chat models when using a single prompt string.
+    prompt_parts.append("Assistant:")
+
+    # The main input is a single 'prompt' string with turns separated by blank lines.
+    payload["prompt"] = "\n\n".join(prompt_parts)
 
-    # Add system_prompt to the payload if it was found.
     if system_prompt:
         payload["system_prompt"] = system_prompt
 
-    # --- CONTEXT FIX END ---
+    # --- PROMPT FORMAT FIX END ---
 
     # Map common OpenAI parameters to Replicate equivalents
-    # Note: Replicate's parameter for max tokens is often 'max_new_tokens'
     if request.max_tokens: payload["max_new_tokens"] = request.max_tokens
     if request.temperature: payload["temperature"] = request.temperature
     if request.top_p: payload["top_p"] = request.top_p
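
For orientation, here is a minimal sketch of what the new flattening produces for a short conversation. The request object is illustrative, and ChatMessage's definition is elided above; the loop only relies on its role and content fields.

    # Illustrative sketch: what prepare_replicate_input() now returns for a
    # short conversation (assumes ChatMessage exposes 'role' and 'content').
    request = OpenAIChatCompletionRequest(
        model="llama3-8b-instruct",
        messages=[
            ChatMessage(role="system", content="You are terse."),
            ChatMessage(role="user", content="Hi!"),
            ChatMessage(role="assistant", content="Hello."),
            ChatMessage(role="user", content="What does httpx do?"),
        ],
    )
    payload = prepare_replicate_input(request)
    assert payload["system_prompt"] == "You are terse."
    assert payload["prompt"] == (
        "User: Hi!\n\n"
        "Assistant: Hello.\n\n"
        "User: What does httpx do?\n\n"
        "Assistant:"
    )
    # Defaults from the request model also map through: temperature=0.7, top_p=1.0.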
@@ -78,13 +83,11 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
 
 async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
     """Handles the full streaming lifecycle using standard Replicate endpoints."""
-    # 1. Start Prediction
     url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
 
     async with httpx.AsyncClient(timeout=60.0) as client:
         try:
-            # Request a streaming prediction
             response = await client.post(url, headers=headers, json={"input": input_payload, "stream": True})
             response.raise_for_status()
             prediction = response.json()
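
The unchanged lines elided after this hunk are where the handler reads the prediction id and stream URL out of this response. For reference, Replicate's create-prediction response exposes the SSE endpoint under urls.stream when "stream": True is requested; a sketch of that extraction (not the elided code itself):

    # Sketch, not the elided code: where the SSE endpoint lives in the
    # create-prediction response when {"stream": True} is requested.
    prediction = response.json()
    prediction_id = prediction["id"]           # used below to tag each chunk
    stream_url = prediction["urls"]["stream"]  # GET with Accept: text/event-stream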
@@ -98,15 +101,13 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
         except httpx.HTTPStatusError as e:
             error_details = e.response.text
             try:
-                # Try to parse the error for a cleaner message
                 error_json = e.response.json()
                 error_details = error_json.get("detail", error_details)
             except json.JSONDecodeError:
-                pass  # Use raw text if not JSON
+                pass
             yield json.dumps({"error": {"message": f"Upstream Error: {error_details}", "type": "replicate_error"}})
             return
 
-        # 2. Connect to the provided Stream URL and process Server-Sent Events (SSE)
         try:
             async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
                 current_event = None
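
For context, Replicate's stream endpoint emits standard server-sent events named output, error, and done; the surrounding loop keys off each "event:" line and forwards the following "data:" payload. Roughly, with illustrative token text:

    # Illustrative SSE frames from the stream URL; the handler remembers the
    # latest "event:" name and treats the next "data:" line as its payload.
    frames = (
        "event: output\n"
        "data: Hello\n"
        "\n"
        "event: output\n"
        "data:  world\n"
        "\n"
        "event: done\n"
        "data: {}\n"
    )
    for line in frames.splitlines():
        if line.startswith("event:"):
            current_event = line[len("event:"):].strip()
        elif line.startswith("data:"):
            data = line[len("data:"):].strip()

One caveat: .strip() on the data payload also removes a token's leading whitespace ("data:  world" yields "world", not " world"), so concatenated output can lose inter-token spaces; trimming only the single space after the colon would preserve them.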
@@ -117,9 +118,7 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
                     data = line[len("data:"):].strip()
 
                     if current_event == "output":
-                        # The 'output' event for chat models sends one token at a time as a plain string.
-                        # We don't need to parse it as JSON.
-                        if data:  # Ensure we don't send empty chunks
+                        if data:
                             chunk = {
                                 "id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_model_id,
                                 "choices": [{"index": 0, "delta": {"content": data}, "finish_reason": None}]
@@ -127,21 +126,16 @@
                             yield json.dumps(chunk)
 
                     elif current_event == "done":
-                        # The 'done' event signals the end of the stream.
                         break
         except httpx.ReadTimeout:
-            # Handle cases where the stream times out
             yield json.dumps({"error": {"message": "Stream timed out.", "type": "timeout_error"}})
             return
 
-
-        # 3. Send the final termination chunk in OpenAI format
         final_chunk = {
             "id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_model_id,
             "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]
         }
         yield json.dumps(final_chunk)
-        # Some clients (like curl) expect a final "[DONE]" message to close the connection.
         yield "[DONE]"
 
 # --- Endpoints ---
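
Taken together, the stream is OpenAI-shaped end to end, so a stock OpenAI client can consume it. A minimal consumption sketch, assuming the app is served on localhost:8000 and the (elided) route is mounted at /v1/chat/completions:

    # Sketch: consuming the proxy's stream with the openai SDK. The base URL,
    # port, and route path are assumptions; the route decorator is not shown
    # in this diff.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-checked-by-this-proxy")
    stream = client.chat.completions.create(
        model="llama3-8b-instruct",
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
    )
    for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="", flush=True)

Note that EventSourceResponse (sse-starlette) frames each string this generator yields as its own "data:" line, so the final yield arrives as the literal "data: [DONE]" terminator that OpenAI-style clients expect.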
@@ -160,34 +154,21 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
     replicate_input = prepare_replicate_input(request)
 
     if request.stream:
-        # Return a streaming response
         return EventSourceResponse(stream_replicate_sse(replicate_id, replicate_input), media_type="text/event-stream")
 
     # Non-streaming fallback
     url = f"https://api.replicate.com/v1/models/{replicate_id}/predictions"
-    headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"}  # Increased wait time
+    headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"}
     async with httpx.AsyncClient() as client:
         try:
             resp = await client.post(url, headers=headers, json={"input": replicate_input}, timeout=130.0)
             resp.raise_for_status()
             pred = resp.json()
-            # The output of chat models is typically a list of strings (tokens)
             output = "".join(pred.get("output", []))
             return {
-                "id": pred.get("id"),
-                "object": "chat.completion",
-                "created": int(time.time()),
-                "model": request.model,
-                "choices": [{
-                    "index": 0,
-                    "message": {"role": "assistant", "content": output},
-                    "finish_reason": "stop"
-                }],
-                "usage": {  # Placeholder usage object
-                    "prompt_tokens": 0,
-                    "completion_tokens": 0,
-                    "total_tokens": 0
-                }
+                "id": pred.get("id"), "object": "chat.completion", "created": int(time.time()), "model": request.model,
+                "choices": [{"index": 0, "message": {"role": "assistant", "content": output}, "finish_reason": "stop"}],
+                "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
             }
         except httpx.HTTPStatusError as e:
             raise HTTPException(status_code=e.response.status_code, detail=f"Error from Replicate API: {e.response.text}")
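
The non-streaming path can be probed directly with httpx, under the same local-server and route assumptions as above:

    # Sketch: exercising the non-streaming fallback (server address and route
    # path are assumptions, as the route decorator is not shown in this diff).
    import httpx

    resp = httpx.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "claude-4.5-haiku",
            "messages": [{"role": "user", "content": "Write a one-line haiku."}],
            "stream": False,
        },
        timeout=130.0,
    )
    resp.raise_for_status()
    print(resp.json()["choices"][0]["message"]["content"])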
 