"""Shared FastAPI dependencies for ML module routes.""" from __future__ import annotations from functools import lru_cache from typing import Any, Dict, Optional from fastapi import Depends, HTTPException, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from pydantic import BaseModel, Field from ml_module.core.auth import SupabaseAuthError, get_supabase_verifier from ml_module.core.exceptions import ProjectNotFoundException from ml_module.services.project_service import ProjectService class AuthenticatedUser(BaseModel): """Represents the authenticated Supabase user extracted from the JWT.""" user_id: str = Field(..., description="Supabase user identifier") email: Optional[str] = Field(default=None) role: Optional[str] = Field(default=None) claims: Dict[str, Any] = Field(default_factory=dict) _http_bearer = HTTPBearer(auto_error=False) async def get_current_user( credentials: HTTPAuthorizationCredentials = Depends(_http_bearer), ) -> AuthenticatedUser: if credentials is None: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authorization header missing") token = credentials.credentials verifier = get_supabase_verifier() try: claims = verifier.decode(token) except SupabaseAuthError as exc: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc user_id = str(claims.get("sub") or claims.get("user_id")) if not user_id: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Supabase token missing subject") app_metadata = claims.get("app_metadata") or {} role = app_metadata.get("role") or claims.get("role") return AuthenticatedUser( user_id=user_id, email=claims.get("email"), role=role, claims=claims, ) def ensure_project_access( project_id: str, *, current_user: AuthenticatedUser, project_service: ProjectService, ) -> None: try: project_service.get_project(current_user.user_id, project_id) except ProjectNotFoundException as exc: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Project not found") from exc except Exception as exc: # pragma: no cover - defensive catch raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorised for this project") from exc @lru_cache(maxsize=1) def get_http_bearer() -> HTTPBearer: """Expose HTTP bearer for reuse (primarily for testing).""" return _http_bearer