File size: 2,571 Bytes
783a952
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""Shared FastAPI dependencies for ML module routes."""
from __future__ import annotations

from functools import lru_cache
from typing import Any, Dict, Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, Field

from ml_module.core.auth import SupabaseAuthError, get_supabase_verifier
from ml_module.core.exceptions import ProjectNotFoundException
from ml_module.services.project_service import ProjectService


class AuthenticatedUser(BaseModel):
    """Caller identity derived from a verified Supabase JWT.

    Only ``user_id`` must be supplied; every other field falls back to an
    empty value when the token does not carry it.
    """

    # Token subject; get_current_user fills it from ``sub`` (or ``user_id``).
    user_id: str = Field(..., description="Supabase user identifier")
    # Value of the ``email`` claim as populated by get_current_user, if any.
    email: Optional[str] = None
    # ``app_metadata.role`` or top-level ``role`` claim, as populated by get_current_user.
    role: Optional[str] = None
    # The full decoded claim set, for callers that need fields not surfaced above.
    claims: Dict[str, Any] = Field(default_factory=dict)


# Shared bearer-token extractor used by get_current_user. ``auto_error=False``
# makes it pass ``None`` to the dependant when no credentials are sent, so
# get_current_user can raise its own 401 ("Authorization header missing")
# instead of relying on the scheme's built-in error.
_http_bearer = HTTPBearer(auto_error=False)


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(_http_bearer),
) -> AuthenticatedUser:
    """Resolve the caller's Supabase identity from the ``Authorization`` header.

    Intended as a FastAPI dependency. The bearer token is decoded with
    ``get_supabase_verifier().decode``. The user id comes from the ``sub``
    claim, falling back to ``user_id``. The role comes from
    ``app_metadata.role`` when present, otherwise from the top-level ``role``
    claim.

    Args:
        credentials: Parsed bearer credentials, or ``None`` when the header is
            absent (``_http_bearer`` is created with ``auto_error=False``).

    Returns:
        An ``AuthenticatedUser`` carrying the full decoded claims.

    Raises:
        HTTPException: 401 (with a ``WWW-Authenticate: Bearer`` challenge) when
            the header is missing, the token fails verification, or the token
            has neither a ``sub`` nor a ``user_id`` claim.
    """
    # RFC 7235/6750: a 401 response should tell the client which scheme to use.
    challenge = {"WWW-Authenticate": "Bearer"}

    if credentials is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authorization header missing",
            headers=challenge,
        )

    verifier = get_supabase_verifier()
    try:
        claims = verifier.decode(credentials.credentials)
    except SupabaseAuthError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(exc),
            headers=challenge,
        ) from exc

    # Validate the raw claim *before* converting to str. str(None) == "None"
    # is truthy, so checking after conversion would let a subject-less token
    # through as user "None".
    subject = claims.get("sub") or claims.get("user_id")
    if not subject:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Supabase token missing subject",
            headers=challenge,
        )
    user_id = str(subject)

    app_metadata = claims.get("app_metadata")
    if not isinstance(app_metadata, dict):
        # Missing or malformed app_metadata: fall back to the top-level claim
        # rather than failing with an AttributeError (HTTP 500).
        app_metadata = {}
    role = app_metadata.get("role") or claims.get("role")

    return AuthenticatedUser(
        user_id=user_id,
        email=claims.get("email"),
        role=role,
        claims=claims,
    )


def ensure_project_access(
    project_id: str,
    *,
    current_user: AuthenticatedUser,
    project_service: ProjectService,
) -> None:
    """Raise unless ``current_user`` can load ``project_id`` via ``project_service``.

    The lookup is ``project_service.get_project(user_id, project_id)``. Its
    return value is discarded; only whether it raises matters here.

    Raises:
        HTTPException: 404 when the service raises ``ProjectNotFoundException``;
            403 for any other exception raised during the lookup.
    """
    try:
        owner_id = current_user.user_id
        project_service.get_project(owner_id, project_id)
    except ProjectNotFoundException as missing:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Project not found",
        ) from missing
    except Exception as failure:  # pragma: no cover - defensive catch
        # NOTE(review): every other failure (possibly including storage or
        # network errors) is reported as 403. Confirm this is intended rather
        # than surfacing a 5xx.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Not authorised for this project",
        ) from failure


@lru_cache(maxsize=1)
def get_http_bearer() -> HTTPBearer:
    """Return the module-wide ``HTTPBearer`` scheme (handy for tests).

    The same ``_http_bearer`` object is returned on every call. The
    ``lru_cache`` wrapper keeps the public function object (including its
    ``cache_clear`` method) stable for existing callers.
    """
    shared_scheme = _http_bearer
    return shared_scheme