Spaces:
Running
Running
import ipaddress
import logging
from datetime import datetime, timedelta
from typing import Optional, Tuple

import httpx
from fastapi import Depends, HTTPException, Request, status
from sqlalchemy import and_, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from core.database import get_db
from core.models import RateLimit, User
from services.jwt_service import (
    InvalidTokenError,
    JWTError,
    TokenExpiredError,
    verify_access_token,
)
# Module-level logger used by every dependency defined in this file.
logger = logging.getLogger(__name__)

# --- ip-api.com geolocation settings ---------------------------------------
# URL template: ``{ip}`` is substituted per lookup, and the query string asks
# only for the fields get_geolocation reads (status, country, regionName).
GEOLOCATION_API_URL = (
    "http://ip-api.com/json/{ip}"
    "?fields=status,country,regionName"
)
# Timeout applied to the geolocation HTTP client, in seconds.
GEOLOCATION_TIMEOUT = 2.0
async def check_rate_limit(
    db: AsyncSession,
    identifier: str,
    endpoint: str,
    limit: int,
    window_minutes: int
) -> bool:
    """
    Check whether a request is within its fixed-window rate limit and record it.

    Looks up the newest ``RateLimit`` row for ``identifier``/``endpoint`` whose
    window started within the last ``window_minutes``. If that row has already
    reached ``limit`` attempts, the request is rejected and NOT counted.
    Otherwise its attempt counter is incremented, or a fresh window row is
    created when none is active. Either change is committed.

    Args:
        db: Async database session; this function commits it.
        identifier: Caller-defined key for the entity being limited.
        endpoint: Name of the endpoint being limited.
        limit: Maximum number of attempts allowed per window.
        window_minutes: Window length in minutes.

    Returns:
        True if the request is allowed, False if the limit is exceeded.
    """
    now = datetime.utcnow()
    window_start = now - timedelta(minutes=window_minutes)

    # Two concurrent "first" requests can each insert a row for the same
    # window, so more than one row may match. Take the newest one rather than
    # using scalar_one_or_none(), which would raise MultipleResultsFound.
    query = (
        select(RateLimit)
        .where(
            and_(
                RateLimit.identifier == identifier,
                RateLimit.endpoint == endpoint,
                RateLimit.window_start >= window_start
            )
        )
        .order_by(RateLimit.window_start.desc())
        .limit(1)
    )
    result = await db.execute(query)
    rate_limit = result.scalars().first()

    if rate_limit is not None:
        if rate_limit.attempts >= limit:
            return False
        # NOTE: read-modify-write on the loaded object; under concurrency some
        # increments may be lost, making the limit slightly more permissive.
        rate_limit.attempts += 1
        await db.commit()
        return True

    # No active window: start one with this request as the first attempt.
    new_limit = RateLimit(
        identifier=identifier,
        endpoint=endpoint,
        attempts=1,
        window_start=now,
        expires_at=now + timedelta(minutes=window_minutes)
    )
    db.add(new_limit)
    await db.commit()
    return True
def _bearer_unauthorized(detail: str) -> HTTPException:
    """Build a 401 error that advertises the Bearer scheme via WWW-Authenticate."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"}
    )


async def get_current_user(
    req: Request,
    db: AsyncSession = Depends(get_db)
) -> User:
    """
    FastAPI dependency: authenticate the request with a Bearer JWT.

    Reads the ``Authorization`` header, verifies the token with
    ``verify_access_token``, loads the matching active ``User`` and compares
    the token's ``token_version`` with the user's stored one. A token whose
    version is lower than the stored version is rejected, which supports
    instant logout/invalidation.

    Usage:
        @router.get("/protected")
        async def protected_route(user: User = Depends(get_current_user)):
            return {"user_id": user.user_id}

    Returns:
        The authenticated, active ``User``.

    Raises:
        HTTPException: 401 if the header is missing or malformed, the token is
            expired or invalid, the user is missing or inactive, or the token
            version is outdated. Every 401 carries ``WWW-Authenticate: Bearer``.
    """
    auth_header = req.headers.get("Authorization")
    if not auth_header:
        raise _bearer_unauthorized("Missing Authorization header")
    if not auth_header.startswith("Bearer "):
        raise _bearer_unauthorized(
            "Invalid Authorization header format. Use: Bearer <token>"
        )
    token = auth_header.split(" ", 1)[1]

    try:
        payload = verify_access_token(token)
    except TokenExpiredError as e:
        raise _bearer_unauthorized("Token has expired. Please sign in again.") from e
    except InvalidTokenError as e:
        raise _bearer_unauthorized(f"Invalid token: {str(e)}") from e
    except JWTError as e:
        raise _bearer_unauthorized(f"Authentication error: {str(e)}") from e

    # Only active users may authenticate.
    query = select(User).where(
        User.user_id == payload.user_id,
        User.is_active == True  # noqa: E712 -- builds a SQL expression
    )
    result = await db.execute(query)
    user = result.scalar_one_or_none()
    if user is None:
        # RFC 7235 requires WWW-Authenticate on every 401 response.
        raise _bearer_unauthorized("User not found or inactive")

    # If the user's stored version is higher, this token was invalidated.
    if payload.token_version < user.token_version:
        logger.info(
            "Token invalidated for user %s: token_version %s < %s",
            user.user_id, payload.token_version, user.token_version
        )
        raise _bearer_unauthorized("Token has been invalidated. Please sign in again.")
    return user
async def verify_credits(
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
) -> User:
    """
    FastAPI dependency: authenticate the caller and charge one credit.

    Authentication is delegated to ``get_current_user``. The credit is
    deducted with a single conditional ``UPDATE ... WHERE credits > 0``
    rather than decrementing the loaded object. With the object decrement,
    two concurrent requests could both pass the balance check and both spend
    the last credit.

    Usage:
        @router.post("/api-endpoint")
        async def api_endpoint(user: User = Depends(verify_credits)):
            # User is authenticated and has 1 credit deducted
            return {"credits_remaining": user.credits}

    Returns:
        The ``User``, refreshed so ``credits``/``last_used_at`` reflect the
        committed values.

    Raises:
        HTTPException: 402 if the user has no credits left.
    """
    # Cheap early rejection using the already-loaded balance.
    if user.credits <= 0:
        raise HTTPException(
            status_code=status.HTTP_402_PAYMENT_REQUIRED,
            detail="Insufficient credits. Please purchase more credits."
        )

    # Atomic conditional decrement: the database re-checks the balance.
    result = await db.execute(
        update(User)
        .where(User.user_id == user.user_id, User.credits > 0)
        .values(credits=User.credits - 1, last_used_at=datetime.utcnow())
        # The object is re-read below, so skip in-session synchronisation.
        .execution_options(synchronize_session=False)
    )
    if result.rowcount == 0:
        # Another request spent the last credit after the check above.
        raise HTTPException(
            status_code=status.HTTP_402_PAYMENT_REQUIRED,
            detail="Insufficient credits. Please purchase more credits."
        )
    await db.commit()
    # Assumes ``user`` was loaded through this same session (FastAPI caches
    # ``get_db`` per request by default) -- refresh reloads the new balance.
    await db.refresh(user)

    logger.debug(
        "Deducted 1 credit from user %s. Remaining: %s", user.user_id, user.credits
    )
    return user
async def get_geolocation(ip_address: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Look up the country and region of an IP address via ip-api.com.

    Only globally routable unicast addresses are sent to the API. The lookup
    is skipped and ``(None, None)`` returned for:

    - empty input;
    - strings that are not a valid IPv4/IPv6 literal (e.g. "localhost");
    - non-public addresses (private, loopback, link-local, reserved,
      multicast, ...).

    Validating first also guarantees that only a canonical IP literal is
    interpolated into the request URL.

    The lookup is best-effort: any error from the HTTP call or from parsing
    the response is logged as a warning and reported as ``(None, None)``.

    Args:
        ip_address: IPv4 or IPv6 address as a string.

    Returns:
        Tuple of (country, region), or (None, None) if the lookup is skipped
        or fails.
    """
    if not ip_address:
        return None, None

    try:
        addr = ipaddress.ip_address(ip_address.strip())
    except ValueError:
        # Not an IP literal (e.g. "localhost" or a malformed header value).
        return None, None

    # Look through IPv4-mapped IPv6 (::ffff:a.b.c.d) to the embedded IPv4.
    if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None:
        addr = addr.ipv4_mapped

    # is_global follows the IANA special-purpose registries. A string-prefix
    # test such as "172." would also reject public ranges like 172.217.0.0/16.
    if not addr.is_global or addr.is_multicast:
        return None, None

    try:
        async with httpx.AsyncClient(timeout=GEOLOCATION_TIMEOUT) as client:
            response = await client.get(GEOLOCATION_API_URL.format(ip=str(addr)))
            if response.status_code == 200:
                data = response.json()
                if data.get("status") == "success":
                    return data.get("country"), data.get("regionName")
    except Exception as e:  # best-effort: geolocation must never break callers
        logger.warning("Geolocation lookup failed for %s: %s", ip_address, e)
    return None, None