Upload 2 files
Browse files- backend/config.py +1 -1
- backend/model_manager.py +143 -9
backend/config.py
CHANGED
|
@@ -24,7 +24,7 @@ DATA_DIR.mkdir(exist_ok=True)
|
|
| 24 |
MODELS_DIR.mkdir(exist_ok=True)
|
| 25 |
|
| 26 |
class Settings(BaseSettings):
|
| 27 |
-
model_config = SettingsConfigDict(env_file=".env", enable_decoding=False)
|
| 28 |
|
| 29 |
# App settings
|
| 30 |
APP_NAME: str = "GAKR AI Chatbot"
|
|
|
|
| 24 |
MODELS_DIR.mkdir(exist_ok=True)
|
| 25 |
|
| 26 |
class Settings(BaseSettings):
|
| 27 |
+
model_config = SettingsConfigDict(env_file=".env", enable_decoding=False, extra="ignore")
|
| 28 |
|
| 29 |
# App settings
|
| 30 |
APP_NAME: str = "GAKR AI Chatbot"
|
backend/model_manager.py
CHANGED
|
@@ -1,9 +1,13 @@
|
|
| 1 |
"""Model manager that proxies inference requests to a Cloudflare Worker."""
|
| 2 |
import asyncio
|
|
|
|
| 3 |
import json
|
| 4 |
import platform
|
|
|
|
|
|
|
| 5 |
from datetime import datetime
|
| 6 |
from typing import Any, AsyncGenerator, Dict, List, Optional
|
|
|
|
| 7 |
|
| 8 |
import httpx
|
| 9 |
|
|
@@ -73,6 +77,117 @@ class ModelManager:
|
|
| 73 |
# Worker communication #
|
| 74 |
# ------------------------------------------------------------------ #
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
def load_model(self, force_reload: bool = False) -> bool:
|
| 77 |
"""Verify that the Cloudflare Worker is healthy and reachable."""
|
| 78 |
if self._worker_healthy and not force_reload:
|
|
@@ -90,8 +205,22 @@ class ModelManager:
|
|
| 90 |
return True
|
| 91 |
self._last_error = f"Worker returned status: {data.get('status')}"
|
| 92 |
except Exception as exc:
|
| 93 |
-
self.
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
return False
|
| 96 |
|
| 97 |
def unload_model(self):
|
|
@@ -101,13 +230,18 @@ class ModelManager:
|
|
| 101 |
|
| 102 |
def _call_worker_sync(self, payload: dict) -> dict:
|
| 103 |
"""Synchronous POST to worker /chat endpoint."""
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
async def _stream_worker_sse(self, payload: dict) -> AsyncGenerator[dict, None]:
|
| 113 |
"""Streaming POST to worker /chat, yields parsed SSE events."""
|
|
|
|
| 1 |
"""Model manager that proxies inference requests to a Cloudflare Worker."""
|
| 2 |
import asyncio
|
| 3 |
+
import http.client
|
| 4 |
import json
|
| 5 |
import platform
|
| 6 |
+
import socket
|
| 7 |
+
import ssl
|
| 8 |
from datetime import datetime
|
| 9 |
from typing import Any, AsyncGenerator, Dict, List, Optional
|
| 10 |
+
from urllib.parse import urlparse
|
| 11 |
|
| 12 |
import httpx
|
| 13 |
|
|
|
|
| 77 |
# Worker communication #
|
| 78 |
# ------------------------------------------------------------------ #
|
| 79 |
|
| 80 |
+
@staticmethod
|
| 81 |
+
def _should_try_dns_fallback(exc: Exception) -> bool:
|
| 82 |
+
message = str(exc).lower()
|
| 83 |
+
indicators = [
|
| 84 |
+
"no address associated with hostname",
|
| 85 |
+
"name or service not known",
|
| 86 |
+
"temporary failure in name resolution",
|
| 87 |
+
"getaddrinfo failed",
|
| 88 |
+
"nodename nor servname provided",
|
| 89 |
+
]
|
| 90 |
+
return any(indicator in message for indicator in indicators)
|
| 91 |
+
|
| 92 |
+
def _resolve_worker_ips(self) -> List[str]:
|
| 93 |
+
parsed = urlparse(self.worker_url)
|
| 94 |
+
hostname = parsed.hostname or ""
|
| 95 |
+
if not hostname:
|
| 96 |
+
return []
|
| 97 |
+
|
| 98 |
+
doh_endpoints = [
|
| 99 |
+
(
|
| 100 |
+
"https://cloudflare-dns.com/dns-query",
|
| 101 |
+
{"name": hostname, "type": "A"},
|
| 102 |
+
{"accept": "application/dns-json"},
|
| 103 |
+
),
|
| 104 |
+
(
|
| 105 |
+
"https://dns.google/resolve",
|
| 106 |
+
{"name": hostname, "type": "A"},
|
| 107 |
+
{},
|
| 108 |
+
),
|
| 109 |
+
]
|
| 110 |
+
|
| 111 |
+
for endpoint, params, headers in doh_endpoints:
|
| 112 |
+
try:
|
| 113 |
+
with httpx.Client(timeout=10, trust_env=False, headers=headers) as client:
|
| 114 |
+
response = client.get(endpoint, params=params)
|
| 115 |
+
response.raise_for_status()
|
| 116 |
+
payload = response.json()
|
| 117 |
+
answers = payload.get("Answer") or []
|
| 118 |
+
ips = [
|
| 119 |
+
str(item.get("data") or "").strip()
|
| 120 |
+
for item in answers
|
| 121 |
+
if str(item.get("type") or "") == "1" and str(item.get("data") or "").strip()
|
| 122 |
+
]
|
| 123 |
+
if ips:
|
| 124 |
+
return ips
|
| 125 |
+
except Exception:
|
| 126 |
+
continue
|
| 127 |
+
|
| 128 |
+
return []
|
| 129 |
+
|
| 130 |
+
def _request_worker_via_resolved_ip(self, method: str, path: str, payload: Optional[dict], timeout: int) -> dict:
|
| 131 |
+
parsed = urlparse(self.worker_url)
|
| 132 |
+
hostname = parsed.hostname or ""
|
| 133 |
+
if not hostname:
|
| 134 |
+
raise RuntimeError("WORKER_URL hostname is empty")
|
| 135 |
+
|
| 136 |
+
port = parsed.port or (443 if (parsed.scheme or "https") == "https" else 80)
|
| 137 |
+
base_path = (parsed.path or "").rstrip("/")
|
| 138 |
+
request_path = f"{base_path}{path}" or "/"
|
| 139 |
+
body = b""
|
| 140 |
+
headers = {
|
| 141 |
+
"Host": hostname,
|
| 142 |
+
"Accept": "application/json",
|
| 143 |
+
"Connection": "close",
|
| 144 |
+
"User-Agent": "gakrchat-backend/1.0",
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
if payload is not None:
|
| 148 |
+
body = json.dumps(payload).encode("utf-8")
|
| 149 |
+
headers["Content-Type"] = "application/json"
|
| 150 |
+
headers["Content-Length"] = str(len(body))
|
| 151 |
+
|
| 152 |
+
request_lines = [f"{method} {request_path} HTTP/1.1"]
|
| 153 |
+
request_lines.extend(f"{key}: {value}" for key, value in headers.items())
|
| 154 |
+
request_bytes = ("\r\n".join(request_lines) + "\r\n\r\n").encode("utf-8") + body
|
| 155 |
+
|
| 156 |
+
ips = self._resolve_worker_ips()
|
| 157 |
+
if not ips:
|
| 158 |
+
raise RuntimeError(f"Unable to resolve worker hostname via DNS-over-HTTPS: {hostname}")
|
| 159 |
+
|
| 160 |
+
ssl_context = ssl.create_default_context()
|
| 161 |
+
last_error: Optional[Exception] = None
|
| 162 |
+
|
| 163 |
+
for ip_address in ips:
|
| 164 |
+
try:
|
| 165 |
+
with socket.create_connection((ip_address, port), timeout=timeout) as raw_socket:
|
| 166 |
+
with ssl_context.wrap_socket(raw_socket, server_hostname=hostname) as tls_socket:
|
| 167 |
+
tls_socket.settimeout(timeout)
|
| 168 |
+
tls_socket.sendall(request_bytes)
|
| 169 |
+
response = http.client.HTTPResponse(tls_socket)
|
| 170 |
+
response.begin()
|
| 171 |
+
response_body = response.read()
|
| 172 |
+
|
| 173 |
+
if response.status >= 400:
|
| 174 |
+
body_preview = response_body.decode("utf-8", errors="replace")[:500]
|
| 175 |
+
raise httpx.HTTPStatusError(
|
| 176 |
+
f"Worker returned status {response.status}: {body_preview}",
|
| 177 |
+
request=None,
|
| 178 |
+
response=None,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
try:
|
| 182 |
+
return json.loads(response_body.decode("utf-8"))
|
| 183 |
+
except json.JSONDecodeError as exc:
|
| 184 |
+
raise RuntimeError("Worker returned non-JSON response during DNS fallback") from exc
|
| 185 |
+
except Exception as exc:
|
| 186 |
+
last_error = exc
|
| 187 |
+
continue
|
| 188 |
+
|
| 189 |
+
raise RuntimeError(f"Worker DNS fallback failed: {last_error}")
|
| 190 |
+
|
| 191 |
def load_model(self, force_reload: bool = False) -> bool:
|
| 192 |
"""Verify that the Cloudflare Worker is healthy and reachable."""
|
| 193 |
if self._worker_healthy and not force_reload:
|
|
|
|
| 205 |
return True
|
| 206 |
self._last_error = f"Worker returned status: {data.get('status')}"
|
| 207 |
except Exception as exc:
|
| 208 |
+
if self._should_try_dns_fallback(exc):
|
| 209 |
+
try:
|
| 210 |
+
data = self._request_worker_via_resolved_ip("GET", "/health", None, timeout=10)
|
| 211 |
+
if data.get("status") == "ok":
|
| 212 |
+
self._worker_healthy = True
|
| 213 |
+
self._worker_model_name = data.get("model")
|
| 214 |
+
self._last_error = None
|
| 215 |
+
print(f"Worker healthy via DNS fallback: model={self._worker_model_name}")
|
| 216 |
+
return True
|
| 217 |
+
self._last_error = f"Worker returned status: {data.get('status')}"
|
| 218 |
+
except Exception as fallback_exc:
|
| 219 |
+
self._last_error = f"Worker health check failed: {exc}; DNS fallback failed: {fallback_exc}"
|
| 220 |
+
self._worker_healthy = False
|
| 221 |
+
else:
|
| 222 |
+
self._last_error = f"Worker health check failed: {exc}"
|
| 223 |
+
self._worker_healthy = False
|
| 224 |
return False
|
| 225 |
|
| 226 |
def unload_model(self):
|
|
|
|
| 230 |
|
| 231 |
def _call_worker_sync(self, payload: dict) -> dict:
|
| 232 |
"""Synchronous POST to worker /chat endpoint."""
|
| 233 |
+
try:
|
| 234 |
+
with httpx.Client(timeout=180, trust_env=False) as client:
|
| 235 |
+
resp = client.post(f"{self.worker_url}/chat", json=payload)
|
| 236 |
+
if resp.status_code >= 400:
|
| 237 |
+
body = resp.text[:500]
|
| 238 |
+
print(f"[WORKER] {resp.status_code} from {self.worker_url}/chat: {body}")
|
| 239 |
+
resp.raise_for_status()
|
| 240 |
+
return resp.json()
|
| 241 |
+
except Exception as exc:
|
| 242 |
+
if self._should_try_dns_fallback(exc):
|
| 243 |
+
return self._request_worker_via_resolved_ip("POST", "/chat", payload, timeout=180)
|
| 244 |
+
raise
|
| 245 |
|
| 246 |
async def _stream_worker_sse(self, payload: dict) -> AsyncGenerator[dict, None]:
|
| 247 |
"""Streaming POST to worker /chat, yields parsed SSE events."""
|