Spaces:
Running
Running
| import os | |
| import atexit | |
| import threading | |
| import time | |
| import uuid | |
| from datetime import datetime, timezone | |
| from tempfile import NamedTemporaryFile | |
| import pandas as pd | |
| from huggingface_hub import HfApi | |
# Hugging Face Hub repository (of type "dataset") that receives the vote shards.
VOTES_REPO_TYPE = "dataset"
VOTES_REPO_ID = "taigasan/e6-visual-ratings"
# Folder inside the repo where each uploaded parquet shard is written.
VOTES_LOG_SUBDIR = "ratings_log"

# Column order of every uploaded parquet shard. Must stay a list: pandas
# interprets a tuple passed to ``df[...]`` as a single column key.
VOTE_COLUMNS = [
    # Row identity and UTC creation timestamp.
    "vote_id", "timestamp",
    # The two compared items and the chosen one (None when there is no winner).
    "md5a", "md5b", "winner_md5",
    # Links to the two compared posts.
    "url_a", "url_b",
    # Context the vote was cast in.
    "dataset", "group", "session_id",
]
| class VoteStorage: | |
| def __init__(self, mode: str, token: str | None = None) -> None: | |
| assert mode in ("hf", "void"), f"Unsupported storage mode: {mode}" | |
| self.mode = mode | |
| is_debug_mode = self.mode == "void" | |
| self._flush_every = 3 if is_debug_mode else 50 | |
| self._flush_interval_sec = 15.0 if is_debug_mode else 300.0 | |
| self._hf_api = HfApi(token=token) | |
| self._flush_condition = threading.Condition(threading.Lock()) | |
| self._sync_event = threading.Event() | |
| self._sync_lock = threading.Lock() | |
| self._votes_buffer: list[dict] = [] | |
| self._shutdown = False | |
| self._flush_thread = threading.Thread(target=self._flush_loop, daemon=True) | |
| self._flush_thread.start() | |
| atexit.register(self.close) | |
| def _upload_votes_batch(self, batch: list[dict]) -> None: | |
| assert batch | |
| if self.mode == "void": | |
| return | |
| df = pd.DataFrame(batch) | |
| for col in VOTE_COLUMNS: | |
| if col not in df.columns: | |
| df[col] = None | |
| df = df[VOTE_COLUMNS] | |
| ts = int(time.time()) | |
| shard = f"votes_{ts}_{uuid.uuid4().hex}.parquet" | |
| self._hf_api.upload_file( | |
| path_or_fileobj=df.to_parquet(index=False), | |
| path_in_repo=f"{VOTES_LOG_SUBDIR}/{shard}", | |
| repo_id=VOTES_REPO_ID, | |
| repo_type=VOTES_REPO_TYPE, | |
| commit_message=f"upload {len(df)} vote rows", | |
| ) | |
| def _flush_loop(self) -> None: | |
| while True: | |
| with self._flush_condition: | |
| while True: | |
| # Forced sync. | |
| if not self._sync_event.is_set(): | |
| if self._votes_buffer: | |
| break | |
| self._sync_event.set() | |
| # Shutdown wanted. | |
| if self._shutdown: | |
| if self._votes_buffer: | |
| break | |
| return | |
| # Have enough votes to flush now. | |
| if len(self._votes_buffer) >= self._flush_every: | |
| break | |
| # Wait for a notify to flush early or shutdown. | |
| if not self._flush_condition.wait(self._flush_interval_sec): | |
| # Interval elapsed. Flush if there is at least one vote. | |
| if self._votes_buffer: | |
| break | |
| # Atomically take the batch of votes. | |
| batch = self._votes_buffer | |
| self._votes_buffer = [] | |
| self._upload_votes_batch(batch) | |
| def sync(self) -> None: | |
| with self._sync_lock: | |
| with self._flush_condition: | |
| is_shutdown = self._shutdown | |
| if not is_shutdown: | |
| self._sync_event.clear() | |
| self._flush_condition.notify() | |
| if not is_shutdown: | |
| self._sync_event.wait() | |
| if is_shutdown: | |
| self._flush_thread.join() | |
| def close(self) -> None: | |
| with self._flush_condition: | |
| self._shutdown = True | |
| self._flush_condition.notify() | |
| self._flush_thread.join() | |
| def queue_row(self, state: dict) -> None: | |
| id_a = int(state["id_a"]) | |
| id_b = int(state["id_b"]) | |
| winner_md5: str | None | |
| match state["winner"]: | |
| case "A": | |
| winner_md5 = state["key_a"] | |
| case "B": | |
| winner_md5 = state["key_b"] | |
| case None: | |
| winner_md5 = None | |
| case _: | |
| raise AssertionError | |
| vote_row = { | |
| "vote_id": uuid.uuid4().hex, | |
| "timestamp": datetime.now(timezone.utc).isoformat(timespec="seconds"), | |
| "md5a": state["key_a"], | |
| "md5b": state["key_b"], | |
| "winner_md5": winner_md5, | |
| "url_a": f"https://e621.net/posts/{id_a}", | |
| "url_b": f"https://e621.net/posts/{id_b}", | |
| "dataset": state["dataset"], | |
| "group": state["group"], | |
| "session_id": state["session_id"], | |
| } | |
| with self._flush_condition: | |
| self._votes_buffer.append(vote_row) | |
| if len(self._votes_buffer) == self._flush_every: | |
| self._flush_condition.notify() | |