KaThaNg committed
Commit 9e412e2 · verified · 1 Parent(s): fc789cc

Update proxy_server.py

Files changed (1)
  1. proxy_server.py +96 -165
proxy_server.py CHANGED
@@ -12,32 +12,23 @@ from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from loguru import logger
 from typing import AsyncGenerator, Set, Optional, Dict, Any, List
-# from urllib.parse import urlparse # Removed: No longer needed for logging
 
 # --- Logging Configuration ---
 logger.remove()
 log_level = os.getenv("LOG_LEVEL", "INFO").upper()
-# Ensure DEBUG level enables detailed logging
 logger.add(sys.stderr, level=log_level, format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>")
 
 # --- Environment Variable Configuration ---
-# Target OpenAI API endpoint (can be OpenAI's official API or another compatible proxy like chat-api)
 OPENAI_API_ENDPOINT = os.getenv("OPENAI_API_ENDPOINT", "https://api.openai.com/v1/chat/completions")
-# API Key for authenticating with the target OpenAI endpoint (if required by the endpoint)
 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
-# List of valid API keys for this proxy service, comma-separated
-PROXY_API_KEYS_STR = os.getenv("PROXY_API_KEYS", "") # Example: "key1,key2,mysecretkey"
+PROXY_API_KEYS_STR = os.getenv("PROXY_API_KEYS", "")
 VALID_API_KEYS: Set[str] = set(key.strip() for key in PROXY_API_KEYS_STR.split(',') if key.strip())
-# httpx client timeouts
 CONNECT_TIMEOUT = float(os.getenv("CONNECT_TIMEOUT", 5.0))
-# Increased read timeout for potentially long LLM responses
 READ_TIMEOUT = float(os.getenv("READ_TIMEOUT", 180.0))
 WRITE_TIMEOUT = float(os.getenv("WRITE_TIMEOUT", 30.0))
 POOL_TIMEOUT = float(os.getenv("POOL_TIMEOUT", 5.0))
-# httpx client connection pool limits
 MAX_CONNECTIONS = int(os.getenv("MAX_CONNECTIONS", 100))
 MAX_KEEPALIVE = int(os.getenv("MAX_KEEPALIVE", 20))
-# Optional outbound proxy (e.g., http://user:pass@host:port or socks5://host:port)
 HTTP_PROXY = os.getenv("HTTP_PROXY")
 
 # --- Global httpx Client ---
@@ -51,11 +42,10 @@ async def lifespan(app: FastAPI):
     timeout_config = httpx.Timeout(connect=CONNECT_TIMEOUT, read=READ_TIMEOUT, write=WRITE_TIMEOUT, pool=POOL_TIMEOUT)
     proxy_config = {"http://": HTTP_PROXY, "https://": HTTP_PROXY} if HTTP_PROXY else None
 
-    # Completely hide target endpoint info from logs
-    logger.info("Initializing httpx client for upstream requests.") # Generic message
+    logger.info("Initializing httpx client for upstream requests.")
 
     if proxy_config:
-        logger.info(f"Using outbound proxy: {HTTP_PROXY}") # Proxy URL might still be sensitive
+        logger.info(f"Using outbound proxy: {HTTP_PROXY}")
     if not OPENAI_API_KEY:
         logger.warning("OPENAI_API_KEY is not set. Requests to the target endpoint might fail if it requires authentication.")
     if not VALID_API_KEYS:
@@ -65,7 +55,7 @@ async def lifespan(app: FastAPI):
         limits=limits,
         timeout=timeout_config,
         proxies=proxy_config,
-        http2=True, # Enable HTTP/2 if supported by the server
+        http2=True,
         follow_redirects=True
     )
     yield
@@ -83,14 +73,14 @@ app = FastAPI(
 # --- CORS Middleware ---
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"], # Allow all origins, or specify allowed origins
+    allow_origins=["*"],
     allow_credentials=True,
-    allow_methods=["*"], # Allow all methods (GET, POST, etc.)
-    allow_headers=["*"], # Allow all headers
+    allow_methods=["*"],
+    allow_headers=["*"],
 )
 
 # --- API Key Authentication ---
-API_KEY_NAME = "X-API-Key" # Standard header name for API keys
+API_KEY_NAME = "X-API-Key"
 api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
 
 async def get_api_key(key: Optional[str] = Security(api_key_header)) -> str:
@@ -110,60 +100,42 @@ async def get_api_key(key: Optional[str] = Security(api_key_header)) -> str:
 
 # --- Format Conversion Logic ---
 
-# --- FIX: Removed client_api_key parameter and the "user" field addition ---
 def claude_request_to_openai_payload(claude_request: Dict[str, Any]) -> Dict[str, Any]:
-    """
-    Converts a Claude API request body to OpenAI API format.
-    """
+    """Converts a Claude API request body to OpenAI API format."""
     messages = []
     system_prompt = claude_request.get("system")
     if system_prompt:
-        # Ensure system prompt content is a string
+        system_content = ""
        if isinstance(system_prompt, list):
-            # Combine text blocks if system prompt is a list (like in Claude's format)
            system_content = "\n".join(block.get("text", "") for block in system_prompt if block.get("type") == "text")
        elif isinstance(system_prompt, str):
            system_content = system_prompt
-        else:
-            system_content = "" # Handle other unexpected types if necessary
        if system_content:
            messages.append({"role": "system", "content": system_content})
 
-
     for msg in claude_request.get("messages", []):
        role = msg.get("role")
        content_parts = []
-        # Claude content is a list of blocks
        if isinstance(msg.get("content"), list):
            for block in msg.get("content", []):
                if block.get("type") == "text":
                    content_parts.append(block.get("text", ""))
-        elif isinstance(msg.get("content"), str): # Handle simple string content if provided
+        elif isinstance(msg.get("content"), str):
            content_parts.append(msg.get("content"))
-
        if role and content_parts:
-            # Combine multiple text blocks into one message for OpenAI
            messages.append({"role": role, "content": "\n".join(content_parts)})
 
-    # Map parameters (add more mappings as needed)
     openai_payload = {
-        "model": claude_request.get("model", "gpt-3.5-turbo"), # Use model from request or default
+        "model": claude_request.get("model", "gpt-3.5-turbo"),
        "messages": messages,
        "stream": claude_request.get("stream", False),
-        # --- FIX: Removed the "user" field ---
-        # Optional parameters - only include if present in Claude request
        **({ "max_tokens": v } if (v := claude_request.get("max_tokens")) is not None else {}),
        **({ "temperature": v } if (v := claude_request.get("temperature")) is not None else {}),
        **({ "top_p": v } if (v := claude_request.get("top_p")) is not None else {}),
        **({ "stop": v } if (v := claude_request.get("stop_sequences")) is not None else {}),
-        # Add other relevant parameter mappings here (e.g., presence_penalty, frequency_penalty)
     }
-
-    # logger.debug("Converted Claude request to OpenAI payload.") # Keep this simple
     return openai_payload
-# --- End Fix ---
 
-# ... (openai_response_to_claude_response and stream_openai_response_to_claude_events remain the same) ...
 def openai_response_to_claude_response(openai_response: Dict[str, Any], claude_request_id: str) -> Dict[str, Any]:
     """Converts a non-streaming OpenAI response to Claude API format."""
     try:
@@ -171,34 +143,25 @@ def openai_response_to_claude_response(openai_response: Dict[str, Any], claude_r
         message = choice.get("message", {})
         content = message.get("content", "")
         role = message.get("role", "assistant")
-        finish_reason = choice.get("finish_reason", "stop") # Map OpenAI finish reasons
+        finish_reason = choice.get("finish_reason", "stop")
 
-        # Map OpenAI finish reasons to Claude stop reasons
         stop_reason_map = {
-            "stop": "end_turn",
-            "length": "max_tokens",
-            "function_call": "tool_use", # Or handle appropriately if tools are used
-            "content_filter": "stop_sequence", # Approximate mapping
-            "null": "stop_sequence", # Approximate mapping
+            "stop": "end_turn", "length": "max_tokens", "function_call": "tool_use",
+            "content_filter": "stop_sequence", "null": "stop_sequence",
        }
-        claude_stop_reason = stop_reason_map.get(finish_reason, "stop_sequence") # Default if unknown
+        claude_stop_reason = stop_reason_map.get(finish_reason, "stop_sequence")
 
        usage = openai_response.get("usage", {})
        prompt_tokens = usage.get("prompt_tokens", 0)
        completion_tokens = usage.get("completion_tokens", 0)
 
        claude_response = {
-            "id": openai_response.get("id", claude_request_id), # Use OpenAI ID or original request ID
-            "type": "message",
-            "role": role,
-            "content": [{"type": "text", "text": content or ""}], # Ensure content is not None
-            "model": openai_response.get("model", "claude-proxy-model"), # Model used by OpenAI
-            "stop_reason": claude_stop_reason,
-            "stop_sequence": None, # OpenAI doesn't explicitly return the sequence
-            "usage": {
-                "input_tokens": prompt_tokens,
-                "output_tokens": completion_tokens,
-            },
+            "id": openai_response.get("id", claude_request_id),
+            "type": "message", "role": role,
+            "content": [{"type": "text", "text": content or ""}],
+            "model": openai_response.get("model", "claude-proxy-model"),
+            "stop_reason": claude_stop_reason, "stop_sequence": None,
+            "usage": { "input_tokens": prompt_tokens, "output_tokens": completion_tokens },
        }
        logger.debug(f"[{claude_request_id}] Converted non-streaming OpenAI response to Claude format.")
        return claude_response
@@ -208,18 +171,18 @@ def openai_response_to_claude_response(openai_response: Dict[str, Any], claude_r
 
 async def stream_openai_response_to_claude_events(openai_response: httpx.Response, claude_request_id: str, requested_model: str) -> AsyncGenerator[str, None]:
     """Converts an OpenAI SSE stream to Claude API SSE format."""
-    message_id = claude_request_id # Use the original request ID for consistency
-    accumulated_content_len = 0 # Track length instead of full content
+    message_id = claude_request_id
+    accumulated_content_len = 0
     openai_finish_reason = None
-    input_tokens = 0 # Will be updated if usage info is sent
-    output_tokens = 0 # Will be updated if usage info is sent
+    input_tokens = 0 # Try to capture this
+    output_tokens = 0
     last_ping_time = time.time()
 
     logger.debug(f"[{message_id}] Starting Claude SSE stream conversion.")
 
     # 1. Send message_start event
-    yield f"event: message_start\ndata: {json.dumps({'type': 'message_start', 'message': {'id': message_id, 'type': 'message', 'role': 'assistant', 'content': [], 'model': requested_model, 'stop_reason': None, 'stop_sequence': None, 'usage': {'input_tokens': 0, 'output_tokens': 0}}})}\n\n"
-    # 2. Send content_block_start event for the text block
+    yield f"event: message_start\ndata: {json.dumps({'type': 'message_start', 'message': {'id': message_id, 'type': 'message', 'role': 'assistant', 'content': [], 'model': requested_model, 'stop_reason': None, 'stop_sequence': None, 'usage': {'input_tokens': 0, 'output_tokens': 0}}})}\n\n" # Initial usage is 0
+    # 2. Send content_block_start event
     yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}})}\n\n"
     # 3. Send initial ping
     yield f"event: ping\ndata: {json.dumps({'type': 'ping'})}\n\n"
@@ -227,95 +190,113 @@ async def stream_openai_response_to_claude_events(openai_response: httpx.Respons
     try:
         async for line in openai_response.aiter_lines():
             line = line.strip()
-            if not line:
-                continue # Skip empty lines
+            if not line: continue
 
             if line.startswith("data:"):
                 data_str = line[len("data: "):].strip()
                 if data_str == "[DONE]":
                     logger.debug(f"[{message_id}] Received [DONE] marker from OpenAI stream.")
-                    break # End of OpenAI stream
+                    break
 
                try:
                    data = json.loads(data_str)
                    choices = data.get("choices", [])
-                    if not choices:
-                        continue
+                    if not choices: continue
+
+                    # --- Try to capture input tokens if sent early ---
+                    usage_update = data.get("usage")
+                    if usage_update and usage_update.get("prompt_tokens") is not None and input_tokens == 0:
+                        input_tokens = usage_update.get("prompt_tokens", 0)
+                        logger.debug(f"[{message_id}] Captured input_tokens: {input_tokens}")
+                    # ---
 
                    delta = choices[0].get("delta", {})
                    content_chunk = delta.get("content")
 
-                    # Check for finish reason in the chunk
                    if choices[0].get("finish_reason"):
                        openai_finish_reason = choices[0].get("finish_reason")
                        logger.debug(f"[{message_id}] Received OpenAI finish_reason: {openai_finish_reason}")
 
-                    # Check for usage update (some models send it at the end)
-                    usage_update = data.get("usage")
-                    if usage_update:
-                        input_tokens = usage_update.get("prompt_tokens", input_tokens)
+                    # Update output tokens based on usage update if available
+                    if usage_update and usage_update.get("completion_tokens") is not None:
                        output_tokens = usage_update.get("completion_tokens", output_tokens)
-                        logger.debug(f"[{message_id}] Received usage update: input={input_tokens}, output={output_tokens}")
+                        logger.debug(f"[{message_id}] Received completion_tokens update: {output_tokens}")
+
 
                    if content_chunk:
                        accumulated_content_len += len(content_chunk)
-                        # 4. Send content_block_delta for the text chunk
+                        # Estimate output tokens if not provided by usage update
+                        if not (usage_update and usage_update.get("completion_tokens") is not None):
+                            output_tokens += 1 # Simple increment per chunk as fallback
+
+                        # 4. Send content_block_delta
                        yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': content_chunk}})}\n\n"
 
                except json.JSONDecodeError:
                    logger.warning(f"[{message_id}] Could not decode JSON from stream line: {data_str}")
-                    continue
                except Exception as e:
                    logger.error(f"[{message_id}] Error processing stream data chunk: {e}")
-                    continue # Skip this chunk
 
-            # Send periodic pings
            current_time = time.time()
-            if current_time - last_ping_time >= 10: # Ping every 10 seconds
+            if current_time - last_ping_time >= 10:
                yield f"event: ping\ndata: {json.dumps({'type': 'ping'})}\n\n"
                last_ping_time = current_time
 
     except httpx.ReadTimeout:
        logger.error(f"[{message_id}] Timeout reading from OpenAI stream.")
-        openai_finish_reason = "error_timeout" # Custom reason
+        openai_finish_reason = "error_timeout"
        yield f"event: error\ndata: {json.dumps({'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Proxy timed out waiting for OpenAI stream'}})}\n\n"
    except Exception as e:
        logger.exception(f"[{message_id}] Unexpected error during stream processing: {e}")
-        openai_finish_reason = "error_exception" # Custom reason
+        openai_finish_reason = "error_exception"
        yield f"event: error\ndata: {json.dumps({'type': 'error', 'error': {'type': 'internal_server_error', 'message': f'Proxy stream processing error: {e}'}})}\n\n"
    finally:
-        # Map OpenAI finish reason to Claude stop reason
        stop_reason_map = {
-            "stop": "end_turn",
-            "length": "max_tokens",
-            "function_call": "tool_use",
-            "content_filter": "stop_sequence",
-            "null": "stop_sequence",
-            "error_timeout": "error", # Custom mapping
-            "error_exception": "error", # Custom mapping
+            "stop": "end_turn", "length": "max_tokens", "function_call": "tool_use",
+            "content_filter": "stop_sequence", "null": "stop_sequence",
+            "error_timeout": "error", "error_exception": "error",
        }
-        claude_stop_reason = stop_reason_map.get(openai_finish_reason, "stop_sequence") # Default
+        claude_stop_reason = stop_reason_map.get(openai_finish_reason, "stop_sequence")
 
        logger.debug(f"[{message_id}] Stream finished. OpenAI finish reason: {openai_finish_reason}, mapped Claude stop reason: {claude_stop_reason}")
 
        # 5. Send content_block_stop
        yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0})}\n\n"
 
-        # 6. Send message_delta with final usage and stop reason
+        # 6. Send message_delta with final stop reason ONLY
        final_delta = {
            'type': 'message_delta',
            'delta': {
                'stop_reason': claude_stop_reason,
-                'stop_sequence': None # OpenAI doesn't provide this
-            },
-            'usage': {
-                'output_tokens': output_tokens if output_tokens > 0 else (accumulated_content_len // 4) # Rough estimate
+                'stop_sequence': None
            }
+            # Removed usage from here
        }
        yield f"event: message_delta\ndata: {json.dumps(final_delta)}\n\n"
 
-        # 7. Send message_stop
-        yield f"event: message_stop\ndata: {json.dumps({'type': 'message_stop'})}\n\n"
+        # 7. Send message_stop (including final usage)
+        # --- FIX: Add usage info to the message_stop event ---
+        final_stop_event_data = {
+            'type': 'message_stop',
+            # Claude API v1 examples sometimes show a nested 'message' object here,
+            # but let's try putting usage directly under the event first for simplicity,
+            # similar to how message_start includes a 'message' object.
+            # If this doesn't work, we might need the nested 'message' structure.
+            'amazon-bedrock-invocationMetrics': { # Mimic Bedrock's potential structure for usage
+                'inputTokenCount': input_tokens,
+                'outputTokenCount': output_tokens if output_tokens > 0 else (accumulated_content_len // 4), # Use estimate if needed
+                'invocationLatency': 0, # Placeholder
+                'firstByteLatency': 0 # Placeholder
+            }
+            # Alternative simpler structure (if the above fails):
+            # 'usage': {
+            #     'input_tokens': input_tokens,
+            #     'output_tokens': output_tokens if output_tokens > 0 else (accumulated_content_len // 4)
+            # }
+        }
+        yield f"event: message_stop\ndata: {json.dumps(final_stop_event_data)}\n\n"
+        # --- End Fix ---
+
        logger.info(f"[{message_id}] Completed sending Claude SSE stream.")
 
 
@@ -323,26 +304,18 @@ def create_error_response(status_code: int, error_type: str, message: str) -> JS
     """Creates a JSONResponse with a Claude-like error structure."""
     return JSONResponse(
         status_code=status_code,
-        content={
-            "type": "error",
-            "error": {
-                "type": error_type,
-                "message": message
-            }
-        }
+        content={"type": "error", "error": {"type": error_type, "message": message}}
     )
 
 # --- Main Proxy Endpoint ---
-# --- FIX: Removed client_api_key parameter from signature as it's not used in payload conversion anymore ---
 @app.post("/v1/messages", dependencies=[Depends(get_api_key)])
 async def proxy_claude_to_openai(request: Request):
-    # --- End Fix ---
     """
     Receives a Claude-formatted request, proxies it to OpenAI,
     and returns a Claude-formatted response.
     Requires a valid API key via the X-API-Key header.
     """
-    request_id = f"msg_{uuid.uuid4().hex[:24]}" # Generate a unique ID for logging/tracking
+    request_id = f"msg_{uuid.uuid4().hex[:24]}"
     try:
         claude_request_data = await request.json()
         logger.info(f"[{request_id}] Received request. Stream: {claude_request_data.get('stream', False)}. Model: {claude_request_data.get('model')}")
@@ -350,19 +323,15 @@ async def proxy_claude_to_openai(request: Request):
         logger.error(f"[{request_id}] Invalid JSON received in request body.")
         return create_error_response(400, "invalid_request_error", "Invalid JSON data in request body.")
 
-    # Convert request format
     try:
-        # --- FIX: Call conversion function without client_api_key ---
         openai_payload = claude_request_to_openai_payload(claude_request_data)
-        # --- End Fix ---
     except Exception as e:
         logger.error(f"[{request_id}] Failed to convert Claude request to OpenAI format: {e}")
         return create_error_response(400, "invalid_request_error", f"Failed to process request data: {e}")
 
     is_streaming = openai_payload.get("stream", False)
-    requested_model = openai_payload.get("model", "unknown_model") # For response generation
+    requested_model = openai_payload.get("model", "unknown_model")
 
-    # Prepare headers for the target OpenAI endpoint
     target_headers = { "Content-Type": "application/json" }
     if OPENAI_API_KEY:
         target_headers["Authorization"] = f"Bearer {OPENAI_API_KEY}"
@@ -370,56 +339,35 @@ async def proxy_claude_to_openai(request: Request):
     else:
         logger.debug(f"[{request_id}] No OPENAI_API_KEY configured for upstream request.")
 
-    # Log headers and payload if log level is DEBUG
-    if logger.level("DEBUG").no >= logger.level(log_level).no: # Check if DEBUG is enabled
+    if logger.level("DEBUG").no >= logger.level(log_level).no:
         logged_headers = target_headers.copy()
-        if "Authorization" in logged_headers:
-            logged_headers["Authorization"] = "Bearer [REDACTED]"
+        if "Authorization" in logged_headers: logged_headers["Authorization"] = "Bearer [REDACTED]"
         logger.debug(f"[{request_id}] Sending request to upstream API.")
         logger.debug(f"[{request_id}] Upstream Headers: {json.dumps(logged_headers)}")
        try:
            payload_str = json.dumps(openai_payload, indent=2)
            max_log_len = 1024
-            if len(payload_str) > max_log_len:
-                logger.debug(f"[{request_id}] Upstream Payload (truncated): {payload_str[:max_log_len]}...")
-            else:
-                logger.debug(f"[{request_id}] Upstream Payload: {payload_str}")
+            logger.debug(f"[{request_id}] Upstream Payload {'(truncated)' if len(payload_str) > max_log_len else ''}: {payload_str[:max_log_len]}{'...' if len(payload_str) > max_log_len else ''}")
        except Exception as log_e:
            logger.warning(f"[{request_id}] Could not serialize or log upstream payload: {log_e}")
    else:
-        logger.debug(f"[{request_id}] Sending request to upstream API...") # Generic message for INFO level
-
+        logger.debug(f"[{request_id}] Sending request to upstream API...")
 
    try:
-        # Build the request to the target endpoint
-        target_request = client.build_request(
-            method="POST",
-            url=OPENAI_API_ENDPOINT,
-            headers=target_headers, # Use the prepared headers
-            json=openai_payload,
-        )
-
-        # Send the request and handle the response
+        target_request = client.build_request("POST", OPENAI_API_ENDPOINT, headers=target_headers, json=openai_payload)
        response = await client.send(target_request, stream=is_streaming)
+        response.raise_for_status()
 
-        # Check for HTTP errors from the target endpoint *before* processing the body
-        response.raise_for_status() # Raises exception for 4xx/5xx errors
-
-        # Process the response based on streaming or non-streaming
        if is_streaming:
            logger.info(f"[{request_id}] Upstream response is streaming. Starting SSE conversion.")
            return StreamingResponse(
                stream_openai_response_to_claude_events(response, request_id, requested_model),
                media_type="text/event-stream",
-                headers={
-                    "X-Content-Type-Options": "nosniff",
-                    "Cache-Control": "no-cache",
-                    "Connection": "keep-alive",
-                }
+                headers={"X-Content-Type-Options": "nosniff", "Cache-Control": "no-cache", "Connection": "keep-alive"}
            )
        else:
            logger.info(f"[{request_id}] Upstream response is non-streaming. Converting.")
-            openai_response_data = response.json() # No await needed
+            openai_response_data = response.json()
            logger.debug(f"[{request_id}] Received non-streaming response from upstream.")
            try:
                claude_response_data = openai_response_to_claude_response(openai_response_data, request_id)
@@ -431,31 +379,22 @@ async def proxy_claude_to_openai(request: Request):
             logger.exception(f"[{request_id}] Unexpected error converting non-streaming response: {e}")
             return create_error_response(500, "internal_server_error", "Unexpected error processing upstream response.")
 
-
-    # --- Error Handling for Target API Request ---
     except httpx.HTTPStatusError as e:
         status_code = e.response.status_code
         error_detail_text = "[Could not decode error response]"
        try:
-            error_detail = e.response.json()
-            error_detail_text = json.dumps(error_detail)
+            error_detail = e.response.json(); error_detail_text = json.dumps(error_detail)
        except json.JSONDecodeError:
-            try:
-                error_detail_text = e.response.text
-            except Exception:
-                logger.warning(f"[{request_id}] Could not read error response body as text.")
-
+            try: error_detail_text = e.response.text
+            except Exception: logger.warning(f"[{request_id}] Could not read error response body as text.")
        logger.error(f"[{request_id}] HTTP error from target endpoint ({status_code}). Response snippet: {error_detail_text[:200]}...")
-
        if status_code == 400: err_type, msg = "invalid_request_error", f"Upstream API Bad Request ({status_code})."
        elif status_code == 401: err_type, msg = "authentication_error", f"Authentication failed with upstream API ({status_code})."
        elif status_code == 403: err_type, msg = "permission_error", f"Forbidden by upstream API ({status_code})."
        elif status_code == 429: err_type, msg = "rate_limit_error", f"Rate limit exceeded with upstream API ({status_code})."
        elif status_code >= 500: err_type, msg = "api_error", f"Upstream API unavailable or encountered an error ({status_code})."
        else: err_type, msg = "api_error", f"Received unexpected error from upstream API ({status_code})."
-
        return create_error_response(status_code, err_type, msg)
-
    except httpx.TimeoutException:
        logger.error(f"[{request_id}] Request to target endpoint timed out ({READ_TIMEOUT}s).")
        return create_error_response(504, "api_error", "Gateway Timeout: Request to upstream API timed out.")
@@ -479,14 +418,12 @@ if __name__ == "__main__":
     from dotenv import load_dotenv
     load_dotenv()
     logger.info("Loaded environment variables from .env file (if present).")
-    # Reload config vars after loading .env
     PROXY_API_KEYS_STR = os.getenv("PROXY_API_KEYS", "")
     VALID_API_KEYS = set(key.strip() for key in PROXY_API_KEYS_STR.split(',') if key.strip())
     OPENAI_API_ENDPOINT = os.getenv("OPENAI_API_ENDPOINT", "https://api.openai.com/v1/chat/completions")
     OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
     HTTP_PROXY = os.getenv("HTTP_PROXY")
     log_level = os.getenv("LOG_LEVEL", "INFO").upper()
-    # Reconfigure logger if level changed
     logger.remove()
     logger.add(sys.stderr, level=log_level, format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>")
     logger.info(f"Log level set to: {log_level}")
@@ -497,12 +434,6 @@ if __name__ == "__main__":
     port = int(os.getenv("PORT", 7860))
     host = os.getenv("HOST", "0.0.0.0")
     log_config_level = log_level.lower() if log_level in ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "TRACE"] else "info"
-
     logger.info(f"Starting Uvicorn server on {host}:{port}")
-    uvicorn.run(
-        "proxy_server:app",
-        host=host,
-        port=port,
-        reload=True,
-        log_level=log_config_level
-    )
+    uvicorn.run("proxy_server:app", host=host, port=port, reload=True, log_level=log_config_level)
+
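
The change above leaves the proxy's external contract intact: clients still POST a Claude-formatted body to /v1/messages with an X-API-Key header. A minimal non-streaming smoke test, as a sketch only: it assumes the server is running locally on the default PORT 7860, that PROXY_API_KEYS contains "test-key", and the model name is just a placeholder forwarded upstream.

import httpx

# Hypothetical test values; adjust to your own PROXY_API_KEYS / PORT settings.
claude_style_request = {
    "model": "gpt-3.5-turbo",  # placeholder; passed through by claude_request_to_openai_payload
    "system": "You are a helpful assistant.",
    "messages": [
        {"role": "user", "content": [{"type": "text", "text": "Say hello."}]}
    ],
    "max_tokens": 64,
    "stream": False,
}

response = httpx.post(
    "http://localhost:7860/v1/messages",
    headers={"X-API-Key": "test-key"},  # header name defined by API_KEY_NAME
    json=claude_style_request,
    timeout=60.0,
)
response.raise_for_status()
body = response.json()
# Claude-shaped response built by openai_response_to_claude_response
print(body["stop_reason"], body["usage"])
print(body["content"][0]["text"])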
 
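
For streaming clients, the substantive change in this commit is where usage lands: token counts now arrive on the final message_stop event, under an 'amazon-bedrock-invocationMetrics' object, instead of on message_delta. A consumer sketch under the same local-setup assumptions (the payload must set "stream": True):

import json
import httpx

def read_claude_stream(url: str, api_key: str, payload: dict) -> None:
    # Iterate the proxy's SSE stream line by line; each event is emitted as
    # "event: <name>" followed by "data: <json>".
    with httpx.stream("POST", url, headers={"X-API-Key": api_key}, json=payload, timeout=None) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():
            if not line.startswith("data:"):
                continue  # skip "event:" lines and blank keep-alives
            event = json.loads(line[len("data:"):].strip())
            if event.get("type") == "content_block_delta":
                print(event["delta"]["text"], end="", flush=True)
            elif event.get("type") == "message_stop":
                # As of this commit, usage rides on message_stop
                metrics = event.get("amazon-bedrock-invocationMetrics", {})
                print(f"\n[input={metrics.get('inputTokenCount')} output={metrics.get('outputTokenCount')}]")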