File size: 5,665 Bytes
0853b44
fba30db
 
0853b44
 
 
 
 
fba30db
 
 
 
26f3f24
0853b44
 
 
 
 
fba30db
0853b44
 
 
fba30db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26f3f24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fba30db
 
 
 
 
 
 
26f3f24
fba30db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0853b44
 
 
 
 
 
 
 
 
 
 
fba30db
0853b44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fba30db
 
 
 
 
 
26f3f24
 
fba30db
 
 
 
0853b44
 
 
 
fba30db
 
0853b44
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import asyncio
import secrets
import sys
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, RedirectResponse

from api.router import api_router
from config import settings
from db.database import init_db
from models.model_loader import get_model_loader
from services.rate_limit import RateLimitContextMiddleware, limiter
from services.report_service import cleanup_expired


class ContentLengthLimitMiddleware(BaseHTTPMiddleware):
    """Reject oversized uploads via the Content-Length header before reading the body.

    Only the declared ``Content-Length`` is inspected, so an oversized request
    is answered with HTTP 413 without its body being read. This saves
    bandwidth and memory compared with letting ``read_upload_bytes`` reject it
    after the read. Requests with no usable header pass through unchanged; this
    includes bodies sent without a Content-Length, such as chunked uploads.
    Those must still be bounded by the downstream reader.
    """

    def __init__(self, app, max_bytes: int) -> None:
        super().__init__(app)
        self._max = max_bytes

    @staticmethod
    def _declared_length(raw):
        """Return the integer value of a Content-Length header, or None.

        None is returned when the header is absent, empty, or not a plain
        string of ASCII digits.
        """
        if not raw:
            return None
        value = raw.strip()
        # str.isdigit() alone also accepts non-ASCII digits such as '²' (which
        # can appear in latin-1 decoded header values); int() rejects those
        # with ValueError, which would surface as a 500.
        if not (value.isascii() and value.isdigit()):
            return None
        return int(value)

    async def dispatch(self, request, call_next):
        declared = self._declared_length(request.headers.get("content-length"))
        if declared is not None and declared > self._max:
            return JSONResponse(
                status_code=413,
                content={"detail": f"Upload exceeds {self._max // (1024 * 1024)} MB limit"},
            )
        return await call_next(request)


class HTTPSRedirectAndHSTSMiddleware(BaseHTTPMiddleware):
    """Outside DEBUG, redirect plain-HTTP requests to HTTPS and add HSTS.

    A request is redirected with status 308 to the same URL with the ``https``
    scheme, unless one of the following holds:
      * the first entry of ``X-Forwarded-Proto`` is ``https``. Chained proxies
        may send a comma-separated list such as ``https, http``; the first
        entry is the client-facing scheme.
      * the request URL scheme is already ``https``;
      * the Host (port removed) is ``127.0.0.1``, ``localhost``, ``::1``, or
        the Host header is missing.
    Responses that reach the app get ``Strict-Transport-Security`` unless it
    is already set. When ``settings.DEBUG`` is true the middleware does nothing.
    """

    # Hosts exempt from the redirect; "" covers a missing Host header.
    _LOCAL_HOSTS = frozenset({"127.0.0.1", "localhost", "::1", ""})

    @staticmethod
    def _host_without_port(raw_host: str) -> str:
        """Return the lower-cased host from a Host header value, without the port.

        Bracketed IPv6 literals are handled: ``[::1]:8000`` -> ``::1``. A plain
        ``split(":", 1)[0]`` would return ``[`` for those.
        """
        host = raw_host.strip().lower()
        if host.startswith("["):
            closing = host.find("]")
            return host[1:closing] if closing != -1 else host
        return host.split(":", 1)[0]

    async def dispatch(self, request, call_next):
        if not settings.DEBUG:
            # Only the first (client-facing) entry matters. Comparing the whole
            # list to "https" caused an endless redirect behind chained proxies.
            forwarded_proto = (
                request.headers.get("x-forwarded-proto", "").split(",", 1)[0].strip().lower()
            )
            host = self._host_without_port(request.headers.get("host", ""))
            if (
                forwarded_proto != "https"
                and request.url.scheme != "https"
                and host not in self._LOCAL_HOSTS
            ):
                https_url = request.url.replace(scheme="https")
                return RedirectResponse(str(https_url), status_code=308)

        response = await call_next(request)
        if not settings.DEBUG:
            response.headers.setdefault("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
        return response


# === Phase 15.3 — JWT / CORS / logging hardening ===

# Placeholder JWT secret. _enforce_production_hardening treats a configured
# JWT_SECRET_KEY equal to this value the same as an unset key.
_DEFAULT_JWT_SECRET = "change-me-in-production"


def _enforce_production_hardening() -> None:
    """Abort startup when running outside DEBUG with unsafe settings (Phase 15.3).

    Two checks run in order:

    1. JWT secret: the key is unsafe if it equals the placeholder, is empty,
       or ``settings.JWT_SECRET_KEY_GENERATED`` is set. A fresh
       ``secrets.token_urlsafe(48)`` value is put in the message as an example
       replacement. In DEBUG this is only a warning. Otherwise an error is
       logged and the process exits with status 1.
    2. CORS: outside DEBUG, a ``'*'`` entry in ``settings.CORS_ORIGINS`` logs
       an error and exits with status 1, because the app enables credentials.
    """
    jwt_key = settings.JWT_SECRET_KEY
    jwt_is_unsafe = (
        jwt_key == _DEFAULT_JWT_SECRET
        or not jwt_key
        or settings.JWT_SECRET_KEY_GENERATED
    )
    if jwt_is_unsafe:
        suggestion = secrets.token_urlsafe(48)
        if not settings.DEBUG:
            logger.error(
                "Refusing to start: JWT_SECRET_KEY is unset or default. "
                f"Set JWT_SECRET_KEY in your environment. Example: {suggestion}"
            )
            sys.exit(1)
        logger.warning(
            "JWT_SECRET_KEY is unset or default — safe in dev only. "
            f"Set it before deploying. Example: {suggestion}"
        )

    wildcard_origin = "*" in settings.CORS_ORIGINS
    if not (wildcard_origin and not settings.DEBUG):
        return
    logger.error(
        "Refusing to start: CORS_ORIGINS contains '*' while allow_credentials=True. "
        "Set an explicit origin list."
    )
    sys.exit(1)


def _configure_logging() -> None:
    """Replace all loguru handlers with the app's sinks, masking email addresses.

    Two sinks are installed, both at level INFO and both behind the same filter:
      * ``sys.stderr``;
      * ``logs/deepshield.log`` with ``rotation="10 MB"``,
        ``retention="7 days"`` and ``enqueue=True``.
    The filter rewrites ``record["message"]`` so that ``user@example.com``
    becomes ``***@example.com``, then always accepts the record.
    """
    import re

    address = re.compile(r"([A-Za-z0-9._%+-]+)@([A-Za-z0-9.-]+\.[A-Za-z]{2,})")

    def redact(record):
        # Mask the local part; keep the domain (capture group 2).
        record["message"] = address.sub(r"***@\2", record["message"])
        return True

    shared = {"filter": redact, "level": "INFO"}
    logger.remove()
    logger.add(sys.stderr, **shared)
    logger.add(
        "logs/deepshield.log",
        rotation="10 MB",
        retention="7 days",
        enqueue=True,
        **shared,
    )


# Runs at import time, before `app` and its lifespan are created below, so
# every later logger call in this module goes through the configured sinks.
_configure_logging()


async def _report_cleanup_loop():
    """Call ``cleanup_expired()`` now and then every 600 seconds, until cancelled.

    Any ``Exception`` raised by ``cleanup_expired()`` is logged at WARNING
    level with its traceback, and the loop continues. ``asyncio.CancelledError``
    derives from ``BaseException`` (Python 3.8+), so it is not caught here and
    task cancellation stops the loop.
    """
    while True:
        try:
            cleanup_expired()
        except Exception as e:  # noqa: BLE001
            # Best-effort cleanup: keep looping, but attach the traceback;
            # str(e) alone made failures hard to diagnose.
            logger.opt(exception=True).warning(f"Report cleanup error: {e}")
        # NOTE(review): cleanup_expired() is called synchronously, so it runs on
        # the event loop and blocks request handling while it works. Consider
        # asyncio.to_thread if it is slow and safe to run off the loop thread.
        await asyncio.sleep(600)  # every 10 min


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work, serve the app, then stop background work on shutdown.

    Startup, in order:
      1. production hardening checks, which may call ``sys.exit``;
      2. ``init_db()``;
      3. model preloading when ``settings.PRELOAD_MODELS`` is true;
      4. starting the background report-cleanup task.

    On exit the cleanup task is cancelled and awaited. This also happens when
    the lifespan is left because of an exception, so the task never outlives
    the application.
    """
    _enforce_production_hardening()
    logger.info("Starting DeepShield backend")
    init_db()
    logger.info("Database initialized")
    if settings.PRELOAD_MODELS:
        get_model_loader().preload_phase1()
    else:
        logger.info("PRELOAD_MODELS=false — models will load on first use")
    task = asyncio.create_task(_report_cleanup_loop())
    try:
        yield
    finally:
        task.cancel()
        # Wait for the cancellation to complete. With return_exceptions=True,
        # the task's own CancelledError is returned as a result instead of
        # raised. A cancellation of this coroutine itself still propagates.
        await asyncio.gather(task, return_exceptions=True)
        logger.info("Shutting down DeepShield backend")


# ASGI application; `lifespan` performs the startup checks/initialisation and
# shutdown cleanup.
app = FastAPI(
    title="DeepShield API",
    description="Explainable AI-based multimodal misinformation detection",
    version="0.1.0",
    lifespan=lifespan,
)

# Phase 15.2 — slowapi rate limiter, stored on app.state
# (presumably where slowapi's decorators/handler look it up — confirm).
app.state.limiter = limiter


# Convert slowapi's RateLimitExceeded into slowapi's default error response.
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(RateLimitContextMiddleware)
# NOTE(review): the order of the add_middleware calls below matters. In
# Starlette, each newly added middleware wraps the previously added ones, so
# CORS (added last) should be outermost and RateLimitContextMiddleware
# innermost. Confirm against the installed Starlette version before reordering.
# Phase 15.3 — enforce HTTPS in production and add HSTS
app.add_middleware(HTTPSRedirectAndHSTSMiddleware)
# Phase 15.3 — reject oversized uploads before reading body
# (limit in bytes, derived from MAX_UPLOAD_SIZE_MB)
app.add_middleware(ContentLengthLimitMiddleware, max_bytes=settings.MAX_UPLOAD_SIZE_MB * 1024 * 1024)

# Phase 15.3 — explicit CORS methods/headers (no wildcards with credentials).
# A '*' origin is refused outside DEBUG by _enforce_production_hardening.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["Authorization", "Content-Type", "Accept", "Origin", "X-Requested-With"],
)

# Mount all API routes defined in api.router.
app.include_router(api_router)


@app.get("/")
def root():
    """Return the service name plus paths to the API docs and health check."""
    service_info = dict(
        service="DeepShield",
        docs="/docs",
        health="/api/v1/health",
    )
    return service_info