Spaces:
Sleeping
Sleeping
| """ | |
| Authentication Router - Google OAuth | |
| Endpoints for Google Sign-In authentication flow. | |
| No more secret keys - users authenticate with their Google account. | |
| """ | |
| from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks | |
| from fastapi.responses import JSONResponse | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from sqlalchemy import select | |
| from datetime import datetime | |
| import uuid | |
| import logging | |
| from core.database import get_db | |
| from core.models import User, AuditLog, ClientUser | |
| from core.schemas import ( | |
| CheckRegistrationRequest, | |
| GoogleAuthRequest, | |
| AuthResponse, | |
| UserInfoResponse, | |
| TokenRefreshRequest, | |
| TokenRefreshResponse | |
| ) | |
| from services.auth_service.google_provider import ( | |
| GoogleAuthService, | |
| GoogleUserInfo, | |
| InvalidTokenError as GoogleInvalidTokenError, | |
| ConfigurationError as GoogleConfigError, | |
| get_google_auth_service, | |
| ) | |
| from services.auth_service.jwt_provider import ( | |
| JWTService, | |
| create_access_token, | |
| create_refresh_token, | |
| get_jwt_service, | |
| InvalidTokenError as JWTInvalidTokenError, | |
| ) | |
| from core.dependencies import check_rate_limit, get_current_user | |
| from services.drive_service import DriveService | |
| from services.audit_service import AuditService | |
logger = logging.getLogger(__name__)

# Router for the authentication endpoints; every route is mounted under /auth
# and grouped under the "auth" tag in the OpenAPI docs.
router = APIRouter(tags=["auth"], prefix="/auth")

# Module-level DriveService instance, created once at import time.
drive_service = DriveService()
async def check_registration(
    request: CheckRegistrationRequest,
    req: Request,
    db: AsyncSession = Depends(get_db)
):
    """
    Report whether a client-side (temporary) user_id is linked to a server account.

    Lets the frontend decide whether it still needs to prompt the user to sign in.
    Rate-limited to 10 requests per minute per client IP.

    Returns:
        ``{"is_registered": bool}`` -- True when at least one ClientUser row
        carries ``request.user_id`` as its ``client_user_id``.

    Raises:
        HTTPException 429: rate limit exceeded.
    """
    # Request.client is None when the ASGI server supplies no client address.
    ip = req.client.host if req.client else "unknown"
    if not await check_rate_limit(db, ip, "/auth/check-registration", 10, 1):
        raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests")

    # The same client_user_id may be linked to several users (google_auth only
    # de-duplicates per (user, client_user_id) pair). scalar_one_or_none()
    # would raise MultipleResultsFound in that case; only existence matters,
    # so fetch at most one row.
    query = (
        select(ClientUser)
        .where(ClientUser.client_user_id == request.user_id)
        .limit(1)
    )
    result = await db.execute(query)
    return {"is_registered": result.scalars().first() is not None}
def detect_client_type(request: Request) -> str:
    """
    Classify the caller as 'web' or 'mobile' from its User-Agent header.

    A User-Agent containing any common browser token (case-insensitive) is
    treated as 'web'; anything else, including a missing header, is treated
    as 'mobile'.
    """
    ua = request.headers.get("user-agent", "").lower()
    for marker in ("mozilla", "chrome", "firefox", "safari", "edge", "opera"):
        if marker in ua:
            return "web"
    return "mobile"
async def _link_client_user(db: AsyncSession, user_pk: int, client_user_id, ip: str) -> None:
    """
    Ensure a ClientUser row maps ``client_user_id`` to the user whose primary key is ``user_pk``.

    Creates the mapping when this (user, client_user_id) pair is not recorded yet,
    otherwise bumps its ``last_seen_at``. Changes are added to the session, not committed.
    """
    result = await db.execute(
        select(ClientUser).where(
            ClientUser.user_id == user_pk,  # Integer FK comparison
            ClientUser.client_user_id == client_user_id
        )
    )
    existing_client = result.scalar_one_or_none()
    if existing_client is None:
        db.add(ClientUser(
            user_id=user_pk,  # Integer FK to users.id
            client_user_id=client_user_id,
            ip_address=ip,  # Standardized IP column
            last_seen_at=datetime.utcnow()
        ))
    else:
        existing_client.last_seen_at = datetime.utcnow()


async def google_auth(
    request: GoogleAuthRequest,
    req: Request,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db)
):
    """
    Authenticate with a Google ID token and issue our own JWT access/refresh tokens.

    Flow:
      1. Rate-limit by client IP (10 attempts per minute).
      2. Verify the Google ID token with the Google auth service.
      3. Find the user by email (preserving migrated accounts' credits) or
         create a new one with 0 credits.
      4. If ``request.temp_user_id`` is given, link it to the user via ClientUser.
      5. Record an audit entry, commit, and schedule a DB backup in the background.

    Supports two client types:
    - "web": Sets refresh_token in an HttpOnly cookie
    - "mobile": Returns refresh_token in the JSON body
    Client type is auto-detected from User-Agent if not provided.

    Raises:
        HTTPException 429: rate limit exceeded.
        HTTPException 503: Google authentication is not configured.
        HTTPException 401: the Google ID token failed verification.
    """
    # Request.client is None when the ASGI server supplies no client address.
    ip = req.client.host if req.client else "unknown"
    # An explicit client_type wins; otherwise infer it from the User-Agent.
    client_type = request.client_type or detect_client_type(req)

    # Rate Limit: 10 attempts per minute per IP
    if not await check_rate_limit(db, ip, "/auth/google", 10, 1):
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Too many authentication attempts"
        )

    # Verify Google token
    try:
        google_service = get_google_auth_service()
        google_info = google_service.verify_token(request.id_token)
    except GoogleConfigError as e:
        logger.error(f"Google Auth not configured: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Google authentication is not configured"
        ) from e
    except GoogleInvalidTokenError as e:
        logger.warning(f"Invalid Google token from {ip}: {e}")
        # Persist the failed attempt before rejecting the request.
        await AuditService.log_event(
            db=db,
            log_type="server",
            action="google_auth",
            status="failed",
            error_message=str(e),
            request=req
        )
        await db.commit()
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid Google token. Please try signing in again."
        ) from e

    # Check for existing user by email (preserves credits for migrated users)
    result = await db.execute(select(User).where(User.email == google_info.email))
    user = result.scalar_one_or_none()
    is_new_user = user is None

    if is_new_user:
        user = User(
            user_id="usr_" + str(uuid.uuid4()),
            email=google_info.email,
            google_id=google_info.google_id,
            name=google_info.name,
            profile_picture=google_info.picture,
            credits=0
        )
        db.add(user)
        # Flush so the database assigns user.id now: it is None until the row
        # is inserted, and both the ClientUser FK and the audit entry below need it.
        await db.flush()
        logger.info(f"New user created via Google: {google_info.email}")
    else:
        # NOTE(review): unlike /refresh, this path does not check user.is_active;
        # confirm whether deactivated accounts should be able to sign in.
        if not user.google_id:
            user.google_id = google_info.google_id
            logger.info(f"Linked Google account to existing user: {user.email}")
        user.name = google_info.name
        user.profile_picture = google_info.picture
        user.last_used_at = datetime.utcnow()

    # Link client_user_id if provided
    if request.temp_user_id:
        await _link_client_user(db, user.id, request.temp_user_id, ip)

    # Log successful auth
    await AuditService.log_event(
        db=db,
        log_type="server",
        user_id=user.id,
        client_user_id=request.temp_user_id,
        action="google_auth",
        status="success",
        request=req
    )
    await db.commit()

    # Create our JWT access token and refresh token
    access_token = create_access_token(user.user_id, user.email, user.token_version)
    new_refresh_token = create_refresh_token(user.user_id, user.email, user.token_version)

    # Sync DB to Drive (Async)
    from services.backup_service import get_backup_service
    background_tasks.add_task(get_backup_service().backup_async)

    response_data = {
        "success": True,
        "access_token": access_token,
        "user_id": user.user_id,
        "email": user.email,
        "name": user.name,
        "credits": user.credits,
        "is_new_user": is_new_user
    }

    # Handle token delivery based on client type
    if client_type == "web":
        import os
        is_production = os.getenv("ENVIRONMENT", "production") == "production"
        response = JSONResponse(content=response_data)
        response.set_cookie(
            key="refresh_token",
            value=new_refresh_token,
            httponly=True,
            secure=is_production,  # True in production (HTTPS), False locally (HTTP)
            samesite="none" if is_production else "lax",  # 'none' for cross-origin in production
            max_age=7 * 24 * 60 * 60,  # 7 days
            domain=None  # Let browser set domain automatically
        )
        logger.info(f"Set refresh_token cookie for web client (production={is_production})")
    else:
        # Mobile: Return refresh token in body
        response_data["refresh_token"] = new_refresh_token
        response = JSONResponse(content=response_data)
        logger.info("Returned refresh_token in body for mobile client")
    return response
async def get_current_user_info(
    user: User = Depends(get_current_user)
):
    """
    Return profile details for the authenticated caller.

    Requires an ``Authorization: Bearer <token>`` header; the caller is
    resolved by the ``get_current_user`` dependency.
    """
    profile = {
        "user_id": user.user_id,
        "email": user.email,
        "name": user.name,
        "credits": user.credits,
        "profile_picture": user.profile_picture,
    }
    return UserInfoResponse(**profile)
async def refresh_token(
    request: TokenRefreshRequest,
    req: Request,
    db: AsyncSession = Depends(get_db)
):
    """
    Exchange a refresh token for a new access token and a rotated refresh token.

    The refresh token is read from the request body (``request.token``) or,
    when that is empty, from the ``refresh_token`` cookie. The rotated refresh
    token is returned the same way it arrived: in the JSON body for body
    callers, as a replacement HttpOnly cookie for cookie callers.

    The token must have a valid signature, must not be expired, must have
    ``type == "refresh"``, must belong to an active user, and its token
    version must not be older than the user's current ``token_version``
    (logout increments that version).

    Raises:
        HTTPException 429: rate limit exceeded.
        HTTPException 401: missing, malformed, expired, wrong-type or
            invalidated token, or unknown/inactive user.
    """
    # Imported up front so the except clauses below can always reference it.
    import jwt as pyjwt

    # Request.client is None when the ASGI server supplies no client address.
    ip = req.client.host if req.client else "unknown"
    # Rate Limit: 20 refreshes per minute per IP (increased for proactive refresh on page load)
    if not await check_rate_limit(db, ip, "/auth/refresh", 20, 1):
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Too many refresh attempts"
        )

    try:
        jwt_service = get_jwt_service()

        # Prefer the body token; fall back to the HttpOnly cookie (web clients).
        token_to_refresh = request.token
        using_cookie = False
        if not token_to_refresh:
            token_to_refresh = req.cookies.get("refresh_token")
            using_cookie = True
        if not token_to_refresh:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Refresh token missing"
            )

        # Verify signature AND expiry. Only refresh tokens are accepted here,
        # so skipping the exp check would let a refresh token live forever.
        payload = pyjwt.decode(
            token_to_refresh,
            jwt_service.secret_key,
            algorithms=[jwt_service.algorithm],
            options={"verify_exp": True}
        )
        user_id = payload.get("sub")
        token_version = payload.get("tv", 1)
        token_type = payload.get("type", "access")

        if not user_id:
            raise JWTInvalidTokenError("Token missing required claims")

        # Verify it's a refresh token
        if token_type != "refresh":
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token type. Expected refresh token."
            )

        # Check if user exists and is active
        query = select(User).where(User.user_id == user_id, User.is_active == True)
        result = await db.execute(query)
        user = result.scalar_one_or_none()
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="User not found or inactive"
            )

        # Tokens minted before the last logout carry an older version.
        if token_version < user.token_version:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has been invalidated. Please sign in again."
            )

        new_access_token = create_access_token(user.user_id, user.email, user.token_version)
        # ROTATION: Issue new refresh token
        new_refresh_token = create_refresh_token(user.user_id, user.email, user.token_version)
        response_data = {
            "success": True,
            "access_token": new_access_token
        }

        if not using_cookie:
            # Came from body: return the rotated token in the body.
            response_data["refresh_token"] = new_refresh_token
            return TokenRefreshResponse(**response_data)

        # Came from cookie: rotate the cookie.
        import os
        is_production = os.getenv("ENVIRONMENT", "production") == "production"
        response = JSONResponse(content=response_data)
        response.set_cookie(
            key="refresh_token",
            value=new_refresh_token,
            httponly=True,
            secure=is_production,  # True in production (HTTPS), False locally (HTTP)
            samesite="none" if is_production else "lax",  # 'none' for cross-origin in production
            max_age=7 * 24 * 60 * 60,
            domain=None  # Let browser set domain automatically
        )
        logger.info(f"Rotated refresh_token cookie (production={is_production})")
        return response
    except pyjwt.ExpiredSignatureError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Refresh token has expired. Please sign in again."
        ) from None
    except (JWTInvalidTokenError, pyjwt.InvalidTokenError) as e:
        # pyjwt.InvalidTokenError covers bad signatures and malformed tokens,
        # which previously escaped as 500 errors.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Cannot refresh token: {str(e)}"
        ) from e
async def logout(
    req: Request,
    background_tasks: BackgroundTasks,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """
    Log out the current user everywhere.

    Increments the user's ``token_version``. /auth/refresh rejects refresh
    tokens that carry an older version, so every existing session must sign
    in again. Whether access tokens are rejected too depends on
    ``get_current_user``; verify there.

    Also records an audit entry, schedules a DB backup, and clears the
    ``refresh_token`` cookie.
    """
    # Increment token version to invalidate all existing tokens
    user.token_version += 1
    logger.info(f"User {user.user_id} logged out. Token version incremented to {user.token_version}")

    await AuditService.log_event(
        db=db,
        log_type="server",
        user_id=user.id,
        action="logout",
        status="success",
        request=req
    )
    await db.commit()

    # Sync DB to Drive (Async)
    from services.backup_service import get_backup_service
    background_tasks.add_task(get_backup_service().backup_async)

    response = JSONResponse(content={"success": True, "message": "Logged out successfully. All sessions invalidated."})
    # Mirror the attributes used when the cookie was set. Browsers reject a
    # SameSite=Lax Set-Cookie on a cross-site response, so a bare
    # delete_cookie() would leave the production cookie in place.
    import os
    is_production = os.getenv("ENVIRONMENT", "production") == "production"
    response.delete_cookie(
        key="refresh_token",
        httponly=True,
        secure=is_production,
        samesite="none" if is_production else "lax"
    )
    return response