import json
import os
import tempfile
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Tuple
import uuid
import hashlib


class RateLimiter:
    """Per-device daily usage limiter persisted to a JSON file.

    A device is identified by a hash of the client IP and User-Agent. Each
    device entry stores a ``count`` and a ``reset_time`` (ISO-8601, next
    midnight UTC when the entry was created). The active limit is
    ``dev_daily_limit`` when the ``DEV_MODE`` environment variable equals
    ``"true"`` (case-insensitive) at construction time, otherwise
    ``daily_limit``.
    """

    def __init__(self, session_file: str, daily_limit: int, dev_daily_limit: int):
        self.session_file = Path(session_file)
        self.daily_limit = daily_limit
        self.dev_daily_limit = dev_daily_limit
        self.is_dev_mode = os.getenv("DEV_MODE", "false").lower() == "true"

        # Create the session file (and its directory) if it doesn't exist yet.
        if not self.session_file.exists():
            self.session_file.parent.mkdir(parents=True, exist_ok=True)
            self._save_data({})

    def _current_limit(self) -> int:
        """Return the daily limit for the current mode (dev or standard)."""
        return self.dev_daily_limit if self.is_dev_mode else self.daily_limit

    def _load_data(self) -> dict:
        """Load rate limit data from file.

        Returns an empty dict if the file is missing, is not valid UTF-8 JSON,
        or does not contain a JSON object at the top level.
        """
        try:
            with open(self.session_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError, FileNotFoundError):
            return {}
        return data if isinstance(data, dict) else {}

    def _save_data(self, data: dict):
        """Save rate limit data to file atomically.

        The JSON is written to a temporary file in the same directory and then
        moved over the target with ``os.replace``, so a concurrent reader never
        sees a partially written file (which ``_load_data`` would treat as
        empty, wiping every device's count).
        """
        fd, tmp_path = tempfile.mkstemp(
            dir=str(self.session_file.parent),
            prefix=self.session_file.name + ".",
            suffix=".tmp",
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, self.session_file)
        except BaseException:
            # Don't leave stray temp files behind if writing/replacing fails.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

    def _get_device_id(self, request) -> str:
        """Generate a consistent device ID from the request's IP and User-Agent.

        Accepts Gradio/Starlette-style requests (``request.client.host``
        attribute, ``request.headers`` mapping with ``.get``) as well as plain
        dicts for ``client``/``headers``. Missing parts fall back to
        ``"unknown"``.
        """
        ip = "unknown"
        user_agent = "unknown"
        try:
            client = getattr(request, "client", None)
            if isinstance(client, dict):
                ip = client.get("host") or "unknown"
            elif client is not None:
                # Starlette's Address is a namedtuple exposing ``.host``.
                ip = getattr(client, "host", None) or "unknown"

            headers = getattr(request, "headers", None)
            # Starlette's Headers is a case-insensitive Mapping, not a dict.
            if headers is not None and hasattr(headers, "get"):
                user_agent = headers.get("user-agent") or "unknown"
        except Exception:
            # Best effort: an unusual request object still gets a fallback ID.
            pass

        # Hash to create a stable, non-reversible ID.
        fingerprint = f"{ip}:{user_agent}"
        return hashlib.sha256(fingerprint.encode()).hexdigest()[:16]

    def _get_next_reset(self) -> datetime:
        """Get next midnight UTC"""
        now = datetime.now(timezone.utc)
        tomorrow = now + timedelta(days=1)
        return tomorrow.replace(hour=0, minute=0, second=0, microsecond=0)

    @staticmethod
    def _parse_reset(value: str) -> datetime:
        """Parse a stored ISO-8601 reset time; naive values are treated as UTC."""
        parsed = datetime.fromisoformat(value)
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    def _new_entry(self) -> dict:
        """Return a fresh zero-count entry that resets at the next midnight UTC."""
        return {"count": 0, "reset_time": self._get_next_reset().isoformat()}

    def _cleanup_expired(self, data: dict) -> dict:
        """Return a copy of ``data`` without expired or malformed entries."""
        now = datetime.now(timezone.utc)
        cleaned = {}
        for device_id, info in data.items():
            # Drop corrupt entries (hand-edited or from an older format)
            # instead of crashing every subsequent request.
            if not isinstance(info, dict) or not isinstance(info.get("count"), int):
                continue
            try:
                reset_time = self._parse_reset(info["reset_time"])
            except (KeyError, TypeError, ValueError):
                continue
            if reset_time > now:
                cleaned[device_id] = info
        return cleaned

    def check_limit(self, request) -> Tuple[bool, int, datetime]:
        """
        Check if device has exceeded rate limit
        Returns: (allowed: bool, remaining: int, reset_time: datetime)
        """
        device_id = self._get_device_id(request)
        data = self._cleanup_expired(self._load_data())
        limit = self._current_limit()
        now = datetime.now(timezone.utc)
        changed = False

        device_info = data.get(device_id)
        if device_info is None:
            # New (or previously expired) device.
            device_info = self._new_entry()
            data[device_id] = device_info
            changed = True

        reset_time = self._parse_reset(device_info["reset_time"])
        if now >= reset_time:
            # Entry expired between cleanup and now: start a new window and
            # report the *new* reset time, not the stale one.
            reset_time = self._get_next_reset()
            device_info["count"] = 0
            device_info["reset_time"] = reset_time.isoformat()
            changed = True

        if changed:
            self._save_data(data)

        current_count = device_info["count"]
        remaining = max(0, limit - current_count)
        allowed = current_count < limit
        return allowed, remaining, reset_time

    def increment(self, request):
        """Increment usage count for device.

        If the device has no live entry (never checked, or its window has
        expired), a fresh entry is created so the usage is still recorded.
        """
        device_id = self._get_device_id(request)
        data = self._cleanup_expired(self._load_data())
        device_info = data.get(device_id)
        if device_info is None:
            device_info = self._new_entry()
            data[device_id] = device_info
        device_info["count"] += 1
        self._save_data(data)

    def get_limit_message(self, remaining: int, reset_time: datetime) -> str:
        """Generate user-friendly limit message"""
        mode = "DEV" if self.is_dev_mode else "Standard"
        limit = self._current_limit()

        if remaining > 0:
            return f"✅ {remaining}/{limit} generations remaining today ({mode} mode)"

        if reset_time.tzinfo is None:
            reset_time = reset_time.replace(tzinfo=timezone.utc)
        # Clamp at zero so a reset time already in the past never renders as
        # a negative countdown.
        seconds_left = max(0, int((reset_time - datetime.now(timezone.utc)).total_seconds()))
        hours_left, rest = divmod(seconds_left, 3600)
        minutes_left = rest // 60
        return f"❌ Daily limit reached ({limit}/{limit}). Resets in {hours_left}h {minutes_left}m (midnight UTC)"