# rate_limiter.py — persistent, HF-dataset-backed rate limiting for visualoverload-submit.
# (Originally committed as "fix persistent rate limits", ca10c01, Paul Gavrikov.)
from config import RL_REPO_ID as REPO_ID
import datetime
import time
from collections import defaultdict
import threading
import json
import os
from datasets import load_dataset
from huggingface_hub import HfApi
import tempfile
HF_TOKEN = os.getenv("HF_TOKEN") # set this in HF Space secrets
def _json_dumper(obj):
try:
return obj.to_dict()
except:
return obj
class RateLimitConfig:
    """Thresholds enforced by a RateLimiter.

    Each threshold is optional; leaving it as None disables that check.

    max_per_day: maximum allowed accesses per calendar day.
    max_total: maximum allowed accesses over the user's lifetime.
    min_interval_seconds: minimum spacing between two consecutive accesses.
    """

    def __init__(self, max_per_day=None, max_total=None, min_interval_seconds=None):
        self.max_per_day, self.max_total, self.min_interval_seconds = (
            max_per_day,
            max_total,
            min_interval_seconds,
        )
class RateLimitState:
    """Mutable per-user usage counters tracked by the rate limiter."""

    def __init__(self):
        self.daily_count = 0          # accesses so far today
        self.total_count = 0          # accesses over the whole lifetime
        self.last_access_time = None  # time.time() of the last allowed access, or None
        self.last_access_date = datetime.date.today()  # date of the last access

    def to_dict(self):
        """Serialize to a JSON-compatible dict (date as an ISO "YYYY-MM-DD" string)."""
        return {
            "daily_count": self.daily_count,
            "total_count": self.total_count,
            "last_access_time": self.last_access_time,
            "last_access_date": self.last_access_date.strftime("%Y-%m-%d"),
        }

    @staticmethod
    def from_dict(data):
        """Rebuild a RateLimitState from `to_dict()` output.

        Bug fix: `to_dict()` stores `last_access_date` as a "%Y-%m-%d"
        string, but this method used to call `.date()` on it, which raised
        AttributeError on every reload of persisted state (and on a missing
        key). Strings are now parsed; datetime objects (e.g. from datasets'
        automatic type inference — assumption, not verified here) are
        converted; a missing value falls back to today, which safely resets
        the daily counter.
        """
        state = RateLimitState()
        state.daily_count = data.get("daily_count", 0)
        state.total_count = data.get("total_count", 0)
        state.last_access_time = data.get("last_access_time")
        raw = data.get("last_access_date")
        if raw is None:
            state.last_access_date = datetime.date.today()
        elif isinstance(raw, str):
            state.last_access_date = datetime.datetime.strptime(raw, "%Y-%m-%d").date()
        elif isinstance(raw, datetime.datetime):
            state.last_access_date = raw.date()
        else:
            state.last_access_date = raw  # already a datetime.date
        return state

    def __repr__(self):
        return json.dumps(self.to_dict(), default=_json_dumper)
class RateLimiter:
    """Per-user rate limiter whose state is periodically persisted to a HF dataset repo.

    Enforces up to three limits from a RateLimitConfig (per-day count, total
    count, minimum interval between accesses). State lives in memory, is
    loaded once at construction, and a daemon thread pushes it back to the
    dataset every `flush_interval` seconds whenever it has changed.
    """

    def __init__(self, config: RateLimitConfig, flush_interval=30):
        """
        config: thresholds to enforce (None values disable a check).
        flush_interval: how often (seconds) to push logs to the HF dataset.
        """
        self.config = config
        # user_id -> RateLimitState; unknown users get a fresh state on first access.
        self.user_log = defaultdict(RateLimitState)
        self.lock = threading.Lock()
        self._dirty = False  # True when in-memory state differs from the persisted copy
        self._stop_event = threading.Event()
        self.flush_interval = flush_interval
        self.load_state()
        # Start background thread for periodic flushing.
        self.flush_thread = threading.Thread(target=self._flush_loop, daemon=True)
        self.flush_thread.start()

    def _today(self):
        """Current local date, factored out for readability."""
        return datetime.date.today()

    def _reset_daily_count_if_needed(self, state: RateLimitState):
        """Zero the daily counter on the first access of a new day. Caller holds the lock."""
        today = self._today()
        if state.last_access_date != today:
            state.daily_count = 0
            state.last_access_date = today

    def is_allowed(self, user_id: str) -> "tuple[bool, str]":
        """Check all configured limits for `user_id`; record the access if allowed.

        Returns (True, "allowed") on success, otherwise (False, <name of the
        limit that was hit>). Counters are incremented only on success.
        (Fixed annotation: this always returned a tuple, never a bare bool.)
        """
        now = time.time()
        with self.lock:
            state = self.user_log[user_id]
            self._reset_daily_count_if_needed(state)
            # Minimum spacing between two consecutive accesses.
            if self.config.min_interval_seconds and state.last_access_time is not None:
                if now - state.last_access_time < self.config.min_interval_seconds:
                    return False, "min_interval_seconds"
            # Per-day cap.
            if self.config.max_per_day and state.daily_count >= self.config.max_per_day:
                return False, "max_per_day"
            # Lifetime cap.
            if self.config.max_total and state.total_count >= self.config.max_total:
                return False, "max_total"
            # All checks passed: record the access.
            state.last_access_time = now
            state.daily_count += 1
            state.total_count += 1
            self._dirty = True  # in-memory state is now ahead of the persisted copy
            return True, "allowed"

    def _flush_loop(self):
        """Background loop: persist state every `flush_interval` seconds while dirty."""
        # Event.wait() instead of time.sleep() so shutdown() interrupts the
        # pause immediately instead of waiting out a full interval.
        while not self._stop_event.wait(self.flush_interval):
            if self._dirty:
                self.save_state()

    def save_state(self):
        """Overwrite logs.jsonl in the dataset repo with the latest state (JSONL for easy inspection).

        The snapshot is taken under the lock, but the slow network upload
        happens outside it, so is_allowed() is never blocked on the network.
        _dirty is cleared at snapshot time and restored on failure, so a
        failed upload is retried on the next flush instead of killing the
        background thread (which an uncaught exception previously did).
        """
        with self.lock:
            rows = [
                {"user_id": user_id, "state": state.to_dict()}
                for user_id, state in self.user_log.items()
            ]
            # Changes made after this point re-set _dirty themselves.
            self._dirty = False
        if not rows:
            return
        try:
            with tempfile.TemporaryDirectory() as tmpdir:
                jsonl_path = os.path.join(tmpdir, "logs.jsonl")
                with open(jsonl_path, "w", encoding="utf-8") as f:
                    for row in rows:
                        f.write(json.dumps(row) + "\n")
                api = HfApi(token=HF_TOKEN)
                api.upload_file(
                    path_or_fileobj=jsonl_path,
                    path_in_repo="logs.jsonl",  # always overwrite this file
                    repo_id=REPO_ID,
                    repo_type="dataset",
                    commit_message="Update rate limit logs",
                )
        except Exception as e:
            # Best-effort persistence: log and retry on the next flush.
            with self.lock:
                self._dirty = True
            print("Failed to flush rate limit logs:", e)
            return
        print(f"Flushed {len(rows)} users to dataset")

    def load_state(self):
        """Load per-user state from logs.jsonl in the HF dataset repo.

        Called once from __init__ before the flush thread starts, so no
        locking is needed. Any failure (missing repo/file, auth, schema)
        leaves the limiter with an empty log and is reported, not raised.
        """
        self.user_log.clear()
        try:
            ds = load_dataset(REPO_ID, data_files="logs.jsonl", split="train", token=HF_TOKEN)
            for row in ds:
                self.user_log[row["user_id"]] = RateLimitState.from_dict(
                    row["state"]
                )
            print(f"Loaded {len(self.user_log)} users from dataset")
        except Exception as e:
            print("Starting with empty log (could not load dataset):", e)

    def shutdown(self):
        """Stop the background flusher and do a final synchronous save."""
        self._stop_event.set()
        self.flush_thread.join(timeout=5)
        if self._dirty:
            self.save_state()

    def get_status(self, user_id: str) -> dict:
        """Report usage and remaining quota for `user_id` without consuming quota.

        Returns a dict with daily_used / daily_remaining / total_used /
        total_remaining (None when the corresponding limit is disabled) and
        wait_seconds until the next access is allowed (0 when unconstrained).
        """
        now = time.time()
        with self.lock:
            if user_id not in self.user_log:
                # Unknown user: full quota available, no waiting required.
                return {
                    "daily_used": 0,
                    "daily_remaining": self.config.max_per_day,
                    "total_used": 0,
                    "total_remaining": self.config.max_total,
                    "wait_seconds": 0,
                }
            state = self.user_log[user_id]
            self._reset_daily_count_if_needed(state)
            remaining_daily = (
                None
                if self.config.max_per_day is None
                else max(0, self.config.max_per_day - state.daily_count)
            )
            remaining_total = (
                None
                if self.config.max_total is None
                else max(0, self.config.max_total - state.total_count)
            )
            wait_time = (
                0
                if self.config.min_interval_seconds is None
                or state.last_access_time is None
                else max(
                    0, self.config.min_interval_seconds - (now - state.last_access_time)
                )
            )
            return {
                "daily_used": state.daily_count,
                "daily_remaining": remaining_daily,
                "total_used": state.total_count,
                "total_remaining": remaining_total,
                "wait_seconds": round(wait_time, 2),
            }