apigateway / dependencies.py
jebin2's picture
token version
19e4a8c
raw
history blame
6.54 kB
import logging
from datetime import datetime, timedelta
from typing import Optional, Tuple
import ipaddress
import httpx
from fastapi import Request, Depends, HTTPException, status
from sqlalchemy import select, and_
from sqlalchemy.ext.asyncio import AsyncSession
from core.database import get_db
from core.models import User, RateLimit
from services.jwt_service import (
verify_access_token,
TokenExpiredError,
InvalidTokenError,
JWTError
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Geolocation API settings
# URL template for the ip-api.com JSON endpoint: `{ip}` is substituted per
# lookup, and `fields` restricts the response to the keys get_geolocation reads.
# NOTE(review): this is a plain-HTTP URL, so looked-up IPs travel unencrypted —
# confirm whether an HTTPS endpoint is available for this deployment.
GEOLOCATION_API_URL = "http://ip-api.com/json/{ip}?fields=status,country,regionName"
GEOLOCATION_TIMEOUT = 2.0 # seconds; passed as the httpx client timeout
async def check_rate_limit(
    db: AsyncSession,
    identifier: str,
    endpoint: str,
    limit: int,
    window_minutes: int
) -> bool:
    """
    Check if request is within rate limits.

    Looks up the newest RateLimit row for (identifier, endpoint) whose
    window_start falls within the last ``window_minutes`` minutes. If one
    exists, its attempt counter is compared against ``limit`` and incremented
    when the request is allowed; otherwise a fresh row is created with one
    attempt. The session is committed whenever a row is created or updated.

    Args:
        db: Async database session used for the lookup and the commit.
        identifier: Value identifying the caller being limited.
        endpoint: Name of the endpoint being limited.
        limit: Maximum number of attempts allowed inside one window.
        window_minutes: Length of the rate-limit window in minutes.

    Returns:
        True if allowed, False if limit exceeded.
    """
    # Naive UTC timestamp, kept as-is so it compares consistently with the
    # values already written to window_start / expires_at.
    now = datetime.utcnow()
    window_start = now - timedelta(minutes=window_minutes)

    # More than one row can fall inside the window (for example when callers
    # use different window lengths for the same identifier/endpoint, or after
    # duplicate inserts). Pick the newest one explicitly instead of relying on
    # scalar_one_or_none(), which raises MultipleResultsFound in that case.
    query = (
        select(RateLimit)
        .where(
            and_(
                RateLimit.identifier == identifier,
                RateLimit.endpoint == endpoint,
                RateLimit.window_start >= window_start,
            )
        )
        .order_by(RateLimit.window_start.desc())
        .limit(1)
    )
    result = await db.execute(query)
    rate_limit = result.scalars().first()

    if rate_limit is None:
        # No active window: start a new one with this request as attempt #1.
        db.add(
            RateLimit(
                identifier=identifier,
                endpoint=endpoint,
                attempts=1,
                window_start=now,
                expires_at=now + timedelta(minutes=window_minutes),
            )
        )
        await db.commit()
        return True

    if rate_limit.attempts >= limit:
        # Limit reached: reject without touching the counter.
        return False

    # Still under the limit: count this attempt.
    rate_limit.attempts += 1
    await db.commit()
    return True
def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error that carries the Bearer challenge header."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    req: Request,
    db: AsyncSession = Depends(get_db)
) -> User:
    """
    Extract and verify JWT from Authorization header.
    Returns the authenticated user.
    Also validates token_version to support instant logout/invalidation.

    Every failure is reported as HTTP 401 with a ``WWW-Authenticate: Bearer``
    header, including the "user not found or inactive" case.

    Usage:
        @router.get("/protected")
        async def protected_route(user: User = Depends(get_current_user)):
            return {"user_id": user.user_id}

    Raises:
        HTTPException: 401 when the header is missing or malformed, the token
            is expired or invalid, the user does not exist or is inactive, or
            the token's version is older than the user's current version.
    """
    auth_header = req.headers.get("Authorization")
    if not auth_header:
        raise _unauthorized("Missing Authorization header")
    if not auth_header.startswith("Bearer "):
        raise _unauthorized("Invalid Authorization header format. Use: Bearer <token>")

    # Everything after the first space is the token itself.
    token = auth_header.split(" ", 1)[1]

    try:
        payload = verify_access_token(token)
    except TokenExpiredError as e:
        raise _unauthorized("Token has expired. Please sign in again.") from e
    except InvalidTokenError as e:
        raise _unauthorized(f"Invalid token: {str(e)}") from e
    except JWTError as e:
        raise _unauthorized(f"Authentication error: {str(e)}") from e

    # Get user from DB (only active users are eligible)
    query = select(User).where(
        User.user_id == payload.user_id,
        User.is_active == True  # noqa: E712 - SQLAlchemy column comparison
    )
    result = await db.execute(query)
    user = result.scalar_one_or_none()
    if not user:
        raise _unauthorized("User not found or inactive")

    # Validate token version - if user's version is higher, token is invalidated
    if payload.token_version < user.token_version:
        logger.info(
            "Token invalidated for user %s: token_version %s < %s",
            user.user_id,
            payload.token_version,
            user.token_version,
        )
        raise _unauthorized("Token has been invalidated. Please sign in again.")

    return user
async def verify_credits(
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
) -> User:
    """
    Authenticate the caller, then charge one credit for the request.

    The user is resolved through the get_current_user dependency. When the
    balance is zero or negative the request is rejected with HTTP 402;
    otherwise the balance is reduced by one, last_used_at is stamped with
    the current UTC time, and the session is committed.

    Usage:
        @router.post("/api-endpoint")
        async def api_endpoint(user: User = Depends(verify_credits)):
            # User is authenticated and has 1 credit deducted
            return {"credits_remaining": user.credits}
    """
    out_of_credits = user.credits <= 0
    if out_of_credits:
        raise HTTPException(
            status_code=status.HTTP_402_PAYMENT_REQUIRED,
            detail="Insufficient credits. Please purchase more credits.",
        )

    # Charge the request and record when the account was last used.
    user.credits = user.credits - 1
    user.last_used_at = datetime.utcnow()
    await db.commit()

    logger.debug(f"Deducted 1 credit from user {user.user_id}. Remaining: {user.credits}")
    return user
async def get_geolocation(ip_address: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Get country and region for an IP address using ip-api.com.

    The input is parsed with the ``ipaddress`` module first. Values that are
    not valid IP addresses (including "localhost") and addresses that are not
    globally routable (loopback, private, link-local, ...) are never sent to
    the API. Only the normalized address text is placed in the request URL.

    Args:
        ip_address: IPv4 or IPv6 address

    Returns:
        Tuple of (country, region) or (None, None) if lookup fails
    """
    if not ip_address:
        return None, None

    try:
        parsed_ip = ipaddress.ip_address(ip_address.strip())
    except ValueError:
        # Not an IP address at all (e.g. "localhost" or malformed input).
        return None, None

    # Treat IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) as their IPv4 form so
    # they are classified by the embedded IPv4 address.
    mapped_ipv4 = getattr(parsed_ip, "ipv4_mapped", None)
    if mapped_ipv4 is not None:
        parsed_ip = mapped_ipv4

    # Skip geolocation for localhost/private IPs. A prefix check on "172."
    # would also skip public addresses, since only 172.16.0.0/12 is private.
    if not parsed_ip.is_global:
        return None, None

    try:
        async with httpx.AsyncClient(timeout=GEOLOCATION_TIMEOUT) as client:
            response = await client.get(GEOLOCATION_API_URL.format(ip=parsed_ip))
            if response.status_code == 200:
                data = response.json()
                if data.get("status") == "success":
                    return data.get("country"), data.get("regionName")
    except Exception as e:
        # Best-effort lookup: any failure is logged and reported as "unknown".
        logger.warning("Geolocation lookup failed for %s: %s", ip_address, e)
    return None, None