|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import time |
|
|
import logging |
|
|
import json |
|
|
import asyncio |
|
|
from typing import Callable, Optional |
|
|
|
|
|
from anyio import EndOfStream |
|
|
from fastapi import FastAPI, Request |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from starlette.responses import Response, JSONResponse |
|
|
from starlette.middleware.gzip import GZipMiddleware |
|
|
from starlette.exceptions import ClientDisconnect |
|
|
|
|
|
|
|
|
try: |
|
|
from pythonjsonlogger import jsonlogger |
|
|
_HAS_PY_JSON_LOGGER = True |
|
|
except Exception: |
|
|
_HAS_PY_JSON_LOGGER = False |
|
|
|
|
|
from .deps import get_settings |
|
|
from .core.rate_limit import RateLimiter |
|
|
from .core.logging import add_trace_id |
|
|
|
|
|
class _SimpleJsonFormatter(logging.Formatter):
    """Minimal JSON log formatter.

    Each record is rendered as one JSON object with the keys ``asctime``,
    ``name``, ``levelname``, ``message`` and ``trace_id`` (``None`` when the
    record carries no ``trace_id`` attribute).  When the record has exception
    or stack information, ``exc_info`` / ``stack_info`` keys are added so that
    ``logger.exception(...)`` calls keep their traceback.
    """

    def format(self, record: logging.LogRecord) -> str:
        """Return *record* serialized as JSON, or as plain text if that fails.

        Args:
            record: The log record to render.

        Returns:
            A JSON string; if the payload cannot be serialized (for example a
            non-JSON-serializable ``trace_id``), a space-separated plain-text
            line carrying the same fields instead.
        """
        payload = {
            "asctime": self.formatTime(record, "%Y-%m-%d %H:%M:%S"),
            "name": record.name,
            "levelname": record.levelname,
            "message": record.getMessage(),
            "trace_id": getattr(record, "trace_id", None),
        }

        # Mirror logging.Formatter.format(): render the traceback once and
        # cache it on the record so other handlers can reuse it.
        if record.exc_info and not record.exc_text:
            record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            payload["exc_info"] = record.exc_text
        if record.stack_info:
            payload["stack_info"] = self.formatStack(record.stack_info)

        try:
            return json.dumps(payload, ensure_ascii=False)
        except Exception:
            # A formatter must never raise out of the logging call; degrade to
            # a plain-text line rather than losing the record.
            line = (
                f'{payload["asctime"]} {payload["name"]} {payload["levelname"]} '
                f'{payload["message"]} trace_id={payload["trace_id"]}'
            )
            for key in ("exc_info", "stack_info"):
                if key in payload:
                    line = f"{line}\n{payload[key]}"
            return line
|
|
|
|
|
# Module logger.  Handlers are attached only when none exist yet, so running
# this setup again does not stack duplicate handlers on the same logger.
_logger = logging.getLogger("matrix-ai")
if not _logger.handlers:
    _logger.setLevel(logging.INFO)
    _handler = logging.StreamHandler()
    if not _HAS_PY_JSON_LOGGER:
        # Optional dependency missing: fall back to the local formatter and
        # say so on the uvicorn error logger.
        _formatter = _SimpleJsonFormatter()
        logging.getLogger("uvicorn.error").warning(
            "python-json-logger not found; using a minimal JSON formatter."
        )
    else:
        _formatter = jsonlogger.JsonFormatter(
            "%(asctime)s %(name)s %(levelname)s %(message)s %(trace_id)s"
        )
    _handler.setFormatter(_formatter)
    _logger.addHandler(_handler)
|
|
|
|
|
# Single limiter instance shared by every request passing through the
# HTTP middleware in this module.
_rate_limiter = RateLimiter()

# Request paths ending in one of these are treated as SSE streams.
_SSE_PATH_SUFFIXES = (
    "/chat/stream",
    "/v1/chat/stream",
)

# Exact paths that skip rate limiting and access logging.
_HEALTH_PATHS = (
    "/health",
    "/livez",
    "/readyz",
)
|
|
|
|
|
def _client_ip(request: Request) -> str:
    """Return the best-effort client address for *request*.

    The first non-empty entry of the ``X-Forwarded-For`` header wins.  If the
    header is absent, or contains only empty entries (e.g. ``" , "``), the
    address of the direct socket peer is used, and ``"unknown"`` when the
    request has no client information at all.

    Args:
        request: The incoming request.

    Returns:
        The client address as a string; never an empty string taken from a
        malformed header.
    """
    # NOTE(review): X-Forwarded-For is client-controlled unless a trusted
    # proxy overwrites it - confirm the deployment before relying on this
    # value for rate limiting.
    xff = request.headers.get("x-forwarded-for")
    if xff:
        for hop in xff.split(","):
            candidate = hop.strip()
            if candidate:
                return candidate
    return request.client.host if request.client else "unknown"
|
|
|
|
|
def _is_sse(request: Request, response: Optional[Response] = None) -> bool:
    """Tell whether this exchange is a Server-Sent Events stream.

    Checked in order: the request path ends with a known SSE suffix, the
    response (when supplied) has a ``Content-Type`` starting with
    ``text/event-stream``, or the request's ``Accept`` header mentions
    ``text/event-stream``.
    """
    if request.url.path.endswith(_SSE_PATH_SUFFIXES):
        return True

    content_type = "" if response is None else response.headers.get("content-type", "")
    if content_type.startswith("text/event-stream"):
        return True

    return "text/event-stream" in request.headers.get("accept", "")
|
|
|
|
|
def attach_middlewares(app: FastAPI) -> None:
    """Register gzip, CORS and the rate-limit/access-log middleware on *app*.

    The HTTP middleware installed here:

    * tags every request with a trace id and echoes it as ``X-Trace-Id``;
    * lets health-check paths through without rate limiting or access logs;
    * rejects requests over the configured per-minute rate with a 429;
    * converts client disconnects into an empty 204 and any other unhandled
      exception into a JSON 500, logging both with the trace id;
    * adds ``X-Process-Time-Ms`` / ``Server-Timing`` headers and writes one
      access-log line per request.

    Args:
        app: The FastAPI application to configure.
    """
    app.add_middleware(GZipMiddleware, minimum_size=512)
    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # very permissive - confirm this is intended for the deployment.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
        expose_headers=["X-Trace-Id", "X-Process-Time-Ms", "Server-Timing"],
    )

    def _disconnect_response(trace_id: object) -> Response:
        """Build the empty 204 reply used when the client has gone away."""
        resp = Response(status_code=204)
        resp.headers.setdefault("X-Trace-Id", str(trace_id))
        return resp

    def _internal_error_response(
        trace_id: object, dur_ms: Optional[float] = None
    ) -> JSONResponse:
        """Build the JSON 500 reply, with timing headers when *dur_ms* is known."""
        headers = {"X-Trace-Id": str(trace_id)}
        if dur_ms is not None:
            headers["X-Process-Time-Ms"] = f"{dur_ms:.2f}"
            headers["Server-Timing"] = f"app;dur={dur_ms:.2f}"
        return JSONResponse(
            {"detail": "Internal Server Error"},
            status_code=500,
            headers=headers,
        )

    @app.middleware("http")
    async def rate_limit_and_log_middleware(request: Request, call_next: Callable):
        add_trace_id(request)
        trace_id = getattr(request.state, "trace_id", "N/A")

        path = request.url.path
        method = request.method
        ua = request.headers.get("user-agent", "-")
        ip = _client_ip(request)

        # Health probes bypass rate limiting and access logging.
        if path in _HEALTH_PATHS:
            try:
                response = await call_next(request)
            except Exception:
                # Record why the probe failed instead of dropping the error,
                # and keep the trace id on the reply like every other path.
                _logger.exception(
                    "Health check %s failed",
                    path, extra={"trace_id": trace_id},
                )
                return JSONResponse(
                    {"status": "unhealthy"},
                    status_code=500,
                    headers={"X-Trace-Id": str(trace_id)},
                )
            response.headers.setdefault("X-Trace-Id", str(trace_id))
            return response

        settings = get_settings()
        if not _rate_limiter.allow(ip, path, settings.limits.rate_per_min):
            _logger.warning(
                "429 Too Many Requests from %s on %s",
                ip, path, extra={"trace_id": trace_id},
            )
            return JSONResponse(
                {"detail": "Too Many Requests"},
                status_code=429,
                headers={"X-Trace-Id": str(trace_id)},
            )

        # perf_counter() is monotonic, so the measured duration cannot go
        # negative or jump when the wall clock is adjusted.
        t0 = time.perf_counter()
        try:
            response = await call_next(request)
        # NOTE(review): ClientDisconnect is normally exported from
        # starlette.requests - verify the module-level import resolves.
        except (EndOfStream, ClientDisconnect, asyncio.CancelledError):
            _logger.info(
                "Client disconnected from stream. Path: %s, IP: %s",
                path, ip, extra={"trace_id": trace_id},
            )
            return _disconnect_response(trace_id)
        except Exception as e:
            if isinstance(e, RuntimeError) and str(e) == "No response returned.":
                _logger.info(
                    "Downstream produced no response (likely streaming disconnect). "
                    "Path: %s, IP: %s",
                    path, ip, extra={"trace_id": trace_id},
                )
                return _disconnect_response(trace_id)

            # Every other failure, including other RuntimeErrors, is logged
            # with the trace id and answered with the same JSON 500.
            _logger.exception(
                "Unhandled error while processing %s %s: %s",
                method, path, e, extra={"trace_id": trace_id},
            )
            dur_ms = (time.perf_counter() - t0) * 1000.0
            return _internal_error_response(trace_id, dur_ms)

        if not isinstance(response, Response):
            _logger.error(
                "Downstream returned no Response object for %s",
                path, extra={"trace_id": trace_id},
            )
            return _internal_error_response(trace_id)

        sse = _is_sse(request, response)
        dur_ms = (time.perf_counter() - t0) * 1000.0
        response.headers.setdefault("X-Trace-Id", str(trace_id))
        response.headers.setdefault("X-Process-Time-Ms", f"{dur_ms:.2f}")
        response.headers.setdefault("Server-Timing", f"app;dur={dur_ms:.2f}")

        if sse:
            response.headers.setdefault("Cache-Control", "no-cache")
            _logger.info(
                '"%s %s" %s (SSE) ip=%s ua="%s" %.2fms',
                method, path, response.status_code, ip, ua, dur_ms,
                extra={"trace_id": trace_id},
            )
            return response

        _logger.info(
            '"%s %s" %s ip=%s ua="%s" %.2fms',
            method, path, response.status_code, ip, ua, dur_ms,
            extra={"trace_id": trace_id},
        )
        return response
|
|
|