File size: 7,025 Bytes
2446f5f 500ef17 5bad7a1 063d7d5 eab2c9c c90334d 7ee09b9 063d7d5 daa63f8 56d0fcf 7ee09b9 8adf131 419de53 7ee09b9 daa63f8 eab2c9c 2446f5f daa63f8 063d7d5 7ee09b9 063d7d5 e50ca24 7ee09b9 063d7d5 6b64125 5bad7a1 b3b4e9a 5bad7a1 b3b4e9a 5bad7a1 2446f5f 5bad7a1 b3b4e9a 2446f5f 7ee09b9 063d7d5 7ee09b9 2446f5f eab2c9c 5bad7a1 eab2c9c 063d7d5 5e05fcc eab2c9c 063d7d5 6b64125 7ee09b9 6a8ddc4 eab2c9c 7ee09b9 eab2c9c 5bad7a1 b3b4e9a daa63f8 b3b4e9a 7ee09b9 5bad7a1 7ee09b9 eab2c9c 5bad7a1 eab2c9c 7ee09b9 b3b4e9a 5bad7a1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import httpx
from fastapi import FastAPI, Request, HTTPException
from starlette.responses import StreamingResponse, JSONResponse
from starlette.background import BackgroundTask
import os
import random
import logging
import time
from contextlib import asynccontextmanager
import json
# --- Production-Ready Configuration ---
# All settings are read once from the environment at import time.
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    level=LOG_LEVEL,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
TARGET_URL = os.getenv("TARGET_URL", "https://api.gmi-serving.com")
# Clamp to at least one attempt: with 0 or a negative value the proxy's retry
# loop (`for attempt in range(MAX_RETRIES)`) would never contact the upstream
# and every request would end in a 502.
MAX_RETRIES = max(1, int(os.getenv("MAX_RETRIES", "10")))
DEFAULT_RETRY_CODES = "429,500,502,503,504"
RETRY_CODES_STR = os.getenv("RETRY_CODES", DEFAULT_RETRY_CODES)
try:
    # Comma-separated status codes; blank entries (e.g. a trailing comma in
    # "429,503,") are skipped instead of invalidating the whole list.
    RETRY_STATUS_CODES = {
        int(code.strip()) for code in RETRY_CODES_STR.split(',') if code.strip()
    }
    logging.info(f"Will retry on the following status codes: {RETRY_STATUS_CODES}")
except ValueError:
    logging.error(f"Invalid RETRY_CODES format: '{RETRY_CODES_STR}'. Falling back to default: {DEFAULT_RETRY_CODES}")
    RETRY_STATUS_CODES = {int(code.strip()) for code in DEFAULT_RETRY_CODES.split(',')}
# --- Helper Functions ---
def generate_random_ip():
    """Return a dotted-quad string built from four random octets.

    Each octet is drawn independently from 1..254 (inclusive) via the
    module-level ``random`` generator, so results are reproducible under
    ``random.seed``.
    """
    octets = [random.randint(1, 254) for _ in range(4)]
    return ".".join(map(str, octets))
def _rewrite_sse_line(line):
    """Tag the JSON object on one SSE ``data:`` line; return any other line unchanged.

    *line* is one line of raw bytes without its terminating ``b"\\n"``.
    A JSON-object payload gets ``NAI-`` prefixed to its ``id`` (when present)
    and a ``provider`` field of ``"TypeGPT"``, and is re-emitted as
    ``data: <json>``. Non-``data:`` lines, empty payloads, ``data: [DONE]``,
    undecodable or invalid JSON, and non-object JSON are returned as-is.
    """
    if not line.startswith(b"data:"):
        return line
    # SSE permits "data:" with or without one following space; strip() also
    # removes a trailing "\r" from CRLF-terminated lines.
    payload = line[len(b"data:"):].strip()
    if not payload or payload == b"[DONE]":
        return line
    try:
        data = json.loads(payload)
    except ValueError:  # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError
        return line
    if not isinstance(data, dict):
        # Only objects can take the extra fields; leave arrays/scalars untouched.
        return line
    if 'id' in data:
        data['id'] = f"NAI-{data['id']}"
    data['provider'] = 'TypeGPT'
    return f"data: {json.dumps(data)}".encode('utf-8')


async def modified_aiter_raw(original_aiter):
    """
    An async generator that intercepts and modifies the streaming data chunks.
    It adds a prefix to the 'id' and includes a 'provider' field.

    *original_aiter* is an async iterable of ``bytes`` chunks (the proxy passes
    ``httpx.Response.aiter_raw()``). Complete lines are yielded one at a time,
    each with a trailing ``b"\\n"``. ``data:`` lines are rewritten by
    ``_rewrite_sse_line``; all other bytes pass through unchanged. A final
    unterminated tail is yielded verbatim.
    """
    pending = bytearray()
    async for chunk in original_aiter:
        # Work on raw bytes: b"\n" never occurs inside a multi-byte UTF-8
        # sequence, so splitting here (and decoding only whole data lines) is
        # safe even when a character straddles two chunks. Searching only the
        # new chunk keeps newline-free bodies linear instead of rescanning.
        cut = chunk.rfind(b"\n")
        if cut == -1:
            pending += chunk
            continue
        pending += chunk[: cut + 1]
        # pending now ends with b"\n", so the last split element is b"" and is skipped.
        for line in bytes(pending).split(b"\n")[:-1]:
            yield _rewrite_sse_line(line) + b"\n"
        pending = bytearray(chunk[cut + 1:])
    # Yield any remaining (unterminated) data unchanged.
    if pending:
        yield bytes(pending)
# --- HTTPX Client Lifecycle Management ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own one shared HTTPX client for the whole lifetime of the app.

    The client targets TARGET_URL with timeouts disabled (``timeout=None``),
    is exposed to request handlers as ``app.state.http_client``, and is closed
    when the app shuts down and this context exits.
    """
    shared_client = httpx.AsyncClient(base_url=TARGET_URL, timeout=None)
    async with shared_client:
        app.state.http_client = shared_client
        yield
# Initialize the FastAPI app with the lifespan manager and every built-in docs
# route disabled. openapi_url must be disabled as well: FastAPI registers
# /openapi.json while constructing the app, i.e. before the catch-all proxy
# route, so it would otherwise be answered locally (exposing this proxy's own
# schema) instead of being forwarded to the target.
app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None, lifespan=lifespan)
# --- API Endpoints ---
# 1. Health Check Route (Defined FIRST)
# This specific route will be matched before the catch-all proxy route.
@app.get("/")
async def health_check():
    """Answer liveness probes with a static OK and the configured target URL.

    Because this route is registered before the catch-all proxy route,
    ``GET /`` is served here rather than being forwarded upstream.
    """
    status_payload = {"status": "ok", "target": TARGET_URL}
    return JSONResponse(status_payload)
# Hop-by-hop header fields (RFC 9110, section 7.6.1) describe a single
# connection and must not be relayed by a proxy in either direction.
_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)


def _strip_hop_by_hop(headers, drop=()):
    """Return a plain dict copy of *headers* without hop-by-hop fields.

    Also removes every field named in the message's own ``Connection`` header
    and any extra lower-case names in *drop*. Values are read through
    ``headers[name]``, so repeated fields collapse the same way the mapping's
    own lookup collapses them.
    """
    listed = {
        token.strip().lower()
        for token in headers.get("connection", "").split(",")
        if token.strip()
    }
    excluded = _HOP_BY_HOP_HEADERS | listed | set(drop)
    return {name: headers[name] for name in headers.keys() if name.lower() not in excluded}


# 2. Catch-All Reverse Proxy Route (Defined SECOND)
# This will capture ALL other requests (e.g., /completions, /v1/models, etc.)
# and forward them. This eliminates any redirect issues.
@app.api_route("/{full_path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
async def reverse_proxy_handler(request: Request):
    """
    A catch-all reverse proxy that forwards requests to the target URL with
    enhanced retry logic and latency logging.

    The request is sent up to MAX_RETRIES times (always at least once). It is
    retried while the upstream answers with a status in RETRY_STATUS_CODES or
    the connection fails with ``httpx.ConnectError``. The last attempt's
    response is returned regardless of its status, streamed through
    ``modified_aiter_raw``.

    Raises:
        HTTPException: 502 when every attempt failed with ``httpx.ConnectError``.
    """
    start_time = time.monotonic()
    client: httpx.AsyncClient = request.app.state.http_client
    url = httpx.URL(path=request.url.path, query=request.url.query.encode("utf-8"))
    # With MAX_RETRIES <= 0 the loop below would never contact the upstream and
    # every request would end in a 502, so always make at least one attempt.
    attempts = max(MAX_RETRIES, 1)

    # Host comes from the client's base URL and Content-Length is recomputed by
    # HTTPX from the buffered body, so neither is copied from the caller.
    # Authorization is an end-to-end header and is forwarded unchanged.
    request_headers = _strip_hop_by_hop(request.headers, drop=("host", "content-length"))
    random_ip = generate_random_ip()
    # Starlette leaves request.client as None when the server reports no peer address.
    client_host = request.client.host if request.client is not None else "unknown"
    logging.info(f"Client '{client_host}' proxied with spoofed IP: {random_ip} for path: {url.path}")
    # NOTE(review): overwriting User-Agent and the forwarding headers with random
    # values hides the real caller from the upstream; confirm this is permitted
    # by the upstream provider.
    request_headers.update(
        {
            "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36",
            "x-forwarded-for": random_ip,
            "x-real-ip": random_ip,
        }
    )
    # The reply is handed to modified_aiter_raw as raw, undecoded bytes
    # (aiter_raw), so request an uncompressed body whose `data:` lines are text.
    request_headers["accept-encoding"] = "identity"
    body = await request.body()

    last_exception = None
    for attempt in range(attempts):
        try:
            rp_req = client.build_request(
                method=request.method, url=url, headers=request_headers, content=body
            )
            rp_resp = await client.send(rp_req, stream=True)
        except httpx.ConnectError as e:
            last_exception = e
            logging.warning(f"Attempt {attempt + 1}/{attempts} for {url.path} failed with connection error: {e}")
            continue

        if rp_resp.status_code not in RETRY_STATUS_CODES or attempt == attempts - 1:
            duration_ms = (time.monotonic() - start_time) * 1000
            log_func = logging.info if rp_resp.is_success else logging.warning
            log_func(f"Request finished: {request.method} {request.url.path} status_code={rp_resp.status_code} latency={duration_ms:.2f}ms")
            return StreamingResponse(
                # Use the new async generator to modify the stream
                modified_aiter_raw(rp_resp.aiter_raw()),
                status_code=rp_resp.status_code,
                # Content-Length is dropped because modified_aiter_raw may
                # lengthen `data:` lines, making the upstream value wrong.
                headers=_strip_hop_by_hop(rp_resp.headers, drop=("content-length",)),
                # Release the upstream connection once streaming has finished.
                background=BackgroundTask(rp_resp.aclose),
            )
        logging.warning(
            f"Attempt {attempt + 1}/{attempts} for {url.path} failed with status {rp_resp.status_code}. Retrying..."
        )
        await rp_resp.aclose()

    duration_ms = (time.monotonic() - start_time) * 1000
    logging.critical(f"Request failed, cannot connect to target: {request.method} {request.url.path} status_code=502 latency={duration_ms:.2f}ms")
    raise HTTPException(
        status_code=502,
        detail=f"Bad Gateway: Cannot connect to target service after {attempts} attempts. {last_exception}"
    )