rkihacker commited on
Commit
c466862
·
verified ·
1 Parent(s): df13528

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +61 -48
main.py CHANGED
@@ -16,7 +16,7 @@ if not REPLICATE_API_TOKEN:
16
  raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
17
 
18
  # FastAPI Init
19
- app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="5.0.0 (Vision Enabled)")
20
 
21
  # --- Pydantic Models ---
22
  class ModelCard(BaseModel):
@@ -30,56 +30,70 @@ class OpenAIChatCompletionRequest(BaseModel):
30
 
31
  # --- Supported Models ---
32
  SUPPORTED_MODELS = {
 
33
  "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
34
- "claude-4.5-haiku": "anthropic/claude-4.5-haiku" # This model supports vision
 
 
 
 
35
  }
36
 
37
  # --- Core Logic ---
38
- def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
39
  """
40
- Formats the input for Replicate API, handling both text and vision (image) inputs.
 
 
41
  """
42
  payload = {}
43
- prompt_parts = []
44
- system_prompt = None
45
- image_url = None # Variable to hold the image data URI
46
-
47
- for msg in request.messages:
48
- if msg.role == "system":
49
- system_prompt = str(msg.content)
50
- elif msg.role == "user":
51
- # --- VISION SUPPORT START ---
52
- if isinstance(msg.content, list):
53
- # This is a multi-modal request (text + image)
54
- text_content = ""
55
- for part in msg.content:
56
- if part.get("type") == "text":
57
- text_content += part.get("text", "") + "\n"
58
- elif part.get("type") == "image_url":
59
- # Capture the first image URL found
60
- if not image_url:
61
- image_url = part.get("image_url", {}).get("url")
62
- # Use the official Claude "Human:" prefix for the prompt
63
- prompt_parts.append(f"Human: {text_content.strip()}")
64
- else:
65
- # Standard text-only message
66
- prompt_parts.append(f"Human: {msg.content}")
67
- # --- VISION SUPPORT END ---
68
- elif msg.role == "assistant":
69
- # Use the official Claude "Assistant:" prefix for the prompt
70
- prompt_parts.append(f"Assistant: {msg.content}")
71
-
72
- # Add the final "Assistant:" turn to prompt the model for its response.
73
- prompt_parts.append("Assistant:")
74
 
75
- payload["prompt"] = "\n\n".join(prompt_parts)
76
-
77
- if system_prompt:
78
- payload["system_prompt"] = system_prompt
79
-
80
- if image_url:
81
- # Add the captured image URL to the payload for the vision model
82
- payload["image"] = image_url
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  # Map common OpenAI parameters to Replicate equivalents
85
  if request.max_tokens: payload["max_new_tokens"] = request.max_tokens
@@ -123,15 +137,13 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
123
  current_event = line[len("event:"):].strip()
124
  elif line.startswith("data:"):
125
  data = line[len("data:"):].strip()
126
-
127
  if current_event == "output":
128
  if data:
129
  chunk = {
130
- "id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_model_id,
131
  "choices": [{"index": 0, "delta": {"content": data}, "finish_reason": None}]
132
  }
133
  yield json.dumps(chunk)
134
-
135
  elif current_event == "done":
136
  break
137
  except httpx.ReadTimeout:
@@ -139,7 +151,7 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
139
  return
140
 
141
  final_chunk = {
142
- "id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_model_id,
143
  "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]
144
  }
145
  yield json.dumps(final_chunk)
@@ -158,7 +170,8 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
158
  raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
159
 
160
  replicate_id = SUPPORTED_MODELS[request.model]
161
- replicate_input = prepare_replicate_input(request)
 
162
 
163
  if request.stream:
164
  return EventSourceResponse(stream_replicate_sse(replicate_id, replicate_input), media_type="text/event-stream")
 
16
  raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
17
 
18
  # FastAPI Init
19
+ app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="6.0.0 (Claude Vision Enabled)")
20
 
21
  # --- Pydantic Models ---
22
  class ModelCard(BaseModel):
 
30
 
31
  # --- Supported Models ---
32
  SUPPORTED_MODELS = {
33
+ # Text Models
34
  "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
35
+ # Anthropic Claude Models (Vision Enabled)
36
+ "claude-4.5-haiku": "anthropic/claude-4.5-haiku",
37
+ "claude-4.5-sonnet": "anthropic/claude-4.5-sonnet",
38
+ # Other Vision Model (uses different input format)
39
+ "llava-13b": "yorickvp/llava-13b:e272157381e2a3bf12df3a8edd1f38d1dbd736bbb7437277c8b34175f8fce358"
40
  }
41
 
42
  # --- Core Logic ---
43
+ def prepare_replicate_input(request: OpenAIChatCompletionRequest, replicate_id: str) -> Dict[str, Any]:
44
  """
45
+ Formats the input for the Replicate API based on the model's requirements.
46
+ - Modern Claude models accept the 'messages' array directly for multimodal input.
47
+ - Other models may require a flattened 'prompt' string and a separate 'image' field.
48
  """
49
  payload = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ # --- MODEL-AWARE PAYLOAD PREPARATION ---
52
+ if "anthropic/claude" in replicate_id:
53
+ # These models support the OpenAI-like 'messages' array directly.
54
+ # This is the correct way to handle multimodal (image) inputs for Claude.
55
+ messages_for_payload = []
56
+ system_prompt = None
57
+ for msg in request.messages:
58
+ if msg.role == "system":
59
+ system_prompt = str(msg.content)
60
+ else:
61
+ # Convert Pydantic model to dict and add to the list
62
+ messages_for_payload.append(msg.dict())
63
+
64
+ payload["messages"] = messages_for_payload
65
+ if system_prompt:
66
+ payload["system_prompt"] = system_prompt
67
+
68
+ else:
69
+ # Fallback for models that require a flattened prompt string (e.g., Llama, Llava)
70
+ prompt_parts = []
71
+ image_input = None
72
+ for msg in request.messages:
73
+ if msg.role == "system":
74
+ # System prompts are handled differently or prepended by the user
75
+ # for these models, often as part of the main prompt.
76
+ # For simplicity, we'll place it at the beginning.
77
+ prompt_parts.insert(0, str(msg.content))
78
+ elif msg.role == "assistant":
79
+ prompt_parts.append(f"Assistant: {msg.content}")
80
+ elif msg.role == "user":
81
+ user_text_content = ""
82
+ if isinstance(msg.content, list):
83
+ for item in msg.content:
84
+ if item.get("type") == "text":
85
+ user_text_content += item.get("text", "")
86
+ elif item.get("type") == "image_url":
87
+ image_url_data = item.get("image_url", {})
88
+ image_input = image_url_data.get("url")
89
+ else:
90
+ user_text_content = str(msg.content)
91
+ prompt_parts.append(f"User: {user_text_content}")
92
+
93
+ prompt_parts.append("Assistant:")
94
+ payload["prompt"] = "\n\n".join(prompt_parts)
95
+ if image_input:
96
+ payload["image"] = image_input
97
 
98
  # Map common OpenAI parameters to Replicate equivalents
99
  if request.max_tokens: payload["max_new_tokens"] = request.max_tokens
 
137
  current_event = line[len("event:"):].strip()
138
  elif line.startswith("data:"):
139
  data = line[len("data:"):].strip()
 
140
  if current_event == "output":
141
  if data:
142
  chunk = {
143
+ "id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_id,
144
  "choices": [{"index": 0, "delta": {"content": data}, "finish_reason": None}]
145
  }
146
  yield json.dumps(chunk)
 
147
  elif current_event == "done":
148
  break
149
  except httpx.ReadTimeout:
 
151
  return
152
 
153
  final_chunk = {
154
+ "id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_id,
155
  "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]
156
  }
157
  yield json.dumps(final_chunk)
 
170
  raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
171
 
172
  replicate_id = SUPPORTED_MODELS[request.model]
173
+ # Pass the replicate_id to the prepare function so it knows which format to use
174
+ replicate_input = prepare_replicate_input(request, replicate_id)
175
 
176
  if request.stream:
177
  return EventSourceResponse(stream_replicate_sse(replicate_id, replicate_input), media_type="text/event-stream")