File size: 8,920 Bytes
2446f5f 500ef17 5bad7a1 063d7d5 eab2c9c c90334d 7ee09b9 61655b8 6a181af 7ee09b9 83583ba 6a181af 02594ce 6a181af 61655b8 02594ce 61655b8 83583ba 61655b8 83583ba 6a181af 02594ce 6a181af 61655b8 6a181af 61655b8 6a181af 61655b8 6a181af 8c7c71f 83583ba de2331b 61655b8 a42e3f7 61655b8 063d7d5 61655b8 063d7d5 61655b8 063d7d5 61655b8 a42e3f7 83583ba a42e3f7 83583ba 6a181af 5bad7a1 83583ba 61655b8 a42e3f7 61655b8 2446f5f 61655b8 2446f5f 7ee09b9 83583ba 61655b8 063d7d5 61655b8 6a181af 83583ba 6a181af 61655b8 83583ba 6a181af 61655b8 83583ba 61655b8 6a181af 61655b8 83583ba 6a181af eab2c9c 6a181af 61655b8 eab2c9c 6a181af 83583ba 61655b8 6a181af b3b4e9a 61655b8 6a181af b3b4e9a 6a181af 83583ba 6a181af b3b4e9a 83583ba 61655b8 83583ba 61655b8 83583ba 6a181af a42e3f7 61655b8 83583ba 61655b8 a42e3f7 61655b8 6a181af b3b4e9a 61655b8 a42e3f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
import asyncio
import json
import logging
import os
import random
import time
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, HTTPException, Request
from starlette.background import BackgroundTask
from starlette.responses import JSONResponse, StreamingResponse
# --- Production-Ready Configuration ---
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
level=LOG_LEVEL,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# URL to fetch the list of all available models and their endpoints
ARTIFACT_URL = os.getenv("ARTIFACT_URL", "https://console.gmicloud.ai/api/v1/ie/artifact/get_public_artifacts")
# Retry logic configuration
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "5"))
DEFAULT_RETRY_CODES = "429,500,502,503,504"
RETRY_CODES_STR = os.getenv("RETRY_CODES", DEFAULT_RETRY_CODES)
try:
RETRY_STATUS_CODES = {int(code.strip()) for code in RETRY_CODES_STR.split(',')}
logger.info(f"Will retry on the following status codes: {RETRY_STATUS_CODES}")
except ValueError:
logger.error(f"Invalid RETRY_CODES format: '{RETRY_CODES_STR}'. Falling back to default: {DEFAULT_RETRY_CODES}")
RETRY_STATUS_CODES = {int(code.strip()) for code in DEFAULT_RETRY_CODES.split(',')}
# --- Helper Functions ---
def generate_random_ip():
    """Return a random dotted-quad IPv4 string; each octet is in 1..254."""
    octets = [random.randint(1, 254) for _ in range(4)]
    return "{}.{}.{}.{}".format(*octets)
async def fetch_and_cache_models(app: FastAPI):
    """
    Fetch the public artifact list and cache a model-name -> endpoint table.

    Runs once at startup; on any failure the app starts with an empty table
    rather than crashing.
    """
    logger.info(f"Fetching model artifacts from: {ARTIFACT_URL}")
    routes = {}
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(ARTIFACT_URL, timeout=30.0)
            response.raise_for_status()
            artifacts = response.json()

        for artifact in artifacts:
            name = artifact.get("artifact_metadata", {}).get("artifact_name")
            endpoints = artifact.get("endpoints", [])
            # Only models with at least one running endpoint are routable.
            if not name or not endpoints:
                continue
            # Multiple endpoints may exist; the first one wins. A more
            # advanced setup could load-balance between them.
            url = endpoints[0].get("endpoint_url")
            if url:
                routes[name] = url

        if routes:
            logger.info(f"Successfully loaded {len(routes)} active models.")
            for name, url in routes.items():
                logger.debug(f" - Model: '{name}' -> Endpoint: '{url}'")
        else:
            logger.warning("No active model endpoints found from artifact URL.")
    except httpx.RequestError as e:
        # Deliberate best-effort: log and continue with an empty table.
        logger.critical(f"Failed to fetch model artifacts on startup: {e}")
    except Exception as e:
        logger.critical(f"An unexpected error occurred during model fetching: {e}")

    app.state.model_routing_table = routes
# --- HTTPX Client Lifecycle Management ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook: owns the shared HTTP client and warms the route cache."""
    # One long-lived client for all forwarded traffic. No base_url, since the
    # target host is chosen per-request from the routing table.
    async with httpx.AsyncClient(timeout=None) as http_client:
        app.state.http_client = http_client
        await fetch_and_cache_models(app)
        yield
    logger.info("Application shutdown complete.")
# Initialize the FastAPI app with the lifespan manager and disabled docs
# (docs_url/redoc_url=None: this proxy exposes no interactive API surface).
app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
# --- API Endpoints ---
@app.get("/")
async def health_check():
    """Liveness probe reporting how many model routes are currently loaded."""
    payload = {
        "status": "ok",
        "active_models": len(app.state.model_routing_table),
    }
    return JSONResponse(payload)
@app.get("/v1/models")
async def list_models(request: Request):
    """
    List all models discovered at startup, in OpenAI-compatible format.
    """
    routes = request.app.state.model_routing_table
    data = []
    for model_id in routes:
        data.append({
            "id": model_id,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "gmi-serving",
        })
    return JSONResponse(content={"object": "list", "data": data})
# Hop-by-hop headers must not be relayed by a proxy (RFC 9110 §7.6.1).
# Forwarding e.g. 'transfer-encoding' or 'connection' from the upstream
# response can corrupt the re-framed stream sent to the client.
_HOP_BY_HOP_HEADERS = {
    "connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
    "te", "trailers", "transfer-encoding", "upgrade",
}


def _build_forward_headers(request: Request) -> dict:
    """Copy inbound headers, drop 'host', and inject a spoofed browser identity."""
    headers = dict(request.headers)
    headers.pop("host", None)  # httpx sets the correct Host for the target
    random_ip = generate_random_ip()
    headers.update({
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36",
        "x-forwarded-for": random_ip,
        "x-real-ip": random_ip,
    })
    return headers


def _relay_headers(upstream_headers) -> dict:
    """Return upstream response headers minus hop-by-hop fields.

    Body bytes are forwarded raw (aiter_raw), so content-length and
    content-encoding remain valid and are kept.
    """
    return {
        key: value for key, value in upstream_headers.items()
        if key.lower() not in _HOP_BY_HOP_HEADERS
    }


@app.post("/v1/chat/completions")
async def chat_completions_proxy(request: Request):
    """
    Forward a chat completion request to the correct model endpoint.

    Resolves the backend host from the routing table cached at startup,
    forwards the request with spoofed client headers, and retries transient
    failures (configurable status codes plus connection errors) with
    exponential backoff.

    Raises:
        HTTPException: 400 for a malformed body, 404 for an unknown model,
            502 when every attempt fails.
    """
    start_time = time.monotonic()

    # --- 1. Get Model Name and Find Target Host ---
    body = await request.body()
    try:
        data = json.loads(body)
        model_name = data.get("model")
        if not model_name:
            # Intentionally raised inside the try: HTTPException is not a
            # JSONDecodeError, so it propagates straight to the client.
            raise HTTPException(status_code=400, detail="Missing 'model' field in request body.")
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON in request body.")

    model_routing_table = request.app.state.model_routing_table
    target_host = model_routing_table.get(model_name)
    if not target_host:
        raise HTTPException(
            status_code=404,
            detail=f"Model '{model_name}' not found or is not currently active."
        )

    # --- 2. Prepare and Forward the Request ---
    client: httpx.AsyncClient = request.app.state.http_client
    # Construct the full URL to the backend service
    target_url = f"https://{target_host}{request.url.path}"
    request_headers = _build_forward_headers(request)
    random_ip = request_headers["x-real-ip"]

    # request.client is None when the ASGI server supplies no peer info.
    client_host = request.client.host if request.client else "unknown"
    logger.info(
        f"Routing request for model '{model_name}' to {target_url} "
        f"(Client: '{client_host}', Spoofed IP: {random_ip})"
    )

    # --- 3. Execute with Retry Logic ---
    last_exception = None
    for attempt in range(MAX_RETRIES):
        try:
            rp_req = client.build_request(
                method=request.method, url=target_url, headers=request_headers, content=body
            )
            rp_resp = await client.send(rp_req, stream=True)

            # If status is not retryable OR it's the last attempt, stream the response
            if rp_resp.status_code not in RETRY_STATUS_CODES or attempt == MAX_RETRIES - 1:
                duration_ms = (time.monotonic() - start_time) * 1000
                log_func = logger.info if rp_resp.is_success else logger.warning
                log_func(f"Request finished for '{model_name}': {request.method} {request.url.path} status_code={rp_resp.status_code} latency={duration_ms:.2f}ms")
                return StreamingResponse(
                    rp_resp.aiter_raw(),
                    status_code=rp_resp.status_code,
                    headers=_relay_headers(rp_resp.headers),
                    background=BackgroundTask(rp_resp.aclose),
                )

            # Otherwise, log and prepare for retry
            logger.warning(
                f"Attempt {attempt + 1}/{MAX_RETRIES} for '{model_name}' failed with status {rp_resp.status_code}. Retrying..."
            )
            await rp_resp.aclose()  # Release the connection before retrying
        except httpx.ConnectError as e:
            last_exception = e
            logger.warning(f"Attempt {attempt + 1}/{MAX_RETRIES} for '{model_name}' failed with connection error: {e}")
        except Exception as e:
            last_exception = e
            logger.error(f"An unexpected error occurred during request forwarding: {e}")
            break  # Don't retry on unexpected errors

        # Exponential backoff before the next attempt. Applies to both
        # retryable statuses and connection errors, and is skipped when
        # no attempts remain.
        if attempt < MAX_RETRIES - 1:
            await asyncio.sleep(1 * (2 ** attempt))

    # --- 4. Handle Final Failure ---
    duration_ms = (time.monotonic() - start_time) * 1000
    logger.critical(f"Request failed for model '{model_name}' after {MAX_RETRIES} attempts. Cannot connect to target: {target_url}. Latency: {duration_ms:.2f}ms")
    raise HTTPException(
        status_code=502,
        detail=f"Bad Gateway: Cannot connect to model backend for '{model_name}' after {MAX_RETRIES} attempts. Last error: {last_exception}"
    )