Spaces:
Running
Running
ref
Browse files- core/dependencies/__init__.py +14 -0
- dependencies.py → core/dependencies/auth.py +8 -139
- core/dependencies/rate_limit.py +60 -0
- core/utils/__init__.py +8 -0
- core/utils/geolocation.py +44 -0
- routers/auth.py +1 -1
- routers/blink.py +1 -1
- tests/test_auth_router.py +3 -3
- tests/test_dependencies.py +11 -11
- tests/test_gemini_router.py +9 -9
- tests/test_payments_router.py +5 -5
- tests/test_rate_limiting.py +13 -13
- tests/test_token_expiry_integration.py +1 -1
core/dependencies/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Core Dependencies
|
| 3 |
+
|
| 4 |
+
FastAPI dependencies for authentication, authorization, and rate limiting.
|
| 5 |
+
Re-exports all dependency functions for backward compatibility.
|
| 6 |
+
"""
|
| 7 |
+
from .auth import get_current_user, get_optional_user
|
| 8 |
+
from .rate_limit import check_rate_limit
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"get_current_user",
|
| 12 |
+
"get_optional_user",
|
| 13 |
+
"check_rate_limit",
|
| 14 |
+
]
|
dependencies.py → core/dependencies/auth.py
RENAMED
|
@@ -1,14 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
-
from
|
| 3 |
-
from typing import Optional, Tuple, Union
|
| 4 |
-
import ipaddress
|
| 5 |
-
import httpx
|
| 6 |
from fastapi import Request, Depends, HTTPException, status
|
| 7 |
-
from sqlalchemy import select
|
| 8 |
from sqlalchemy.ext.asyncio import AsyncSession
|
| 9 |
|
| 10 |
from core.database import get_db
|
| 11 |
-
from core.models import User
|
| 12 |
from services.auth_service.jwt_provider import (
|
| 13 |
verify_access_token,
|
| 14 |
TokenExpiredError,
|
|
@@ -18,56 +20,6 @@ from services.auth_service.jwt_provider import (
|
|
| 18 |
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
-
# Geolocation API settings
|
| 22 |
-
GEOLOCATION_API_URL = "http://ip-api.com/json/{ip}?fields=status,country,regionName"
|
| 23 |
-
GEOLOCATION_TIMEOUT = 2.0 # seconds
|
| 24 |
-
|
| 25 |
-
async def check_rate_limit(
|
| 26 |
-
db: AsyncSession,
|
| 27 |
-
identifier: str,
|
| 28 |
-
endpoint: str,
|
| 29 |
-
limit: int,
|
| 30 |
-
window_minutes: int
|
| 31 |
-
) -> bool:
|
| 32 |
-
"""
|
| 33 |
-
Check if request is within rate limits.
|
| 34 |
-
Returns True if allowed, False if limit exceeded.
|
| 35 |
-
"""
|
| 36 |
-
now = datetime.utcnow()
|
| 37 |
-
window_start = now - timedelta(minutes=window_minutes)
|
| 38 |
-
|
| 39 |
-
# Check existing limit (get most recent if multiple exist)
|
| 40 |
-
query = select(RateLimit).where(
|
| 41 |
-
and_(
|
| 42 |
-
RateLimit.identifier == identifier,
|
| 43 |
-
RateLimit.endpoint == endpoint,
|
| 44 |
-
RateLimit.window_start >= window_start
|
| 45 |
-
)
|
| 46 |
-
).order_by(RateLimit.window_start.desc())
|
| 47 |
-
result = await db.execute(query)
|
| 48 |
-
rate_limit = result.scalars().first()
|
| 49 |
-
|
| 50 |
-
if rate_limit:
|
| 51 |
-
if rate_limit.attempts >= limit:
|
| 52 |
-
return False
|
| 53 |
-
|
| 54 |
-
# Increment attempts
|
| 55 |
-
rate_limit.attempts += 1
|
| 56 |
-
await db.commit()
|
| 57 |
-
return True
|
| 58 |
-
else:
|
| 59 |
-
# Create new rate limit record
|
| 60 |
-
new_limit = RateLimit(
|
| 61 |
-
identifier=identifier,
|
| 62 |
-
endpoint=endpoint,
|
| 63 |
-
attempts=1,
|
| 64 |
-
window_start=now,
|
| 65 |
-
expires_at=now + timedelta(minutes=window_minutes)
|
| 66 |
-
)
|
| 67 |
-
db.add(new_limit)
|
| 68 |
-
await db.commit()
|
| 69 |
-
return True
|
| 70 |
-
|
| 71 |
|
| 72 |
async def get_current_user(
|
| 73 |
req: Request,
|
|
@@ -205,86 +157,3 @@ async def get_optional_user(
|
|
| 205 |
return None
|
| 206 |
|
| 207 |
return user
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
# =============================================================================
|
| 211 |
-
# LEGACY CREDIT VERIFICATION - DEPRECATED
|
| 212 |
-
# =============================================================================
|
| 213 |
-
# These functions are NO LONGER USED.
|
| 214 |
-
# Credit operations are now handled automatically by CreditMiddleware.
|
| 215 |
-
# Keeping these commented for reference during migration.
|
| 216 |
-
# =============================================================================
|
| 217 |
-
|
| 218 |
-
# async def verify_credits(
|
| 219 |
-
# user: User = Depends(get_current_user),
|
| 220 |
-
# db: AsyncSession = Depends(get_db)
|
| 221 |
-
# ) -> User:
|
| 222 |
-
# """
|
| 223 |
-
# DEPRECATED: Credit verification is now handled by CreditMiddleware.
|
| 224 |
-
# This function manually deducted credits, which bypassed the transaction system.
|
| 225 |
-
# """
|
| 226 |
-
# if user.credits <= 0:
|
| 227 |
-
# raise HTTPException(
|
| 228 |
-
# status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
| 229 |
-
# detail="Insufficient credits. Please purchase more credits."
|
| 230 |
-
# )
|
| 231 |
-
#
|
| 232 |
-
# user.credits -= 1
|
| 233 |
-
# user.last_used_at = datetime.utcnow()
|
| 234 |
-
# await db.commit()
|
| 235 |
-
#
|
| 236 |
-
# logger.debug(f"Deducted 1 credit from user {user.user_id}. Remaining: {user.credits}")
|
| 237 |
-
# return user
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
# async def verify_video_credits(
|
| 241 |
-
# user: User = Depends(get_current_user),
|
| 242 |
-
# db: AsyncSession = Depends(get_db)
|
| 243 |
-
# ) -> User:
|
| 244 |
-
# """
|
| 245 |
-
# DEPRECATED: Video credit verification is now handled by CreditMiddleware.
|
| 246 |
-
# This function manually deducted credits, which bypassed the transaction system.
|
| 247 |
-
# """
|
| 248 |
-
# cost = 10
|
| 249 |
-
# if user.credits < cost:
|
| 250 |
-
# raise HTTPException(
|
| 251 |
-
# status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
| 252 |
-
# detail=f"Insufficient credits. Video generation requires {cost} credits."
|
| 253 |
-
# )
|
| 254 |
-
#
|
| 255 |
-
# user.credits -= cost
|
| 256 |
-
# user.last_used_at = datetime.utcnow()
|
| 257 |
-
# await db.commit()
|
| 258 |
-
#
|
| 259 |
-
# logger.debug(f"Deducted {cost} credits from user {user.user_id} for video generation. Remaining: {user.credits}")
|
| 260 |
-
# return user
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
async def get_geolocation(ip_address: str) -> Tuple[Optional[str], Optional[str]]:
|
| 264 |
-
"""
|
| 265 |
-
Get country and region for an IP address using ip-api.com.
|
| 266 |
-
|
| 267 |
-
Args:
|
| 268 |
-
ip_address: IPv4 or IPv6 address
|
| 269 |
-
|
| 270 |
-
Returns:
|
| 271 |
-
Tuple of (country, region) or (None, None) if lookup fails
|
| 272 |
-
"""
|
| 273 |
-
if not ip_address:
|
| 274 |
-
return None, None
|
| 275 |
-
|
| 276 |
-
# Skip geolocation for localhost/private IPs
|
| 277 |
-
if ip_address in ("127.0.0.1", "::1", "localhost") or ip_address.startswith(("192.168.", "10.", "172.")):
|
| 278 |
-
return None, None
|
| 279 |
-
|
| 280 |
-
try:
|
| 281 |
-
async with httpx.AsyncClient(timeout=GEOLOCATION_TIMEOUT) as client:
|
| 282 |
-
response = await client.get(GEOLOCATION_API_URL.format(ip=ip_address))
|
| 283 |
-
if response.status_code == 200:
|
| 284 |
-
data = response.json()
|
| 285 |
-
if data.get("status") == "success":
|
| 286 |
-
return data.get("country"), data.get("regionName")
|
| 287 |
-
except Exception as e:
|
| 288 |
-
logger.warning(f"Geolocation lookup failed for {ip_address}: {e}")
|
| 289 |
-
|
| 290 |
-
return None, None
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Authentication Dependencies
|
| 3 |
+
|
| 4 |
+
FastAPI dependencies for user authentication and authorization.
|
| 5 |
+
"""
|
| 6 |
import logging
|
| 7 |
+
from typing import Optional
|
|
|
|
|
|
|
|
|
|
| 8 |
from fastapi import Request, Depends, HTTPException, status
|
| 9 |
+
from sqlalchemy import select
|
| 10 |
from sqlalchemy.ext.asyncio import AsyncSession
|
| 11 |
|
| 12 |
from core.database import get_db
|
| 13 |
+
from core.models import User
|
| 14 |
from services.auth_service.jwt_provider import (
|
| 15 |
verify_access_token,
|
| 16 |
TokenExpiredError,
|
|
|
|
| 20 |
|
| 21 |
logger = logging.getLogger(__name__)
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
async def get_current_user(
|
| 25 |
req: Request,
|
|
|
|
| 157 |
return None
|
| 158 |
|
| 159 |
return user
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
core/dependencies/rate_limit.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Rate Limiting Dependencies
|
| 3 |
+
|
| 4 |
+
Functions for checking and enforcing rate limits on API endpoints.
|
| 5 |
+
"""
|
| 6 |
+
import logging
|
| 7 |
+
from datetime import datetime, timedelta
|
| 8 |
+
from sqlalchemy import select, and_
|
| 9 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 10 |
+
|
| 11 |
+
from core.models import RateLimit
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
async def check_rate_limit(
|
| 17 |
+
db: AsyncSession,
|
| 18 |
+
identifier: str,
|
| 19 |
+
endpoint: str,
|
| 20 |
+
limit: int,
|
| 21 |
+
window_minutes: int
|
| 22 |
+
) -> bool:
|
| 23 |
+
"""
|
| 24 |
+
Check if request is within rate limits.
|
| 25 |
+
Returns True if allowed, False if limit exceeded.
|
| 26 |
+
"""
|
| 27 |
+
now = datetime.utcnow()
|
| 28 |
+
window_start = now - timedelta(minutes=window_minutes)
|
| 29 |
+
|
| 30 |
+
# Check existing limit (get most recent if multiple exist)
|
| 31 |
+
query = select(RateLimit).where(
|
| 32 |
+
and_(
|
| 33 |
+
RateLimit.identifier == identifier,
|
| 34 |
+
RateLimit.endpoint == endpoint,
|
| 35 |
+
RateLimit.window_start >= window_start
|
| 36 |
+
)
|
| 37 |
+
).order_by(RateLimit.window_start.desc())
|
| 38 |
+
result = await db.execute(query)
|
| 39 |
+
rate_limit = result.scalars().first()
|
| 40 |
+
|
| 41 |
+
if rate_limit:
|
| 42 |
+
if rate_limit.attempts >= limit:
|
| 43 |
+
return False
|
| 44 |
+
|
| 45 |
+
# Increment attempts
|
| 46 |
+
rate_limit.attempts += 1
|
| 47 |
+
await db.commit()
|
| 48 |
+
return True
|
| 49 |
+
else:
|
| 50 |
+
# Create new rate limit record
|
| 51 |
+
new_limit = RateLimit(
|
| 52 |
+
identifier=identifier,
|
| 53 |
+
endpoint=endpoint,
|
| 54 |
+
attempts=1,
|
| 55 |
+
window_start=now,
|
| 56 |
+
expires_at=now + timedelta(minutes=window_minutes)
|
| 57 |
+
)
|
| 58 |
+
db.add(new_limit)
|
| 59 |
+
await db.commit()
|
| 60 |
+
return True
|
core/utils/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Core Utilities
|
| 3 |
+
|
| 4 |
+
Utility functions for the application.
|
| 5 |
+
"""
|
| 6 |
+
from .geolocation import get_geolocation
|
| 7 |
+
|
| 8 |
+
__all__ = ["get_geolocation"]
|
core/utils/geolocation.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Geolocation Utilities
|
| 3 |
+
|
| 4 |
+
Utilities for IP address geolocation lookup.
|
| 5 |
+
"""
|
| 6 |
+
import logging
|
| 7 |
+
from typing import Tuple, Optional
|
| 8 |
+
import httpx
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
# Geolocation API settings
|
| 13 |
+
GEOLOCATION_API_URL = "http://ip-api.com/json/{ip}?fields=status,country,regionName"
|
| 14 |
+
GEOLOCATION_TIMEOUT = 2.0 # seconds
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
async def get_geolocation(ip_address: str) -> Tuple[Optional[str], Optional[str]]:
|
| 18 |
+
"""
|
| 19 |
+
Get country and region for an IP address using ip-api.com.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
ip_address: IPv4 or IPv6 address
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
Tuple of (country, region) or (None, None) if lookup fails
|
| 26 |
+
"""
|
| 27 |
+
if not ip_address:
|
| 28 |
+
return None, None
|
| 29 |
+
|
| 30 |
+
# Skip geolocation for localhost/private IPs
|
| 31 |
+
if ip_address in ("127.0.0.1", "::1", "localhost") or ip_address.startswith(("192.168.", "10.", "172.")):
|
| 32 |
+
return None, None
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
async with httpx.AsyncClient(timeout=GEOLOCATION_TIMEOUT) as client:
|
| 36 |
+
response = await client.get(GEOLOCATION_API_URL.format(ip=ip_address))
|
| 37 |
+
if response.status_code == 200:
|
| 38 |
+
data = response.json()
|
| 39 |
+
if data.get("status") == "success":
|
| 40 |
+
return data.get("country"), data.get("regionName")
|
| 41 |
+
except Exception as e:
|
| 42 |
+
logger.warning(f"Geolocation lookup failed for {ip_address}: {e}")
|
| 43 |
+
|
| 44 |
+
return None, None
|
routers/auth.py
CHANGED
|
@@ -36,7 +36,7 @@ from services.auth_service.jwt_provider import (
|
|
| 36 |
get_jwt_service,
|
| 37 |
InvalidTokenError as JWTInvalidTokenError,
|
| 38 |
)
|
| 39 |
-
from dependencies import check_rate_limit, get_current_user
|
| 40 |
from services.drive_service import DriveService
|
| 41 |
from services.audit_service import AuditService
|
| 42 |
|
|
|
|
| 36 |
get_jwt_service,
|
| 37 |
InvalidTokenError as JWTInvalidTokenError,
|
| 38 |
)
|
| 39 |
+
from core.dependencies import check_rate_limit, get_current_user
|
| 40 |
from services.drive_service import DriveService
|
| 41 |
from services.audit_service import AuditService
|
| 42 |
|
routers/blink.py
CHANGED
|
@@ -11,7 +11,7 @@ import logging
|
|
| 11 |
from core.database import get_db
|
| 12 |
from core.models import User, GeminiJob, Contact, ClientUser
|
| 13 |
from services.encryption_service import decrypt_multiple_blocks
|
| 14 |
-
from
|
| 15 |
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
|
|
|
| 11 |
from core.database import get_db
|
| 12 |
from core.models import User, GeminiJob, Contact, ClientUser
|
| 13 |
from services.encryption_service import decrypt_multiple_blocks
|
| 14 |
+
from core.utils import get_geolocation
|
| 15 |
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
tests/test_auth_router.py
CHANGED
|
@@ -401,7 +401,7 @@ class TestGetCurrentUserInfo:
|
|
| 401 |
"""GET /me returns authenticated user info."""
|
| 402 |
from routers.auth import router
|
| 403 |
from fastapi import FastAPI
|
| 404 |
-
from dependencies import get_current_user
|
| 405 |
from core.models import User
|
| 406 |
|
| 407 |
app = FastAPI()
|
|
@@ -655,7 +655,7 @@ class TestLogout:
|
|
| 655 |
"""Logout increments user's token version."""
|
| 656 |
from routers.auth import router
|
| 657 |
from fastapi import FastAPI
|
| 658 |
-
from dependencies import get_current_user
|
| 659 |
from core.database import get_db
|
| 660 |
from core.models import User
|
| 661 |
|
|
@@ -688,7 +688,7 @@ class TestLogout:
|
|
| 688 |
"""Logout deletes refresh token cookie."""
|
| 689 |
from routers.auth import router
|
| 690 |
from fastapi import FastAPI
|
| 691 |
-
from dependencies import get_current_user
|
| 692 |
from core.database import get_db
|
| 693 |
from core.models import User
|
| 694 |
|
|
|
|
| 401 |
"""GET /me returns authenticated user info."""
|
| 402 |
from routers.auth import router
|
| 403 |
from fastapi import FastAPI
|
| 404 |
+
from core.dependencies import get_current_user
|
| 405 |
from core.models import User
|
| 406 |
|
| 407 |
app = FastAPI()
|
|
|
|
| 655 |
"""Logout increments user's token version."""
|
| 656 |
from routers.auth import router
|
| 657 |
from fastapi import FastAPI
|
| 658 |
+
from core.dependencies import get_current_user
|
| 659 |
from core.database import get_db
|
| 660 |
from core.models import User
|
| 661 |
|
|
|
|
| 688 |
"""Logout deletes refresh token cookie."""
|
| 689 |
from routers.auth import router
|
| 690 |
from fastapi import FastAPI
|
| 691 |
+
from core.dependencies import get_current_user
|
| 692 |
from core.database import get_db
|
| 693 |
from core.models import User
|
| 694 |
|
tests/test_dependencies.py
CHANGED
|
@@ -24,7 +24,7 @@ class TestGetCurrentUser:
|
|
| 24 |
@pytest.mark.asyncio
|
| 25 |
async def test_valid_token_returns_user(self, db_session):
|
| 26 |
"""Valid JWT token returns authenticated user."""
|
| 27 |
-
from dependencies import get_current_user
|
| 28 |
from core.models import User
|
| 29 |
|
| 30 |
# Create user
|
|
@@ -51,7 +51,7 @@ class TestGetCurrentUser:
|
|
| 51 |
@pytest.mark.asyncio
|
| 52 |
async def test_missing_auth_header_raises_401(self, db_session):
|
| 53 |
"""Missing Authorization header raises 401."""
|
| 54 |
-
from dependencies import get_current_user
|
| 55 |
|
| 56 |
mock_request = MagicMock(spec=Request)
|
| 57 |
mock_request.headers.get.return_value = None
|
|
@@ -64,7 +64,7 @@ class TestGetCurrentUser:
|
|
| 64 |
@pytest.mark.asyncio
|
| 65 |
async def test_invalid_header_format_raises_401(self, db_session):
|
| 66 |
"""Invalid Authorization header format raises 401."""
|
| 67 |
-
from dependencies import get_current_user
|
| 68 |
|
| 69 |
mock_request = MagicMock(spec=Request)
|
| 70 |
mock_request.headers.get.return_value = "InvalidFormat token123"
|
|
@@ -77,7 +77,7 @@ class TestGetCurrentUser:
|
|
| 77 |
@pytest.mark.asyncio
|
| 78 |
async def test_expired_token_raises_401(self, db_session):
|
| 79 |
"""Expired JWT token raises 401."""
|
| 80 |
-
from dependencies import get_current_user
|
| 81 |
from services.auth_service.jwt_provider import TokenExpiredError
|
| 82 |
|
| 83 |
mock_request = MagicMock(spec=Request)
|
|
@@ -94,7 +94,7 @@ class TestGetCurrentUser:
|
|
| 94 |
@pytest.mark.asyncio
|
| 95 |
async def test_invalid_token_raises_401(self, db_session):
|
| 96 |
"""Invalid JWT token raises 401."""
|
| 97 |
-
from dependencies import get_current_user
|
| 98 |
from services.auth_service.jwt_provider import InvalidTokenError
|
| 99 |
|
| 100 |
mock_request = MagicMock(spec=Request)
|
|
@@ -111,7 +111,7 @@ class TestGetCurrentUser:
|
|
| 111 |
@pytest.mark.asyncio
|
| 112 |
async def test_token_version_mismatch_raises_401(self, db_session):
|
| 113 |
"""Mismatched token version (after logout) raises 401."""
|
| 114 |
-
from dependencies import get_current_user
|
| 115 |
from core.models import User
|
| 116 |
|
| 117 |
# User has token_version=2 (logged out)
|
|
@@ -147,7 +147,7 @@ class TestRateLimitDependency:
|
|
| 147 |
@pytest.mark.asyncio
|
| 148 |
async def test_rate_limit_function_exists(self, db_session):
|
| 149 |
"""check_rate_limit function is accessible."""
|
| 150 |
-
from dependencies import check_rate_limit
|
| 151 |
|
| 152 |
result = await check_rate_limit(
|
| 153 |
db=db_session,
|
|
@@ -171,7 +171,7 @@ class TestGeolocation:
|
|
| 171 |
@pytest.mark.asyncio
|
| 172 |
async def test_geolocation_with_valid_ip(self):
|
| 173 |
"""Get geolocation for valid IP address."""
|
| 174 |
-
from
|
| 175 |
|
| 176 |
with patch('dependencies.httpx.AsyncClient') as mock_client:
|
| 177 |
# Mock API response
|
|
@@ -193,7 +193,7 @@ class TestGeolocation:
|
|
| 193 |
@pytest.mark.asyncio
|
| 194 |
async def test_geolocation_with_invalid_ip(self):
|
| 195 |
"""Handle invalid IP gracefully."""
|
| 196 |
-
from
|
| 197 |
|
| 198 |
country, region = await get_geolocation("invalid_ip")
|
| 199 |
|
|
@@ -204,7 +204,7 @@ class TestGeolocation:
|
|
| 204 |
@pytest.mark.asyncio
|
| 205 |
async def test_geolocation_with_none_ip(self):
|
| 206 |
"""Handle None IP gracefully."""
|
| 207 |
-
from
|
| 208 |
|
| 209 |
country, region = await get_geolocation(None)
|
| 210 |
|
|
@@ -214,7 +214,7 @@ class TestGeolocation:
|
|
| 214 |
@pytest.mark.asyncio
|
| 215 |
async def test_geolocation_api_failure(self):
|
| 216 |
"""Handle API failure gracefully."""
|
| 217 |
-
from
|
| 218 |
|
| 219 |
with patch('dependencies.httpx.AsyncClient') as mock_client:
|
| 220 |
# Mock API failure
|
|
|
|
| 24 |
@pytest.mark.asyncio
|
| 25 |
async def test_valid_token_returns_user(self, db_session):
|
| 26 |
"""Valid JWT token returns authenticated user."""
|
| 27 |
+
from core.dependencies import get_current_user
|
| 28 |
from core.models import User
|
| 29 |
|
| 30 |
# Create user
|
|
|
|
| 51 |
@pytest.mark.asyncio
|
| 52 |
async def test_missing_auth_header_raises_401(self, db_session):
|
| 53 |
"""Missing Authorization header raises 401."""
|
| 54 |
+
from core.dependencies import get_current_user
|
| 55 |
|
| 56 |
mock_request = MagicMock(spec=Request)
|
| 57 |
mock_request.headers.get.return_value = None
|
|
|
|
| 64 |
@pytest.mark.asyncio
|
| 65 |
async def test_invalid_header_format_raises_401(self, db_session):
|
| 66 |
"""Invalid Authorization header format raises 401."""
|
| 67 |
+
from core.dependencies import get_current_user
|
| 68 |
|
| 69 |
mock_request = MagicMock(spec=Request)
|
| 70 |
mock_request.headers.get.return_value = "InvalidFormat token123"
|
|
|
|
| 77 |
@pytest.mark.asyncio
|
| 78 |
async def test_expired_token_raises_401(self, db_session):
|
| 79 |
"""Expired JWT token raises 401."""
|
| 80 |
+
from core.dependencies import get_current_user
|
| 81 |
from services.auth_service.jwt_provider import TokenExpiredError
|
| 82 |
|
| 83 |
mock_request = MagicMock(spec=Request)
|
|
|
|
| 94 |
@pytest.mark.asyncio
|
| 95 |
async def test_invalid_token_raises_401(self, db_session):
|
| 96 |
"""Invalid JWT token raises 401."""
|
| 97 |
+
from core.dependencies import get_current_user
|
| 98 |
from services.auth_service.jwt_provider import InvalidTokenError
|
| 99 |
|
| 100 |
mock_request = MagicMock(spec=Request)
|
|
|
|
| 111 |
@pytest.mark.asyncio
|
| 112 |
async def test_token_version_mismatch_raises_401(self, db_session):
|
| 113 |
"""Mismatched token version (after logout) raises 401."""
|
| 114 |
+
from core.dependencies import get_current_user
|
| 115 |
from core.models import User
|
| 116 |
|
| 117 |
# User has token_version=2 (logged out)
|
|
|
|
| 147 |
@pytest.mark.asyncio
|
| 148 |
async def test_rate_limit_function_exists(self, db_session):
|
| 149 |
"""check_rate_limit function is accessible."""
|
| 150 |
+
from core.dependencies import check_rate_limit
|
| 151 |
|
| 152 |
result = await check_rate_limit(
|
| 153 |
db=db_session,
|
|
|
|
| 171 |
@pytest.mark.asyncio
|
| 172 |
async def test_geolocation_with_valid_ip(self):
|
| 173 |
"""Get geolocation for valid IP address."""
|
| 174 |
+
from core.utils import get_geolocation
|
| 175 |
|
| 176 |
with patch('dependencies.httpx.AsyncClient') as mock_client:
|
| 177 |
# Mock API response
|
|
|
|
| 193 |
@pytest.mark.asyncio
|
| 194 |
async def test_geolocation_with_invalid_ip(self):
|
| 195 |
"""Handle invalid IP gracefully."""
|
| 196 |
+
from core.utils import get_geolocation
|
| 197 |
|
| 198 |
country, region = await get_geolocation("invalid_ip")
|
| 199 |
|
|
|
|
| 204 |
@pytest.mark.asyncio
|
| 205 |
async def test_geolocation_with_none_ip(self):
|
| 206 |
"""Handle None IP gracefully."""
|
| 207 |
+
from core.utils import get_geolocation
|
| 208 |
|
| 209 |
country, region = await get_geolocation(None)
|
| 210 |
|
|
|
|
| 214 |
@pytest.mark.asyncio
|
| 215 |
async def test_geolocation_api_failure(self):
|
| 216 |
"""Handle API failure gracefully."""
|
| 217 |
+
from core.utils import get_geolocation
|
| 218 |
|
| 219 |
with patch('dependencies.httpx.AsyncClient') as mock_client:
|
| 220 |
# Mock API failure
|
tests/test_gemini_router.py
CHANGED
|
@@ -218,7 +218,7 @@ class TestJobStatusEndpoint:
|
|
| 218 |
"""Return 404 for non-existent job."""
|
| 219 |
from routers.gemini import router
|
| 220 |
from fastapi import FastAPI
|
| 221 |
-
from dependencies import get_current_user
|
| 222 |
from core.database import get_db
|
| 223 |
|
| 224 |
app = FastAPI()
|
|
@@ -248,7 +248,7 @@ class TestJobStatusEndpoint:
|
|
| 248 |
"""Return queued status with position."""
|
| 249 |
from routers.gemini import router
|
| 250 |
from fastapi import FastAPI
|
| 251 |
-
from dependencies import get_current_user
|
| 252 |
from core.database import get_db
|
| 253 |
|
| 254 |
app = FastAPI()
|
|
@@ -291,7 +291,7 @@ class TestJobStatusEndpoint:
|
|
| 291 |
"""Return completed status with output."""
|
| 292 |
from routers.gemini import router
|
| 293 |
from fastapi import FastAPI
|
| 294 |
-
from dependencies import get_current_user
|
| 295 |
from core.database import get_db
|
| 296 |
|
| 297 |
app = FastAPI()
|
|
@@ -331,7 +331,7 @@ class TestJobStatusEndpoint:
|
|
| 331 |
"""Return failed status with error."""
|
| 332 |
from routers.gemini import router
|
| 333 |
from fastapi import FastAPI
|
| 334 |
-
from dependencies import get_current_user
|
| 335 |
from core.database import get_db
|
| 336 |
|
| 337 |
app = FastAPI()
|
|
@@ -392,7 +392,7 @@ class TestDownloadEndpoint:
|
|
| 392 |
"""Return 404 for non-existent job."""
|
| 393 |
from routers.gemini import router
|
| 394 |
from fastapi import FastAPI
|
| 395 |
-
from dependencies import get_current_user
|
| 396 |
from core.database import get_db
|
| 397 |
|
| 398 |
app = FastAPI()
|
|
@@ -420,7 +420,7 @@ class TestDownloadEndpoint:
|
|
| 420 |
"""Return 400 if video not ready."""
|
| 421 |
from routers.gemini import router
|
| 422 |
from fastapi import FastAPI
|
| 423 |
-
from dependencies import get_current_user
|
| 424 |
from core.database import get_db
|
| 425 |
|
| 426 |
app = FastAPI()
|
|
@@ -476,7 +476,7 @@ class TestCancelEndpoint:
|
|
| 476 |
"""Return 404 for non-existent job."""
|
| 477 |
from routers.gemini import router
|
| 478 |
from fastapi import FastAPI
|
| 479 |
-
from dependencies import get_current_user
|
| 480 |
from core.database import get_db
|
| 481 |
|
| 482 |
app = FastAPI()
|
|
@@ -504,7 +504,7 @@ class TestCancelEndpoint:
|
|
| 504 |
"""Only queued jobs can be cancelled."""
|
| 505 |
from routers.gemini import router
|
| 506 |
from fastapi import FastAPI
|
| 507 |
-
from dependencies import get_current_user
|
| 508 |
from core.database import get_db
|
| 509 |
|
| 510 |
app = FastAPI()
|
|
@@ -537,7 +537,7 @@ class TestCancelEndpoint:
|
|
| 537 |
"""Successfully cancel a queued job."""
|
| 538 |
from routers.gemini import router
|
| 539 |
from fastapi import FastAPI
|
| 540 |
-
from dependencies import get_current_user
|
| 541 |
from core.database import get_db
|
| 542 |
|
| 543 |
app = FastAPI()
|
|
|
|
| 218 |
"""Return 404 for non-existent job."""
|
| 219 |
from routers.gemini import router
|
| 220 |
from fastapi import FastAPI
|
| 221 |
+
from core.dependencies import get_current_user
|
| 222 |
from core.database import get_db
|
| 223 |
|
| 224 |
app = FastAPI()
|
|
|
|
| 248 |
"""Return queued status with position."""
|
| 249 |
from routers.gemini import router
|
| 250 |
from fastapi import FastAPI
|
| 251 |
+
from core.dependencies import get_current_user
|
| 252 |
from core.database import get_db
|
| 253 |
|
| 254 |
app = FastAPI()
|
|
|
|
| 291 |
"""Return completed status with output."""
|
| 292 |
from routers.gemini import router
|
| 293 |
from fastapi import FastAPI
|
| 294 |
+
from core.dependencies import get_current_user
|
| 295 |
from core.database import get_db
|
| 296 |
|
| 297 |
app = FastAPI()
|
|
|
|
| 331 |
"""Return failed status with error."""
|
| 332 |
from routers.gemini import router
|
| 333 |
from fastapi import FastAPI
|
| 334 |
+
from core.dependencies import get_current_user
|
| 335 |
from core.database import get_db
|
| 336 |
|
| 337 |
app = FastAPI()
|
|
|
|
| 392 |
"""Return 404 for non-existent job."""
|
| 393 |
from routers.gemini import router
|
| 394 |
from fastapi import FastAPI
|
| 395 |
+
from core.dependencies import get_current_user
|
| 396 |
from core.database import get_db
|
| 397 |
|
| 398 |
app = FastAPI()
|
|
|
|
| 420 |
"""Return 400 if video not ready."""
|
| 421 |
from routers.gemini import router
|
| 422 |
from fastapi import FastAPI
|
| 423 |
+
from core.dependencies import get_current_user
|
| 424 |
from core.database import get_db
|
| 425 |
|
| 426 |
app = FastAPI()
|
|
|
|
| 476 |
"""Return 404 for non-existent job."""
|
| 477 |
from routers.gemini import router
|
| 478 |
from fastapi import FastAPI
|
| 479 |
+
from core.dependencies import get_current_user
|
| 480 |
from core.database import get_db
|
| 481 |
|
| 482 |
app = FastAPI()
|
|
|
|
| 504 |
"""Only queued jobs can be cancelled."""
|
| 505 |
from routers.gemini import router
|
| 506 |
from fastapi import FastAPI
|
| 507 |
+
from core.dependencies import get_current_user
|
| 508 |
from core.database import get_db
|
| 509 |
|
| 510 |
app = FastAPI()
|
|
|
|
| 537 |
"""Successfully cancel a queued job."""
|
| 538 |
from routers.gemini import router
|
| 539 |
from fastapi import FastAPI
|
| 540 |
+
from core.dependencies import get_current_user
|
| 541 |
from core.database import get_db
|
| 542 |
|
| 543 |
app = FastAPI()
|
tests/test_payments_router.py
CHANGED
|
@@ -163,7 +163,7 @@ class TestCreateOrder:
|
|
| 163 |
"""Reject invalid package_id."""
|
| 164 |
from routers.payments import router
|
| 165 |
from fastapi import FastAPI
|
| 166 |
-
from dependencies import get_current_user
|
| 167 |
|
| 168 |
app = FastAPI()
|
| 169 |
|
|
@@ -189,7 +189,7 @@ class TestCreateOrder:
|
|
| 189 |
"""Return 503 if Razorpay not configured."""
|
| 190 |
from routers.payments import router
|
| 191 |
from fastapi import FastAPI
|
| 192 |
-
from dependencies import get_current_user
|
| 193 |
|
| 194 |
app = FastAPI()
|
| 195 |
|
|
@@ -241,7 +241,7 @@ class TestVerifyPayment:
|
|
| 241 |
"""Return 404 for unknown transaction."""
|
| 242 |
from routers.payments import router
|
| 243 |
from fastapi import FastAPI
|
| 244 |
-
from dependencies import get_current_user
|
| 245 |
from core.database import get_db
|
| 246 |
|
| 247 |
app = FastAPI()
|
|
@@ -395,7 +395,7 @@ class TestPaymentHistory:
|
|
| 395 |
"""History returns empty list for user with no transactions."""
|
| 396 |
from routers.payments import router
|
| 397 |
from fastapi import FastAPI
|
| 398 |
-
from dependencies import get_current_user
|
| 399 |
from core.database import get_db
|
| 400 |
|
| 401 |
app = FastAPI()
|
|
@@ -433,7 +433,7 @@ class TestPaymentHistory:
|
|
| 433 |
"""History respects pagination parameters."""
|
| 434 |
from routers.payments import router
|
| 435 |
from fastapi import FastAPI
|
| 436 |
-
from dependencies import get_current_user
|
| 437 |
from core.database import get_db
|
| 438 |
|
| 439 |
app = FastAPI()
|
|
|
|
| 163 |
"""Reject invalid package_id."""
|
| 164 |
from routers.payments import router
|
| 165 |
from fastapi import FastAPI
|
| 166 |
+
from core.dependencies import get_current_user
|
| 167 |
|
| 168 |
app = FastAPI()
|
| 169 |
|
|
|
|
| 189 |
"""Return 503 if Razorpay not configured."""
|
| 190 |
from routers.payments import router
|
| 191 |
from fastapi import FastAPI
|
| 192 |
+
from core.dependencies import get_current_user
|
| 193 |
|
| 194 |
app = FastAPI()
|
| 195 |
|
|
|
|
| 241 |
"""Return 404 for unknown transaction."""
|
| 242 |
from routers.payments import router
|
| 243 |
from fastapi import FastAPI
|
| 244 |
+
from core.dependencies import get_current_user
|
| 245 |
from core.database import get_db
|
| 246 |
|
| 247 |
app = FastAPI()
|
|
|
|
| 395 |
"""History returns empty list for user with no transactions."""
|
| 396 |
from routers.payments import router
|
| 397 |
from fastapi import FastAPI
|
| 398 |
+
from core.dependencies import get_current_user
|
| 399 |
from core.database import get_db
|
| 400 |
|
| 401 |
app = FastAPI()
|
|
|
|
| 433 |
"""History respects pagination parameters."""
|
| 434 |
from routers.payments import router
|
| 435 |
from fastapi import FastAPI
|
| 436 |
+
from core.dependencies import get_current_user
|
| 437 |
from core.database import get_db
|
| 438 |
|
| 439 |
app = FastAPI()
|
tests/test_rate_limiting.py
CHANGED
|
@@ -26,7 +26,7 @@ class TestRateLimitBasics:
|
|
| 26 |
@pytest.mark.asyncio
|
| 27 |
async def test_first_request_allowed(self, db_session):
|
| 28 |
"""First request within limit is allowed."""
|
| 29 |
-
from dependencies import check_rate_limit
|
| 30 |
|
| 31 |
result = await check_rate_limit(
|
| 32 |
db=db_session,
|
|
@@ -41,7 +41,7 @@ class TestRateLimitBasics:
|
|
| 41 |
@pytest.mark.asyncio
|
| 42 |
async def test_within_limit_allowed(self, db_session):
|
| 43 |
"""Requests within limit are allowed."""
|
| 44 |
-
from dependencies import check_rate_limit
|
| 45 |
|
| 46 |
# Make 3 requests (limit is 5)
|
| 47 |
for i in range(3):
|
|
@@ -57,7 +57,7 @@ class TestRateLimitBasics:
|
|
| 57 |
@pytest.mark.asyncio
|
| 58 |
async def test_exceed_limit_blocked(self, db_session):
|
| 59 |
"""Requests exceeding limit are blocked."""
|
| 60 |
-
from dependencies import check_rate_limit
|
| 61 |
|
| 62 |
# Make exactly limit requests
|
| 63 |
for i in range(5):
|
|
@@ -91,7 +91,7 @@ class TestWindowBasedLimiting:
|
|
| 91 |
@pytest.mark.asyncio
|
| 92 |
async def test_rate_limit_creates_window(self, db_session):
|
| 93 |
"""Rate limit creates time window entry."""
|
| 94 |
-
from dependencies import check_rate_limit
|
| 95 |
from core.models import RateLimit
|
| 96 |
|
| 97 |
await check_rate_limit(
|
|
@@ -115,7 +115,7 @@ class TestWindowBasedLimiting:
|
|
| 115 |
@pytest.mark.asyncio
|
| 116 |
async def test_attempts_increment_in_window(self, db_session):
|
| 117 |
"""Attempts increment within same window."""
|
| 118 |
-
from dependencies import check_rate_limit
|
| 119 |
from core.models import RateLimit
|
| 120 |
|
| 121 |
identifier = "10.10.10.10"
|
|
@@ -153,7 +153,7 @@ class TestPerIPAndEndpoint:
|
|
| 153 |
@pytest.mark.asyncio
|
| 154 |
async def test_different_ips_separate_limits(self, db_session):
|
| 155 |
"""Different IPs have separate rate limits."""
|
| 156 |
-
from dependencies import check_rate_limit
|
| 157 |
|
| 158 |
# IP 1 makes 5 requests
|
| 159 |
for i in range(5):
|
|
@@ -188,7 +188,7 @@ class TestPerIPAndEndpoint:
|
|
| 188 |
@pytest.mark.asyncio
|
| 189 |
async def test_different_endpoints_separate_limits(self, db_session):
|
| 190 |
"""Same IP has separate limits for different endpoints."""
|
| 191 |
-
from dependencies import check_rate_limit
|
| 192 |
|
| 193 |
ip = "203.0.113.50"
|
| 194 |
|
|
@@ -233,7 +233,7 @@ class TestRateLimitExpiry:
|
|
| 233 |
@pytest.mark.asyncio
|
| 234 |
async def test_rate_limit_has_expiry(self, db_session):
|
| 235 |
"""Rate limit entry has expiry time."""
|
| 236 |
-
from dependencies import check_rate_limit
|
| 237 |
from core.models import RateLimit
|
| 238 |
|
| 239 |
await check_rate_limit(
|
|
@@ -266,7 +266,7 @@ class TestRateLimitEdgeCases:
|
|
| 266 |
@pytest.mark.asyncio
|
| 267 |
async def test_zero_limit_blocks_all(self, db_session):
|
| 268 |
"""Limit of 0 blocks all requests."""
|
| 269 |
-
from dependencies import check_rate_limit
|
| 270 |
|
| 271 |
# First request with limit=0 should be blocked
|
| 272 |
result = await check_rate_limit(
|
|
@@ -296,7 +296,7 @@ class TestRateLimitEdgeCases:
|
|
| 296 |
@pytest.mark.asyncio
|
| 297 |
async def test_limit_of_one(self, db_session):
|
| 298 |
"""Limit of 1 allows only first request."""
|
| 299 |
-
from dependencies import check_rate_limit
|
| 300 |
|
| 301 |
result1 = await check_rate_limit(
|
| 302 |
db=db_session,
|
|
@@ -319,7 +319,7 @@ class TestRateLimitEdgeCases:
|
|
| 319 |
@pytest.mark.asyncio
|
| 320 |
async def test_very_short_window(self, db_session):
|
| 321 |
"""Very short time window works correctly."""
|
| 322 |
-
from dependencies import check_rate_limit
|
| 323 |
|
| 324 |
# 1 minute window
|
| 325 |
result = await check_rate_limit(
|
|
@@ -335,7 +335,7 @@ class TestRateLimitEdgeCases:
|
|
| 335 |
@pytest.mark.asyncio
|
| 336 |
async def test_long_window(self, db_session):
|
| 337 |
"""Long time window works correctly."""
|
| 338 |
-
from dependencies import check_rate_limit
|
| 339 |
|
| 340 |
# 24 hour window
|
| 341 |
result = await check_rate_limit(
|
|
@@ -359,7 +359,7 @@ class TestRateLimitPersistence:
|
|
| 359 |
@pytest.mark.asyncio
|
| 360 |
async def test_rate_limit_persists(self, db_session):
|
| 361 |
"""Rate limit data persists across checks."""
|
| 362 |
-
from dependencies import check_rate_limit
|
| 363 |
from core.models import RateLimit
|
| 364 |
|
| 365 |
identifier = "192.168.1.99"
|
|
|
|
| 26 |
@pytest.mark.asyncio
|
| 27 |
async def test_first_request_allowed(self, db_session):
|
| 28 |
"""First request within limit is allowed."""
|
| 29 |
+
from core.dependencies import check_rate_limit
|
| 30 |
|
| 31 |
result = await check_rate_limit(
|
| 32 |
db=db_session,
|
|
|
|
| 41 |
@pytest.mark.asyncio
|
| 42 |
async def test_within_limit_allowed(self, db_session):
|
| 43 |
"""Requests within limit are allowed."""
|
| 44 |
+
from core.dependencies import check_rate_limit
|
| 45 |
|
| 46 |
# Make 3 requests (limit is 5)
|
| 47 |
for i in range(3):
|
|
|
|
| 57 |
@pytest.mark.asyncio
|
| 58 |
async def test_exceed_limit_blocked(self, db_session):
|
| 59 |
"""Requests exceeding limit are blocked."""
|
| 60 |
+
from core.dependencies import check_rate_limit
|
| 61 |
|
| 62 |
# Make exactly limit requests
|
| 63 |
for i in range(5):
|
|
|
|
| 91 |
@pytest.mark.asyncio
|
| 92 |
async def test_rate_limit_creates_window(self, db_session):
|
| 93 |
"""Rate limit creates time window entry."""
|
| 94 |
+
from core.dependencies import check_rate_limit
|
| 95 |
from core.models import RateLimit
|
| 96 |
|
| 97 |
await check_rate_limit(
|
|
|
|
| 115 |
@pytest.mark.asyncio
|
| 116 |
async def test_attempts_increment_in_window(self, db_session):
|
| 117 |
"""Attempts increment within same window."""
|
| 118 |
+
from core.dependencies import check_rate_limit
|
| 119 |
from core.models import RateLimit
|
| 120 |
|
| 121 |
identifier = "10.10.10.10"
|
|
|
|
| 153 |
@pytest.mark.asyncio
|
| 154 |
async def test_different_ips_separate_limits(self, db_session):
|
| 155 |
"""Different IPs have separate rate limits."""
|
| 156 |
+
from core.dependencies import check_rate_limit
|
| 157 |
|
| 158 |
# IP 1 makes 5 requests
|
| 159 |
for i in range(5):
|
|
|
|
| 188 |
@pytest.mark.asyncio
|
| 189 |
async def test_different_endpoints_separate_limits(self, db_session):
|
| 190 |
"""Same IP has separate limits for different endpoints."""
|
| 191 |
+
from core.dependencies import check_rate_limit
|
| 192 |
|
| 193 |
ip = "203.0.113.50"
|
| 194 |
|
|
|
|
| 233 |
@pytest.mark.asyncio
|
| 234 |
async def test_rate_limit_has_expiry(self, db_session):
|
| 235 |
"""Rate limit entry has expiry time."""
|
| 236 |
+
from core.dependencies import check_rate_limit
|
| 237 |
from core.models import RateLimit
|
| 238 |
|
| 239 |
await check_rate_limit(
|
|
|
|
| 266 |
@pytest.mark.asyncio
|
| 267 |
async def test_zero_limit_blocks_all(self, db_session):
|
| 268 |
"""Limit of 0 blocks all requests."""
|
| 269 |
+
from core.dependencies import check_rate_limit
|
| 270 |
|
| 271 |
# First request with limit=0 should be blocked
|
| 272 |
result = await check_rate_limit(
|
|
|
|
| 296 |
@pytest.mark.asyncio
|
| 297 |
async def test_limit_of_one(self, db_session):
|
| 298 |
"""Limit of 1 allows only first request."""
|
| 299 |
+
from core.dependencies import check_rate_limit
|
| 300 |
|
| 301 |
result1 = await check_rate_limit(
|
| 302 |
db=db_session,
|
|
|
|
| 319 |
@pytest.mark.asyncio
|
| 320 |
async def test_very_short_window(self, db_session):
|
| 321 |
"""Very short time window works correctly."""
|
| 322 |
+
from core.dependencies import check_rate_limit
|
| 323 |
|
| 324 |
# 1 minute window
|
| 325 |
result = await check_rate_limit(
|
|
|
|
| 335 |
@pytest.mark.asyncio
|
| 336 |
async def test_long_window(self, db_session):
|
| 337 |
"""Long time window works correctly."""
|
| 338 |
+
from core.dependencies import check_rate_limit
|
| 339 |
|
| 340 |
# 24 hour window
|
| 341 |
result = await check_rate_limit(
|
|
|
|
| 359 |
@pytest.mark.asyncio
|
| 360 |
async def test_rate_limit_persists(self, db_session):
|
| 361 |
"""Rate limit data persists across checks."""
|
| 362 |
+
from core.dependencies import check_rate_limit
|
| 363 |
from core.models import RateLimit
|
| 364 |
|
| 365 |
identifier = "192.168.1.99"
|
tests/test_token_expiry_integration.py
CHANGED
|
@@ -197,7 +197,7 @@ class TestTokenVersioning:
|
|
| 197 |
"""Logout increments version, invalidating all existing tokens."""
|
| 198 |
from routers.auth import router
|
| 199 |
from fastapi import FastAPI
|
| 200 |
-
from dependencies import get_current_user
|
| 201 |
from core.database import get_db
|
| 202 |
from core.models import User
|
| 203 |
from services.auth_service.jwt_provider import create_access_token, create_refresh_token
|
|
|
|
| 197 |
"""Logout increments version, invalidating all existing tokens."""
|
| 198 |
from routers.auth import router
|
| 199 |
from fastapi import FastAPI
|
| 200 |
+
from core.dependencies import get_current_user
|
| 201 |
from core.database import get_db
|
| 202 |
from core.models import User
|
| 203 |
from services.auth_service.jwt_provider import create_access_token, create_refresh_token
|