# deepshield / main.py
# ar07xd's picture
# Sync from GitHub via hub-sync
# b1d2ce2 verified
import asyncio
import secrets
import sys
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, RedirectResponse
from api.router import api_router
from config import settings
from db.database import init_db
from models.model_loader import get_model_loader
from services.rate_limit import RateLimitContextMiddleware, limiter
from services.report_service import cleanup_expired
class ContentLengthLimitMiddleware(BaseHTTPMiddleware):
    """Reject oversized uploads via Content-Length header before reading body.

    Saves bandwidth + memory vs letting read_upload_bytes reject post-read.

    Only the declared ``Content-Length`` value is inspected. Requests without
    the header, or whose value is not a plain ASCII decimal number, are passed
    through to the next handler unchanged.
    """

    def __init__(self, app, max_bytes: int) -> None:
        """Wrap *app*; requests declaring more than *max_bytes* are rejected."""
        super().__init__(app)
        self._max = max_bytes

    async def dispatch(self, request, call_next):
        """Return a 413 JSON response for oversized requests, else delegate."""
        declared = request.headers.get("content-length")
        # ``str.isdigit`` alone also accepts characters such as "²" that
        # ``int()`` cannot parse; requiring ASCII keeps a malformed header
        # from raising ValueError inside the middleware.
        if declared and declared.isascii() and declared.isdigit():
            try:
                too_large = int(declared) > self._max
            except ValueError:
                # Python 3.11+ caps the length of strings int() will convert;
                # an all-digit value that long is certainly over the limit.
                too_large = True
            if too_large:
                return JSONResponse(
                    status_code=413,
                    content={"detail": f"Upload exceeds {self._max // (1024 * 1024)} MB limit"},
                )
        return await call_next(request)
class HTTPSRedirectAndHSTSMiddleware(BaseHTTPMiddleware):
    """When DEBUG is off, redirect non-HTTPS requests and add an HSTS header."""

    @staticmethod
    def _needs_https_redirect(request) -> bool:
        """Tell whether *request* arrived insecurely on a non-local host."""
        proto_header = request.headers.get("x-forwarded-proto", "").lower()
        hostname = request.headers.get("host", "").split(":", 1)[0].lower()
        if proto_header == "https":
            return False
        if request.url.scheme == "https":
            return False
        # Loopback hosts and a missing Host header are never redirected.
        return hostname not in {"127.0.0.1", "localhost", ""}

    async def dispatch(self, request, call_next):
        """Redirect to HTTPS when required; otherwise delegate and add HSTS."""
        if not settings.DEBUG and self._needs_https_redirect(request):
            secure_url = request.url.replace(scheme="https")
            # 308 keeps the request method and body on the redirected request.
            return RedirectResponse(str(secure_url), status_code=308)

        response = await call_next(request)
        if not settings.DEBUG:
            # setdefault leaves any HSTS header set further down untouched.
            response.headers.setdefault(
                "Strict-Transport-Security",
                "max-age=31536000; includeSubDomains",
            )
        return response
# === Phase 15.3 — JWT / CORS / logging hardening ===
# Placeholder secret value. _enforce_production_hardening() treats a
# JWT_SECRET_KEY equal to this string the same as an unset key.
_DEFAULT_JWT_SECRET = "change-me-in-production"
def _enforce_production_hardening() -> None:
    """Abort startup outside DEBUG when the JWT secret or CORS config is unsafe.

    With DEBUG on, an unsafe JWT secret only produces a warning and the CORS
    wildcard check is skipped. With DEBUG off, either problem logs an error
    and exits the process with status 1.
    """
    jwt_key = settings.JWT_SECRET_KEY
    key_is_unsafe = (
        not jwt_key
        or jwt_key == _DEFAULT_JWT_SECRET
        or settings.JWT_SECRET_KEY_GENERATED
    )
    if key_is_unsafe:
        # Offer a freshly generated value the operator can copy into the env.
        example = secrets.token_urlsafe(48)
        if not settings.DEBUG:
            logger.error(
                "Refusing to start: JWT_SECRET_KEY is unset or default. "
                f"Set JWT_SECRET_KEY in your environment. Example: {example}"
            )
            sys.exit(1)
        logger.warning(
            "JWT_SECRET_KEY is unset or default — safe in dev only. "
            f"Set it before deploying. Example: {example}"
        )

    wildcard_in_production = not settings.DEBUG and "*" in settings.CORS_ORIGINS
    if wildcard_in_production:
        logger.error(
            "Refusing to start: CORS_ORIGINS contains '*' while allow_credentials=True. "
            "Set an explicit origin list."
        )
        sys.exit(1)
def _configure_logging() -> None:
    """Replace the existing sinks with stderr and a rotating file, masking e-mails."""
    import re

    address = re.compile(r"([A-Za-z0-9._%+-]+)@([A-Za-z0-9.-]+\.[A-Za-z]{2,})")

    def _redact(record):
        # Hide the local part of any e-mail address but keep its domain;
        # returning True means the record is never dropped by this filter.
        record["message"] = address.sub(r"***@\2", record["message"])
        return True

    shared_options = {"filter": _redact, "level": "INFO"}

    logger.remove()
    logger.add(sys.stderr, **shared_options)
    logger.add(
        "logs/deepshield.log",
        rotation="10 MB",
        retention="7 days",
        enqueue=True,
        **shared_options,
    )
# Install the log sinks at import time, before the app object below is created.
_configure_logging()
async def _report_cleanup_loop():
    """Call ``cleanup_expired`` forever, pausing ten minutes between runs.

    A failing run is logged as a warning and does not stop the loop.
    """
    pause_seconds = 600  # every 10 min
    while True:
        try:
            cleanup_expired()
        except Exception as exc:  # noqa: BLE001
            logger.warning(f"Report cleanup error: {exc}")
        await asyncio.sleep(pause_seconds)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup checks and initialization, then clean up on shutdown.

    Startup: enforce production hardening, initialize the database, optionally
    preload models, and start the periodic report-cleanup task.
    Shutdown: cancel the cleanup task and wait for it to finish, even when the
    application exits the context with an error.
    """
    _enforce_production_hardening()
    logger.info("Starting DeepShield backend")
    init_db()
    logger.info("Database initialized")
    if settings.PRELOAD_MODELS:
        get_model_loader().preload_phase1()
    else:
        logger.info("PRELOAD_MODELS=false — models will load on first use")

    cleanup_task = asyncio.create_task(_report_cleanup_loop())
    try:
        yield
    finally:
        # cancel() only requests cancellation; awaiting the task lets it
        # actually unwind before the event loop is torn down.
        cleanup_task.cancel()
        try:
            await cleanup_task
        except asyncio.CancelledError:
            # Expected outcome of cancelling the background loop.
            pass
        logger.info("Shutting down DeepShield backend")
# Application object; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(
title="DeepShield API",
description="Explainable AI-based multimodal misinformation detection",
version="0.1.0",
lifespan=lifespan,
)
# NOTE(review): middleware registration order below is significant. Starlette
# is understood to run the most recently added middleware first (outermost) —
# confirm against the Starlette docs before reordering any add_middleware call.
# Phase 15.2 — slowapi rate limiter
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(RateLimitContextMiddleware)
# Phase 15.3 — enforce HTTPS in production and add HSTS
app.add_middleware(HTTPSRedirectAndHSTSMiddleware)
# Phase 15.3 — reject oversized uploads before reading body
# MAX_UPLOAD_SIZE_MB is converted from megabytes to bytes here.
app.add_middleware(ContentLengthLimitMiddleware, max_bytes=settings.MAX_UPLOAD_SIZE_MB * 1024 * 1024)
# Phase 15.3 — explicit CORS methods/headers (no wildcards with credentials)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Authorization", "Content-Type", "Accept", "Origin", "X-Requested-With"],
)
# Mount all API routes defined in api.router.
app.include_router(api_router)
@app.get("/")
def root():
    # Service landing payload: names the service and points at the docs and
    # health endpoints. Kept as a comment rather than a docstring so the
    # generated API description is unchanged.
    landing = {
        "service": "DeepShield",
        "docs": "/docs",
        "health": "/api/v1/health",
    }
    return landing