Spaces:
Running
Running
Initial changes.
#1
by RedHotTensors - opened
- app.py +163 -117
- storage.py +78 -56
app.py
CHANGED
|
@@ -5,6 +5,7 @@ import time
|
|
| 5 |
import uuid
|
| 6 |
import os
|
| 7 |
import html
|
|
|
|
| 8 |
|
| 9 |
import pandas as pd
|
| 10 |
from huggingface_hub import hf_hub_download
|
|
@@ -29,26 +30,52 @@ _pool_path = hf_hub_download(
|
|
| 29 |
token=RATINGS_APP_TOKEN
|
| 30 |
)
|
| 31 |
_pool_df = pd.read_parquet(_pool_path)
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
_stats_lock = threading.Lock()
|
|
|
|
|
|
|
| 34 |
_stats_last_loaded_at = 0.0
|
| 35 |
-
_stats_by_key: dict[str, tuple[int, int]] = {}
|
| 36 |
_explorer_df = pd.DataFrame(columns=["group", "id", "md5", "rating", "sample_url", "image_url", "classifier", "classifier_score", "percentile"])
|
| 37 |
|
| 38 |
|
| 39 |
def _reload_stats_if_due(force: bool = False):
|
| 40 |
-
global _stats_last_loaded_at,
|
| 41 |
now = time.time()
|
|
|
|
| 42 |
if not force and (now - _stats_last_loaded_at) < STATS_RELOAD_S:
|
| 43 |
return
|
|
|
|
| 44 |
with _stats_lock:
|
| 45 |
now = time.time()
|
|
|
|
| 46 |
if not force and (now - _stats_last_loaded_at) < STATS_RELOAD_S:
|
| 47 |
return
|
| 48 |
-
|
|
|
|
| 49 |
repo_id=POOL_REPO_ID,
|
| 50 |
token=RATINGS_APP_TOKEN,
|
| 51 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
classifier_scores_path = hf_hub_download(
|
| 53 |
repo_id=POOL_REPO_ID,
|
| 54 |
filename="classifier_scores.parquet",
|
|
@@ -77,62 +104,81 @@ def _reload_stats_if_due(force: bool = False):
|
|
| 77 |
|
| 78 |
_reload_stats_if_due(force=True)
|
| 79 |
|
| 80 |
-
def
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
if picked is not None:
|
| 107 |
-
return picked
|
| 108 |
|
| 109 |
-
#
|
| 110 |
-
picked =
|
| 111 |
if picked is not None:
|
| 112 |
-
return picked
|
| 113 |
|
| 114 |
-
#
|
| 115 |
-
|
| 116 |
-
picked = _pick_from_mask(vote_totals == 2)
|
| 117 |
if picked is not None:
|
| 118 |
-
return picked
|
| 119 |
|
| 120 |
-
#
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
unsampled_row = gdf[unsampled_mask].sample(1).iloc[0]
|
| 124 |
-
sampled_df = gdf[~unsampled_mask]
|
| 125 |
-
if len(sampled_df) >= 1:
|
| 126 |
-
sampled_row = sampled_df.sample(1).iloc[0]
|
| 127 |
-
else:
|
| 128 |
-
sampled_row = gdf.drop(index=unsampled_row.name).sample(1).iloc[0]
|
| 129 |
-
return unsampled_row, sampled_row, "unsampled+sampled"
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
def _row_image_url(row) -> str:
|
| 138 |
sample_url = row.get("sample_url")
|
|
@@ -148,39 +194,20 @@ DATASETS: dict[str, dict] = {
|
|
| 148 |
"fetch_pair": _pool_fetch_pair,
|
| 149 |
"get_id": lambda row: row["md5"],
|
| 150 |
"get_image": _row_image_url,
|
| 151 |
-
"groups":
|
| 152 |
},
|
| 153 |
}
|
| 154 |
DEFAULT_DATASET = list(DATASETS.keys())[0]
|
| 155 |
|
| 156 |
def _select_groups(cfg: dict, rating_pref: str) -> list[str]:
|
| 157 |
-
groups = list(cfg["groups"].keys())
|
| 158 |
if rating_pref == "all":
|
| 159 |
-
return groups
|
| 160 |
-
return [g for g in groups if g.endswith(f"_{rating_pref}")]
|
| 161 |
-
|
| 162 |
-
def _commit_oldest_pending(state: dict):
|
| 163 |
-
pending = state.setdefault("pending", [])
|
| 164 |
-
if len(pending) <= 1:
|
| 165 |
-
return
|
| 166 |
-
oldest = pending.pop(0)
|
| 167 |
-
if oldest.get("winner") in ("A", "B"):
|
| 168 |
-
_apply_local_stats_update(oldest["winner"], oldest["key_a"], oldest["key_b"])
|
| 169 |
-
threading.Thread(target=VOTE_STORAGE.append_vote_row, args=(oldest.copy(), oldest.get("winner")), daemon=True).start()
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
def _apply_local_stats_update(winner: str, key_a: str, key_b: str):
|
| 173 |
-
assert winner in ("A", "B")
|
| 174 |
-
with _stats_lock:
|
| 175 |
-
wins_a, losses_a = _stats_by_key.get(str(key_a), (0, 0))
|
| 176 |
-
wins_b, losses_b = _stats_by_key.get(str(key_b), (0, 0))
|
| 177 |
-
if winner == "A":
|
| 178 |
-
_stats_by_key[str(key_a)] = (wins_a + 1, losses_a)
|
| 179 |
-
_stats_by_key[str(key_b)] = (wins_b, losses_b + 1)
|
| 180 |
-
else:
|
| 181 |
-
_stats_by_key[str(key_a)] = (wins_a, losses_a + 1)
|
| 182 |
-
_stats_by_key[str(key_b)] = (wins_b + 1, losses_b)
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
def _format_rating_post_row(post_id: int, wins: int, losses: int, label: str | None = None) -> str:
|
| 186 |
total_votes = wins + losses
|
|
@@ -190,8 +217,8 @@ def _format_rating_post_row(post_id: int, wins: int, losses: int, label: str | N
|
|
| 190 |
|
| 191 |
def _render_current(state: dict, submit_status: str = "") -> tuple:
|
| 192 |
_reload_stats_if_due()
|
| 193 |
-
wins_a, losses_a =
|
| 194 |
-
wins_b, losses_b =
|
| 195 |
title_a = "Image A"
|
| 196 |
title_b = "Image B"
|
| 197 |
img_a_html = f"<div class=\"rating-card\"><div class=\"rating-card-title\"><strong>{html.escape(title_a)}</strong></div><div class=\"rating-image-frame\"><img src=\"{html.escape(state['url_a'])}\" class=\"rating-image\" loading=\"eager\" referrerpolicy=\"no-referrer\"></div></div>"
|
|
@@ -207,7 +234,6 @@ def _render_current(state: dict, submit_status: str = "") -> tuple:
|
|
| 207 |
return img_a_html, img_b_html, link_a, link_b, back_md, group_md, pair_reason_md, status_md, state
|
| 208 |
|
| 209 |
|
| 210 |
-
|
| 211 |
def _normalize_rating_pref(pref: str | None) -> str:
|
| 212 |
return pref if pref in ("safe", "all") else "safe"
|
| 213 |
|
|
@@ -274,13 +300,12 @@ def new_round(dataset_name: str, rating_pref: str, state: dict) -> tuple:
|
|
| 274 |
cfg = DATASETS[dataset_name]
|
| 275 |
groups = _select_groups(cfg, rating_pref)
|
| 276 |
assert groups, f"No groups for rating preference: {rating_pref}"
|
|
|
|
| 277 |
group = random.choice(groups)
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
row_a, row_b = pair_data
|
| 283 |
-
pair_reason = ""
|
| 284 |
state.setdefault("session_id", uuid.uuid4().hex)
|
| 285 |
key_a = cfg["get_id"](row_a)
|
| 286 |
key_b = cfg["get_id"](row_b)
|
|
@@ -295,8 +320,9 @@ def new_round(dataset_name: str, rating_pref: str, state: dict) -> tuple:
|
|
| 295 |
|
| 296 |
def _queue_decision(winner: str | None, state: dict):
|
| 297 |
assert state.get("session_id"), "Missing session_id: refusing to record vote"
|
| 298 |
-
|
| 299 |
-
|
|
|
|
| 300 |
"winner": winner,
|
| 301 |
"key_a": state["key_a"],
|
| 302 |
"key_b": state["key_b"],
|
|
@@ -309,43 +335,63 @@ def _queue_decision(winner: str | None, state: dict):
|
|
| 309 |
"group": state["group"],
|
| 310 |
"pair_reason": state.get("pair_reason", ""),
|
| 311 |
"session_id": state["session_id"],
|
| 312 |
-
}
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
_commit_oldest_pending(state)
|
| 317 |
|
| 318 |
def vote(winner: str | None, state: dict, submit_key: str | None) -> tuple:
|
| 319 |
-
assert winner in ("A", "B", None)
|
| 320 |
if _normalize_submit_key(submit_key) != SUBMIT_KEY:
|
| 321 |
return _render_current(state, "Wrong submission key.")
|
|
|
|
| 322 |
_queue_decision(winner, state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
return new_round(state["dataset"], state["rating_pref"], state)
|
| 324 |
|
| 325 |
def go_back(state: dict) -> tuple:
|
| 326 |
pending = state.setdefault("pending", [])
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
return _render_current(state)
|
| 350 |
|
| 351 |
# -- UI ---------------------------------------------------------------------
|
|
|
|
| 5 |
import uuid
|
| 6 |
import os
|
| 7 |
import html
|
| 8 |
+
import sys
|
| 9 |
|
| 10 |
import pandas as pd
|
| 11 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 30 |
token=RATINGS_APP_TOKEN
|
| 31 |
)
|
| 32 |
_pool_df = pd.read_parquet(_pool_path)

# Per-image vote tallies kept directly on the pool frame. They start at zero
# and are overwritten/adjusted in place by positional (iloc) writes.
_pool_df["wins"] = 0
_pool_df["losses"] = 0
_pool_df["votes"] = 0

# Column positions of the tally columns, for use with `_pool_df.iloc[row, cols]`.
WINS_LOC = _pool_df.columns.get_loc("wins")
LOSSES_LOC = _pool_df.columns.get_loc("losses")
VOTES_LOC = _pool_df.columns.get_loc("votes")

# md5 -> positional row number in `_pool_df` (an iloc position, not an index
# label). If an md5 occurs more than once, the last occurrence wins.
_md5_to_idx = {md5: idx for idx, md5 in enumerate(_pool_df["md5"])}

_stats_lock = threading.Lock()
_pool_lock = threading.Lock()

_stats_last_loaded_at = 0.0
_explorer_df = pd.DataFrame(columns=["group", "id", "md5", "rating", "sample_url", "image_url", "classifier", "classifier_score", "percentile"])
|
| 48 |
|
| 49 |
|
| 50 |
def _reload_stats_if_due(force: bool = False):
|
| 51 |
+
global _stats_last_loaded_at,_explorer_df
|
| 52 |
now = time.time()
|
| 53 |
+
|
| 54 |
if not force and (now - _stats_last_loaded_at) < STATS_RELOAD_S:
|
| 55 |
return
|
| 56 |
+
|
| 57 |
with _stats_lock:
|
| 58 |
now = time.time()
|
| 59 |
+
|
| 60 |
if not force and (now - _stats_last_loaded_at) < STATS_RELOAD_S:
|
| 61 |
return
|
| 62 |
+
|
| 63 |
+
stats_by_key = load_stats_by_md5(
|
| 64 |
repo_id=POOL_REPO_ID,
|
| 65 |
token=RATINGS_APP_TOKEN,
|
| 66 |
)
|
| 67 |
+
|
| 68 |
+
with _pool_lock:
|
| 69 |
+
n_missing = 0
|
| 70 |
+
for md5, stats in stats_by_key.items():
|
| 71 |
+
if (idx := _md5_to_idx.get(md5)) is not None:
|
| 72 |
+
_pool_df.iloc[idx, [WINS_LOC, LOSSES_LOC, VOTES_LOC]] = (*stats, stats[0] + stats[1])
|
| 73 |
+
else:
|
| 74 |
+
n_missing += 1
|
| 75 |
+
|
| 76 |
+
if n_missing:
|
| 77 |
+
print(f"{n_missing} md5s have stats but are not in the pool!", file=sys.stderr)
|
| 78 |
+
|
| 79 |
classifier_scores_path = hf_hub_download(
|
| 80 |
repo_id=POOL_REPO_ID,
|
| 81 |
filename="classifier_scores.parquet",
|
|
|
|
| 104 |
|
| 105 |
_reload_stats_if_due(force=True)
|
| 106 |
|
| 107 |
+
def _pick_from_bins(df: pd.DataFrame, field: str) -> tuple[pd.Series, pd.Series, int] | None:
|
| 108 |
+
if len(df) < 2:
|
| 109 |
+
return None
|
| 110 |
+
|
| 111 |
+
least = df[field].min()
|
| 112 |
+
if least >= 10:
|
| 113 |
+
return None # don't push too hard for a total order
|
| 114 |
+
|
| 115 |
+
remaining = (df[field] < 10).sum() - 1
|
| 116 |
+
|
| 117 |
+
candidates = df[df[field] == least]
|
| 118 |
+
if len(candidates) > 1:
|
| 119 |
+
sample = candidates.sample(2, replace=False)
|
| 120 |
+
return sample.iloc[0], sample.iloc[1], remaining
|
| 121 |
+
|
| 122 |
+
first = candidates.iloc[0]
|
| 123 |
+
while True:
|
| 124 |
+
least += 1
|
| 125 |
+
candidates = df[df[field] == least]
|
| 126 |
+
if candidates.empty:
|
| 127 |
+
continue
|
| 128 |
+
|
| 129 |
+
sample = candidates.sample(1)
|
| 130 |
+
return first, sample.iloc[0], remaining
|
| 131 |
+
|
| 132 |
+
def _pick_from(df: pd.DataFrame, weights: pd.Series | None = None) -> tuple[pd.Series, pd.Series, int] | None:
|
| 133 |
+
if len(df) < 2:
|
| 134 |
+
return None
|
| 135 |
+
|
| 136 |
+
remaining = len(df) - 2
|
| 137 |
+
|
| 138 |
+
sample = df.sample(2, weights=weights, replace=False)
|
| 139 |
+
return sample.iloc[0], sample.iloc[1], remaining
|
| 140 |
+
|
| 141 |
+
def _pool_fetch_pair(group: str) -> tuple[pd.Series, pd.Series, int, str]:
    """Choose the next pair of images to compare within ``group``.

    Stages are tried in order and the first one that yields a pair wins:

    1. two images with a wins-only record,
    2. two images with a losses-only record,
    3. two images that have exactly two votes,
    4. (always when nothing is rated yet, otherwise with probability 0.75)
       a rated image against a not-yet-rated one,
    5. two rated images sampled with weight ``1 / votes``.

    Args:
        group: Value of the pool's ``group`` column to draw from.

    Returns:
        ``(row_a, row_b, remaining, reason)`` where ``reason`` names the stage
        that produced the pair and ``remaining`` is that stage's leftover
        count.

    Raises:
        ValueError: If the group cannot supply two images.
    """
    # NOTE(review): `_pool_df` is read here without holding `_pool_lock`,
    # while writers update the tallies under it - confirm this is acceptable.
    gdf = _pool_df[_pool_df["group"] == group]
    ranked = gdf[gdf["votes"] > 0]

    # 1) Pair images that have wins-only records.
    picked = _pick_from_bins(ranked[ranked["losses"] == 0], "wins")
    if picked is not None:
        return *picked, "wins-only"

    # 2) Pair images that have losses-only records.
    picked = _pick_from_bins(ranked[ranked["wins"] == 0], "losses")
    if picked is not None:
        return *picked, "losses-only"

    # 3) Ensure a minimum density of 3.
    picked = _pick_from(ranked[ranked["votes"] == 2])
    if picked is not None:
        return *picked, "sparse"

    # 4) Introduce a new image.
    if ranked.empty or random.random() < 0.75:
        unranked = gdf[gdf["votes"] == 0]

        if ranked.empty:  # Very first vote.
            picked = _pick_from(unranked)
            if picked is None:
                raise ValueError("Group is empty.")

            return *picked, "init"

        if not unranked.empty:
            return (
                ranked.sample(1).iloc[0],
                unranked.sample(1).iloc[0],
                len(unranked) - 1,
                "new",
            )

    # 5) Vote-weighted random sampling.
    picked = _pick_from(ranked, weights=(1.0 / ranked["votes"]))
    if picked is None:
        # Reachable when only one image in the group has votes and stage 4
        # did not produce a pair. Raise explicitly instead of relying on an
        # assert, which is stripped under `python -O`.
        raise ValueError("Group has fewer than two rated images to pair.")

    return *picked, "random"
|
| 182 |
|
| 183 |
def _row_image_url(row) -> str:
|
| 184 |
sample_url = row.get("sample_url")
|
|
|
|
| 194 |
"fetch_pair": _pool_fetch_pair,
|
| 195 |
"get_id": lambda row: row["md5"],
|
| 196 |
"get_image": _row_image_url,
|
| 197 |
+
"groups": sorted(_pool_df["group"].unique()),
|
| 198 |
},
|
| 199 |
}
|
| 200 |
DEFAULT_DATASET = list(DATASETS.keys())[0]
|
| 201 |
|
| 202 |
def _select_groups(cfg: dict, rating_pref: str) -> list[str]:
|
|
|
|
| 203 |
if rating_pref == "all":
|
| 204 |
+
return cfg["groups"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
+
return [
|
| 207 |
+
g
|
| 208 |
+
for g in cfg["groups"]
|
| 209 |
+
if g.endswith(f"_{rating_pref}")
|
| 210 |
+
]
|
| 211 |
|
| 212 |
def _format_rating_post_row(post_id: int, wins: int, losses: int, label: str | None = None) -> str:
|
| 213 |
total_votes = wins + losses
|
|
|
|
| 217 |
|
| 218 |
def _render_current(state: dict, submit_status: str = "") -> tuple:
|
| 219 |
_reload_stats_if_due()
|
| 220 |
+
wins_a, losses_a = _pool_df.iloc[_md5_to_idx[state["key_a"]], [WINS_LOC, LOSSES_LOC]]
|
| 221 |
+
wins_b, losses_b = _pool_df.iloc[_md5_to_idx[state["key_b"]], [WINS_LOC, LOSSES_LOC]]
|
| 222 |
title_a = "Image A"
|
| 223 |
title_b = "Image B"
|
| 224 |
img_a_html = f"<div class=\"rating-card\"><div class=\"rating-card-title\"><strong>{html.escape(title_a)}</strong></div><div class=\"rating-image-frame\"><img src=\"{html.escape(state['url_a'])}\" class=\"rating-image\" loading=\"eager\" referrerpolicy=\"no-referrer\"></div></div>"
|
|
|
|
| 234 |
return img_a_html, img_b_html, link_a, link_b, back_md, group_md, pair_reason_md, status_md, state
|
| 235 |
|
| 236 |
|
|
|
|
| 237 |
def _normalize_rating_pref(pref: str | None) -> str:
|
| 238 |
return pref if pref in ("safe", "all") else "safe"
|
| 239 |
|
|
|
|
| 300 |
cfg = DATASETS[dataset_name]
|
| 301 |
groups = _select_groups(cfg, rating_pref)
|
| 302 |
assert groups, f"No groups for rating preference: {rating_pref}"
|
| 303 |
+
|
| 304 |
group = random.choice(groups)
|
| 305 |
+
row_a, row_b, reason_remaining, pair_reason = cfg["fetch_pair"](group)
|
| 306 |
+
|
| 307 |
+
pair_reason = f"{pair_reason} ({reason_remaining})"
|
| 308 |
+
|
|
|
|
|
|
|
| 309 |
state.setdefault("session_id", uuid.uuid4().hex)
|
| 310 |
key_a = cfg["get_id"](row_a)
|
| 311 |
key_b = cfg["get_id"](row_b)
|
|
|
|
| 320 |
|
| 321 |
def _queue_decision(winner: str | None, state: dict):
|
| 322 |
assert state.get("session_id"), "Missing session_id: refusing to record vote"
|
| 323 |
+
|
| 324 |
+
pending = state.setdefault("pending", [])
|
| 325 |
+
pending.append({
|
| 326 |
"winner": winner,
|
| 327 |
"key_a": state["key_a"],
|
| 328 |
"key_b": state["key_b"],
|
|
|
|
| 335 |
"group": state["group"],
|
| 336 |
"pair_reason": state.get("pair_reason", ""),
|
| 337 |
"session_id": state["session_id"],
|
| 338 |
+
})
|
| 339 |
+
|
| 340 |
+
if len(pending) > 1:
|
| 341 |
+
VOTE_STORAGE.queue_row(pending.pop(0))
|
|
|
|
| 342 |
|
| 343 |
def vote(winner: str | None, state: dict, submit_key: str | None) -> tuple:
    """Record a decision for the current pair and move on to a new round.

    Args:
        winner: ``"A"``, ``"B"``, or ``None`` for a decision without a winner.
        state: Session state; ``key_a``/``key_b`` identify the current pair.
        submit_key: Key supplied by the user; must match ``SUBMIT_KEY`` after
            normalization.

    Returns:
        The render tuple for the current pair with an error status when the
        submission key is wrong, otherwise the result of ``new_round``.

    Raises:
        AssertionError: If ``winner`` is not one of the accepted values.
    """
    if _normalize_submit_key(submit_key) != SUBMIT_KEY:
        return _render_current(state, "Wrong submission key.")

    # Validate before queueing so that an invalid decision is never recorded.
    if winner not in ("A", "B", None):
        raise AssertionError(f"Unexpected winner: {winner!r}")

    _queue_decision(winner, state)

    if winner is not None:
        if winner == "A":
            won_key, lost_key = state["key_a"], state["key_b"]
        else:
            won_key, lost_key = state["key_b"], state["key_a"]

        # Reflect the vote in the in-memory tallies right away.
        with _pool_lock:
            _pool_df.iloc[_md5_to_idx[won_key], [WINS_LOC, VOTES_LOC]] += 1
            _pool_df.iloc[_md5_to_idx[lost_key], [LOSSES_LOC, VOTES_LOC]] += 1

    return new_round(state["dataset"], state["rating_pref"], state)
|
| 363 |
|
| 364 |
def go_back(state: dict) -> tuple:
    """Undo the most recent pending decision and show that pair again.

    If ``state["pending"]`` holds a decision, it is removed, the session state
    is restored to the pair it was made on, and its effect on the in-memory
    tallies is reversed. With nothing pending, the current pair is re-rendered
    unchanged.

    Args:
        state: Session state dict.

    Returns:
        The render tuple produced by ``_render_current``.

    Raises:
        AssertionError: If the pending decision has an unexpected winner.
    """
    pending = state.setdefault("pending", [])

    if pending:
        # Validate before popping/mutating so a bad entry leaves state intact.
        winner = pending[-1]["winner"]
        if winner not in ("A", "B", None):
            raise AssertionError(f"Unexpected winner: {winner!r}")

        last = pending.pop()
        state.update(
            dataset=last["dataset"],
            rating_pref=last["rating_pref"],
            key_a=last["key_a"],
            key_b=last["key_b"],
            id_a=last["id_a"],
            id_b=last["id_b"],
            url_a=last["url_a"],
            url_b=last["url_b"],
            group=last["group"],
            pair_reason=last.get("pair_reason", ""),
        )

        if winner is not None:
            if winner == "A":
                won_key, lost_key = state["key_a"], state["key_b"]
            else:
                won_key, lost_key = state["key_b"], state["key_a"]

            # NOTE(review): if the tallies were reloaded from storage after
            # this vote was cast, this decrement may no longer mirror the
            # earlier increment - confirm.
            with _pool_lock:
                _pool_df.iloc[_md5_to_idx[won_key], [WINS_LOC, VOTES_LOC]] -= 1
                _pool_df.iloc[_md5_to_idx[lost_key], [LOSSES_LOC, VOTES_LOC]] -= 1

    return _render_current(state)
|
| 396 |
|
| 397 |
# -- UI ---------------------------------------------------------------------
|
storage.py
CHANGED
|
@@ -30,79 +30,98 @@ class VoteStorage:
|
|
| 30 |
def __init__(self, mode: str, token: str | None = None):
|
| 31 |
assert mode in ("hf", "void"), f"Unsupported storage mode: {mode}"
|
| 32 |
self.mode = mode
|
| 33 |
-
self._token = token
|
| 34 |
is_debug_mode = self.mode == "void"
|
|
|
|
| 35 |
self._flush_every = 3 if is_debug_mode else 50
|
| 36 |
self._flush_interval_sec = 15.0 if is_debug_mode else 300.0
|
| 37 |
-
|
|
|
|
| 38 |
self._votes_buffer: list[dict] = []
|
| 39 |
-
|
|
|
|
| 40 |
self._flush_thread = threading.Thread(target=self._flush_loop, daemon=True)
|
| 41 |
self._flush_thread.start()
|
| 42 |
-
atexit.register(self.close)
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
|
|
|
| 46 |
|
| 47 |
def _empty_votes_df(self) -> pd.DataFrame:
|
| 48 |
return pd.DataFrame(columns=VOTE_COLUMNS)
|
| 49 |
|
| 50 |
def _upload_votes_batch(self, df: pd.DataFrame, commit_message: str):
|
| 51 |
-
assert set(VOTE_COLUMNS).issubset(df.columns), "Missing vote columns in upload batch"
|
| 52 |
if self.mode == "void":
|
| 53 |
-
_ = commit_message
|
| 54 |
return
|
|
|
|
| 55 |
ts = int(time.time())
|
| 56 |
shard = f"votes_{ts}_{uuid.uuid4().hex}.parquet"
|
| 57 |
-
api = HfApi(token=self._hf_token())
|
| 58 |
-
with NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
|
| 59 |
-
tmp_path = tmp.name
|
| 60 |
-
try:
|
| 61 |
-
df[VOTE_COLUMNS].to_parquet(tmp_path, index=False)
|
| 62 |
-
api.upload_file(
|
| 63 |
-
path_or_fileobj=tmp_path,
|
| 64 |
-
path_in_repo=f"{VOTES_LOG_SUBDIR}/{shard}",
|
| 65 |
-
repo_id=VOTES_REPO_ID,
|
| 66 |
-
repo_type=VOTES_REPO_TYPE,
|
| 67 |
-
commit_message=commit_message,
|
| 68 |
-
)
|
| 69 |
-
finally:
|
| 70 |
-
if os.path.exists(tmp_path):
|
| 71 |
-
os.remove(tmp_path)
|
| 72 |
-
|
| 73 |
-
def _flush_votes(self, force: bool = False):
|
| 74 |
-
with self._votes_lock:
|
| 75 |
-
if not self._votes_buffer:
|
| 76 |
-
return
|
| 77 |
-
if not force and len(self._votes_buffer) < self._flush_every:
|
| 78 |
-
return
|
| 79 |
-
batch = list(self._votes_buffer)
|
| 80 |
-
self._votes_buffer.clear()
|
| 81 |
-
incoming = pd.DataFrame(batch)
|
| 82 |
-
for col in VOTE_COLUMNS:
|
| 83 |
-
if col not in incoming.columns:
|
| 84 |
-
incoming[col] = None
|
| 85 |
-
self._upload_votes_batch(incoming[VOTE_COLUMNS], commit_message=f"append {len(batch)} vote rows")
|
| 86 |
-
|
| 87 |
-
def _flush_loop(self):
|
| 88 |
-
while not self._stop_event.wait(self._flush_interval_sec):
|
| 89 |
-
self._flush_votes(force=True)
|
| 90 |
-
|
| 91 |
-
def close(self):
|
| 92 |
-
if self._stop_event.is_set():
|
| 93 |
-
return
|
| 94 |
-
self._stop_event.set()
|
| 95 |
-
self._flush_thread.join(timeout=1.0)
|
| 96 |
-
self._flush_votes(force=True)
|
| 97 |
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
id_a = int(state["id_a"])
|
| 100 |
id_b = int(state["id_b"])
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
vote_row = {
|
| 107 |
"vote_id": uuid.uuid4().hex,
|
| 108 |
"timestamp": datetime.now(timezone.utc).isoformat(timespec="seconds"),
|
|
@@ -115,6 +134,9 @@ class VoteStorage:
|
|
| 115 |
"group": state["group"],
|
| 116 |
"session_id": state["session_id"],
|
| 117 |
}
|
| 118 |
-
|
|
|
|
| 119 |
self._votes_buffer.append(vote_row)
|
| 120 |
-
|
|
|
|
|
|
|
|
|
| 30 |
def __init__(self, mode: str, token: str | None = None):
    """Set up the vote buffer and start the background flush thread.

    Args:
        mode: ``"hf"`` or ``"void"``. ``"void"`` is treated as debug mode
            and uses a smaller batch size and a shorter flush interval.
        token: Token handed to ``HfApi``.
    """
    assert mode in ("hf", "void"), f"Unsupported storage mode: {mode}"
    self.mode = mode
    is_debug_mode = self.mode == "void"

    self._flush_every = 3 if is_debug_mode else 50
    self._flush_interval_sec = 15.0 if is_debug_mode else 300.0

    self._shutdown = False
    self._votes_buffer: list[dict] = []

    # Create the API client before the worker starts, so the thread never
    # observes a half-initialised object and no thread is left running if
    # client construction fails.
    self.hf_api = HfApi(token=token)

    self._flush_condition = threading.Condition(threading.Lock())
    self._flush_thread = threading.Thread(target=self._flush_loop, daemon=True)
    self._flush_thread.start()

    atexit.register(self.close)
|
| 48 |
|
| 49 |
def _empty_votes_df(self) -> pd.DataFrame:
    """Return an empty DataFrame whose columns are ``VOTE_COLUMNS``."""
    return pd.DataFrame(columns=VOTE_COLUMNS)
|
| 51 |
|
| 52 |
def _upload_votes_batch(self, df: pd.DataFrame, commit_message: str):
|
|
|
|
| 53 |
if self.mode == "void":
|
|
|
|
| 54 |
return
|
| 55 |
+
|
| 56 |
ts = int(time.time())
|
| 57 |
shard = f"votes_{ts}_{uuid.uuid4().hex}.parquet"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
+
self.hf_api.upload_file(
|
| 60 |
+
path_or_fileobj=df.to_parquet(index=False),
|
| 61 |
+
path_in_repo=f"{VOTES_LOG_SUBDIR}/{shard}",
|
| 62 |
+
repo_id=VOTES_REPO_ID,
|
| 63 |
+
repo_type=VOTES_REPO_TYPE,
|
| 64 |
+
commit_message=commit_message,
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
def _flush_loop(self) -> None:
|
| 68 |
+
while True:
|
| 69 |
+
with self._flush_condition:
|
| 70 |
+
while True:
|
| 71 |
+
if self._shutdown:
|
| 72 |
+
# Flush last batch of votes.
|
| 73 |
+
if self._votes_buffer:
|
| 74 |
+
break
|
| 75 |
+
|
| 76 |
+
return
|
| 77 |
+
|
| 78 |
+
# Have enough votes to flush now.
|
| 79 |
+
if len(self._votes_buffer) >= self._flush_every:
|
| 80 |
+
break
|
| 81 |
+
|
| 82 |
+
# Wait for a notify to flush early or shutdown.
|
| 83 |
+
if not self._flush_condition.wait(self._flush_interval_sec):
|
| 84 |
+
# Interval elapsed. Flush if there is at least one vote.
|
| 85 |
+
if self._votes_buffer:
|
| 86 |
+
break
|
| 87 |
+
|
| 88 |
+
# Atomically take the batch of votes.
|
| 89 |
+
batch = self._votes_buffer
|
| 90 |
+
self._votes_buffer = []
|
| 91 |
+
|
| 92 |
+
assert batch
|
| 93 |
+
batch_df = pd.DataFrame(batch)
|
| 94 |
+
del batch
|
| 95 |
+
|
| 96 |
+
for col in VOTE_COLUMNS:
|
| 97 |
+
if col not in batch_df.columns:
|
| 98 |
+
batch_df[col] = None
|
| 99 |
+
|
| 100 |
+
batch_df = batch_df[VOTE_COLUMNS]
|
| 101 |
+
self._upload_votes_batch(batch_df, commit_message=f"upload {len(batch_df)} vote rows")
|
| 102 |
+
|
| 103 |
+
def close(self) -> None:
    """Request shutdown and wait for the flush thread to finish.

    Sets the shutdown flag under the condition's lock, wakes the flush
    thread, then joins it without a timeout.
    """
    with self._flush_condition:
        self._shutdown = True
        self._flush_condition.notify()

    self._flush_thread.join()
|
| 109 |
+
|
| 110 |
+
def queue_row(self, state: dict) -> None:
|
| 111 |
id_a = int(state["id_a"])
|
| 112 |
id_b = int(state["id_b"])
|
| 113 |
+
|
| 114 |
+
winner_md5: str | None
|
| 115 |
+
match state["winner"]:
|
| 116 |
+
case "A":
|
| 117 |
+
winner_md5 = state["key_a"]
|
| 118 |
+
case "B":
|
| 119 |
+
winner_md5 = state["key_b"]
|
| 120 |
+
case None:
|
| 121 |
+
winner_md5 = None
|
| 122 |
+
case _:
|
| 123 |
+
raise AssertionError
|
| 124 |
+
|
| 125 |
vote_row = {
|
| 126 |
"vote_id": uuid.uuid4().hex,
|
| 127 |
"timestamp": datetime.now(timezone.utc).isoformat(timespec="seconds"),
|
|
|
|
| 134 |
"group": state["group"],
|
| 135 |
"session_id": state["session_id"],
|
| 136 |
}
|
| 137 |
+
|
| 138 |
+
with self._flush_condition:
|
| 139 |
self._votes_buffer.append(vote_row)
|
| 140 |
+
|
| 141 |
+
if len(self._votes_buffer) == self._flush_every:
|
| 142 |
+
self._flush_condition.notify()
|