apigateway / routers /auth.py
jebin2's picture
ref
a42ab7e
"""
Authentication Router - Google OAuth
Endpoints for Google Sign-In authentication flow.
No more secret keys - users authenticate with their Google account.
"""
from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks
from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from datetime import datetime
import uuid
import logging
from core.database import get_db
from core.models import User, AuditLog, ClientUser
from core.schemas import (
CheckRegistrationRequest,
GoogleAuthRequest,
AuthResponse,
UserInfoResponse,
TokenRefreshRequest,
TokenRefreshResponse
)
from services.auth_service.google_provider import (
GoogleAuthService,
GoogleUserInfo,
InvalidTokenError as GoogleInvalidTokenError,
ConfigurationError as GoogleConfigError,
get_google_auth_service,
)
from services.auth_service.jwt_provider import (
JWTService,
create_access_token,
create_refresh_token,
get_jwt_service,
InvalidTokenError as JWTInvalidTokenError,
)
from core.dependencies import check_rate_limit, get_current_user
from services.drive_service import DriveService
from services.audit_service import AuditService
# Module logger and the router that groups every /auth endpoint below.
logger = logging.getLogger(name=__name__)
router = APIRouter(tags=["auth"], prefix="/auth")
# NOTE(review): drive_service is not referenced by the endpoints in this
# module; it is kept in case other modules import it from here.
drive_service = DriveService()
@router.post("/check-registration")
async def check_registration(
request: CheckRegistrationRequest,
req: Request,
db: AsyncSession = Depends(get_db)
):
"""
Check if a temporary user_id has completed registration.
Useful for frontend to check if user needs to sign in.
"""
# Rate Limit: 10 requests per minute per IP
ip = req.client.host
if not await check_rate_limit(db, ip, "/auth/check-registration", 10, 1):
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests")
# Check if this client_user_id has been linked to a server user
query = select(ClientUser).where(ClientUser.client_user_id == request.user_id)
result = await db.execute(query)
client_user = result.scalar_one_or_none()
return {"is_registered": client_user is not None}
def detect_client_type(request: Request) -> str:
    """
    Classify the caller as 'web' or 'mobile' from its User-Agent header.

    The lowercased User-Agent is scanned for common browser tokens. Any
    match yields 'web'; everything else, including a missing header,
    yields 'mobile'.
    """
    ua = request.headers.get("user-agent", "").lower()
    for token in ("mozilla", "chrome", "firefox", "safari", "edge", "opera"):
        if token in ua:
            return "web"
    return "mobile"
@router.post("/google", response_model=AuthResponse)
async def google_auth(
request: GoogleAuthRequest,
req: Request,
background_tasks: BackgroundTasks,
db: AsyncSession = Depends(get_db)
):
"""
Authenticate with Google ID token.
Supports two client types:
- "web": Sets refresh_token in HttpOnly cookie (secure)
- "mobile": Returns refresh_token in JSON body
Client type is auto-detected from User-Agent if not provided.
"""
response = JSONResponse(content={}) # Placeholder, will be populated later
ip = req.client.host
# Auto-detect client type if not explicitly provided
client_type = request.client_type if request.client_type else detect_client_type(req)
# Rate Limit: 10 attempts per minute per IP
if not await check_rate_limit(db, ip, "/auth/google", 10, 1):
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Too many authentication attempts"
)
# Verify Google token
try:
google_service = get_google_auth_service()
google_info = google_service.verify_token(request.id_token)
except GoogleConfigError as e:
logger.error(f"Google Auth not configured: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Google authentication is not configured"
)
except GoogleInvalidTokenError as e:
logger.warning(f"Invalid Google token from {ip}: {e}")
# Log failed attempt
await AuditService.log_event(
db=db,
log_type="server",
action="google_auth",
status="failed",
error_message=str(e),
request=req
)
await db.commit()
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid Google token. Please try signing in again."
)
# Check for existing user by email (preserves credits for migrated users)
query = select(User).where(User.email == google_info.email)
result = await db.execute(query)
user = result.scalar_one_or_none()
is_new_user = False
if user:
# Existing user - update Google info
if not user.google_id:
user.google_id = google_info.google_id
logger.info(f"Linked Google account to existing user: {user.email}")
user.name = google_info.name
user.profile_picture = google_info.picture
user.last_used_at = datetime.utcnow()
# Link client_user_id if provided
if request.temp_user_id:
# Check if this client mapping exists
client_query = select(ClientUser).where(
ClientUser.user_id == user.id, # Integer FK comparison
ClientUser.client_user_id == request.temp_user_id
)
client_result = await db.execute(client_query)
existing_client = client_result.scalar_one_or_none()
if not existing_client:
# Create new client user mapping
client_user = ClientUser(
user_id=user.id, # Integer FK to users.id
client_user_id=request.temp_user_id,
ip_address=ip, # Standardized IP column
last_seen_at=datetime.utcnow()
)
db.add(client_user)
else:
# Update last seen
existing_client.last_seen_at = datetime.utcnow()
else:
# New user - create account
is_new_user = True
user = User(
user_id="usr_" + str(uuid.uuid4()),
email=google_info.email,
google_id=google_info.google_id,
name=google_info.name,
profile_picture=google_info.picture,
credits=0
)
db.add(user)
logger.info(f"New user created via Google: {google_info.email}")
# Create client user mapping if temp_user_id provided
if request.temp_user_id:
client_user = ClientUser(
user_id=user.id, # Integer FK to users.id (will be set after flush)
client_user_id=request.temp_user_id,
ip_address=ip, # Standardized IP column
last_seen_at=datetime.utcnow()
)
db.add(client_user)
# Log successful auth
await AuditService.log_event(
db=db,
log_type="server",
user_id=user.id,
client_user_id=request.temp_user_id,
action="google_auth",
status="success",
request=req
)
await db.commit()
# Create our JWT access token and refresh token
access_token = create_access_token(user.user_id, user.email, user.token_version)
refresh_token = create_refresh_token(user.user_id, user.email, user.token_version)
# Sync DB to Drive (Async)
from services.backup_service import get_backup_service
backup_service = get_backup_service()
background_tasks.add_task(backup_service.backup_async)
# Prepare response data
response_data = {
"success": True,
"access_token": access_token,
"user_id": user.user_id,
"email": user.email,
"name": user.name,
"credits": user.credits,
"is_new_user": is_new_user
}
# Handle token delivery based on client type
if client_type == "web":
# Web: Set HttpOnly cookie for refresh token
response = JSONResponse(content=response_data)
# Cookie settings for production
import os
is_production = os.getenv("ENVIRONMENT", "production") == "production"
response.set_cookie(
key="refresh_token",
value=refresh_token,
httponly=True,
secure=is_production, # True in production (HTTPS), False locally (HTTP)
samesite="none" if is_production else "lax", # 'none' for cross-origin in production
max_age=7 * 24 * 60 * 60, # 7 days
domain=None # Let browser set domain automatically
)
logger.info(f"Set refresh_token cookie for web client (production={is_production})")
else:
# Mobile: Return refresh token in body
response_data["refresh_token"] = refresh_token
response = JSONResponse(content=response_data)
logger.info(f"Returned refresh_token in body for mobile client")
return response
@router.get("/me", response_model=UserInfoResponse)
async def get_current_user_info(
user: User = Depends(get_current_user)
):
"""
Get current authenticated user info.
Requires Authorization: Bearer <token> header.
"""
return UserInfoResponse(
user_id=user.user_id,
email=user.email,
name=user.name,
credits=user.credits,
profile_picture=user.profile_picture
)
@router.post("/refresh", response_model=TokenRefreshResponse)
async def refresh_token(
request: TokenRefreshRequest,
req: Request,
db: AsyncSession = Depends(get_db)
):
"""
Refresh an access token.
Use this when the current token is about to expire
(or has recently expired) to get a new one without
requiring the user to sign in again.
Validates that the token_version is still valid before refreshing.
"""
ip = req.client.host
# Rate Limit: 20 refreshes per minute per IP (increased for proactive refresh on page load)
if not await check_rate_limit(db, ip, "/auth/refresh", 20, 1):
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Too many refresh attempts"
)
try:
jwt_service = get_jwt_service()
# Get token from body or cookie
token_to_refresh = request.token
using_cookie = False
if not token_to_refresh:
token_to_refresh = req.cookies.get("refresh_token")
using_cookie = True
if not token_to_refresh:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token missing"
)
# Decode the token (without verifying expiry) to get user info
import jwt as pyjwt
payload = pyjwt.decode(
token_to_refresh,
jwt_service.secret_key,
algorithms=[jwt_service.algorithm],
options={"verify_exp": False}
)
user_id = payload.get("sub")
token_version = payload.get("tv", 1)
token_type = payload.get("type", "access")
if not user_id:
raise JWTInvalidTokenError("Token missing required claims")
# Verify it's a refresh token
if token_type != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token type. Expected refresh token."
)
# Check if user exists and token version is still valid
query = select(User).where(User.user_id == user_id, User.is_active == True)
result = await db.execute(query)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or inactive"
)
# Validate token version
if token_version < user.token_version:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token has been invalidated. Please sign in again."
)
# Create new access token
new_access_token = create_access_token(user.user_id, user.email, user.token_version)
# ROTATION: Issue new refresh token
new_refresh_token = create_refresh_token(user.user_id, user.email, user.token_version)
response_data = {
"success": True,
"access_token": new_access_token
}
if using_cookie:
# If came from cookie, rotate cookie
response = JSONResponse(content=response_data)
# Cookie settings for production
import os
is_production = os.getenv("ENVIRONMENT", "production") == "production"
response.set_cookie(
key="refresh_token",
value=new_refresh_token,
httponly=True,
secure=is_production, # True in production (HTTPS), False locally (HTTP)
samesite="none" if is_production else "lax", # 'none' for cross-origin in production
max_age=7 * 24 * 60 * 60,
domain=None # Let browser set domain automatically
)
logger.info(f"Rotated refresh_token cookie (production={is_production})")
return response
else:
# If came from body, return in body
response_data["refresh_token"] = new_refresh_token
return TokenRefreshResponse(**response_data)
except JWTInvalidTokenError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Cannot refresh token: {str(e)}"
)
@router.post("/logout")
async def logout(
req: Request,
background_tasks: BackgroundTasks,
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
Logout current user.
Increments the user's token_version which invalidates ALL existing
tokens for this user. This provides instant logout across all devices.
"""
ip = req.client.host
# Increment token version to invalidate all existing tokens
user.token_version += 1
logger.info(f"User {user.user_id} logged out. Token version incremented to {user.token_version}")
# Log logout
await AuditService.log_event(
db=db,
log_type="server",
user_id=user.id,
action="logout",
status="success",
request=req
)
await db.commit()
# Sync DB to Drive (Async)
from services.backup_service import get_backup_service
backup_service = get_backup_service()
background_tasks.add_task(backup_service.backup_async)
response = JSONResponse(content={"success": True, "message": "Logged out successfully. All sessions invalidated."})
response.delete_cookie(key="refresh_token")
return response