import os import atexit import threading import time import uuid from datetime import datetime, timezone from tempfile import NamedTemporaryFile import pandas as pd from huggingface_hub import HfApi VOTES_REPO_ID = "taigasan/e6-visual-ratings" VOTES_REPO_TYPE = "dataset" VOTES_LOG_SUBDIR = "ratings_log" VOTE_COLUMNS = [ "vote_id", "timestamp", "md5a", "md5b", "winner_md5", "url_a", "url_b", "dataset", "group", "session_id", ] class VoteStorage: def __init__(self, mode: str, token: str | None = None) -> None: assert mode in ("hf", "void"), f"Unsupported storage mode: {mode}" self.mode = mode is_debug_mode = self.mode == "void" self._flush_every = 3 if is_debug_mode else 50 self._flush_interval_sec = 15.0 if is_debug_mode else 300.0 self._hf_api = HfApi(token=token) self._flush_condition = threading.Condition(threading.Lock()) self._sync_event = threading.Event() self._sync_lock = threading.Lock() self._votes_buffer: list[dict] = [] self._shutdown = False self._flush_thread = threading.Thread(target=self._flush_loop, daemon=True) self._flush_thread.start() atexit.register(self.close) def _upload_votes_batch(self, batch: list[dict]) -> None: assert batch if self.mode == "void": return df = pd.DataFrame(batch) for col in VOTE_COLUMNS: if col not in df.columns: df[col] = None df = df[VOTE_COLUMNS] ts = int(time.time()) shard = f"votes_{ts}_{uuid.uuid4().hex}.parquet" self._hf_api.upload_file( path_or_fileobj=df.to_parquet(index=False), path_in_repo=f"{VOTES_LOG_SUBDIR}/{shard}", repo_id=VOTES_REPO_ID, repo_type=VOTES_REPO_TYPE, commit_message=f"upload {len(df)} vote rows", ) def _flush_loop(self) -> None: while True: with self._flush_condition: while True: # Forced sync. if not self._sync_event.is_set(): if self._votes_buffer: break self._sync_event.set() # Shutdown wanted. if self._shutdown: if self._votes_buffer: break return # Have enough votes to flush now. if len(self._votes_buffer) >= self._flush_every: break # Wait for a notify to flush early or shutdown. if not self._flush_condition.wait(self._flush_interval_sec): # Interval elapsed. Flush if there is at least one vote. if self._votes_buffer: break # Atomically take the batch of votes. batch = self._votes_buffer self._votes_buffer = [] self._upload_votes_batch(batch) def sync(self) -> None: with self._sync_lock: with self._flush_condition: is_shutdown = self._shutdown if not is_shutdown: self._sync_event.clear() self._flush_condition.notify() if not is_shutdown: self._sync_event.wait() if is_shutdown: self._flush_thread.join() def close(self) -> None: with self._flush_condition: self._shutdown = True self._flush_condition.notify() self._flush_thread.join() def queue_row(self, state: dict) -> None: id_a = int(state["id_a"]) id_b = int(state["id_b"]) winner_md5: str | None match state["winner"]: case "A": winner_md5 = state["key_a"] case "B": winner_md5 = state["key_b"] case None: winner_md5 = None case _: raise AssertionError vote_row = { "vote_id": uuid.uuid4().hex, "timestamp": datetime.now(timezone.utc).isoformat(timespec="seconds"), "md5a": state["key_a"], "md5b": state["key_b"], "winner_md5": winner_md5, "url_a": f"https://e621.net/posts/{id_a}", "url_b": f"https://e621.net/posts/{id_b}", "dataset": state["dataset"], "group": state["group"], "session_id": state["session_id"], } with self._flush_condition: self._votes_buffer.append(vote_row) if len(self._votes_buffer) == self._flush_every: self._flush_condition.notify()