import asyncio
import json
import logging
import os
import random
import time
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, Request, HTTPException
from starlette.background import BackgroundTask
from starlette.responses import StreamingResponse, JSONResponse

# --- Production-Ready Configuration ---
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    level=LOG_LEVEL,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# URL to fetch the list of all available models and their endpoints
ARTIFACT_URL = os.getenv("ARTIFACT_URL", "https://console.gmicloud.ai/api/v1/ie/artifact/get_public_artifacts")
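# The artifact response is parsed by fetch_and_cache_models below; each entry is
# expected to look roughly like this (illustrative shape, not the full schema):
#   {"artifact_metadata": {"artifact_name": "<model-id>"},
#    "endpoints": [{"endpoint_url": "<host>"}]}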

# Retry logic configuration
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "5"))
DEFAULT_RETRY_CODES = "429,500,502,503,504"
RETRY_CODES_STR = os.getenv("RETRY_CODES", DEFAULT_RETRY_CODES)
try:
    RETRY_STATUS_CODES = {int(code.strip()) for code in RETRY_CODES_STR.split(',')}
    logger.info(f"Will retry on the following status codes: {RETRY_STATUS_CODES}")
except ValueError:
    logger.error(f"Invalid RETRY_CODES format: '{RETRY_CODES_STR}'. Falling back to default: {DEFAULT_RETRY_CODES}")
    RETRY_STATUS_CODES = {int(code.strip()) for code in DEFAULT_RETRY_CODES.split(',')}

# --- Helper Functions ---

def generate_random_ip():
    """Generates a random, valid-looking IPv4 address."""
    return ".".join(str(random.randint(1, 254)) for _ in range(4))

async def fetch_and_cache_models(app: FastAPI):
    """
    Fetches the list of public artifacts and caches a routing table.
    This runs once on application startup.
    """
    logger.info(f"Fetching model artifacts from: {ARTIFACT_URL}")
    model_routing_table = {}
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(ARTIFACT_URL, timeout=30.0)
            response.raise_for_status()
            artifacts = response.json()

        for artifact in artifacts:
            model_name = artifact.get("artifact_metadata", {}).get("artifact_name")
            endpoints = artifact.get("endpoints", [])
            
            # We only care about models that have a running endpoint
            if model_name and endpoints:
                # A model may expose multiple endpoints; we use the first one here.
                # A more advanced setup could load-balance across them.
                endpoint_url = endpoints[0].get("endpoint_url")
                if endpoint_url:
                    model_routing_table[model_name] = endpoint_url
        
        if not model_routing_table:
            logger.warning("No active model endpoints found from artifact URL.")
        else:
            logger.info(f"Successfully loaded {len(model_routing_table)} active models.")
            for name, url in model_routing_table.items():
                logger.debug(f"  - Model: '{name}' -> Endpoint: '{url}'")

    except httpx.RequestError as e:
        logger.critical(f"Failed to fetch model artifacts on startup: {e}")
        # In a real-world scenario, you might want the app to fail to start
        # or handle this more gracefully. For now, we start with an empty table.
    except Exception as e:
        logger.critical(f"An unexpected error occurred during model fetching: {e}")
        
    app.state.model_routing_table = model_routing_table
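    # The resulting table maps model names to backend hosts, e.g. (hypothetical values):
    #   {"example-model": "model-backend.example.com"}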


# --- HTTPX Client Lifecycle Management ---

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manages the app's lifecycle for startup and shutdown."""
    # Create a single, long-lived HTTP client for forwarding requests
    # No base_url as we will be calling different hosts dynamically
    async with httpx.AsyncClient(timeout=None) as client:
        app.state.http_client = client
        # Fetch and cache model routes on startup
        await fetch_and_cache_models(app)
        yield
    logger.info("Application shutdown complete.")

# Initialize the FastAPI app with the lifespan manager and disabled docs
app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)

# --- API Endpoints ---

@app.get("/")
async def health_check():
    """Provides a basic health check endpoint."""
    return JSONResponse({
        "status": "ok",
        "active_models": len(app.state.model_routing_table)
    })

@app.get("/v1/models")
async def list_models(request: Request):
    """
    Lists all available models discovered at startup.
    Formatted to be compatible with the OpenAI API.
    """
    model_routing_table = request.app.state.model_routing_table
    model_list = [
        {
            "id": model_id,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "gmi-serving",
        }
        for model_id in model_routing_table.keys()
    ]
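    # Example response body (model ids are whatever the artifact service returned;
    # "example-model" below is hypothetical):
    #   {"object": "list", "data": [{"id": "example-model", "object": "model", ...}]}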
    return JSONResponse(content={"object": "list", "data": model_list})


@app.post("/v1/chat/completions")
async def chat_completions_proxy(request: Request):
    """
    Forwards chat completion requests to the correct model endpoint.
    """
    start_time = time.monotonic()
    
    # --- 1. Get Model Name and Find Target Host ---
    body = await request.body()
    try:
        data = json.loads(body)
        model_name = data.get("model")
        if not model_name:
            raise HTTPException(status_code=400, detail="Missing 'model' field in request body.")
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON in request body.")

    model_routing_table = request.app.state.model_routing_table
    target_host = model_routing_table.get(model_name)

    if not target_host:
        raise HTTPException(
            status_code=404,
            detail=f"Model '{model_name}' not found or is not currently active."
        )

    # --- 2. Prepare and Forward the Request ---
    client: httpx.AsyncClient = request.app.state.http_client
    
    # Construct the full URL to the backend service
    target_url = f"https://{target_host}{request.url.path}"
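    # e.g. "https://inference-host.example.com/v1/chat/completions" (hypothetical host)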
    
    request_headers = dict(request.headers)
    request_headers.pop("host", None)

    random_ip = generate_random_ip()
    spoof_headers = {
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36",
        "x-forwarded-for": random_ip,
        "x-real-ip": random_ip,
    }
    request_headers.update(spoof_headers)

    logger.info(
        f"Routing request for model '{model_name}' to {target_url} "
        f"(Client: '{request.client.host}', Spoofed IP: {random_ip})"
    )
    
    # --- 3. Execute with Retry Logic ---
    last_exception = None
    for attempt in range(MAX_RETRIES):
        try:
            rp_req = client.build_request(
                method=request.method, url=target_url, headers=request_headers, content=body
            )
            rp_resp = await client.send(rp_req, stream=True)

            # If status is not retryable OR it's the last attempt, stream the response
            if rp_resp.status_code not in RETRY_STATUS_CODES or attempt == MAX_RETRIES - 1:
                duration_ms = (time.monotonic() - start_time) * 1000
                log_func = logger.info if rp_resp.is_success else logger.warning
                log_func(f"Request finished for '{model_name}': {request.method} {request.url.path} status_code={rp_resp.status_code} latency={duration_ms:.2f}ms")
                
                return StreamingResponse(
                    rp_resp.aiter_raw(),
                    status_code=rp_resp.status_code,
                    headers=rp_resp.headers,
                    background=BackgroundTask(rp_resp.aclose),
                )

            # Otherwise, log and prepare for retry
            logger.warning(
                f"Attempt {attempt + 1}/{MAX_RETRIES} for '{model_name}' failed with status {rp_resp.status_code}. Retrying..."
            )
            await rp_resp.aclose() # Ensure the connection is closed before retrying
            await asyncio.sleep(1 * (2 ** attempt)) # Exponential backoff

        except httpx.ConnectError as e:
            last_exception = e
            logger.warning(f"Attempt {attempt + 1}/{MAX_RETRIES} for '{model_name}' failed with connection error: {e}")
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(1 * (2 ** attempt))  # Back off before the next attempt
        
        except Exception as e:
            last_exception = e
            logger.error(f"An unexpected error occurred during request forwarding: {e}")
            break # Don't retry on unexpected errors

    # --- 4. Handle Final Failure ---
    duration_ms = (time.monotonic() - start_time) * 1000
    logger.critical(f"Request failed for model '{model_name}' after {MAX_RETRIES} attempts. Cannot connect to target: {target_url}. Latency: {duration_ms:.2f}ms")
    
    raise HTTPException(
        status_code=502,
        detail=f"Bad Gateway: Cannot connect to model backend for '{model_name}' after {MAX_RETRIES} attempts. Last error: {last_exception}"
    )
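
# --- Example Usage (assumes this module is saved as main.py) ---
#
# Run the proxy locally:
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Then query it like any OpenAI-compatible endpoint (model name is hypothetical):
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "example-model", "messages": [{"role": "user", "content": "Hello"}]}'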