FaceForgeAI_ZeroGPU / rate_limiter.py
VcRlAgent's picture
Refactor for Headshot and Scene Generation using the Instant-ID model hosted on Replicate
7a57599
import json
import os
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Tuple
import uuid
import hashlib
class RateLimiter:
    """File-backed per-device daily usage limiter.

    Usage counts are stored in a JSON file keyed by a short device ID derived
    from the caller's IP address and User-Agent header. Each entry holds a
    ``count`` and a ``reset_time`` (ISO-8601, UTC); entries whose reset time
    has passed are discarded when limits are checked.
    """

    def __init__(self, session_file: str, daily_limit: int, dev_daily_limit: int):
        """Configure the limiter and make sure the session file exists.

        Args:
            session_file: Path of the JSON file that stores usage counts.
            daily_limit: Allowed uses per device per day in standard mode.
            dev_daily_limit: Allowed uses per device per day when the
                ``DEV_MODE`` environment variable is ``"true"``.
        """
        self.session_file = Path(session_file)
        self.daily_limit = daily_limit
        self.dev_daily_limit = dev_daily_limit
        self.is_dev_mode = os.getenv("DEV_MODE", "false").lower() == "true"
        # Create session file if doesn't exist
        if not self.session_file.exists():
            self._save_data({})

    def _load_data(self) -> dict:
        """Load rate limit data from file.

        Returns an empty dict when the file is missing, is not valid JSON, or
        does not contain a JSON object.
        """
        try:
            with open(self.session_file, 'r', encoding="utf-8") as f:
                loaded = json.load(f)
        except (json.JSONDecodeError, FileNotFoundError):
            return {}
        # A non-object payload (e.g. a list) would break every caller that
        # treats the result as a mapping, so treat it like a corrupt file.
        return loaded if isinstance(loaded, dict) else {}

    def _save_data(self, data: dict):
        """Save rate limit data to file, creating parent directories if needed."""
        self.session_file.parent.mkdir(parents=True, exist_ok=True)
        with open(self.session_file, 'w', encoding="utf-8") as f:
            json.dump(data, f, indent=2)

    @staticmethod
    def _read_field(source, key: str) -> str:
        """Read ``key`` from a mapping-like or attribute-style object.

        Returns ``'unknown'`` when the source is ``None``, the key is absent,
        or the value is empty.
        """
        if source is None:
            return 'unknown'
        try:
            getter = getattr(source, "get", None)
            if callable(getter):
                # dicts and other mapping-like objects
                value = getter(key)
            else:
                # attribute-style objects, e.g. a (host, port) named tuple
                value = getattr(source, key, None)
        except (AttributeError, KeyError, TypeError):
            return 'unknown'
        return str(value) if value else 'unknown'

    def _get_device_id(self, request) -> str:
        """Generate consistent device ID from request client address and headers.

        ``request.client`` and ``request.headers`` may each be either a
        mapping or an attribute-style object; anything that cannot be read
        falls back to ``'unknown'``.
        """
        ip = self._read_field(getattr(request, 'client', None), 'host')
        user_agent = self._read_field(getattr(request, 'headers', None), 'user-agent')
        # Hash to create stable ID
        fingerprint = f"{ip}:{user_agent}"
        return hashlib.sha256(fingerprint.encode()).hexdigest()[:16]

    def _get_next_reset(self) -> datetime:
        """Get next midnight UTC"""
        now = datetime.now(timezone.utc)
        tomorrow = now + timedelta(days=1)
        return tomorrow.replace(hour=0, minute=0, second=0, microsecond=0)

    def _cleanup_expired(self, data: dict) -> dict:
        """Remove expired entries.

        Entries that are malformed (not a dict, missing or unparseable
        ``reset_time``) are dropped as well, so one bad record cannot make
        every later check fail.
        """
        now = datetime.now(timezone.utc)
        cleaned = {}
        for device_id, info in data.items():
            try:
                reset_time = datetime.fromisoformat(info["reset_time"])
                still_valid = reset_time > now
            except (KeyError, TypeError, ValueError):
                # TypeError also covers naive timestamps, which cannot be
                # compared with the timezone-aware ``now``.
                continue
            if still_valid:
                cleaned[device_id] = info
        return cleaned

    def check_limit(self, request) -> Tuple[bool, int, datetime]:
        """
        Check if device has exceeded rate limit

        Returns:
            (allowed: bool, remaining: int, reset_time: datetime)
        """
        device_id = self._get_device_id(request)
        data = self._cleanup_expired(self._load_data())
        limit = self.dev_daily_limit if self.is_dev_mode else self.daily_limit
        now = datetime.now(timezone.utc)

        device_info = data.get(device_id)
        if device_info is None:
            # New device (or its previous window expired and was cleaned up)
            reset_time = self._get_next_reset()
            device_info = {"count": 0, "reset_time": reset_time.isoformat()}
            data[device_id] = device_info
            self._save_data(data)
        else:
            reset_time = datetime.fromisoformat(device_info["reset_time"])
            # The window can lapse between the cleanup above and this check;
            # start a new one and report the NEW reset time to the caller.
            if now >= reset_time:
                reset_time = self._get_next_reset()
                device_info["count"] = 0
                device_info["reset_time"] = reset_time.isoformat()
                self._save_data(data)

        current_count = device_info.get("count", 0)
        remaining = max(0, limit - current_count)
        allowed = current_count < limit
        return allowed, remaining, reset_time

    def increment(self, request):
        """Increment usage count for device.

        A device without a stored entry gets a fresh one, so the usage is
        still recorded instead of being silently dropped.
        """
        device_id = self._get_device_id(request)
        data = self._load_data()
        device_info = data.get(device_id)
        if not isinstance(device_info, dict):
            device_info = {
                "count": 0,
                "reset_time": self._get_next_reset().isoformat()
            }
            data[device_id] = device_info
        device_info["count"] = device_info.get("count", 0) + 1
        self._save_data(data)

    def get_limit_message(self, remaining: int, reset_time: datetime) -> str:
        """Generate user-friendly limit message.

        Args:
            remaining: Generations still available to the device.
            reset_time: Timezone-aware time at which the count resets.
        """
        mode = "DEV" if self.is_dev_mode else "Standard"
        limit = self.dev_daily_limit if self.is_dev_mode else self.daily_limit
        if remaining > 0:
            return f"✅ {remaining}/{limit} generations remaining today ({mode} mode)"
        now = datetime.now(timezone.utc)
        # Clamp at zero so a reset time already in the past never renders a
        # negative countdown.
        seconds_left = max(0, int((reset_time - now).total_seconds()))
        hours_left, leftover = divmod(seconds_left, 3600)
        minutes_left = leftover // 60
        return f"❌ Daily limit reached ({limit}/{limit}). Resets in {hours_left}h {minutes_left}m (midnight UTC)"