Files changed (2) hide show
  1. app.py +163 -117
  2. storage.py +78 -56
app.py CHANGED
@@ -5,6 +5,7 @@ import time
5
  import uuid
6
  import os
7
  import html
 
8
 
9
  import pandas as pd
10
  from huggingface_hub import hf_hub_download
@@ -29,26 +30,52 @@ _pool_path = hf_hub_download(
29
  token=RATINGS_APP_TOKEN
30
  )
31
  _pool_df = pd.read_parquet(_pool_path)
32
- _pool_group_dfs = {g: gdf for g, gdf in _pool_df.groupby("group")}
 
 
 
 
 
 
 
 
 
33
  _stats_lock = threading.Lock()
 
 
34
  _stats_last_loaded_at = 0.0
35
- _stats_by_key: dict[str, tuple[int, int]] = {}
36
  _explorer_df = pd.DataFrame(columns=["group", "id", "md5", "rating", "sample_url", "image_url", "classifier", "classifier_score", "percentile"])
37
 
38
 
39
  def _reload_stats_if_due(force: bool = False):
40
- global _stats_last_loaded_at, _stats_by_key, _explorer_df
41
  now = time.time()
 
42
  if not force and (now - _stats_last_loaded_at) < STATS_RELOAD_S:
43
  return
 
44
  with _stats_lock:
45
  now = time.time()
 
46
  if not force and (now - _stats_last_loaded_at) < STATS_RELOAD_S:
47
  return
48
- _stats_by_key = load_stats_by_md5(
 
49
  repo_id=POOL_REPO_ID,
50
  token=RATINGS_APP_TOKEN,
51
  )
 
 
 
 
 
 
 
 
 
 
 
 
52
  classifier_scores_path = hf_hub_download(
53
  repo_id=POOL_REPO_ID,
54
  filename="classifier_scores.parquet",
@@ -77,62 +104,81 @@ def _reload_stats_if_due(force: bool = False):
77
 
78
  _reload_stats_if_due(force=True)
79
 
80
- def _pool_fetch_pair(group_name: str) -> tuple:
81
- gdf = _pool_group_dfs[group_name]
82
- assert len(gdf) >= 2, f"Not enough rows for group: {group_name}"
83
- md5_keys = gdf["md5"].astype(str)
84
- wins = md5_keys.map(lambda k: _stats_by_key.get(k, (0, 0))[0])
85
- losses = md5_keys.map(lambda k: _stats_by_key.get(k, (0, 0))[1])
86
-
87
- def _pick_from_mask(mask: pd.Series):
88
- candidate_df = gdf[mask]
89
- if len(candidate_df) < 2:
90
- return None
91
- sample = candidate_df.sample(2, replace=False)
92
- return sample.iloc[0], sample.iloc[1]
93
-
94
-
95
- # 1) Repeat the lowest-margin edge participating in a cycle. (To prevent deadlock, stop if all margins are 4+.)
96
- # a) If deadlocked on a cycle with 4+ images and no inner cycles, sample a random missing edge inside the cycle.
97
- # 2) Pair images that both have wins only . (One of them will lose/tie. Stop when there is only one left.)
98
- # 3) Pair images that both have losses only. (One of them will win/tie. Stop when there is only one left.)
99
- # 4) Pair images with only 2 edges.
100
- # 5) X% chance, re-sample an existing edge, inversely proportional to existing number of samples.
101
- # 6) Y% chance, sample a random missing edge between images already sampled.
102
- # 7) Pair an unsampled image with a random sampled image.
103
-
104
- # 2) Pair images that currently have wins-only records.
105
- picked = _pick_from_mask((wins > 0) & (losses == 0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  if picked is not None:
107
- return picked[0], picked[1], "wins-only"
108
 
109
- # 3) Pair images that currently have losses-only records.
110
- picked = _pick_from_mask((wins == 0) & (losses > 0))
111
  if picked is not None:
112
- return picked[0], picked[1], "losses-only"
113
 
114
- # 4) Pair images that currently have exactly 2 total edges.
115
- vote_totals = wins + losses
116
- picked = _pick_from_mask(vote_totals == 2)
117
  if picked is not None:
118
- return picked[0], picked[1], "total_votes=2"
119
 
120
- # 7) Prefer pairing an unsampled image with a random previously sampled image.
121
- unsampled_mask = vote_totals == 0
122
- if unsampled_mask.any():
123
- unsampled_row = gdf[unsampled_mask].sample(1).iloc[0]
124
- sampled_df = gdf[~unsampled_mask]
125
- if len(sampled_df) >= 1:
126
- sampled_row = sampled_df.sample(1).iloc[0]
127
- else:
128
- sampled_row = gdf.drop(index=unsampled_row.name).sample(1).iloc[0]
129
- return unsampled_row, sampled_row, "unsampled+sampled"
130
 
131
- # 8) Safety fall back to low-vote weighted sampling.
132
- sample_weights = 1.0 / (vote_totals + 1.0)
133
- sample = gdf.sample(2, weights=sample_weights, replace=False)
134
- return sample.iloc[0], sample.iloc[1], "low-vote"
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  def _row_image_url(row) -> str:
138
  sample_url = row.get("sample_url")
@@ -148,39 +194,20 @@ DATASETS: dict[str, dict] = {
148
  "fetch_pair": _pool_fetch_pair,
149
  "get_id": lambda row: row["md5"],
150
  "get_image": _row_image_url,
151
- "groups": {g: g for g in sorted(_pool_df["group"].unique())},
152
  },
153
  }
154
  DEFAULT_DATASET = list(DATASETS.keys())[0]
155
 
156
  def _select_groups(cfg: dict, rating_pref: str) -> list[str]:
157
- groups = list(cfg["groups"].keys())
158
  if rating_pref == "all":
159
- return groups
160
- return [g for g in groups if g.endswith(f"_{rating_pref}")]
161
-
162
- def _commit_oldest_pending(state: dict):
163
- pending = state.setdefault("pending", [])
164
- if len(pending) <= 1:
165
- return
166
- oldest = pending.pop(0)
167
- if oldest.get("winner") in ("A", "B"):
168
- _apply_local_stats_update(oldest["winner"], oldest["key_a"], oldest["key_b"])
169
- threading.Thread(target=VOTE_STORAGE.append_vote_row, args=(oldest.copy(), oldest.get("winner")), daemon=True).start()
170
-
171
-
172
- def _apply_local_stats_update(winner: str, key_a: str, key_b: str):
173
- assert winner in ("A", "B")
174
- with _stats_lock:
175
- wins_a, losses_a = _stats_by_key.get(str(key_a), (0, 0))
176
- wins_b, losses_b = _stats_by_key.get(str(key_b), (0, 0))
177
- if winner == "A":
178
- _stats_by_key[str(key_a)] = (wins_a + 1, losses_a)
179
- _stats_by_key[str(key_b)] = (wins_b, losses_b + 1)
180
- else:
181
- _stats_by_key[str(key_a)] = (wins_a, losses_a + 1)
182
- _stats_by_key[str(key_b)] = (wins_b + 1, losses_b)
183
 
 
 
 
 
 
184
 
185
  def _format_rating_post_row(post_id: int, wins: int, losses: int, label: str | None = None) -> str:
186
  total_votes = wins + losses
@@ -190,8 +217,8 @@ def _format_rating_post_row(post_id: int, wins: int, losses: int, label: str | N
190
 
191
  def _render_current(state: dict, submit_status: str = "") -> tuple:
192
  _reload_stats_if_due()
193
- wins_a, losses_a = _stats_by_key.get(str(state["key_a"]), (0, 0))
194
- wins_b, losses_b = _stats_by_key.get(str(state["key_b"]), (0, 0))
195
  title_a = "Image A"
196
  title_b = "Image B"
197
  img_a_html = f"<div class=\"rating-card\"><div class=\"rating-card-title\"><strong>{html.escape(title_a)}</strong></div><div class=\"rating-image-frame\"><img src=\"{html.escape(state['url_a'])}\" class=\"rating-image\" loading=\"eager\" referrerpolicy=\"no-referrer\"></div></div>"
@@ -207,7 +234,6 @@ def _render_current(state: dict, submit_status: str = "") -> tuple:
207
  return img_a_html, img_b_html, link_a, link_b, back_md, group_md, pair_reason_md, status_md, state
208
 
209
 
210
-
211
  def _normalize_rating_pref(pref: str | None) -> str:
212
  return pref if pref in ("safe", "all") else "safe"
213
 
@@ -274,13 +300,12 @@ def new_round(dataset_name: str, rating_pref: str, state: dict) -> tuple:
274
  cfg = DATASETS[dataset_name]
275
  groups = _select_groups(cfg, rating_pref)
276
  assert groups, f"No groups for rating preference: {rating_pref}"
 
277
  group = random.choice(groups)
278
- pair_data = cfg["fetch_pair"](cfg["groups"][group])
279
- if len(pair_data) == 3:
280
- row_a, row_b, pair_reason = pair_data
281
- else:
282
- row_a, row_b = pair_data
283
- pair_reason = ""
284
  state.setdefault("session_id", uuid.uuid4().hex)
285
  key_a = cfg["get_id"](row_a)
286
  key_b = cfg["get_id"](row_b)
@@ -295,8 +320,9 @@ def new_round(dataset_name: str, rating_pref: str, state: dict) -> tuple:
295
 
296
  def _queue_decision(winner: str | None, state: dict):
297
  assert state.get("session_id"), "Missing session_id: refusing to record vote"
298
- state.setdefault("pending", [])
299
- decision = {
 
300
  "winner": winner,
301
  "key_a": state["key_a"],
302
  "key_b": state["key_b"],
@@ -309,43 +335,63 @@ def _queue_decision(winner: str | None, state: dict):
309
  "group": state["group"],
310
  "pair_reason": state.get("pair_reason", ""),
311
  "session_id": state["session_id"],
312
- }
313
- state["pending"].append(decision)
314
- state["last_decision"] = decision
315
- state["can_go_back"] = True
316
- _commit_oldest_pending(state)
317
 
318
  def vote(winner: str | None, state: dict, submit_key: str | None) -> tuple:
319
- assert winner in ("A", "B", None)
320
  if _normalize_submit_key(submit_key) != SUBMIT_KEY:
321
  return _render_current(state, "Wrong submission key.")
 
322
  _queue_decision(winner, state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  return new_round(state["dataset"], state["rating_pref"], state)
324
 
325
  def go_back(state: dict) -> tuple:
326
  pending = state.setdefault("pending", [])
327
- if not state.get("can_go_back"):
328
- return _render_current(state)
329
- last = state.get("last_decision")
330
- if not last:
331
- state["can_go_back"] = False
332
- return _render_current(state)
333
- if pending and pending[-1] == last:
334
- pending.pop()
335
- state["can_go_back"] = False
336
- state["last_decision"] = None
337
- state.update(
338
- dataset=last["dataset"],
339
- rating_pref=last["rating_pref"],
340
- key_a=last["key_a"],
341
- key_b=last["key_b"],
342
- id_a=last["id_a"],
343
- id_b=last["id_b"],
344
- url_a=last["url_a"],
345
- url_b=last["url_b"],
346
- group=last["group"],
347
- pair_reason=last.get("pair_reason", ""),
348
- )
 
 
 
 
 
 
 
349
  return _render_current(state)
350
 
351
  # -- UI ---------------------------------------------------------------------
 
5
  import uuid
6
  import os
7
  import html
8
+ import sys
9
 
10
  import pandas as pd
11
  from huggingface_hub import hf_hub_download
 
30
  token=RATINGS_APP_TOKEN
31
  )
32
# Candidate-image pool, loaded once at import time. The three counter columns
# start at zero and are filled in later from the stored vote statistics.
_pool_df = pd.read_parquet(_pool_path).assign(wins=0, losses=0, votes=0)

# Positional column indices of the counters, cached for `.iloc` access.
WINS_LOC, LOSSES_LOC, VOTES_LOC = map(
    _pool_df.columns.get_loc, ("wins", "losses", "votes")
)

# md5 -> positional row index in `_pool_df` (a repeated md5 keeps its last row).
_md5_to_idx = dict(zip(_pool_df["md5"], range(len(_pool_df))))

# `_stats_lock` serializes stats reloads; `_pool_lock` guards writes to the
# wins/losses/votes columns of `_pool_df`.
_stats_lock = threading.Lock()
_pool_lock = threading.Lock()

# Compared against time.time() by `_reload_stats_if_due`.
_stats_last_loaded_at = 0.0

_explorer_df = pd.DataFrame(columns=["group", "id", "md5", "rating", "sample_url", "image_url", "classifier", "classifier_score", "percentile"])
48
 
49
 
50
  def _reload_stats_if_due(force: bool = False):
51
+ global _stats_last_loaded_at,_explorer_df
52
  now = time.time()
53
+
54
  if not force and (now - _stats_last_loaded_at) < STATS_RELOAD_S:
55
  return
56
+
57
  with _stats_lock:
58
  now = time.time()
59
+
60
  if not force and (now - _stats_last_loaded_at) < STATS_RELOAD_S:
61
  return
62
+
63
+ stats_by_key = load_stats_by_md5(
64
  repo_id=POOL_REPO_ID,
65
  token=RATINGS_APP_TOKEN,
66
  )
67
+
68
+ with _pool_lock:
69
+ n_missing = 0
70
+ for md5, stats in stats_by_key.items():
71
+ if (idx := _md5_to_idx.get(md5)) is not None:
72
+ _pool_df.iloc[idx, [WINS_LOC, LOSSES_LOC, VOTES_LOC]] = (*stats, stats[0] + stats[1])
73
+ else:
74
+ n_missing += 1
75
+
76
+ if n_missing:
77
+ print(f"{n_missing} md5s have stats but are not in the pool!", file=sys.stderr)
78
+
79
  classifier_scores_path = hf_hub_download(
80
  repo_id=POOL_REPO_ID,
81
  filename="classifier_scores.parquet",
 
104
 
105
  _reload_stats_if_due(force=True)
106
 
107
+ def _pick_from_bins(df: pd.DataFrame, field: str) -> tuple[pd.Series, pd.Series, int] | None:
108
+ if len(df) < 2:
109
+ return None
110
+
111
+ least = df[field].min()
112
+ if least >= 10:
113
+ return None # don't push too hard for a total order
114
+
115
+ remaining = (df[field] < 10).sum() - 1
116
+
117
+ candidates = df[df[field] == least]
118
+ if len(candidates) > 1:
119
+ sample = candidates.sample(2, replace=False)
120
+ return sample.iloc[0], sample.iloc[1], remaining
121
+
122
+ first = candidates.iloc[0]
123
+ while True:
124
+ least += 1
125
+ candidates = df[df[field] == least]
126
+ if candidates.empty:
127
+ continue
128
+
129
+ sample = candidates.sample(1)
130
+ return first, sample.iloc[0], remaining
131
+
132
+ def _pick_from(df: pd.DataFrame, weights: pd.Series | None = None) -> tuple[pd.Series, pd.Series, int] | None:
133
+ if len(df) < 2:
134
+ return None
135
+
136
+ remaining = len(df) - 2
137
+
138
+ sample = df.sample(2, weights=weights, replace=False)
139
+ return sample.iloc[0], sample.iloc[1], remaining
140
+
141
def _pool_fetch_pair(group: str) -> tuple[pd.Series, pd.Series, int, str]:
    """Choose the next two images of *group* to compare.

    Strategies are tried in order and the first one that yields a pair wins:

    1. ``"wins-only"``   - two voted images without losses, lowest win counts first.
    2. ``"losses-only"`` - two voted images without wins, lowest loss counts first.
    3. ``"sparse"``      - two images that each have exactly two votes.
    4. ``"init"``        - nothing voted on yet: two random images;
       ``"new"``         - one voted image against one never-voted image
       (taken 75% of the time, or always when only one voted image exists).
    5. ``"random"``      - two voted images, weighted by ``1 / votes``.

    Returns:
        ``(row_a, row_b, remaining, reason)``. ``remaining`` is a
        strategy-specific count of candidates left, and ``reason`` is the
        strategy label.

    Raises:
        ValueError: nothing has been voted on and the group holds fewer than
            two never-voted images.
        AssertionError: the group has a single voted image and nothing to
            pair it with.
    """
    gdf = _pool_df[_pool_df["group"] == group]
    ranked = gdf[gdf["votes"] > 0]

    # 1) Pair images that have wins-only records.
    picked = _pick_from_bins(ranked[ranked["losses"] == 0], "wins")
    if picked is not None:
        return *picked, "wins-only"

    # 2) Pair images that have losses-only records.
    picked = _pick_from_bins(ranked[ranked["wins"] == 0], "losses")
    if picked is not None:
        return *picked, "losses-only"

    # 3) Ensure a minimum density of 3.
    picked = _pick_from(ranked[ranked["votes"] == 2])
    if picked is not None:
        return *picked, "sparse"

    # 4) Introduce a new image.
    unranked = gdf[gdf["votes"] == 0]

    if ranked.empty:  # Very first vote.
        picked = _pick_from(unranked)
        if picked is None:
            raise ValueError("Group is empty.")

        return *picked, "init"

    # A single voted image cannot form a pair in step 5, so in that case a new
    # image is always introduced rather than only 75% of the time.
    if not unranked.empty and (len(ranked) < 2 or random.random() < 0.75):
        return (
            ranked.sample(1).iloc[0],
            unranked.sample(1).iloc[0],
            len(unranked) - 1, "new"
        )

    # 5) Vote-weighted random sampling.
    picked = _pick_from(ranked, weights=(1.0 / ranked["votes"]))
    if picked is None:
        # Raised explicitly so the check survives `python -O`.
        raise AssertionError(
            f"Cannot form a pair for group {group!r}: "
            "one voted image and no unvoted image to pair it with."
        )
    return *picked, "random"
182
 
183
  def _row_image_url(row) -> str:
184
  sample_url = row.get("sample_url")
 
194
  "fetch_pair": _pool_fetch_pair,
195
  "get_id": lambda row: row["md5"],
196
  "get_image": _row_image_url,
197
+ "groups": sorted(_pool_df["group"].unique()),
198
  },
199
  }
200
# The first dataset declared in DATASETS is the one used by default.
DEFAULT_DATASET = [*DATASETS][0]
201
 
202
  def _select_groups(cfg: dict, rating_pref: str) -> list[str]:
 
203
  if rating_pref == "all":
204
+ return cfg["groups"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
+ return [
207
+ g
208
+ for g in cfg["groups"]
209
+ if g.endswith(f"_{rating_pref}")
210
+ ]
211
 
212
  def _format_rating_post_row(post_id: int, wins: int, losses: int, label: str | None = None) -> str:
213
  total_votes = wins + losses
 
217
 
218
  def _render_current(state: dict, submit_status: str = "") -> tuple:
219
  _reload_stats_if_due()
220
+ wins_a, losses_a = _pool_df.iloc[_md5_to_idx[state["key_a"]], [WINS_LOC, LOSSES_LOC]]
221
+ wins_b, losses_b = _pool_df.iloc[_md5_to_idx[state["key_b"]], [WINS_LOC, LOSSES_LOC]]
222
  title_a = "Image A"
223
  title_b = "Image B"
224
  img_a_html = f"<div class=\"rating-card\"><div class=\"rating-card-title\"><strong>{html.escape(title_a)}</strong></div><div class=\"rating-image-frame\"><img src=\"{html.escape(state['url_a'])}\" class=\"rating-image\" loading=\"eager\" referrerpolicy=\"no-referrer\"></div></div>"
 
234
  return img_a_html, img_b_html, link_a, link_b, back_md, group_md, pair_reason_md, status_md, state
235
 
236
 
 
237
  def _normalize_rating_pref(pref: str | None) -> str:
238
  return pref if pref in ("safe", "all") else "safe"
239
 
 
300
  cfg = DATASETS[dataset_name]
301
  groups = _select_groups(cfg, rating_pref)
302
  assert groups, f"No groups for rating preference: {rating_pref}"
303
+
304
  group = random.choice(groups)
305
+ row_a, row_b, reason_remaining, pair_reason = cfg["fetch_pair"](group)
306
+
307
+ pair_reason = f"{pair_reason} ({reason_remaining})"
308
+
 
 
309
  state.setdefault("session_id", uuid.uuid4().hex)
310
  key_a = cfg["get_id"](row_a)
311
  key_b = cfg["get_id"](row_b)
 
320
 
321
  def _queue_decision(winner: str | None, state: dict):
322
  assert state.get("session_id"), "Missing session_id: refusing to record vote"
323
+
324
+ pending = state.setdefault("pending", [])
325
+ pending.append({
326
  "winner": winner,
327
  "key_a": state["key_a"],
328
  "key_b": state["key_b"],
 
335
  "group": state["group"],
336
  "pair_reason": state.get("pair_reason", ""),
337
  "session_id": state["session_id"],
338
+ })
339
+
340
+ if len(pending) > 1:
341
+ VOTE_STORAGE.queue_row(pending.pop(0))
 
342
 
343
def vote(winner: str | None, state: dict, submit_key: str | None) -> tuple:
    """Record a decision for the current pair and move on to a new round.

    Args:
        winner: ``"A"`` or ``"B"`` for the preferred image, ``None`` when
            neither image is chosen.
        state: Session state; ``key_a`` / ``key_b`` identify the current pair.
        submit_key: Key typed by the user; it must normalize to ``SUBMIT_KEY``.

    Returns:
        The render tuple of the current pair with an error status when the
        submission key is wrong, otherwise whatever ``new_round`` returns.

    Raises:
        AssertionError: *winner* is not ``"A"``, ``"B"`` or ``None``.
    """
    if _normalize_submit_key(submit_key) != SUBMIT_KEY:
        return _render_current(state, "Wrong submission key.")

    # Validate before anything is queued: an invalid winner must not end up in
    # the pending list, from where it would later be handed to storage.
    if winner not in ("A", "B", None):
        raise AssertionError(f"Unexpected winner: {winner!r}")

    _queue_decision(winner, state)

    if winner is not None:
        if winner == "A":
            won_key, lost_key = state["key_a"], state["key_b"]
        else:
            won_key, lost_key = state["key_b"], state["key_a"]

        # Apply the vote to the in-memory counters.
        with _pool_lock:
            _pool_df.iloc[_md5_to_idx[won_key], [WINS_LOC, VOTES_LOC]] += 1
            _pool_df.iloc[_md5_to_idx[lost_key], [LOSSES_LOC, VOTES_LOC]] += 1

    return new_round(state["dataset"], state["rating_pref"], state)
363
 
364
def go_back(state: dict) -> tuple:
    """Undo the most recent pending decision and show its pair again.

    If ``state["pending"]`` holds a decision, it is removed, the pair it
    described is restored into *state*, and the counters that decision added
    to the in-memory pool are subtracted again. With nothing pending, the
    current pair is simply rendered.

    Raises:
        AssertionError: the stored winner is not ``"A"``, ``"B"`` or ``None``.
    """
    pending = state.setdefault("pending", [])
    if not pending:
        return _render_current(state)

    last = pending.pop()

    restored = {
        name: last[name]
        for name in (
            "dataset", "rating_pref",
            "key_a", "key_b",
            "id_a", "id_b",
            "url_a", "url_b",
            "group",
        )
    }
    restored["pair_reason"] = last.get("pair_reason", "")
    state.update(restored)

    key_a, key_b = state["key_a"], state["key_b"]
    with _pool_lock:
        outcome = last["winner"]
        if outcome == "A":
            _pool_df.iloc[_md5_to_idx[key_a], [WINS_LOC, VOTES_LOC]] -= 1
            _pool_df.iloc[_md5_to_idx[key_b], [LOSSES_LOC, VOTES_LOC]] -= 1
        elif outcome == "B":
            _pool_df.iloc[_md5_to_idx[key_b], [WINS_LOC, VOTES_LOC]] -= 1
            _pool_df.iloc[_md5_to_idx[key_a], [LOSSES_LOC, VOTES_LOC]] -= 1
        elif outcome is not None:
            raise AssertionError

    return _render_current(state)
396
 
397
  # -- UI ---------------------------------------------------------------------
storage.py CHANGED
@@ -30,79 +30,98 @@ class VoteStorage:
30
  def __init__(self, mode: str, token: str | None = None):
31
  assert mode in ("hf", "void"), f"Unsupported storage mode: {mode}"
32
  self.mode = mode
33
- self._token = token
34
  is_debug_mode = self.mode == "void"
 
35
  self._flush_every = 3 if is_debug_mode else 50
36
  self._flush_interval_sec = 15.0 if is_debug_mode else 300.0
37
- self._votes_lock = threading.Lock()
 
38
  self._votes_buffer: list[dict] = []
39
- self._stop_event = threading.Event()
 
40
  self._flush_thread = threading.Thread(target=self._flush_loop, daemon=True)
41
  self._flush_thread.start()
42
- atexit.register(self.close)
43
 
44
- def _hf_token(self) -> str | None:
45
- return self._token
 
46
 
47
  def _empty_votes_df(self) -> pd.DataFrame:
48
  return pd.DataFrame(columns=VOTE_COLUMNS)
49
 
50
  def _upload_votes_batch(self, df: pd.DataFrame, commit_message: str):
51
- assert set(VOTE_COLUMNS).issubset(df.columns), "Missing vote columns in upload batch"
52
  if self.mode == "void":
53
- _ = commit_message
54
  return
 
55
  ts = int(time.time())
56
  shard = f"votes_{ts}_{uuid.uuid4().hex}.parquet"
57
- api = HfApi(token=self._hf_token())
58
- with NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
59
- tmp_path = tmp.name
60
- try:
61
- df[VOTE_COLUMNS].to_parquet(tmp_path, index=False)
62
- api.upload_file(
63
- path_or_fileobj=tmp_path,
64
- path_in_repo=f"{VOTES_LOG_SUBDIR}/{shard}",
65
- repo_id=VOTES_REPO_ID,
66
- repo_type=VOTES_REPO_TYPE,
67
- commit_message=commit_message,
68
- )
69
- finally:
70
- if os.path.exists(tmp_path):
71
- os.remove(tmp_path)
72
-
73
- def _flush_votes(self, force: bool = False):
74
- with self._votes_lock:
75
- if not self._votes_buffer:
76
- return
77
- if not force and len(self._votes_buffer) < self._flush_every:
78
- return
79
- batch = list(self._votes_buffer)
80
- self._votes_buffer.clear()
81
- incoming = pd.DataFrame(batch)
82
- for col in VOTE_COLUMNS:
83
- if col not in incoming.columns:
84
- incoming[col] = None
85
- self._upload_votes_batch(incoming[VOTE_COLUMNS], commit_message=f"append {len(batch)} vote rows")
86
-
87
- def _flush_loop(self):
88
- while not self._stop_event.wait(self._flush_interval_sec):
89
- self._flush_votes(force=True)
90
-
91
- def close(self):
92
- if self._stop_event.is_set():
93
- return
94
- self._stop_event.set()
95
- self._flush_thread.join(timeout=1.0)
96
- self._flush_votes(force=True)
97
 
98
- def append_vote_row(self, state: dict, winner: str | None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  id_a = int(state["id_a"])
100
  id_b = int(state["id_b"])
101
- winner_md5 = None
102
- if winner == "A":
103
- winner_md5 = state["key_a"]
104
- elif winner == "B":
105
- winner_md5 = state["key_b"]
 
 
 
 
 
 
 
106
  vote_row = {
107
  "vote_id": uuid.uuid4().hex,
108
  "timestamp": datetime.now(timezone.utc).isoformat(timespec="seconds"),
@@ -115,6 +134,9 @@ class VoteStorage:
115
  "group": state["group"],
116
  "session_id": state["session_id"],
117
  }
118
- with self._votes_lock:
 
119
  self._votes_buffer.append(vote_row)
120
- self._flush_votes()
 
 
 
30
  def __init__(self, mode: str, token: str | None = None):
31
  assert mode in ("hf", "void"), f"Unsupported storage mode: {mode}"
32
  self.mode = mode
 
33
  is_debug_mode = self.mode == "void"
34
+
35
  self._flush_every = 3 if is_debug_mode else 50
36
  self._flush_interval_sec = 15.0 if is_debug_mode else 300.0
37
+
38
+ self._shutdown = False
39
  self._votes_buffer: list[dict] = []
40
+
41
+ self._flush_condition = threading.Condition(threading.Lock())
42
  self._flush_thread = threading.Thread(target=self._flush_loop, daemon=True)
43
  self._flush_thread.start()
 
44
 
45
+ self.hf_api = HfApi(token=token)
46
+
47
+ atexit.register(self.close)
48
 
49
  def _empty_votes_df(self) -> pd.DataFrame:
50
  return pd.DataFrame(columns=VOTE_COLUMNS)
51
 
52
  def _upload_votes_batch(self, df: pd.DataFrame, commit_message: str):
 
53
  if self.mode == "void":
 
54
  return
55
+
56
  ts = int(time.time())
57
  shard = f"votes_{ts}_{uuid.uuid4().hex}.parquet"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ self.hf_api.upload_file(
60
+ path_or_fileobj=df.to_parquet(index=False),
61
+ path_in_repo=f"{VOTES_LOG_SUBDIR}/{shard}",
62
+ repo_id=VOTES_REPO_ID,
63
+ repo_type=VOTES_REPO_TYPE,
64
+ commit_message=commit_message,
65
+ )
66
+
67
+ def _flush_loop(self) -> None:
68
+ while True:
69
+ with self._flush_condition:
70
+ while True:
71
+ if self._shutdown:
72
+ # Flush last batch of votes.
73
+ if self._votes_buffer:
74
+ break
75
+
76
+ return
77
+
78
+ # Have enough votes to flush now.
79
+ if len(self._votes_buffer) >= self._flush_every:
80
+ break
81
+
82
+ # Wait for a notify to flush early or shutdown.
83
+ if not self._flush_condition.wait(self._flush_interval_sec):
84
+ # Interval elapsed. Flush if there is at least one vote.
85
+ if self._votes_buffer:
86
+ break
87
+
88
+ # Atomically take the batch of votes.
89
+ batch = self._votes_buffer
90
+ self._votes_buffer = []
91
+
92
+ assert batch
93
+ batch_df = pd.DataFrame(batch)
94
+ del batch
95
+
96
+ for col in VOTE_COLUMNS:
97
+ if col not in batch_df.columns:
98
+ batch_df[col] = None
99
+
100
+ batch_df = batch_df[VOTE_COLUMNS]
101
+ self._upload_votes_batch(batch_df, commit_message=f"upload {len(batch_df)} vote rows")
102
+
103
+ def close(self) -> None:
104
+ with self._flush_condition:
105
+ self._shutdown = True
106
+ self._flush_condition.notify()
107
+
108
+ self._flush_thread.join()
109
+
110
+ def queue_row(self, state: dict) -> None:
111
  id_a = int(state["id_a"])
112
  id_b = int(state["id_b"])
113
+
114
+ winner_md5: str | None
115
+ match state["winner"]:
116
+ case "A":
117
+ winner_md5 = state["key_a"]
118
+ case "B":
119
+ winner_md5 = state["key_b"]
120
+ case None:
121
+ winner_md5 = None
122
+ case _:
123
+ raise AssertionError
124
+
125
  vote_row = {
126
  "vote_id": uuid.uuid4().hex,
127
  "timestamp": datetime.now(timezone.utc).isoformat(timespec="seconds"),
 
134
  "group": state["group"],
135
  "session_id": state["session_id"],
136
  }
137
+
138
+ with self._flush_condition:
139
  self._votes_buffer.append(vote_row)
140
+
141
+ if len(self._votes_buffer) == self._flush_every:
142
+ self._flush_condition.notify()