File size: 4,699 Bytes
b3c30a9
 
 
 
 
 
 
 
 
 
7a57599
b3c30a9
 
7a57599
b3c30a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d9cf76
b3c30a9
4d9cf76
 
 
 
 
 
 
 
b3c30a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d9cf76
b3c30a9
 
 
 
 
 
 
 
 
 
7a57599
b3c30a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d9cf76
b3c30a9
 
 
 
 
 
 
 
 
 
 
7a57599
b3c30a9
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import json
import os
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Tuple
import uuid
import hashlib


class RateLimiter:
    """Per-device daily usage limiter persisted to a JSON file.

    A device is identified by a truncated SHA-256 of the client IP and
    User-Agent. Each device's counter lives until the midnight UTC following
    the device's first recorded request, after which a new window starts.

    NOTE(review): load/modify/save is not locked, so concurrent requests can
    lose updates (a device may get slightly more than its limit). Writes are
    atomic, so the file itself is never left half-written.
    """

    def __init__(self, session_file: str, daily_limit: int, dev_daily_limit: int):
        """
        Args:
            session_file: Path of the JSON file holding per-device counters.
            daily_limit: Generations allowed per device per day.
            dev_daily_limit: Limit used instead when env ``DEV_MODE`` is "true".
        """
        self.session_file = Path(session_file)
        self.daily_limit = daily_limit
        self.dev_daily_limit = dev_daily_limit
        self.is_dev_mode = os.getenv("DEV_MODE", "false").lower() == "true"

        # Create the session file (and its directory) if it doesn't exist yet.
        if not self.session_file.exists():
            self.session_file.parent.mkdir(parents=True, exist_ok=True)
            self._save_data({})

    def _current_limit(self) -> int:
        """Return the daily limit for the active mode (DEV or standard)."""
        return self.dev_daily_limit if self.is_dev_mode else self.daily_limit

    def _load_data(self) -> dict:
        """Load rate limit data; return {} if the file is missing, corrupt, or not a JSON object."""
        try:
            with open(self.session_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (json.JSONDecodeError, FileNotFoundError):
            return {}
        return data if isinstance(data, dict) else {}

    def _save_data(self, data: dict):
        """Persist *data* atomically: write a unique sibling temp file, then rename it over the target."""
        tmp_path = self.session_file.with_name(
            f"{self.session_file.name}.{uuid.uuid4().hex}.tmp"
        )
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            # os.replace is atomic, so readers see either the old or the new file, never a partial one.
            os.replace(tmp_path, self.session_file)
        finally:
            # No-op after a successful replace; removes the partial temp file on failure.
            tmp_path.unlink(missing_ok=True)

    def _get_device_id(self, request) -> str:
        """Derive a stable, anonymised device ID from client IP and User-Agent.

        Supports Gradio/Starlette-style requests, where ``client`` is an object
        with a ``host`` attribute, as well as a plain dict with a ``"host"`` key.
        ``headers`` may be any mapping exposing ``.get``. Missing pieces fall
        back to ``'unknown'``.
        """
        client = getattr(request, 'client', None)
        if isinstance(client, dict):
            ip = client.get('host')
        else:
            ip = getattr(client, 'host', None)

        headers = getattr(request, 'headers', None)
        user_agent = headers.get('user-agent') if hasattr(headers, 'get') else None

        # Hash to create a stable ID without storing raw IPs.
        fingerprint = f"{ip or 'unknown'}:{user_agent or 'unknown'}"
        return hashlib.sha256(fingerprint.encode()).hexdigest()[:16]

    def _get_next_reset(self) -> datetime:
        """Get next midnight UTC"""
        now = datetime.now(timezone.utc)
        tomorrow = now + timedelta(days=1)
        return tomorrow.replace(hour=0, minute=0, second=0, microsecond=0)

    def _new_entry(self) -> dict:
        """Return a fresh counter entry whose window ends at the next midnight UTC."""
        return {"count": 0, "reset_time": self._get_next_reset().isoformat()}

    @staticmethod
    def _parse_reset_time(info) -> "datetime | None":
        """Return the entry's ``reset_time`` as an aware datetime, or None if missing/invalid."""
        try:
            reset_time = datetime.fromisoformat(info["reset_time"])
        except (KeyError, TypeError, ValueError):
            return None
        if reset_time.tzinfo is None:
            # Treat naive timestamps as UTC so they compare with an aware "now".
            reset_time = reset_time.replace(tzinfo=timezone.utc)
        return reset_time

    def _cleanup_expired(self, data: dict) -> dict:
        """Return a copy of *data* keeping only live, well-formed entries.

        Entries whose reset time has passed are removed. So are entries with an
        unparsable ``reset_time`` or a non-integer ``count``, so that a damaged
        file cannot crash every request.
        """
        now = datetime.now(timezone.utc)
        cleaned = {}

        for device_id, info in data.items():
            reset_time = self._parse_reset_time(info)
            if reset_time is None or reset_time <= now:
                continue
            if not isinstance(info.get("count"), int):
                continue
            cleaned[device_id] = info

        return cleaned

    def check_limit(self, request) -> Tuple[bool, int, datetime]:
        """
        Check if device has exceeded rate limit.

        A new device gets a fresh entry, which is persisted.

        Returns:
            (allowed: bool, remaining: int, reset_time: datetime). reset_time is
            the end of the device's current window (timezone-aware UTC).
        """
        device_id = self._get_device_id(request)
        data = self._cleanup_expired(self._load_data())
        limit = self._current_limit()

        device_info = data.get(device_id)
        if device_info is None:
            # First request from this device in the current window.
            device_info = self._new_entry()
            data[device_id] = device_info
            self._save_data(data)

        reset_time = self._parse_reset_time(device_info)

        # Covers the window rolling over between cleanup and now: start fresh
        # and report the NEW reset time rather than the expired one.
        if datetime.now(timezone.utc) >= reset_time:
            reset_time = self._get_next_reset()
            device_info["count"] = 0
            device_info["reset_time"] = reset_time.isoformat()
            self._save_data(data)

        current_count = device_info["count"]
        remaining = max(0, limit - current_count)
        allowed = current_count < limit

        return allowed, remaining, reset_time

    def increment(self, request):
        """Record one use for the requesting device.

        If the device has no live entry (it was never checked, or its window
        has expired), a fresh window is started, so the use is counted rather
        than silently dropped.
        """
        device_id = self._get_device_id(request)
        data = self._cleanup_expired(self._load_data())

        device_info = data.setdefault(device_id, self._new_entry())
        device_info["count"] += 1
        self._save_data(data)

    def get_limit_message(self, remaining: int, reset_time: datetime) -> str:
        """Generate a user-friendly limit message.

        Args:
            remaining: Generations left, as returned by ``check_limit``.
            reset_time: End of the current window, as returned by ``check_limit``.
                A naive value is interpreted as UTC.
        """
        mode = "DEV" if self.is_dev_mode else "Standard"
        limit = self._current_limit()

        if remaining > 0:
            return f"✅ {remaining}/{limit} generations remaining today ({mode} mode)"

        if reset_time.tzinfo is None:
            reset_time = reset_time.replace(tzinfo=timezone.utc)
        # Clamp at zero so a reset time that has just passed never shows negative durations.
        seconds_left = max(0, int((reset_time - datetime.now(timezone.utc)).total_seconds()))
        hours_left, leftover = divmod(seconds_left, 3600)
        minutes_left = leftover // 60

        return f"❌ Daily limit reached ({limit}/{limit}). Resets in {hours_left}h {minutes_left}m (midnight UTC)"