# hasari-api / services/backend/security.py
# Author: erdoganpeker
# v0.3.0 — multimodal vehicle damage MVP
# Commit: e327f0d
"""
backend/security.py
-------------------
Security primitives for arac-hasar-v2.
Owned by: Security Engineer.
Other backend files MUST import from here rather than rolling their own
crypto / JWT / API key logic.
Exports
=======
Password hashing
- hash_password(password) -> str
- verify_password(plain, hashed) -> bool
JWT
- create_access_token(user_id, role, expires_minutes=ACCESS_TOKEN_MINUTES) -> str
- create_refresh_token(user_id, expires_days=7) -> str
- verify_token(token, expected_type="access") -> TokenPayload
API keys (for service-to-service / pilot integrations)
- generate_api_key() -> tuple[plain_key, hash_to_store]
- verify_api_key(plain_key, stored_hash) -> bool
FastAPI dependencies
- require_user -> TokenPayload (any authenticated user)
- require_admin -> TokenPayload (role == "admin")
File upload validators
- sniff_image_mime(buf) -> str | None
- validate_image_upload(buf, *, max_size_mb=20, max_w=10000, max_h=10000)
-> ValidatedImage (decoded PIL image, EXIF stripped, orientation applied)
- sanitize_filename(name) -> str
Config (read from env at import time)
- JWT_SECRET_KEY (32+ chars; hard-fail if shorter in staging/production)
- JWT_ALGORITHM (default HS256)
- ACCESS_TOKEN_MINUTES (default 15)
- REFRESH_TOKEN_DAYS (default 7)
- BCRYPT_ROUNDS (default 12)
- RATE_LIMIT_REDIS_URL (used by middleware)
- ALLOWED_ORIGINS (CSV, parsed into list)
- ENVIRONMENT (development | staging | production)
SQL injection note
==================
This codebase relies on SQLAlchemy parameterized queries / ORM. NEVER build
SQL via f-strings or .format(). If raw SQL is unavoidable, use
`text("... :param ...").bindparams(param=value)`.
"""
from __future__ import annotations
import hashlib
import hmac
import io
import logging
import os
import re
import secrets
import time
import unicodedata
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Final, Literal, Optional, Tuple
from uuid import UUID
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
import bcrypt as _bcrypt
from jose import JWTError, jwt
from PIL import Image, ImageOps, UnidentifiedImageError
from pydantic import BaseModel, Field
log = logging.getLogger("backend.security")
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Every setting is read once, when this module is imported; later changes to
# the process environment are not picked up.
_env = os.environ.get

ENVIRONMENT: Final[str] = _env("ENVIRONMENT", "development").lower()
JWT_SECRET_KEY: Final[str] = _env("JWT_SECRET_KEY", "")
JWT_ALGORITHM: Final[str] = _env("JWT_ALGORITHM", "HS256")
ACCESS_TOKEN_MINUTES: Final[int] = int(_env("ACCESS_TOKEN_MINUTES", "15"))
REFRESH_TOKEN_DAYS: Final[int] = int(_env("REFRESH_TOKEN_DAYS", "7"))
BCRYPT_ROUNDS: Final[int] = int(_env("BCRYPT_ROUNDS", "12"))
# A dedicated RATE_LIMIT_REDIS_URL wins; otherwise fall back to the shared REDIS_URL.
RATE_LIMIT_REDIS_URL: Final[str] = _env(
    "RATE_LIMIT_REDIS_URL", _env("REDIS_URL", "redis://redis:6379/1")
)
# Comma-separated list; surrounding whitespace and empty entries are dropped.
ALLOWED_ORIGINS: Final[list[str]] = list(
    filter(None, map(str.strip, _env("ALLOWED_ORIGINS", "").split(",")))
)
# JWS algorithm names accepted at boot in staging/production. "none" and any
# name not listed here (including typos such as "HS265") are rejected up front.
# NOTE(review): this module signs and verifies with a single secret string
# (_EFFECTIVE_JWT_SECRET); RS*/ES* presumably only work if that string is a PEM
# key -- confirm whether they should stay allowed.
_SUPPORTED_JWT_ALGORITHMS: Final[frozenset[str]] = frozenset(
    {"HS256", "HS384", "HS512", "RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}
)


def _validate_config() -> None:
    """Check the JWT settings that were read from the environment.

    In staging/production, insecure settings are fatal. In every other
    ENVIRONMENT they only produce log warnings.

    Raises:
        RuntimeError: (staging/production only) if JWT_SECRET_KEY is shorter
            than 32 characters, or JWT_ALGORITHM is not one of
            _SUPPORTED_JWT_ALGORITHMS.
    """
    if ENVIRONMENT in ("staging", "production"):
        if len(JWT_SECRET_KEY) < 32:
            raise RuntimeError(
                "JWT_SECRET_KEY must be at least 32 characters in non-dev environments"
            )
        # Exact allow-list rather than a prefix test, so malformed names fail at
        # boot instead of on the first sign/verify call.
        if JWT_ALGORITHM not in _SUPPORTED_JWT_ALGORITHMS:
            raise RuntimeError(f"Unsupported JWT_ALGORITHM: {JWT_ALGORITHM}")
        return
    if ENVIRONMENT != "development":
        # A typo such as "prod" lands here and silently skips the hard checks above.
        log.warning(
            "ENVIRONMENT=%r is not development/staging/production; "
            "applying development-level (warning-only) security checks",
            ENVIRONMENT,
        )
    if len(JWT_SECRET_KEY) < 32:
        log.warning(
            "JWT_SECRET_KEY is short or unset; OK only for local dev. "
            "Set a 32+ char secret before staging/production."
        )
_validate_config()
# Signing key actually used for encoding and decoding JWTs. When JWT_SECRET_KEY
# is empty (only tolerated outside staging/production), each process that
# imports this module generates its own random secret, so its tokens do not
# survive a restart and are not accepted by other processes.
if JWT_SECRET_KEY:
    _EFFECTIVE_JWT_SECRET = JWT_SECRET_KEY
else:
    _EFFECTIVE_JWT_SECRET = secrets.token_urlsafe(48)
# ---------------------------------------------------------------------------
# Password hashing (bcrypt; cost from BCRYPT_ROUNDS, default 12)
# ---------------------------------------------------------------------------
def hash_password(password: str) -> str:
    """Return the bcrypt hash of *password* as a UTF-8 string.

    The work factor comes from BCRYPT_ROUNDS (12 unless configured), and a
    fresh salt is generated on every call.

    Raises:
        ValueError: if *password* is not a non-empty str, or if its UTF-8
            encoding is longer than bcrypt's 72-byte input limit. Over-long
            input is rejected instead of being silently truncated.

    Implementation note: this calls the ``bcrypt`` package directly. passlib
    was dropped because passlib 1.7.4 (unmaintained since 2020) mis-detects
    bcrypt >= 4.x and raises a spurious "password cannot be longer than 72
    bytes" error for every input.
    """
    if isinstance(password, str) and password:
        encoded = password.encode("utf-8")
    else:
        raise ValueError("password must be a non-empty string")
    if len(encoded) > 72:
        raise ValueError("password exceeds 72 bytes (bcrypt limit)")
    hashed = _bcrypt.hashpw(encoded, _bcrypt.gensalt(rounds=BCRYPT_ROUNDS))
    return hashed.decode("utf-8")
def verify_password(plain: str, hashed: str) -> bool:
    """Check *plain* against a bcrypt hash produced by ``hash_password``.

    Returns False (never raises) in these cases:
      - either input is empty or not a str;
      - *plain* is longer than bcrypt's 72-byte limit when UTF-8 encoded;
      - the hash is malformed (bcrypt raises ValueError/TypeError);
      - *plain* cannot be UTF-8 encoded.
    The comparison itself is delegated to ``bcrypt.checkpw``. Never logs the inputs.
    """
    if not isinstance(plain, str) or not isinstance(hashed, str):
        return False
    if not plain or not hashed:
        return False
    try:
        plain_bytes = plain.encode("utf-8")
        # hash_password refuses inputs over 72 bytes, so no stored hash can
        # legitimately match one. Rejecting them here also stops bcrypt
        # versions that silently truncate from accepting
        # "<real password><anything>".
        if len(plain_bytes) > 72:
            return False
        return _bcrypt.checkpw(plain_bytes, hashed.encode("utf-8"))
    except (ValueError, TypeError):
        # Includes UnicodeEncodeError (a ValueError) and malformed-hash errors.
        return False
# ---------------------------------------------------------------------------
# JWT
# ---------------------------------------------------------------------------
TokenType = Literal["access", "refresh"]


class TokenPayload(BaseModel):
    """JWT claims after decoding, in the shape handed to request handlers.

    The fields mirror the claims written when a token is built. Claims that
    are not declared here (such as ``iss``) are not modelled.
    """

    sub: str = Field(..., description="user_id as string")
    role: str = "user"
    type: TokenType = "access"
    iat: int
    exp: int
    jti: str

    @property
    def user_id(self) -> UUID:
        """``sub`` parsed back into a UUID (raises ValueError if it is not one)."""
        return UUID(hex=self.sub)
def _build_token(
    *,
    user_id: UUID,
    role: str,
    token_type: TokenType,
    lifetime: timedelta,
) -> str:
    """Sign and return a JWT carrying this module's standard claim set.

    Claims, in order:
      - sub: ``str(user_id)``
      - role, type
      - iat / exp: whole Unix seconds (exp = now + *lifetime*)
      - jti: random URL-safe id
      - iss: fixed "arac-hasar-v2"
    The token is signed with _EFFECTIVE_JWT_SECRET using JWT_ALGORITHM.
    """
    issued = datetime.now(tz=timezone.utc)
    expires = issued + lifetime
    claims = dict(
        sub=str(user_id),
        role=role,
        type=token_type,
        iat=int(issued.timestamp()),
        exp=int(expires.timestamp()),
        jti=secrets.token_urlsafe(16),
        iss="arac-hasar-v2",
    )
    return jwt.encode(claims, _EFFECTIVE_JWT_SECRET, algorithm=JWT_ALGORITHM)
def create_access_token(
    user_id: UUID, role: str, expires_minutes: int = ACCESS_TOKEN_MINUTES
) -> str:
    """Issue a short-lived access token.

    Args:
        user_id: stored in the ``sub`` claim as ``str(user_id)``.
        role: embedded verbatim in the ``role`` claim. ``require_admin``
            compares this value against "admin".
        expires_minutes: token lifetime. Defaults to ACCESS_TOKEN_MINUTES
            (the env var of the same name, 15 when unset). The default is
            captured when this module is imported.

    Returns:
        The signed JWT string, with ``type`` claim "access".
    """
    return _build_token(
        user_id=user_id,
        role=role,
        token_type="access",
        lifetime=timedelta(minutes=expires_minutes),
    )
def create_refresh_token(user_id: UUID, expires_days: int = REFRESH_TOKEN_DAYS) -> str:
    """Issue a long-lived refresh token.

    The lifetime defaults to REFRESH_TOKEN_DAYS (7 unless configured). The
    embedded role is always "user": privileges are deliberately not carried
    over, so the refresh flow must re-check the real role from the DB.
    """
    lifetime = timedelta(days=expires_days)
    return _build_token(
        user_id=user_id, role="user", token_type="refresh", lifetime=lifetime
    )
def verify_token(token: str, expected_type: str = "access") -> TokenPayload:
    """Decode and validate a JWT, returning its claims as a TokenPayload.

    The following are checked:
      - the signature, using JWT_ALGORITHM only;
      - expiry;
      - presence of the exp/iat/sub/jti claims;
      - the claim shapes declared on TokenPayload;
      - that the ``type`` claim equals *expected_type*.

    Raises:
        HTTPException(401): on any failure. The response always carries a
            ``WWW-Authenticate: Bearer`` challenge header.
    """

    def _reject(detail: str) -> HTTPException:
        # One place for the 401 shape so every failure path sends the challenge header.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    if not token:
        raise _reject("missing token")
    try:
        claims = jwt.decode(
            token,
            _EFFECTIVE_JWT_SECRET,
            algorithms=[JWT_ALGORITHM],
            # python-jose expects one boolean ``require_<claim>`` flag per claim.
            # A PyJWT-style ``{"require": [...]}`` list is silently ignored.
            options={
                "require_exp": True,
                "require_iat": True,
                "require_sub": True,
                "require_jti": True,
            },
        )
    except JWTError as e:
        log.info("jwt.verify.fail reason=%s", type(e).__name__)
        raise _reject("invalid or expired token") from None
    try:
        payload = TokenPayload(**claims)
    except (TypeError, ValueError) as e:
        # pydantic's ValidationError subclasses ValueError.
        log.info("jwt.verify.fail reason=%s", type(e).__name__)
        raise _reject("malformed token payload") from None
    if payload.type != expected_type:
        raise _reject(f"wrong token type (expected {expected_type})")
    # NOTE(review): tokens are minted with iss="arac-hasar-v2", but the issuer is
    # not verified here. Consider passing issuer= to jwt.decode once no older
    # tokens without iss remain.
    return payload
# ---------------------------------------------------------------------------
# API keys (for pilot integrations / service-to-service)
# ---------------------------------------------------------------------------
_API_KEY_PREFIX = "ahv2_"
_API_KEY_BYTES = 32  # random bytes per key -> 256 bits of entropy


def generate_api_key() -> Tuple[str, str]:
    """Mint a new API key and return ``(plain_key, hash_to_store)``.

    ``plain_key`` is the "ahv2_" prefix followed by URL-safe base64 of 32
    random bytes. Show it to the user exactly once and persist only the hex
    SHA-256 digest. A fast, unsalted hash is acceptable here because the key
    already carries 256 bits of entropy, which makes bcrypt-style slow
    hashing unnecessary.
    """
    plain_key = _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_BYTES)
    digest = hashlib.sha256(plain_key.encode("utf-8"))
    return plain_key, digest.hexdigest()
def verify_api_key(plain_key: str, stored_hash: str) -> bool:
    """Check a presented API key against its stored SHA-256 hex digest.

    Returns False for empty inputs and for keys without the "ahv2_" prefix.
    Otherwise the digests are compared with ``hmac.compare_digest``
    (constant time).
    """
    if not (plain_key and stored_hash) or not plain_key.startswith(_API_KEY_PREFIX):
        return False
    candidate = hashlib.sha256(plain_key.encode("utf-8")).hexdigest()
    return hmac.compare_digest(candidate, stored_hash)
# ---------------------------------------------------------------------------
# FastAPI auth dependencies
# ---------------------------------------------------------------------------
_bearer = HTTPBearer(auto_error=False, description="JWT access token")


def _extract_token(
    creds: Optional[HTTPAuthorizationCredentials] = Depends(_bearer),
) -> str:
    """Return the raw token from an ``Authorization: Bearer <token>`` header.

    ``_bearer`` is built with ``auto_error=False``, so a missing header arrives
    here as ``None``. It is answered with this module's own 401, including a
    ``WWW-Authenticate: Bearer`` challenge.
    """
    challenge = {"WWW-Authenticate": "Bearer"}
    if creds is None or not creds.credentials:
        problem = "missing bearer token"
    elif creds.scheme.lower() != "bearer":
        problem = "invalid auth scheme"
    else:
        return creds.credentials
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=problem,
        headers=challenge,
    )
def require_user(token: str = Depends(_extract_token)) -> TokenPayload:
    """FastAPI dependency: resolve the caller from a bearer *access* token.

    Any valid access token passes, whatever its role. Refresh tokens (and any
    other invalid token) are refused with a 401 by ``verify_token``.
    """
    payload = verify_token(token, expected_type="access")
    return payload
def require_admin(payload: TokenPayload = Depends(require_user)) -> TokenPayload:
    """FastAPI dependency: an authenticated caller whose role claim is "admin".

    Authentication failures come from ``require_user`` (401). An authenticated
    non-admin caller gets HTTPException(403).
    """
    is_admin = payload.role == "admin"
    if is_admin:
        return payload
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="admin privileges required",
    )
# ---------------------------------------------------------------------------
# File upload validation
# ---------------------------------------------------------------------------
# Magic-byte signatures. We do NOT trust client-provided Content-Type or filename.
_MAGIC_SIGS: tuple[tuple[bytes, str], ...] = (
    (b"\xff\xd8\xff", "image/jpeg"),
    (b"\x89PNG\r\n\x1a\n", "image/png"),
    # WEBP ("RIFF" + 4 size bytes + "WEBP") is not a plain prefix; sniff_image_mime handles it.
)
ALLOWED_IMAGE_MIMES: frozenset[str] = frozenset({"image/jpeg", "image/png", "image/webp"})
DEFAULT_MAX_UPLOAD_MB: Final[int] = 20
DEFAULT_MAX_DIM: Final[int] = 10_000


def sniff_image_mime(buf: bytes) -> Optional[str]:
    """Identify an image by its leading magic bytes.

    Returns "image/jpeg", "image/png" or "image/webp". Returns None for
    anything else, and for buffers that are empty or shorter than 12 bytes.
    """
    if not buf or len(buf) < 12:
        return None
    matched = next((mime for sig, mime in _MAGIC_SIGS if buf.startswith(sig)), None)
    if matched is not None:
        return matched
    # WEBP: bytes 0..3 are "RIFF" and bytes 8..11 are "WEBP" (4..7 hold the size).
    is_webp = buf[:4] == b"RIFF" and buf[8:12] == b"WEBP"
    return "image/webp" if is_webp else None
@dataclass(frozen=True)
class ValidatedImage:
    """Result of validate_image_upload: decoded image + sanitized metadata.

    Immutable (frozen dataclass). Positional construction follows the field
    order below.
    """
    image: Image.Image  # PIL image, EXIF stripped, orientation applied
    mime: str  # MIME type from magic-byte sniffing, never from the client
    width: int  # pixel width after EXIF orientation was applied
    height: int  # pixel height after EXIF orientation was applied
    size_bytes: int  # length of the raw uploaded buffer
    sha256: str  # hex SHA-256 of the raw uploaded bytes (before normalization)
def _image_dimension_error(max_w: int, max_h: int) -> HTTPException:
    """Build the 400 response used when an image exceeds max_w x max_h."""
    return HTTPException(
        status_code=400,
        detail=f"image dimensions exceed {max_w}x{max_h}",
    )


def _decode_oriented_image(buf: bytes, *, max_w: int, max_h: int) -> Image.Image:
    """Decode *buf* with Pillow and return a loaded, EXIF-oriented image.

    Raises HTTPException(400) when Pillow rejects the data. Images whose header
    dimensions cannot fit max_w x max_h in either orientation are rejected
    before any pixel data is decoded.
    """
    try:
        with Image.open(io.BytesIO(buf)) as probe:
            probe.verify()  # cheap structural check; the probe is unusable afterwards
        with Image.open(io.BytesIO(buf)) as src:  # reopen: verify() consumes the stream
            # Image.open only parses the header, so src.size is known without
            # decoding. If the longer side exceeds the larger limit, no rotation
            # can make the image fit, so skip the expensive decode.
            if max(src.size) > max(max_w, max_h):
                raise _image_dimension_error(max_w, max_h)
            # exif_transpose decodes the pixels and returns a new image, which
            # stays usable after `src` goes out of scope.
            return ImageOps.exif_transpose(src)
    except (
        UnidentifiedImageError,
        Image.DecompressionBombError,
        OSError,
        ValueError,
        # Pillow reports some structural damage (e.g. broken PNG chunks found by
        # verify()) as SyntaxError rather than OSError.
        SyntaxError,
    ):
        raise HTTPException(status_code=400, detail="invalid or corrupt image") from None


def _strip_image_metadata(img: Image.Image) -> Image.Image:
    """Return a copy of *img* with the same mode, size, pixels and palette but no metadata.

    ``copy()`` keeps the palette of "P"/"PA" images and copies pixel data
    natively. Rebuilding the image pixel-by-pixel would lose the palette and
    materialize every pixel as a Python object. Only ``transparency`` is
    carried over, because it affects how the pixels render rather than
    describing the capture (GPS, camera serial, timestamps, ...).
    """
    clean = img.copy()
    transparency = img.info.get("transparency")
    clean.info = {}  # drops exif, icc_profile, XMP, comments, ...
    if transparency is not None:
        clean.info["transparency"] = transparency
    return clean


def validate_image_upload(
    buf: bytes,
    *,
    max_size_mb: int = DEFAULT_MAX_UPLOAD_MB,
    max_w: int = DEFAULT_MAX_DIM,
    max_h: int = DEFAULT_MAX_DIM,
) -> ValidatedImage:
    """
    Validate and normalize an image upload.
    Steps:
    1. Size limit (default 20 MB).
    2. MIME sniff by magic bytes (jpeg/png/webp only).
    3. Decode with PIL. Reject undecodable/corrupt data and decompression bombs.
       Images whose header size cannot fit the limits are rejected before decoding.
    4. Apply EXIF orientation (so downstream ML sees the correct rotation).
    5. Dimension limits on the oriented image.
    6. Drop metadata (EXIF PII: GPS, camera serial, timestamps) while keeping
       the image mode and palette intact.
    Returns a ValidatedImage. Raises HTTPException(400/413) on rejection.
    """
    size_bytes = len(buf)
    if size_bytes == 0:
        raise HTTPException(status_code=400, detail="empty upload")
    if size_bytes > max_size_mb * 1024 * 1024:
        raise HTTPException(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            detail=f"file exceeds {max_size_mb} MB limit",
        )
    mime = sniff_image_mime(buf)
    if mime not in ALLOWED_IMAGE_MIMES:
        raise HTTPException(
            status_code=400,
            detail="unsupported image type (allowed: jpeg, png, webp)",
        )
    # Guard against decompression-bomb DoS. Pillow raises DecompressionBombError
    # above 2x this pixel count and only warns between 1x and 2x.
    # NOTE(review): this is a process-wide Pillow setting; concurrent calls with
    # different limits overwrite each other. Consider setting it once at startup.
    Image.MAX_IMAGE_PIXELS = max_w * max_h
    img = _decode_oriented_image(buf, max_w=max_w, max_h=max_h)
    # Orientation may swap width and height, so re-check the exact per-axis limits.
    if img.width > max_w or img.height > max_h:
        raise _image_dimension_error(max_w, max_h)
    clean = _strip_image_metadata(img)
    return ValidatedImage(
        image=clean,
        mime=mime,
        width=clean.width,
        height=clean.height,
        size_bytes=size_bytes,
        sha256=hashlib.sha256(buf).hexdigest(),
    )
# Filename sanitizer ---------------------------------------------------------
_UNSAFE_FILENAME_RE = re.compile(r"[^A-Za-z0-9._-]+")


def sanitize_filename(name: str, *, fallback_ext: str = "bin") -> str:
    """
    Produce a safe filename for S3/storage.
    Defenses:
    - No path components (basename only).
    - Strip null bytes and control chars.
    - Unicode normalize NFKD then drop non-ASCII.
    - Whitelist [A-Za-z0-9._-]; everything else -> "_".
    - If nothing survives, fall back to "upload.<fallback_ext>". The fallback
      goes through the same whitelist, so a hostile fallback_ext cannot
      smuggle in path separators.
    - Reject reserved Windows names just in case (CON, PRN, AUX, NUL, COM1..LPT9).
    - Cap length at 120 chars, keeping an extension of up to 8 chars.
    - ALWAYS prefix with a fresh uuid4 hex + "_" to prevent collisions.
    """

    def _whitelist(text: str) -> str:
        # Collapse every run of disallowed chars to "_" and trim edge dots/underscores/dashes.
        return _UNSAFE_FILENAME_RE.sub("_", text).strip("._-")

    if not isinstance(name, str):
        name = "upload"
    # basename only -- defeats ../, absolute paths, backslashes.
    name = os.path.basename(name.replace("\\", "/"))
    # Strip nulls/control chars before normalization.
    name = name.replace("\x00", "")
    name = "".join(ch for ch in name if ch.isprintable())
    # Unicode normalize and drop non-ASCII.
    name = unicodedata.normalize("NFKD", name).encode("ascii", "ignore").decode("ascii")
    name = _whitelist(name) or _whitelist(f"upload.{fallback_ext}")
    # Reserved Windows names (belt + suspenders).
    stem = name.split(".", 1)[0].upper()
    if stem in {"CON", "PRN", "AUX", "NUL"} or re.fullmatch(r"(COM|LPT)[1-9]", stem):
        name = f"_{name}"
    if len(name) > 120:
        # Keep a short extension intact; otherwise hard-truncate.
        root, dot, ext = name.rpartition(".")
        if dot and len(ext) <= 8:
            name = f"{root[: 120 - len(ext) - 1]}.{ext}"
        else:
            name = name[:120]
    return f"{uuid.uuid4().hex}_{name}"
# ---------------------------------------------------------------------------
# Helpers re-exported for middleware / handlers
# ---------------------------------------------------------------------------
def utcnow_ms() -> int:
    """Current Unix epoch time in whole milliseconds, for access logs.

    This is wall-clock time: it follows system clock adjustments and is not
    guaranteed to be monotonic.
    """
    seconds = time.time()
    return int(seconds * 1000)
# Public surface of this module; everything else is an implementation detail.
__all__ = [
    # configuration values
    "ENVIRONMENT", "JWT_ALGORITHM", "ACCESS_TOKEN_MINUTES", "REFRESH_TOKEN_DAYS",
    "BCRYPT_ROUNDS", "RATE_LIMIT_REDIS_URL", "ALLOWED_ORIGINS",
    # password hashing
    "hash_password", "verify_password",
    # JWT issue / verify
    "TokenPayload", "create_access_token", "create_refresh_token", "verify_token",
    # API keys
    "generate_api_key", "verify_api_key",
    # FastAPI dependencies
    "require_user", "require_admin",
    # upload validation
    "ValidatedImage", "ALLOWED_IMAGE_MIMES", "sniff_image_mime",
    "validate_image_upload", "sanitize_filename",
    # misc helpers
    "utcnow_ms",
]