Spaces:
Running
Running
| from config import RL_REPO_ID as REPO_ID | |
| import datetime | |
| import time | |
| from collections import defaultdict | |
| import threading | |
| import json | |
| import os | |
| from datasets import load_dataset | |
| from huggingface_hub import HfApi | |
| import tempfile | |
| HF_TOKEN = os.getenv("HF_TOKEN") # set this in HF Space secrets | |
| def _json_dumper(obj): | |
| try: | |
| return obj.to_dict() | |
| except: | |
| return obj | |
class RateLimitConfig:
    """Thresholds enforced by a RateLimiter.

    Any limit left as None disables that particular check.
    """

    def __init__(self, max_per_day=None, max_total=None, min_interval_seconds=None):
        self.max_per_day = max_per_day                    # requests allowed per calendar day
        self.max_total = max_total                        # lifetime request cap
        self.min_interval_seconds = min_interval_seconds  # cooldown between consecutive requests
class RateLimitState:
    """Mutable per-user counters, persisted as a JSON row in the dataset log."""

    def __init__(self):
        self.daily_count = 0
        self.total_count = 0
        self.last_access_time = None            # epoch seconds of the last allowed access
        self.last_access_date = datetime.date.today()

    def to_dict(self):
        """Serialize to a JSON-friendly dict (date becomes a 'YYYY-MM-DD' string)."""
        return {
            "daily_count": self.daily_count,
            "total_count": self.total_count,
            "last_access_time": self.last_access_time,
            "last_access_date": self.last_access_date.strftime("%Y-%m-%d"),
        }

    @staticmethod
    def from_dict(data):
        """Inverse of to_dict(); callers invoke it as RateLimitState.from_dict(row).

        BUGFIX: to_dict() stores last_access_date as a 'YYYY-MM-DD' string, but
        the old code called `.date()` on it as if it were a datetime, raising
        AttributeError — which load_state()'s broad except then swallowed,
        silently discarding ALL persisted state on every restart. It also
        crashed (`None.date()`) when the key was missing; now we fall back to
        today's date, matching __init__.
        """
        state = RateLimitState()
        state.daily_count = data.get("daily_count", 0)
        state.total_count = data.get("total_count", 0)
        state.last_access_time = data.get("last_access_time")
        raw_date = data.get("last_access_date")
        if raw_date:
            state.last_access_date = datetime.date.fromisoformat(raw_date)
        return state

    def __repr__(self):
        return json.dumps(self.to_dict(), default=_json_dumper)
class RateLimiter:
    """Per-user rate limiter whose state is periodically persisted to a HF
    dataset repo as logs.jsonl.

    A daemon thread flushes dirty state every `flush_interval` seconds; all
    counter reads/writes are guarded by `self.lock`.
    """

    def __init__(self, config: RateLimitConfig, flush_interval=30):
        """
        flush_interval = how often (seconds) to push logs to HF dataset
        """
        self.config = config
        self.user_log = defaultdict(RateLimitState)
        self.lock = threading.Lock()
        self._dirty = False
        self._stop_event = threading.Event()
        self.flush_interval = flush_interval
        self.load_state()
        # Start background thread for periodic flushing
        self.flush_thread = threading.Thread(target=self._flush_loop, daemon=True)
        self.flush_thread.start()

    def _today(self):
        # Single seam for "today" so it can be patched in tests.
        return datetime.date.today()

    def _reset_daily_count_if_needed(self, state: RateLimitState):
        """Zero the daily counter on the first access of a new calendar day."""
        today = self._today()
        if state.last_access_date != today:
            state.daily_count = 0
            state.last_access_date = today

    def is_allowed(self, user_id: str) -> "tuple[bool, str]":
        """Check all configured limits for user_id; record the access if allowed.

        Returns (True, "allowed") on success, or (False, <failed limit name>).
        BUGFIX: the annotation previously claimed `-> bool`, but this method
        has always returned a 2-tuple (every branch below does); the
        annotation now matches what callers actually receive.
        """
        now = time.time()
        with self.lock:
            state = self.user_log[user_id]
            self._reset_daily_count_if_needed(state)
            # Check min time between accesses
            if self.config.min_interval_seconds and state.last_access_time is not None:
                if now - state.last_access_time < self.config.min_interval_seconds:
                    return False, "min_interval_seconds"
            # Check daily limit
            if self.config.max_per_day and state.daily_count >= self.config.max_per_day:
                return False, "max_per_day"
            # Check total limit
            if self.config.max_total and state.total_count >= self.config.max_total:
                return False, "max_total"
            # All checks passed, update counters
            state.last_access_time = now
            state.daily_count += 1
            state.total_count += 1
            self._dirty = True  # mark logs as modified
            return True, "allowed"

    def _flush_loop(self):
        """Background loop to flush logs periodically.

        Uses Event.wait() instead of time.sleep() so shutdown() takes effect
        immediately rather than after up to flush_interval seconds.
        BUGFIX: an exception during upload used to propagate out of this loop
        and permanently kill the flusher thread; now it is logged and retried
        on the next cycle (save_state re-marks state dirty on failure).
        """
        while not self._stop_event.wait(self.flush_interval):
            if self._dirty:
                try:
                    self.save_state()
                except Exception as e:
                    print("Flush failed, will retry on next cycle:", e)

    def save_state(self):
        """Overwrite dataset with latest logs as JSONL (easy to inspect).

        BUGFIX: the old version held self.lock for the entire network upload,
        blocking every is_allowed() call for the duration of a HF commit.
        Now the state is snapshotted (and _dirty cleared) under the lock,
        while the slow upload happens outside it; on upload failure _dirty is
        restored so the snapshot's changes are not lost.
        """
        with self.lock:
            rows = [
                {"user_id": user_id, "state": state.to_dict()}
                for user_id, state in self.user_log.items()
            ]
            # Optimistically clear: concurrent writes will re-set it, and the
            # except-branch below restores it if the upload fails.
            self._dirty = False
        if not rows:
            return
        try:
            with tempfile.TemporaryDirectory() as tmpdir:
                jsonl_path = os.path.join(tmpdir, "logs.jsonl")
                with open(jsonl_path, "w", encoding="utf-8") as f:
                    for row in rows:
                        f.write(json.dumps(row) + "\n")
                api = HfApi(token=HF_TOKEN)
                api.upload_file(
                    path_or_fileobj=jsonl_path,
                    path_in_repo="logs.jsonl",  # always overwrite this file
                    repo_id=REPO_ID,
                    repo_type="dataset",
                    commit_message="Update rate limit logs",
                )
        except Exception:
            with self.lock:
                self._dirty = True  # keep this snapshot's changes flushable
            raise
        print(f"Flushed {len(rows)} users to dataset")

    def load_state(self):
        """Load from logs.jsonl in HF dataset repo.

        Not locked: only safe to call before the flush thread starts
        (as __init__ does).
        """
        self.user_log.clear()
        try:
            ds = load_dataset(REPO_ID, data_files="logs.jsonl", split="train", token=HF_TOKEN)
            for row in ds:
                self.user_log[row["user_id"]] = RateLimitState.from_dict(
                    row["state"]
                )
            print(f"Loaded {len(self.user_log)} users from dataset")
        except Exception as e:
            # Deliberate best-effort: a fresh Space has no logs.jsonl yet.
            print("Starting with empty log (could not load dataset):", e)

    def shutdown(self):
        """Stop background flusher and do a final save."""
        self._stop_event.set()
        self.flush_thread.join(timeout=5)
        if self._dirty:
            self.save_state()

    def get_status(self, user_id: str) -> dict:
        """Report usage and remaining allowance for user_id without consuming it.

        Remaining values are None when the corresponding limit is disabled.
        """
        now = time.time()
        with self.lock:
            if user_id not in self.user_log:
                # Unknown user: full allowance, no cooldown pending.
                return {
                    "daily_used": 0,
                    "daily_remaining": self.config.max_per_day,
                    "total_used": 0,
                    "total_remaining": self.config.max_total,
                    "wait_seconds": 0,
                }
            state = self.user_log[user_id]
            self._reset_daily_count_if_needed(state)
            remaining_daily = (
                None
                if self.config.max_per_day is None
                else max(0, self.config.max_per_day - state.daily_count)
            )
            remaining_total = (
                None
                if self.config.max_total is None
                else max(0, self.config.max_total - state.total_count)
            )
            wait_time = (
                0
                if self.config.min_interval_seconds is None
                or state.last_access_time is None
                else max(
                    0, self.config.min_interval_seconds - (now - state.last_access_time)
                )
            )
            return {
                "daily_used": state.daily_count,
                "daily_remaining": remaining_daily,
                "total_used": state.total_count,
                "total_remaining": remaining_total,
                "wait_seconds": round(wait_time, 2),
            }