File size: 6,668 Bytes
16ec2cf
 
 
 
 
 
 
 
 
4dcfed0
16ec2cf
 
4dcfed0
16ec2cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dcfed0
16ec2cf
 
 
 
 
4dcfed0
16ec2cf
 
 
 
 
 
 
 
 
 
 
4dcfed0
16ec2cf
 
 
4dcfed0
16ec2cf
 
 
 
 
 
4dcfed0
16ec2cf
 
4dcfed0
16ec2cf
 
4dcfed0
16ec2cf
 
4dcfed0
16ec2cf
 
4dcfed0
 
 
 
 
16ec2cf
 
 
4dcfed0
16ec2cf
 
 
 
 
 
 
 
4dcfed0
16ec2cf
 
 
 
4dcfed0
16ec2cf
 
4dcfed0
16ec2cf
 
 
 
 
 
 
4dcfed0
 
 
16ec2cf
4dcfed0
16ec2cf
 
4dcfed0
16ec2cf
4dcfed0
16ec2cf
 
 
 
 
 
4dcfed0
 
 
16ec2cf
4dcfed0
16ec2cf
 
 
4dcfed0
16ec2cf
 
 
 
4dcfed0
16ec2cf
 
 
 
 
 
 
4dcfed0
16ec2cf
 
4dcfed0
16ec2cf
 
 
 
 
 
 
4dcfed0
16ec2cf
 
 
 
 
 
 
 
 
4dcfed0
 
 
 
 
16ec2cf
 
 
 
 
 
 
 
 
 
 
 
4dcfed0
16ec2cf
 
 
 
 
 
 
 
 
4dcfed0
16ec2cf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""
src/utils/rate_limiter.py
Domain-Specific Rate Limiter for Concurrent Web Scraping

Provides thread-safe rate limiting to prevent anti-bot detection when
multiple agents scrape the same domains concurrently.

Usage:
    from src.utils.rate_limiter import RateLimiter, get_rate_limiter

    # Global singleton
    limiter = get_rate_limiter()

    # Acquire before making request
    with limiter.acquire("twitter"):
        # Make request to Twitter
        pass
"""

import threading
import time
import logging
from typing import Dict, Optional
from contextlib import contextmanager
from collections import defaultdict

# Module-level logger shared by everything in this file. It is registered under
# the fixed dotted name "Roger.rate_limiter" rather than __name__, so the logger
# name does not depend on how this module is imported.
logger = logging.getLogger("Roger.rate_limiter")


class RateLimiter:
    """
    Thread-safe rate limiter with domain-specific limits.

    Every domain is throttled in three independent ways:
    - Requests per minute, tracked as a sliding 60-second window of timestamps
    - Maximum concurrent requests, enforced with one semaphore per domain
    - Minimum delay between consecutive requests to the same domain

    Domain names are matched case-insensitively; a domain without its own
    entry uses the "default" configuration.
    """

    # Default configuration per domain (requests_per_minute, max_concurrent, min_delay_seconds)
    DEFAULT_LIMITS = {
        "twitter": {"rpm": 15, "max_concurrent": 2, "min_delay": 2.0},
        "facebook": {"rpm": 10, "max_concurrent": 2, "min_delay": 3.0},
        "linkedin": {"rpm": 10, "max_concurrent": 1, "min_delay": 5.0},
        "instagram": {"rpm": 10, "max_concurrent": 2, "min_delay": 3.0},
        "reddit": {"rpm": 30, "max_concurrent": 3, "min_delay": 1.0},
        "news": {"rpm": 60, "max_concurrent": 5, "min_delay": 0.5},
        "government": {"rpm": 30, "max_concurrent": 3, "min_delay": 1.0},
        "default": {"rpm": 30, "max_concurrent": 3, "min_delay": 1.0},
    }

    # Width of the sliding window used for requests-per-minute accounting (seconds)
    _RPM_WINDOW_SECONDS = 60

    def __init__(self, custom_limits: Optional[Dict] = None):
        """
        Initialize rate limiter with optional custom limits.

        Args:
            custom_limits: Optional dict mapping a domain name to a dict of
                overrides ("rpm", "max_concurrent", "min_delay"). An override
                may be partial: keys it omits are taken from that domain's
                entry in DEFAULT_LIMITS, or from DEFAULT_LIMITS["default"]
                for a domain that has no built-in entry. Domain names are
                lower-cased so they match the lower-cased lookups.
        """
        # Copy the inner dicts too, so per-instance changes never leak into
        # the class-level defaults.
        self._limits = {
            name: dict(limits) for name, limits in self.DEFAULT_LIMITS.items()
        }
        if custom_limits:
            for name, overrides in custom_limits.items():
                key = name.lower()
                base = self._limits.get(key, self.DEFAULT_LIMITS["default"])
                self._limits[key] = {**base, **overrides}

        # Per-domain semaphores for concurrent request limiting
        self._semaphores: Dict[str, threading.Semaphore] = {}

        # Per-domain last request timestamps (0.0 means "no request yet")
        self._last_request: Dict[str, float] = defaultdict(float)

        # Per-domain request timestamps (for RPM tracking)
        self._request_counts: Dict[str, list] = defaultdict(list)

        # Lock for thread-safe access to shared state
        self._lock = threading.Lock()

        logger.info(
            f"[RateLimiter] Initialized with {len(self._limits)} domain configurations"
        )

    def _get_domain_config(self, domain: str) -> Dict:
        """Get configuration for a domain, falling back to default."""
        return self._limits.get(domain.lower(), self._limits["default"])

    def _get_semaphore(self, domain: str) -> threading.Semaphore:
        """Get or create semaphore for a domain."""
        domain = domain.lower()
        with self._lock:
            if domain not in self._semaphores:
                config = self._get_domain_config(domain)
                self._semaphores[domain] = threading.Semaphore(config["max_concurrent"])
            return self._semaphores[domain]

    def _wait_for_rate_limit(self, domain: str) -> None:
        """
        Block until a request to ``domain`` is allowed, then record it.

        The shared lock is held only while inspecting and updating the
        bookkeeping; the actual sleeping happens with the lock released so
        that a throttled domain never stalls other domains or get_stats().
        After every sleep the limits are re-evaluated with a fresh
        timestamp, because another thread may have recorded a request for
        the same domain in the meantime.
        """
        domain = domain.lower()
        config = self._get_domain_config(domain)
        min_delay = config["min_delay"]
        rpm_limit = config["rpm"]
        window = self._RPM_WINDOW_SECONDS

        while True:
            with self._lock:
                now = time.time()

                # Drop request timestamps that have left the sliding window
                recent = [
                    ts for ts in self._request_counts[domain] if now - ts < window
                ]
                self._request_counts[domain] = recent

                # Time still owed to the minimum delay between requests
                delay_wait = 0.0
                last = self._last_request[domain]
                if last > 0:
                    delay_wait = min_delay - (now - last)

                # Time until the oldest request leaves the window (small
                # margin added so the re-check sees it as expired)
                rpm_wait = 0.0
                if len(recent) >= rpm_limit:
                    rpm_wait = window - (now - recent[0]) + 0.1

                # Both waits run concurrently, so only the longer one matters
                wait_time = max(delay_wait, rpm_wait)
                if wait_time <= 0:
                    # Record this request while still holding the lock so
                    # the check and the update are one atomic step
                    self._last_request[domain] = now
                    recent.append(now)
                    return

            # Lock released here: log and sleep without blocking other threads
            if rpm_wait > 0 and rpm_wait >= delay_wait:
                logger.warning(
                    f"[RateLimiter] {domain}: RPM limit ({rpm_limit}) reached, waiting {wait_time:.2f}s"
                )
            else:
                logger.debug(
                    f"[RateLimiter] {domain}: waiting {wait_time:.2f}s for min_delay"
                )
            time.sleep(wait_time)

    @contextmanager
    def acquire(self, domain: str):
        """
        Context manager to acquire rate limit slot for a domain.

        Takes one of the domain's concurrency slots, waits until the
        min-delay and RPM limits allow a request, then yields. The
        concurrency slot is released when the ``with`` body exits, whether
        normally or by exception.

        Usage:
            with limiter.acquire("twitter"):
                # Make request
                pass
        """
        domain = domain.lower()
        semaphore = self._get_semaphore(domain)

        logger.debug(f"[RateLimiter] {domain}: acquiring slot...")
        semaphore.acquire()

        try:
            self._wait_for_rate_limit(domain)
            logger.debug(f"[RateLimiter] {domain}: slot acquired")
            yield
        finally:
            semaphore.release()
            logger.debug(f"[RateLimiter] {domain}: slot released")

    def get_stats(self) -> Dict:
        """
        Get current rate limiter statistics.

        Returns:
            Dict keyed by domain (only domains that have been tracked), each
            value holding "requests_last_minute" (count of requests in the
            last 60 seconds) and "last_request_ago" (seconds since the last
            recorded request, or None if there is none).
        """
        window = self._RPM_WINDOW_SECONDS
        with self._lock:
            now = time.time()
            stats = {}
            for domain, timestamps in self._request_counts.items():
                last = self._last_request[domain]
                stats[domain] = {
                    "requests_last_minute": sum(
                        1 for ts in timestamps if now - ts < window
                    ),
                    "last_request_ago": now - last if last else None,
                }
            return stats


# Global singleton instance. Starts as None; get_rate_limiter() creates the
# RateLimiter on first use and reset_rate_limiter() sets it back to None.
_rate_limiter: Optional[RateLimiter] = None
# Held by get_rate_limiter() and reset_rate_limiter() whenever they read or
# rebind _rate_limiter.
_rate_limiter_lock = threading.Lock()


def get_rate_limiter() -> RateLimiter:
    """
    Return the shared module-level RateLimiter, creating it on first use.

    The check and the creation both happen while holding the module lock,
    so concurrent first calls end up with the same instance.
    """
    global _rate_limiter

    with _rate_limiter_lock:
        instance = _rate_limiter
        if instance is None:
            instance = RateLimiter()
            _rate_limiter = instance
        return instance


def reset_rate_limiter() -> None:
    """
    Reset the global rate limiter (useful for testing).

    Sets the module-level singleton back to None while holding the module
    lock, so the next get_rate_limiter() call builds a fresh RateLimiter.
    This only rebinds the global; code that already holds a reference to
    the previous instance keeps using that instance.
    """
    global _rate_limiter

    with _rate_limiter_lock:
        _rate_limiter = None