rkihacker commited on
Commit
61655b8
·
verified ·
1 Parent(s): 83583ba

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +137 -42
main.py CHANGED
@@ -6,100 +6,187 @@ import os
6
  import random
7
  import logging
8
  import time
 
9
  from contextlib import asynccontextmanager
10
 
11
  # --- Production-Ready Configuration ---
12
  LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
13
  logging.basicConfig(
14
  level=LOG_LEVEL,
15
- format='%(asctime)s - %(levelname)s - %(message)s'
16
  )
 
17
 
18
- TARGET_URL = os.getenv("TARGET_URL", "https://api.gmi-serving.com")
19
- MAX_RETRIES = int(os.getenv("MAX_RETRIES", "15"))
 
 
 
20
  DEFAULT_RETRY_CODES = "429,500,502,503,504"
21
  RETRY_CODES_STR = os.getenv("RETRY_CODES", DEFAULT_RETRY_CODES)
22
  try:
23
  RETRY_STATUS_CODES = {int(code.strip()) for code in RETRY_CODES_STR.split(',')}
24
- logging.info(f"Will retry on the following status codes: {RETRY_STATUS_CODES}")
25
  except ValueError:
26
- logging.error(f"Invalid RETRY_CODES format: '{RETRY_CODES_STR}'. Falling back to default: {DEFAULT_RETRY_CODES}")
27
  RETRY_STATUS_CODES = {int(code.strip()) for code in DEFAULT_RETRY_CODES.split(',')}
28
 
29
- # --- Helper Function ---
 
30
  def generate_random_ip():
31
  """Generates a random, valid-looking IPv4 address."""
32
  return ".".join(str(random.randint(1, 254)) for _ in range(4))
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # --- HTTPX Client Lifecycle Management ---
 
35
  @asynccontextmanager
36
  async def lifespan(app: FastAPI):
37
- """Manages the lifecycle of the HTTPX client."""
38
- async with httpx.AsyncClient(base_url=TARGET_URL, timeout=None) as client:
 
 
39
  app.state.http_client = client
 
 
40
  yield
 
41
 
42
  # Initialize the FastAPI app with the lifespan manager and disabled docs
43
  app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
44
 
45
  # --- API Endpoints ---
46
 
47
- # 1. Health Check Route (Defined FIRST)
48
- # This specific route will be matched before the catch-all proxy route.
49
  @app.get("/")
50
  async def health_check():
51
  """Provides a basic health check endpoint."""
52
- return JSONResponse({"status": "ok", "target": TARGET_URL})
 
 
 
53
 
54
- # 2. Catch-All Reverse Proxy Route (Defined SECOND)
55
- # This will capture ALL other requests (e.g., /completions, /v1/models, etc.)
56
- # and forward them. This eliminates any redirect issues.
57
- @app.api_route("/{full_path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
58
- async def reverse_proxy_handler(request: Request):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  """
60
- A catch-all reverse proxy that forwards requests to the target URL with
61
- enhanced retry logic and latency logging.
62
  """
63
  start_time = time.monotonic()
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  client: httpx.AsyncClient = request.app.state.http_client
66
- url = httpx.URL(path=request.url.path, query=request.url.query.encode("utf-8"))
67
-
 
 
68
  request_headers = dict(request.headers)
69
  request_headers.pop("host", None)
70
 
71
  random_ip = generate_random_ip()
72
- logging.info(f"Client '{request.client.host}' proxied with spoofed IP: {random_ip} for path: {url.path}")
73
-
74
- specific_headers = {
75
  "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36",
76
  "x-forwarded-for": random_ip,
77
  "x-real-ip": random_ip,
78
- "x-originating-ip": random_ip,
79
- "x-remote-ip": random_ip,
80
- "x-remote-addr": random_ip,
81
- "x-host": random_ip,
82
- "x-forwarded-host": random_ip,
83
  }
84
- request_headers.update(specific_headers)
85
-
86
- if "authorization" in request.headers:
87
- request_headers["authorization"] = request.headers["authorization"]
88
 
89
- body = await request.body()
 
 
 
90
 
 
91
  last_exception = None
92
  for attempt in range(MAX_RETRIES):
93
  try:
94
  rp_req = client.build_request(
95
- method=request.method, url=url, headers=request_headers, content=body
96
  )
97
  rp_resp = await client.send(rp_req, stream=True)
98
 
 
99
  if rp_resp.status_code not in RETRY_STATUS_CODES or attempt == MAX_RETRIES - 1:
100
  duration_ms = (time.monotonic() - start_time) * 1000
101
- log_func = logging.info if rp_resp.is_success else logging.warning
102
- log_func(f"Request finished: {request.method} {request.url.path} status_code={rp_resp.status_code} latency={duration_ms:.2f}ms")
103
 
104
  return StreamingResponse(
105
  rp_resp.aiter_raw(),
@@ -108,19 +195,27 @@ async def reverse_proxy_handler(request: Request):
108
  background=BackgroundTask(rp_resp.aclose),
109
  )
110
 
111
- logging.warning(
112
- f"Attempt {attempt + 1}/{MAX_RETRIES} for {url.path} failed with status {rp_resp.status_code}. Retrying..."
 
113
  )
114
- await rp_resp.aclose()
 
115
 
116
  except httpx.ConnectError as e:
117
  last_exception = e
118
- logging.warning(f"Attempt {attempt + 1}/{MAX_RETRIES} for {url.path} failed with connection error: {e}")
 
 
 
 
 
119
 
 
120
  duration_ms = (time.monotonic() - start_time) * 1000
121
- logging.critical(f"Request failed, cannot connect to target: {request.method} {request.url.path} status_code=502 latency={duration_ms:.2f}ms")
122
 
123
  raise HTTPException(
124
  status_code=502,
125
- detail=f"Bad Gateway: Cannot connect to target service after {MAX_RETRIES} attempts. {last_exception}"
126
  )
 
6
  import random
7
  import logging
8
  import time
9
+ import json
10
  from contextlib import asynccontextmanager
11
 
12
  # --- Production-Ready Configuration ---
13
  LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
14
  logging.basicConfig(
15
  level=LOG_LEVEL,
16
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
17
  )
18
+ logger = logging.getLogger(__name__)
19
 
20
+ # URL to fetch the list of all available models and their endpoints
21
+ ARTIFACT_URL = os.getenv("ARTIFACT_URL", "https://console.gmicloud.ai/api/v1/ie/artifact/get_public_artifacts")
22
+
23
+ # Retry logic configuration
24
+ MAX_RETRIES = int(os.getenv("MAX_RETRIES", "5"))
25
  DEFAULT_RETRY_CODES = "429,500,502,503,504"
26
  RETRY_CODES_STR = os.getenv("RETRY_CODES", DEFAULT_RETRY_CODES)
27
  try:
28
  RETRY_STATUS_CODES = {int(code.strip()) for code in RETRY_CODES_STR.split(',')}
29
+ logger.info(f"Will retry on the following status codes: {RETRY_STATUS_CODES}")
30
  except ValueError:
31
+ logger.error(f"Invalid RETRY_CODES format: '{RETRY_CODES_STR}'. Falling back to default: {DEFAULT_RETRY_CODES}")
32
  RETRY_STATUS_CODES = {int(code.strip()) for code in DEFAULT_RETRY_CODES.split(',')}
33
 
34
+ # --- Helper Functions ---
35
+
36
def generate_random_ip():
    """Return a plausible random IPv4 address as a dotted-quad string."""
    # Octets are drawn from 1-254, avoiding the reserved 0 and 255 values.
    octets = [str(random.randint(1, 254)) for _ in range(4)]
    return ".".join(octets)
39
 
40
async def fetch_and_cache_models(app: FastAPI):
    """Build the model-name -> endpoint routing table from the public artifact list.

    Runs once at application startup and stores the result on
    ``app.state.model_routing_table``. On any failure the app still starts,
    just with an empty table.
    """
    logger.info(f"Fetching model artifacts from: {ARTIFACT_URL}")
    routes = {}
    try:
        async with httpx.AsyncClient() as http:
            resp = await http.get(ARTIFACT_URL, timeout=30.0)
            resp.raise_for_status()
            for entry in resp.json():
                name = entry.get("artifact_metadata", {}).get("artifact_name")
                live_endpoints = entry.get("endpoints", [])
                # Only models with at least one running endpoint are routable.
                if not (name and live_endpoints):
                    continue
                # Only the first endpoint is used; no load balancing here.
                first_url = live_endpoints[0].get("endpoint_url")
                if first_url:
                    routes[name] = first_url

        if not routes:
            logger.warning("No active model endpoints found from artifact URL.")
        else:
            logger.info(f"Successfully loaded {len(routes)} active models.")
            for name, url in routes.items():
                logger.debug(f" - Model: '{name}' -> Endpoint: '{url}'")

    except httpx.RequestError as e:
        # Deliberately non-fatal: the app starts with an empty routing table.
        logger.critical(f"Failed to fetch model artifacts on startup: {e}")
    except Exception as e:
        logger.critical(f"An unexpected error occurred during model fetching: {e}")

    app.state.model_routing_table = routes
80
+
81
+
82
  # --- HTTPX Client Lifecycle Management ---
83
+
84
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook: owns the shared HTTP client and the model cache."""
    # One long-lived client with no base_url — target hosts are chosen per request.
    client = httpx.AsyncClient(timeout=None)
    try:
        app.state.http_client = client
        # Populate the model routing table before serving any traffic.
        await fetch_and_cache_models(app)
        yield
    finally:
        await client.aclose()
    logger.info("Application shutdown complete.")
95
 
96
  # Initialize the FastAPI app with the lifespan manager and disabled docs
97
  app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
98
 
99
  # --- API Endpoints ---
100
 
 
 
101
@app.get("/")
async def health_check():
    """Liveness probe: reports how many model routes are currently cached."""
    payload = {
        "status": "ok",
        "active_models": len(app.state.model_routing_table),
    }
    return JSONResponse(payload)
108
 
109
@app.get("/v1/models")
async def list_models(request: Request):
    """OpenAI-compatible listing of the models discovered at startup."""
    data = []
    for model_id in request.app.state.model_routing_table:
        data.append({
            "id": model_id,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "gmi-serving",
        })
    return JSONResponse(content={"object": "list", "data": data})
126
+
127
+
128
@app.post("/v1/chat/completions")
async def chat_completions_proxy(request: Request):
    """
    Forwards chat completion requests to the correct model endpoint.

    Flow: parse the 'model' field from the JSON body, look up the target host
    in the startup routing table, then forward the request with spoofed client
    headers and exponential-backoff retries on retryable statuses and
    connection failures.

    Raises:
        HTTPException 400: body is not JSON, or lacks a 'model' field.
        HTTPException 404: the requested model has no active endpoint.
        HTTPException 502: all forwarding attempts failed.
    """
    # FIX: the module never imports asyncio, so the retry backoff below raised
    # NameError on first use. Function-scope import keeps this edit standalone.
    import asyncio

    start_time = time.monotonic()

    # --- 1. Get Model Name and Find Target Host ---
    body = await request.body()
    try:
        data = json.loads(body)
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON in request body.")
    model_name = data.get("model")
    if not model_name:
        raise HTTPException(status_code=400, detail="Missing 'model' field in request body.")

    model_routing_table = request.app.state.model_routing_table
    target_host = model_routing_table.get(model_name)

    if not target_host:
        raise HTTPException(
            status_code=404,
            detail=f"Model '{model_name}' not found or is not currently active."
        )

    # --- 2. Prepare and Forward the Request ---
    client: httpx.AsyncClient = request.app.state.http_client

    # Construct the full URL to the backend service.
    # NOTE(review): assumes the cached endpoint host carries no scheme of its
    # own — confirm against the artifact endpoint_url format.
    target_url = f"https://{target_host}{request.url.path}"

    request_headers = dict(request.headers)
    request_headers.pop("host", None)  # let httpx set the host for the new target

    random_ip = generate_random_ip()
    spoof_headers = {
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36",
        "x-forwarded-for": random_ip,
        "x-real-ip": random_ip,
    }
    request_headers.update(spoof_headers)

    # FIX: request.client can be None (e.g. some ASGI test clients) — guard it
    # instead of crashing with AttributeError while logging.
    client_host = request.client.host if request.client else "unknown"
    logger.info(
        f"Routing request for model '{model_name}' to {target_url} "
        f"(Client: '{client_host}', Spoofed IP: {random_ip})"
    )

    # --- 3. Execute with Retry Logic ---
    last_exception = None
    for attempt in range(MAX_RETRIES):
        try:
            rp_req = client.build_request(
                method=request.method, url=target_url, headers=request_headers, content=body
            )
            rp_resp = await client.send(rp_req, stream=True)

            # If status is not retryable OR it's the last attempt, stream the response
            if rp_resp.status_code not in RETRY_STATUS_CODES or attempt == MAX_RETRIES - 1:
                duration_ms = (time.monotonic() - start_time) * 1000
                log_func = logger.info if rp_resp.is_success else logger.warning
                log_func(f"Request finished for '{model_name}': {request.method} {request.url.path} status_code={rp_resp.status_code} latency={duration_ms:.2f}ms")

                return StreamingResponse(
                    rp_resp.aiter_raw(),
                    status_code=rp_resp.status_code,
                    headers=rp_resp.headers,
                    background=BackgroundTask(rp_resp.aclose),
                )

            # Otherwise, log, release the connection, back off, and retry.
            logger.warning(
                f"Attempt {attempt + 1}/{MAX_RETRIES} for '{model_name}' failed with status {rp_resp.status_code}. Retrying..."
            )
            await rp_resp.aclose()  # Ensure the connection is closed before retrying
            await asyncio.sleep(1 * (2 ** attempt))  # Exponential backoff

        except (httpx.ConnectError, httpx.ConnectTimeout) as e:
            # FIX: connection timeouts previously fell into the broad handler
            # below and aborted all retries; they are retryable failures.
            last_exception = e
            logger.warning(f"Attempt {attempt + 1}/{MAX_RETRIES} for '{model_name}' failed with connection error: {e}")
            # FIX: back off here too — the original retried connection errors
            # immediately with no delay.
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(1 * (2 ** attempt))

        except Exception as e:
            last_exception = e
            logger.error(f"An unexpected error occurred during request forwarding: {e}")
            break  # Don't retry on unexpected errors

    # --- 4. Handle Final Failure ---
    duration_ms = (time.monotonic() - start_time) * 1000
    logger.critical(f"Request failed for model '{model_name}' after {MAX_RETRIES} attempts. Cannot connect to target: {target_url}. Latency: {duration_ms:.2f}ms")

    raise HTTPException(
        status_code=502,
        detail=f"Bad Gateway: Cannot connect to model backend for '{model_name}' after {MAX_RETRIES} attempts. Last error: {last_exception}"
    )