Spaces:
Sleeping
Sleeping
| """ | |
| Rate Limiting Dependencies | |
| Functions for checking and enforcing rate limits on API endpoints. | |
| """ | |
import logging
from datetime import datetime, timedelta, timezone

from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession

from core.models import RateLimit
# One logger per module, named after this module's import path.
logger = logging.getLogger(name=__name__)
async def check_rate_limit(
    db: AsyncSession,
    identifier: str,
    endpoint: str,
    limit: int,
    window_minutes: int
) -> bool:
    """
    Check whether a request is within its rate limit and record the attempt.

    Looks up the most recent ``RateLimit`` row for ``(identifier, endpoint)``
    whose ``window_start`` lies within the last ``window_minutes`` minutes.
    If such a row exists and already has ``limit`` attempts, the request is
    refused. Otherwise its ``attempts`` counter is incremented. If no such row
    exists, a new one is created with ``attempts=1``. The session is committed
    on every path that reaches the database.

    Args:
        db: Async SQLAlchemy session used for the lookup and the write.
        identifier: Caller-chosen key for the client being limited.
        endpoint: Name of the endpoint being limited.
        limit: Maximum attempts allowed per window. A value <= 0 refuses
            every request without touching the database.
        window_minutes: Length of the rate-limit window, in minutes.

    Returns:
        True if the request is allowed, False if the limit is exceeded.
    """
    # A non-positive limit means no request may pass. Without this guard the
    # "create a new record" branch below would still admit the first request.
    if limit <= 0:
        return False

    # Naive UTC timestamp, identical to the deprecated datetime.utcnow().
    # Kept naive so comparisons with stored window_start values are unchanged.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    window_start = now - timedelta(minutes=window_minutes)

    # Fetch only the most recent in-window record. with_for_update() row-locks
    # it on backends that support SELECT ... FOR UPDATE, so concurrent requests
    # cannot both read the same count and increment past the limit. Dialects
    # without row locking (e.g. SQLite) emit no FOR UPDATE clause.
    # NOTE(review): this does not stop two concurrent first requests from both
    # inserting a new row; a unique constraint would be needed for that.
    query = (
        select(RateLimit)
        .where(
            and_(
                RateLimit.identifier == identifier,
                RateLimit.endpoint == endpoint,
                RateLimit.window_start >= window_start,
            )
        )
        .order_by(RateLimit.window_start.desc())
        .limit(1)
        .with_for_update()
    )
    result = await db.execute(query)
    rate_limit = result.scalars().first()

    if rate_limit is None:
        # First request in this window: start a fresh counter.
        db.add(
            RateLimit(
                identifier=identifier,
                endpoint=endpoint,
                attempts=1,
                window_start=now,
                expires_at=now + timedelta(minutes=window_minutes),
            )
        )
        await db.commit()
        return True

    if rate_limit.attempts >= limit:
        # Log before committing: after commit the ORM instance is expired and
        # reading its attributes on an AsyncSession would trigger implicit IO.
        logger.warning(
            "Rate limit exceeded for %s on %s (%d/%d attempts)",
            identifier,
            endpoint,
            rate_limit.attempts,
            limit,
        )
        # Commit to end the transaction and release the row lock taken above.
        await db.commit()
        return False

    rate_limit.attempts += 1
    await db.commit()
    return True