# e6-visual-ratings / storage.py
# RedHotTensors's picture
# Synchronously flush votes before reloading stats to prevent out-of-date reloading.
# 29de1ae
import atexit
import logging
import os
import threading
import time
import uuid
from datetime import datetime, timezone
from tempfile import NamedTemporaryFile

import pandas as pd
from huggingface_hub import HfApi
# Hub repository that receives the uploaded vote shards (passed as
# ``repo_id`` / ``repo_type`` to ``HfApi.upload_file``).
VOTES_REPO_ID = "taigasan/e6-visual-ratings"
VOTES_REPO_TYPE = "dataset"

# Directory inside the repository under which each shard is written.
VOTES_LOG_SUBDIR = "ratings_log"

# Column set and column order of every uploaded shard. Columns missing from
# a batch are added as all-None before the frame is reordered to this list.
VOTE_COLUMNS = [
    "vote_id",
    "timestamp",
    "md5a",
    "md5b",
    "winner_md5",
    "url_a",
    "url_b",
    "dataset",
    "group",
    "session_id",
]
class VoteStorage:
def __init__(self, mode: str, token: str | None = None) -> None:
assert mode in ("hf", "void"), f"Unsupported storage mode: {mode}"
self.mode = mode
is_debug_mode = self.mode == "void"
self._flush_every = 3 if is_debug_mode else 50
self._flush_interval_sec = 15.0 if is_debug_mode else 300.0
self._hf_api = HfApi(token=token)
self._flush_condition = threading.Condition(threading.Lock())
self._sync_event = threading.Event()
self._sync_lock = threading.Lock()
self._votes_buffer: list[dict] = []
self._shutdown = False
self._flush_thread = threading.Thread(target=self._flush_loop, daemon=True)
self._flush_thread.start()
atexit.register(self.close)
def _upload_votes_batch(self, batch: list[dict]) -> None:
assert batch
if self.mode == "void":
return
df = pd.DataFrame(batch)
for col in VOTE_COLUMNS:
if col not in df.columns:
df[col] = None
df = df[VOTE_COLUMNS]
ts = int(time.time())
shard = f"votes_{ts}_{uuid.uuid4().hex}.parquet"
self._hf_api.upload_file(
path_or_fileobj=df.to_parquet(index=False),
path_in_repo=f"{VOTES_LOG_SUBDIR}/{shard}",
repo_id=VOTES_REPO_ID,
repo_type=VOTES_REPO_TYPE,
commit_message=f"upload {len(df)} vote rows",
)
def _flush_loop(self) -> None:
while True:
with self._flush_condition:
while True:
# Forced sync.
if not self._sync_event.is_set():
if self._votes_buffer:
break
self._sync_event.set()
# Shutdown wanted.
if self._shutdown:
if self._votes_buffer:
break
return
# Have enough votes to flush now.
if len(self._votes_buffer) >= self._flush_every:
break
# Wait for a notify to flush early or shutdown.
if not self._flush_condition.wait(self._flush_interval_sec):
# Interval elapsed. Flush if there is at least one vote.
if self._votes_buffer:
break
# Atomically take the batch of votes.
batch = self._votes_buffer
self._votes_buffer = []
self._upload_votes_batch(batch)
def sync(self) -> None:
with self._sync_lock:
with self._flush_condition:
is_shutdown = self._shutdown
if not is_shutdown:
self._sync_event.clear()
self._flush_condition.notify()
if not is_shutdown:
self._sync_event.wait()
if is_shutdown:
self._flush_thread.join()
def close(self) -> None:
with self._flush_condition:
self._shutdown = True
self._flush_condition.notify()
self._flush_thread.join()
def queue_row(self, state: dict) -> None:
id_a = int(state["id_a"])
id_b = int(state["id_b"])
winner_md5: str | None
match state["winner"]:
case "A":
winner_md5 = state["key_a"]
case "B":
winner_md5 = state["key_b"]
case None:
winner_md5 = None
case _:
raise AssertionError
vote_row = {
"vote_id": uuid.uuid4().hex,
"timestamp": datetime.now(timezone.utc).isoformat(timespec="seconds"),
"md5a": state["key_a"],
"md5b": state["key_b"],
"winner_md5": winner_md5,
"url_a": f"https://e621.net/posts/{id_a}",
"url_b": f"https://e621.net/posts/{id_b}",
"dataset": state["dataset"],
"group": state["group"],
"session_id": state["session_id"],
}
with self._flush_condition:
self._votes_buffer.append(vote_row)
if len(self._votes_buffer) == self._flush_every:
self._flush_condition.notify()