rkihacker committed
Commit a135be4 · verified · 1 Parent(s): c466862

Update main.py

Files changed (1)
  1. main.py +45 -59
main.py CHANGED
@@ -16,7 +16,7 @@ if not REPLICATE_API_TOKEN:
     raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
 
 # FastAPI Init
-app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="6.0.0 (Claude Vision Enabled)")
+app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="7.0.0 (Unified Prompt Fix)")
 
 # --- Pydantic Models ---
 class ModelCard(BaseModel):
@@ -30,71 +30,59 @@ class OpenAIChatCompletionRequest(BaseModel):
 
 # --- Supported Models ---
 SUPPORTED_MODELS = {
-    # Text Models
     "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
-    # Anthropic Claude Models (Vision Enabled)
     "claude-4.5-haiku": "anthropic/claude-4.5-haiku",
     "claude-4.5-sonnet": "anthropic/claude-4.5-sonnet",
-    # Other Vision Model (uses different input format)
     "llava-13b": "yorickvp/llava-13b:e272157381e2a3bf12df3a8edd1f38d1dbd736bbb7437277c8b34175f8fce358"
 }
 
 # --- Core Logic ---
-def prepare_replicate_input(request: OpenAIChatCompletionRequest, replicate_id: str) -> Dict[str, Any]:
+def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
     """
-    Formats the input for the Replicate API based on the model's requirements.
-    - Modern Claude models accept the 'messages' array directly for multimodal input.
-    - Other models may require a flattened 'prompt' string and a separate 'image' field.
+    Formats the input for the Replicate API. This function now uses a unified approach
+    for all models, flattening the message history into a single 'prompt' string
+    and handling images separately, as required by Replicate's API.
     """
     payload = {}
 
-    # --- MODEL-AWARE PAYLOAD PREPARATION ---
-    if "anthropic/claude" in replicate_id:
-        # These models support the OpenAI-like 'messages' array directly.
-        # This is the correct way to handle multimodal (image) inputs for Claude.
-        messages_for_payload = []
-        system_prompt = None
-        for msg in request.messages:
-            if msg.role == "system":
-                system_prompt = str(msg.content)
-            else:
-                # Convert Pydantic model to dict and add to the list
-                messages_for_payload.append(msg.dict())
-
-        payload["messages"] = messages_for_payload
-        if system_prompt:
-            payload["system_prompt"] = system_prompt
-
-    else:
-        # Fallback for models that require a flattened prompt string (e.g., Llama, Llava)
-        prompt_parts = []
-        image_input = None
-        for msg in request.messages:
-            if msg.role == "system":
-                # System prompts are handled differently or prepended by the user
-                # for these models, often as part of the main prompt.
-                # For simplicity, we'll place it at the beginning.
-                prompt_parts.insert(0, str(msg.content))
-            elif msg.role == "assistant":
-                prompt_parts.append(f"Assistant: {msg.content}")
-            elif msg.role == "user":
-                user_text_content = ""
-                if isinstance(msg.content, list):
-                    for item in msg.content:
-                        if item.get("type") == "text":
-                            user_text_content += item.get("text", "")
-                        elif item.get("type") == "image_url":
-                            image_url_data = item.get("image_url", {})
-                            image_input = image_url_data.get("url")
-                else:
-                    user_text_content = str(msg.content)
-                prompt_parts.append(f"User: {user_text_content}")
-
-        prompt_parts.append("Assistant:")
-        payload["prompt"] = "\n\n".join(prompt_parts)
-        if image_input:
-            payload["image"] = image_input
+    prompt_parts = []
+    system_prompt = None
+    image_input = None
+
+    for msg in request.messages:
+        if msg.role == "system":
+            # Extract system prompt; it will be a separate parameter.
+            system_prompt = str(msg.content)
+        elif msg.role == "assistant":
+            prompt_parts.append(f"Assistant: {msg.content}")
+        elif msg.role == "user":
+            user_text_content = ""
+            if isinstance(msg.content, list):
+                # Handle multimodal (vision) input from OpenAI format
+                for item in msg.content:
+                    if item.get("type") == "text":
+                        user_text_content += item.get("text", "")
+                    elif item.get("type") == "image_url":
+                        image_url_data = item.get("image_url", {})
+                        # The 'image' parameter is used by Claude, Llava, etc., on Replicate
+                        image_input = image_url_data.get("url")
+            else:
+                user_text_content = str(msg.content)
+
+            prompt_parts.append(f"User: {user_text_content}")
+
+    # The final "Assistant:" turn prompts the model for a response.
+    prompt_parts.append("Assistant:")
+
+    # All models on Replicate's API expect a single 'prompt' string.
+    payload["prompt"] = "\n\n".join(prompt_parts)
+
+    if system_prompt:
+        payload["system_prompt"] = system_prompt
+
+    if image_input:
+        payload["image"] = image_input
+
     # Map common OpenAI parameters to Replicate equivalents
     if request.max_tokens: payload["max_new_tokens"] = request.max_tokens
     if request.temperature: payload["temperature"] = request.temperature
@@ -140,7 +128,7 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
         if current_event == "output":
             if data:
                 chunk = {
-                    "id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_id,
+                    "id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_model_id,
                     "choices": [{"index": 0, "delta": {"content": data}, "finish_reason": None}]
                 }
                 yield json.dumps(chunk)
@@ -151,7 +139,7 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
             return
 
     final_chunk = {
-        "id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_id,
+        "id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_model_id,
        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]
     }
     yield json.dumps(final_chunk)
@@ -169,15 +157,13 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
     if request.model not in SUPPORTED_MODELS:
         raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
 
-    replicate_id = SUPPORTED_MODELS[request.model]
-    # Pass the replicate_id to the prepare function so it knows which format to use
-    replicate_input = prepare_replicate_input(request, replicate_id)
+    replicate_input = prepare_replicate_input(request)
 
     if request.stream:
-        return EventSourceResponse(stream_replicate_sse(replicate_id, replicate_input), media_type="text/event-stream")
+        return EventSourceResponse(stream_replicate_sse(SUPPORTED_MODELS[request.model], replicate_input), media_type="text/event-stream")
 
     # Non-streaming fallback
-    url = f"https://api.replicate.com/v1/models/{replicate_id}/predictions"
+    url = f"https://api.replicate.com/v1/models/{SUPPORTED_MODELS[request.model]}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"}
     async with httpx.AsyncClient() as client:
         try:
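For reviewers, a minimal standalone sketch of the flattening the new prepare_replicate_input performs. It mirrors the committed logic but uses plain dicts in place of the app's Pydantic message models; the sample messages and image URL are made up for illustration:

# Standalone sketch of the unified flattening (hypothetical sample data).
messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": [
        {"type": "text", "text": "What is in this image?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ]},
]

prompt_parts, system_prompt, image_input = [], None, None
for msg in messages:
    if msg["role"] == "system":
        system_prompt = str(msg["content"])          # becomes the 'system_prompt' field
    elif msg["role"] == "assistant":
        prompt_parts.append(f"Assistant: {msg['content']}")
    elif msg["role"] == "user":
        text = ""
        if isinstance(msg["content"], list):         # OpenAI-style multimodal content
            for item in msg["content"]:
                if item.get("type") == "text":
                    text += item.get("text", "")
                elif item.get("type") == "image_url":
                    image_input = item.get("image_url", {}).get("url")
        else:
            text = str(msg["content"])
        prompt_parts.append(f"User: {text}")
prompt_parts.append("Assistant:")                    # trailing turn cues the model to respond

payload = {"prompt": "\n\n".join(prompt_parts)}
if system_prompt:
    payload["system_prompt"] = system_prompt
if image_input:
    payload["image"] = image_input
print(payload)
# {'prompt': 'User: What is in this image?\n\nAssistant:',
#  'system_prompt': 'You are a concise assistant.',
#  'image': 'https://example.com/cat.png'}

Note that only the last image_url in the history wins, since image_input is overwritten on each occurrence; that matches the committed code.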
 
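And a hedged smoke test against a running instance of the layer. The base URL, port, and the /v1/chat/completions route are assumptions (the route decorator sits outside this diff); the request fields match the OpenAIChatCompletionRequest model exercised above:

# Hypothetical smoke test; adjust host, port, and route to your deployment.
import httpx

resp = httpx.post(
    "http://localhost:8000/v1/chat/completions",  # route path is an assumption
    json={
        "model": "claude-4.5-sonnet",  # any key from SUPPORTED_MODELS
        "stream": False,               # True returns SSE chunks instead
        "max_tokens": 256,
        "messages": [
            {"role": "system", "content": "Answer briefly."},
            {"role": "user", "content": "What does this service do?"},
        ],
    },
    timeout=130.0,  # the upstream call uses Prefer: wait=120, so allow a little more
)
print(resp.status_code, resp.json())

With "stream": true, the endpoint instead emits chat.completion.chunk events whose final chunk carries finish_reason "stop", as set up in stream_replicate_sse.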