import json
import logging
import os
import random
import time
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, Request, HTTPException
from starlette.background import BackgroundTask
from starlette.responses import StreamingResponse, JSONResponse

# --- Production-Ready Configuration ---
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    level=LOG_LEVEL,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

TARGET_URL = os.getenv("TARGET_URL", "https://api.gmi-serving.com")

# MAX_RETRIES is used as the total number of upstream attempts
# (``range(MAX_RETRIES)``), so anything below 1 would mean the proxy never
# contacts the target at all. Clamp it instead of silently answering 502.
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "10"))
if MAX_RETRIES < 1:
    logging.warning(f"MAX_RETRIES={MAX_RETRIES} is below 1. Using 1 attempt instead.")
    MAX_RETRIES = 1

DEFAULT_RETRY_CODES = "429,500,502,503,504"
RETRY_CODES_STR = os.getenv("RETRY_CODES", DEFAULT_RETRY_CODES)
try:
    RETRY_STATUS_CODES = {int(code.strip()) for code in RETRY_CODES_STR.split(',')}
    logging.info(f"Will retry on the following status codes: {RETRY_STATUS_CODES}")
except ValueError:
    logging.error(f"Invalid RETRY_CODES format: '{RETRY_CODES_STR}'. Falling back to default: {DEFAULT_RETRY_CODES}")
    RETRY_STATUS_CODES = {int(code.strip()) for code in DEFAULT_RETRY_CODES.split(',')}


# --- Helper Functions ---
def generate_random_ip():
    """Generates a random, valid-looking IPv4 address.

    Each of the four octets is drawn independently from 1..254.
    """
    return ".".join(str(random.randint(1, 254)) for _ in range(4))


_SSE_DATA_PREFIX = b"data:"
_SSE_DONE_MARKER = b"[DONE]"


def _rewrite_sse_line(line):
    """Rewrite one SSE line (bytes, without its trailing ``\\n``).

    Lines of the form ``data: <json object>`` get ``NAI-`` prepended to their
    ``id`` (when present) and a ``provider`` field set to ``TypeGPT``. Every
    other line is returned untouched, byte for byte: non-``data:`` lines,
    empty payloads, ``[DONE]``, payloads that are not valid UTF-8 JSON, and
    JSON values that are not objects.
    """
    if not line.startswith(_SSE_DATA_PREFIX):
        return line

    # The space after "data:" is optional in SSE, so strip only the field
    # name and let .strip() remove the optional space and any trailing "\r".
    payload = line[len(_SSE_DATA_PREFIX):].strip()
    if not payload or payload == _SSE_DONE_MARKER:
        return line

    try:
        # UnicodeDecodeError and json.JSONDecodeError are both ValueErrors.
        data = json.loads(payload.decode("utf-8"))
    except ValueError:
        return line

    # Only JSON objects can carry 'id' / 'provider'; scalars and arrays
    # are forwarded unchanged instead of raising TypeError.
    if not isinstance(data, dict):
        return line

    if 'id' in data:
        data['id'] = f"NAI-{data['id']}"
    data['provider'] = 'TypeGPT'

    rewritten = f"data: {json.dumps(data)}".encode('utf-8')
    # Keep CRLF framing consistent with the lines that are passed through.
    if line.endswith(b"\r"):
        rewritten += b"\r"
    return rewritten


async def modified_aiter_raw(original_aiter):
    """
    An async generator that intercepts and modifies the streaming data chunks.
    It adds a prefix to the 'id' and includes a 'provider' field.

    The stream is buffered and split on newline *bytes* rather than decoded
    chunk by chunk, so a multi-byte UTF-8 character that straddles two network
    chunks, or a body that is not UTF-8 text at all (e.g. content-encoded
    data), can no longer raise UnicodeDecodeError mid-stream. Anything that is
    not a rewritable ``data:`` line is forwarded unchanged.

    Args:
        original_aiter: async iterable yielding ``bytes`` chunks.

    Yields:
        ``bytes``: one newline-terminated line at a time, followed by any
        unterminated remainder once the source is exhausted.
    """
    buffer = b""
    async for chunk in original_aiter:
        buffer += chunk
        # One split per chunk: everything but the last element is a complete
        # line; the last element is the (possibly empty) unterminated tail.
        *complete_lines, buffer = buffer.split(b"\n")
        for line in complete_lines:
            yield _rewrite_sse_line(line) + b"\n"

    # Yield any remaining data in the buffer (no trailing newline was seen,
    # so it is forwarded as-is without attempting a rewrite).
    if buffer:
        yield buffer


# --- HTTPX Client Lifecycle Management ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manages the lifecycle of the HTTPX client.

    A single AsyncClient bound to TARGET_URL is created at startup, exposed as
    ``app.state.http_client``, and closed when the application shuts down.
    ``timeout=None`` disables all httpx timeouts for upstream requests.
    """
    async with httpx.AsyncClient(base_url=TARGET_URL, timeout=None) as client:
        app.state.http_client = client
        yield


# Initialize the FastAPI app with the lifespan manager and disabled docs
app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)


# --- API Endpoints ---

# 1. Health Check Route (Defined FIRST)
# This specific route will be matched before the catch-all proxy route.
@app.get("/")
async def health_check():
    """Provides a basic health check endpoint."""
    return JSONResponse({"status": "ok", "target": TARGET_URL})


# 2. Catch-All Reverse Proxy Route (Defined SECOND)
# This will capture ALL other requests (e.g., /completions, /v1/models, etc.)
# and forward them. This eliminates any redirect issues.
@app.api_route("/{full_path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
async def reverse_proxy_handler(request: Request):
    """
    A catch-all reverse proxy that forwards requests to the target URL
    with enhanced retry logic and latency logging.

    The incoming method, path, query string, headers (minus ``host``) and body
    are replayed against the shared httpx client. ``user-agent``,
    ``x-forwarded-for`` and ``x-real-ip`` are overwritten on every request.
    Up to MAX_RETRIES attempts are made: a response whose status is in
    RETRY_STATUS_CODES is closed and retried immediately, except on the last
    attempt, where it is returned as-is. Connection errors are retried too.

    Args:
        request: the incoming Starlette/FastAPI request.

    Returns:
        StreamingResponse: the upstream response, streamed through
        ``modified_aiter_raw``; the upstream response is closed by a
        background task once streaming finishes.

    Raises:
        HTTPException: 502 when every attempt ended in ``httpx.ConnectError``.
    """
    start_time = time.monotonic()

    client: httpx.AsyncClient = request.app.state.http_client
    url = httpx.URL(path=request.url.path, query=request.url.query.encode("utf-8"))

    # dict(request.headers) already carries 'authorization' when the caller
    # sent one, so no separate copy is needed.
    request_headers = dict(request.headers)
    request_headers.pop("host", None)

    random_ip = generate_random_ip()
    # request.client is None when the ASGI server supplies no peer address.
    client_host = request.client.host if request.client else "unknown"
    logging.info(f"Client '{client_host}' proxied with spoofed IP: {random_ip} for path: {url.path}")

    specific_headers = {
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36",
        "x-forwarded-for": random_ip,
        "x-real-ip": random_ip,
    }
    request_headers.update(specific_headers)

    body = await request.body()

    last_exception = None
    for attempt in range(MAX_RETRIES):
        try:
            rp_req = client.build_request(
                method=request.method, url=url, headers=request_headers, content=body
            )
            rp_resp = await client.send(rp_req, stream=True)

            if rp_resp.status_code not in RETRY_STATUS_CODES or attempt == MAX_RETRIES - 1:
                duration_ms = (time.monotonic() - start_time) * 1000
                log_func = logging.info if rp_resp.is_success else logging.warning
                log_func(f"Request finished: {request.method} {request.url.path} status_code={rp_resp.status_code} latency={duration_ms:.2f}ms")

                # The body is rewritten on the fly and may grow, so the
                # upstream Content-Length can no longer be trusted; dropping
                # it lets the server frame the stream itself. HEAD responses
                # have no body, so their Content-Length is kept.
                response_headers = {
                    name: value
                    for name, value in rp_resp.headers.items()
                    if request.method == "HEAD" or name.lower() != "content-length"
                }

                return StreamingResponse(
                    # Use the new async generator to modify the stream
                    modified_aiter_raw(rp_resp.aiter_raw()),
                    status_code=rp_resp.status_code,
                    headers=response_headers,
                    background=BackgroundTask(rp_resp.aclose),
                )

            logging.warning(
                f"Attempt {attempt + 1}/{MAX_RETRIES} for {url.path} failed with status {rp_resp.status_code}. Retrying..."
            )
            # Release the connection before the next attempt.
            await rp_resp.aclose()

        except httpx.ConnectError as e:
            last_exception = e
            logging.warning(f"Attempt {attempt + 1}/{MAX_RETRIES} for {url.path} failed with connection error: {e}")

    duration_ms = (time.monotonic() - start_time) * 1000
    logging.critical(f"Request failed, cannot connect to target: {request.method} {request.url.path} status_code=502 latency={duration_ms:.2f}ms")

    raise HTTPException(
        status_code=502,
        detail=f"Bad Gateway: Cannot connect to target service after {MAX_RETRIES} attempts. {last_exception}"
    )