Spaces:
Running
Running
File size: 4,835 Bytes
124cd9f 74f8675 124cd9f 29de1ae 3d2cda1 124cd9f 3d2cda1 e406439 124cd9f 29de1ae e406439 29de1ae 124cd9f 29de1ae e406439 124cd9f e406439 124cd9f 29de1ae 3d2cda1 124cd9f e406439 29de1ae 3d2cda1 124cd9f 29de1ae e406439 29de1ae e406439 29de1ae e406439 29de1ae e406439 29de1ae e406439 29de1ae e406439 124cd9f e406439 124cd9f e406439 124cd9f e406439 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 | import os
import atexit
import threading
import time
import uuid
from datetime import datetime, timezone
from tempfile import NamedTemporaryFile
import pandas as pd
from huggingface_hub import HfApi
# Destination of the uploaded vote shards: repository id, repository type,
# and the folder inside the repository that the shard files are written to.
VOTES_REPO_ID = "taigasan/e6-visual-ratings"
VOTES_REPO_TYPE = "dataset"
VOTES_LOG_SUBDIR = "ratings_log"

# Columns of an uploaded vote table, in the order they are written.
VOTE_COLUMNS = [
    # Identity of the vote itself.
    "vote_id", "timestamp",
    # The compared pair and the chosen side.
    "md5a", "md5b", "winner_md5",
    # Links to the two compared posts.
    "url_a", "url_b",
    # Where the pair came from and which session cast the vote.
    "dataset", "group", "session_id",
]
class VoteStorage:
def __init__(self, mode: str, token: str | None = None) -> None:
assert mode in ("hf", "void"), f"Unsupported storage mode: {mode}"
self.mode = mode
is_debug_mode = self.mode == "void"
self._flush_every = 3 if is_debug_mode else 50
self._flush_interval_sec = 15.0 if is_debug_mode else 300.0
self._hf_api = HfApi(token=token)
self._flush_condition = threading.Condition(threading.Lock())
self._sync_event = threading.Event()
self._sync_lock = threading.Lock()
self._votes_buffer: list[dict] = []
self._shutdown = False
self._flush_thread = threading.Thread(target=self._flush_loop, daemon=True)
self._flush_thread.start()
atexit.register(self.close)
def _upload_votes_batch(self, batch: list[dict]) -> None:
assert batch
if self.mode == "void":
return
df = pd.DataFrame(batch)
for col in VOTE_COLUMNS:
if col not in df.columns:
df[col] = None
df = df[VOTE_COLUMNS]
ts = int(time.time())
shard = f"votes_{ts}_{uuid.uuid4().hex}.parquet"
self._hf_api.upload_file(
path_or_fileobj=df.to_parquet(index=False),
path_in_repo=f"{VOTES_LOG_SUBDIR}/{shard}",
repo_id=VOTES_REPO_ID,
repo_type=VOTES_REPO_TYPE,
commit_message=f"upload {len(df)} vote rows",
)
def _flush_loop(self) -> None:
while True:
with self._flush_condition:
while True:
# Forced sync.
if not self._sync_event.is_set():
if self._votes_buffer:
break
self._sync_event.set()
# Shutdown wanted.
if self._shutdown:
if self._votes_buffer:
break
return
# Have enough votes to flush now.
if len(self._votes_buffer) >= self._flush_every:
break
# Wait for a notify to flush early or shutdown.
if not self._flush_condition.wait(self._flush_interval_sec):
# Interval elapsed. Flush if there is at least one vote.
if self._votes_buffer:
break
# Atomically take the batch of votes.
batch = self._votes_buffer
self._votes_buffer = []
self._upload_votes_batch(batch)
def sync(self) -> None:
with self._sync_lock:
with self._flush_condition:
is_shutdown = self._shutdown
if not is_shutdown:
self._sync_event.clear()
self._flush_condition.notify()
if not is_shutdown:
self._sync_event.wait()
if is_shutdown:
self._flush_thread.join()
def close(self) -> None:
with self._flush_condition:
self._shutdown = True
self._flush_condition.notify()
self._flush_thread.join()
def queue_row(self, state: dict) -> None:
id_a = int(state["id_a"])
id_b = int(state["id_b"])
winner_md5: str | None
match state["winner"]:
case "A":
winner_md5 = state["key_a"]
case "B":
winner_md5 = state["key_b"]
case None:
winner_md5 = None
case _:
raise AssertionError
vote_row = {
"vote_id": uuid.uuid4().hex,
"timestamp": datetime.now(timezone.utc).isoformat(timespec="seconds"),
"md5a": state["key_a"],
"md5b": state["key_b"],
"winner_md5": winner_md5,
"url_a": f"https://e621.net/posts/{id_a}",
"url_b": f"https://e621.net/posts/{id_b}",
"dataset": state["dataset"],
"group": state["group"],
"session_id": state["session_id"],
}
with self._flush_condition:
self._votes_buffer.append(vote_row)
if len(self._votes_buffer) == self._flush_every:
self._flush_condition.notify()
|