extraplus committed on
Commit
b3243bf
·
verified ·
1 Parent(s): 73f6a2f

Upload 2 files

Browse files
Files changed (2) hide show
  1. backend/config.py +1 -1
  2. backend/model_manager.py +143 -9
backend/config.py CHANGED
@@ -24,7 +24,7 @@ DATA_DIR.mkdir(exist_ok=True)
24
  MODELS_DIR.mkdir(exist_ok=True)
25
 
26
  class Settings(BaseSettings):
27
- model_config = SettingsConfigDict(env_file=".env", enable_decoding=False)
28
 
29
  # App settings
30
  APP_NAME: str = "GAKR AI Chatbot"
 
24
  MODELS_DIR.mkdir(exist_ok=True)
25
 
26
  class Settings(BaseSettings):
27
+ model_config = SettingsConfigDict(env_file=".env", enable_decoding=False, extra="ignore")
28
 
29
  # App settings
30
  APP_NAME: str = "GAKR AI Chatbot"
backend/model_manager.py CHANGED
@@ -1,9 +1,13 @@
1
  """Model manager that proxies inference requests to a Cloudflare Worker."""
2
  import asyncio
 
3
  import json
4
  import platform
 
 
5
  from datetime import datetime
6
  from typing import Any, AsyncGenerator, Dict, List, Optional
 
7
 
8
  import httpx
9
 
@@ -73,6 +77,117 @@ class ModelManager:
73
  # Worker communication #
74
  # ------------------------------------------------------------------ #
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  def load_model(self, force_reload: bool = False) -> bool:
77
  """Verify that the Cloudflare Worker is healthy and reachable."""
78
  if self._worker_healthy and not force_reload:
@@ -90,8 +205,22 @@ class ModelManager:
90
  return True
91
  self._last_error = f"Worker returned status: {data.get('status')}"
92
  except Exception as exc:
93
- self._last_error = f"Worker health check failed: {exc}"
94
- self._worker_healthy = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  return False
96
 
97
  def unload_model(self):
@@ -101,13 +230,18 @@ class ModelManager:
101
 
102
  def _call_worker_sync(self, payload: dict) -> dict:
103
  """Synchronous POST to worker /chat endpoint."""
104
- with httpx.Client(timeout=180, trust_env=False) as client:
105
- resp = client.post(f"{self.worker_url}/chat", json=payload)
106
- if resp.status_code >= 400:
107
- body = resp.text[:500]
108
- print(f"[WORKER] {resp.status_code} from {self.worker_url}/chat: {body}")
109
- resp.raise_for_status()
110
- return resp.json()
 
 
 
 
 
111
 
112
  async def _stream_worker_sse(self, payload: dict) -> AsyncGenerator[dict, None]:
113
  """Streaming POST to worker /chat, yields parsed SSE events."""
 
1
  """Model manager that proxies inference requests to a Cloudflare Worker."""
2
  import asyncio
3
+ import http.client
4
  import json
5
  import platform
6
+ import socket
7
+ import ssl
8
  from datetime import datetime
9
  from typing import Any, AsyncGenerator, Dict, List, Optional
10
+ from urllib.parse import urlparse
11
 
12
  import httpx
13
 
 
77
  # Worker communication #
78
  # ------------------------------------------------------------------ #
79
 
80
+ @staticmethod
81
+ def _should_try_dns_fallback(exc: Exception) -> bool:
82
+ message = str(exc).lower()
83
+ indicators = [
84
+ "no address associated with hostname",
85
+ "name or service not known",
86
+ "temporary failure in name resolution",
87
+ "getaddrinfo failed",
88
+ "nodename nor servname provided",
89
+ ]
90
+ return any(indicator in message for indicator in indicators)
91
+
92
+ def _resolve_worker_ips(self) -> List[str]:
93
+ parsed = urlparse(self.worker_url)
94
+ hostname = parsed.hostname or ""
95
+ if not hostname:
96
+ return []
97
+
98
+ doh_endpoints = [
99
+ (
100
+ "https://cloudflare-dns.com/dns-query",
101
+ {"name": hostname, "type": "A"},
102
+ {"accept": "application/dns-json"},
103
+ ),
104
+ (
105
+ "https://dns.google/resolve",
106
+ {"name": hostname, "type": "A"},
107
+ {},
108
+ ),
109
+ ]
110
+
111
+ for endpoint, params, headers in doh_endpoints:
112
+ try:
113
+ with httpx.Client(timeout=10, trust_env=False, headers=headers) as client:
114
+ response = client.get(endpoint, params=params)
115
+ response.raise_for_status()
116
+ payload = response.json()
117
+ answers = payload.get("Answer") or []
118
+ ips = [
119
+ str(item.get("data") or "").strip()
120
+ for item in answers
121
+ if str(item.get("type") or "") == "1" and str(item.get("data") or "").strip()
122
+ ]
123
+ if ips:
124
+ return ips
125
+ except Exception:
126
+ continue
127
+
128
+ return []
129
+
130
+ def _request_worker_via_resolved_ip(self, method: str, path: str, payload: Optional[dict], timeout: int) -> dict:
131
+ parsed = urlparse(self.worker_url)
132
+ hostname = parsed.hostname or ""
133
+ if not hostname:
134
+ raise RuntimeError("WORKER_URL hostname is empty")
135
+
136
+ port = parsed.port or (443 if (parsed.scheme or "https") == "https" else 80)
137
+ base_path = (parsed.path or "").rstrip("/")
138
+ request_path = f"{base_path}{path}" or "/"
139
+ body = b""
140
+ headers = {
141
+ "Host": hostname,
142
+ "Accept": "application/json",
143
+ "Connection": "close",
144
+ "User-Agent": "gakrchat-backend/1.0",
145
+ }
146
+
147
+ if payload is not None:
148
+ body = json.dumps(payload).encode("utf-8")
149
+ headers["Content-Type"] = "application/json"
150
+ headers["Content-Length"] = str(len(body))
151
+
152
+ request_lines = [f"{method} {request_path} HTTP/1.1"]
153
+ request_lines.extend(f"{key}: {value}" for key, value in headers.items())
154
+ request_bytes = ("\r\n".join(request_lines) + "\r\n\r\n").encode("utf-8") + body
155
+
156
+ ips = self._resolve_worker_ips()
157
+ if not ips:
158
+ raise RuntimeError(f"Unable to resolve worker hostname via DNS-over-HTTPS: {hostname}")
159
+
160
+ ssl_context = ssl.create_default_context()
161
+ last_error: Optional[Exception] = None
162
+
163
+ for ip_address in ips:
164
+ try:
165
+ with socket.create_connection((ip_address, port), timeout=timeout) as raw_socket:
166
+ with ssl_context.wrap_socket(raw_socket, server_hostname=hostname) as tls_socket:
167
+ tls_socket.settimeout(timeout)
168
+ tls_socket.sendall(request_bytes)
169
+ response = http.client.HTTPResponse(tls_socket)
170
+ response.begin()
171
+ response_body = response.read()
172
+
173
+ if response.status >= 400:
174
+ body_preview = response_body.decode("utf-8", errors="replace")[:500]
175
+ raise httpx.HTTPStatusError(
176
+ f"Worker returned status {response.status}: {body_preview}",
177
+ request=None,
178
+ response=None,
179
+ )
180
+
181
+ try:
182
+ return json.loads(response_body.decode("utf-8"))
183
+ except json.JSONDecodeError as exc:
184
+ raise RuntimeError("Worker returned non-JSON response during DNS fallback") from exc
185
+ except Exception as exc:
186
+ last_error = exc
187
+ continue
188
+
189
+ raise RuntimeError(f"Worker DNS fallback failed: {last_error}")
190
+
191
  def load_model(self, force_reload: bool = False) -> bool:
192
  """Verify that the Cloudflare Worker is healthy and reachable."""
193
  if self._worker_healthy and not force_reload:
 
205
  return True
206
  self._last_error = f"Worker returned status: {data.get('status')}"
207
  except Exception as exc:
208
+ if self._should_try_dns_fallback(exc):
209
+ try:
210
+ data = self._request_worker_via_resolved_ip("GET", "/health", None, timeout=10)
211
+ if data.get("status") == "ok":
212
+ self._worker_healthy = True
213
+ self._worker_model_name = data.get("model")
214
+ self._last_error = None
215
+ print(f"Worker healthy via DNS fallback: model={self._worker_model_name}")
216
+ return True
217
+ self._last_error = f"Worker returned status: {data.get('status')}"
218
+ except Exception as fallback_exc:
219
+ self._last_error = f"Worker health check failed: {exc}; DNS fallback failed: {fallback_exc}"
220
+ self._worker_healthy = False
221
+ else:
222
+ self._last_error = f"Worker health check failed: {exc}"
223
+ self._worker_healthy = False
224
  return False
225
 
226
  def unload_model(self):
 
230
 
231
  def _call_worker_sync(self, payload: dict) -> dict:
232
  """Synchronous POST to worker /chat endpoint."""
233
+ try:
234
+ with httpx.Client(timeout=180, trust_env=False) as client:
235
+ resp = client.post(f"{self.worker_url}/chat", json=payload)
236
+ if resp.status_code >= 400:
237
+ body = resp.text[:500]
238
+ print(f"[WORKER] {resp.status_code} from {self.worker_url}/chat: {body}")
239
+ resp.raise_for_status()
240
+ return resp.json()
241
+ except Exception as exc:
242
+ if self._should_try_dns_fallback(exc):
243
+ return self._request_worker_via_resolved_ip("POST", "/chat", payload, timeout=180)
244
+ raise
245
 
246
  async def _stream_worker_sse(self, payload: dict) -> AsyncGenerator[dict, None]:
247
  """Streaming POST to worker /chat, yields parsed SSE events."""