Spaces:
Running
Running
rht
#2
by RedHotTensors - opened
- app.py +241 -304
- compact_logs.py +41 -28
- explorer.py +0 -6
- stats_from_logs.py +32 -30
- storage.py +59 -100
app.py
CHANGED
|
@@ -5,9 +5,6 @@ import time
|
|
| 5 |
import uuid
|
| 6 |
import os
|
| 7 |
import html
|
| 8 |
-
import sys
|
| 9 |
-
|
| 10 |
-
from typing import Callable
|
| 11 |
|
| 12 |
import pandas as pd
|
| 13 |
from huggingface_hub import hf_hub_download
|
|
@@ -32,161 +29,110 @@ _pool_path = hf_hub_download(
|
|
| 32 |
token=RATINGS_APP_TOKEN
|
| 33 |
)
|
| 34 |
_pool_df = pd.read_parquet(_pool_path)
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
WINS_LOC = _pool_df.columns.get_loc("wins")
|
| 38 |
-
LOSSES_LOC = _pool_df.columns.get_loc("losses")
|
| 39 |
-
TIES_LOC = _pool_df.columns.get_loc("ties")
|
| 40 |
-
VOTES_LOC = _pool_df.columns.get_loc("votes")
|
| 41 |
-
WINRATE_LOC = _pool_df.columns.get_loc("winrate")
|
| 42 |
-
|
| 43 |
-
_md5_to_idx = { md5: idx for idx, md5 in enumerate(_pool_df["md5"]) }
|
| 44 |
-
|
| 45 |
-
_pool_lock = threading.Lock()
|
| 46 |
-
|
| 47 |
_stats_last_loaded_at = 0.0
|
|
|
|
| 48 |
_explorer_df = pd.DataFrame(columns=["group", "id", "md5", "rating", "sample_url", "image_url", "classifier", "classifier_score", "percentile"])
|
| 49 |
|
| 50 |
-
def _load_stats() -> None:
|
| 51 |
-
VOTE_STORAGE.sync()
|
| 52 |
-
load_stats_by_md5(repo_id=POOL_REPO_ID, token=RATINGS_APP_TOKEN)
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
stats.wins, stats.losses, stats.ties, stats.votes, stats.winrate
|
| 63 |
-
)
|
| 64 |
-
else:
|
| 65 |
-
n_missing += 1
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
token=RATINGS_APP_TOKEN,
|
| 81 |
-
)
|
| 82 |
-
validation_df = pd.read_parquet(
|
| 83 |
-
validation_set_path,
|
| 84 |
-
columns=["group", "id", "md5", "rating", "sample_url", "image_url"],
|
| 85 |
-
)
|
| 86 |
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
classifier_scores_df["md5"] = classifier_scores_df["md5"].astype(str)
|
| 92 |
-
validation_df["md5"] = validation_df["md5"].astype(str)
|
| 93 |
-
|
| 94 |
-
global _explorer_df
|
| 95 |
-
_explorer_df = validation_df.merge(classifier_scores_df, on="md5", how="left", validate="one_to_many")
|
| 96 |
-
|
| 97 |
-
def _stats_reloader() -> None:
|
| 98 |
-
while True:
|
| 99 |
-
time.sleep(STATS_RELOAD_S)
|
| 100 |
-
_load_stats()
|
| 101 |
-
|
| 102 |
-
_load_stats()
|
| 103 |
-
threading.Thread(target=_stats_reloader, daemon=True).start()
|
| 104 |
-
|
| 105 |
-
def _pick_from(df: pd.DataFrame, *, weights: pd.Series | None = None) -> tuple[pd.Series, pd.Series, int] | None:
|
| 106 |
-
if len(df) < 2:
|
| 107 |
-
return None
|
| 108 |
-
|
| 109 |
-
sample = df.sample(2, weights=weights, replace=False)
|
| 110 |
-
return sample.iloc[0], sample.iloc[1], len(df)
|
| 111 |
-
|
| 112 |
-
def _pick_similar(
|
| 113 |
-
df: pd.DataFrame,
|
| 114 |
-
distance: Callable[[pd.DataFrame, pd.Series], pd.Series],
|
| 115 |
-
*,
|
| 116 |
-
weights: Callable[[pd.DataFrame], pd.Series] | None = None,
|
| 117 |
-
other_df: pd.DataFrame | None = None,
|
| 118 |
-
) -> tuple[pd.Series, pd.Series, int] | None:
|
| 119 |
-
if len(df) < 2:
|
| 120 |
-
return None
|
| 121 |
-
|
| 122 |
-
if other_df is None:
|
| 123 |
-
other_df = df
|
| 124 |
-
elif len(other_df) < 2:
|
| 125 |
-
return None
|
| 126 |
-
|
| 127 |
-
weight_vals: pd.Series | None = None
|
| 128 |
-
if weights is not None:
|
| 129 |
-
weight_vals = weights(df)
|
| 130 |
-
|
| 131 |
-
first = df.sample(weights=weight_vals).iloc[0]
|
| 132 |
-
weight_vals = 1.0 / (1.0 + distance(other_df, first))
|
| 133 |
-
|
| 134 |
-
while True:
|
| 135 |
-
other = other_df.sample(weights=weight_vals).iloc[0]
|
| 136 |
-
if other["md5"] != first["md5"]:
|
| 137 |
-
return first, other, len(df)
|
| 138 |
-
|
| 139 |
-
def _pool_fetch_pair(group: str) -> tuple[pd.Series, pd.Series, int, str]:
|
| 140 |
-
gdf = _pool_df[_pool_df["group"] == group]
|
| 141 |
-
voted = gdf[gdf["votes"] > 0]
|
| 142 |
-
votes = voted["votes"]
|
| 143 |
-
|
| 144 |
-
# Pair first-time winners.
|
| 145 |
-
picked = _pick_from(voted[(votes == 1) & (voted["wins"] == 1)])
|
| 146 |
-
if picked is not None:
|
| 147 |
-
return *picked, "new-winners"
|
| 148 |
|
| 149 |
-
# Pair first-time losers.
|
| 150 |
-
picked = _pick_from(voted[(votes == 1) & (voted["losses"] == 1)])
|
| 151 |
-
if picked is not None:
|
| 152 |
-
return *picked, "new-losers"
|
| 153 |
-
|
| 154 |
-
def record_distance(df: pd.DataFrame, pivot: pd.Series) -> pd.Series:
|
| 155 |
-
return (
|
| 156 |
-
(df["wins"] - pivot["wins"])**2 +
|
| 157 |
-
(df["losses"] - pivot["losses"])**2
|
| 158 |
-
)**0.75 # L2 is a bit too loose
|
| 159 |
-
|
| 160 |
-
# Link cliques to main network and break ties.
|
| 161 |
-
nonties = votes - voted["ties"]
|
| 162 |
-
picked = _pick_similar(
|
| 163 |
-
voted[(nonties == 0) | (votes == 2)],
|
| 164 |
-
record_distance,
|
| 165 |
-
other_df=voted[nonties > 3],
|
| 166 |
-
)
|
| 167 |
-
if picked is not None:
|
| 168 |
-
return *picked, "sparse"
|
| 169 |
-
|
| 170 |
-
# Introduce new images.
|
| 171 |
-
if len(voted) < 8 or random.random() < 0.33:
|
| 172 |
-
unvoted = gdf[gdf["votes"] == 0]
|
| 173 |
-
match len(unvoted):
|
| 174 |
-
case 0:
|
| 175 |
-
pass
|
| 176 |
-
case 1:
|
| 177 |
-
return unvoted.iloc[0], voted.iloc[0], 1, "new"
|
| 178 |
-
case _:
|
| 179 |
-
picked = _pick_from(unvoted)
|
| 180 |
-
assert picked is not None
|
| 181 |
-
return *picked, "new"
|
| 182 |
-
|
| 183 |
-
# Vote-weighted random sampling between similar winrates, slighlty biased against picking losers.
|
| 184 |
-
picked = _pick_similar(
|
| 185 |
-
voted, record_distance,
|
| 186 |
-
weights=lambda df: 1.0 / (df["votes"]**1.25 + 0.1 * df["losses"]),
|
| 187 |
-
)
|
| 188 |
-
assert picked is not None
|
| 189 |
-
return *picked, "fair-probe"
|
| 190 |
|
| 191 |
def _row_image_url(row) -> str:
|
| 192 |
sample_url = row.get("sample_url")
|
|
@@ -202,47 +148,89 @@ DATASETS: dict[str, dict] = {
|
|
| 202 |
"fetch_pair": _pool_fetch_pair,
|
| 203 |
"get_id": lambda row: row["md5"],
|
| 204 |
"get_image": _row_image_url,
|
| 205 |
-
"groups": sorted(_pool_df["group"].unique()),
|
| 206 |
},
|
| 207 |
}
|
| 208 |
DEFAULT_DATASET = list(DATASETS.keys())[0]
|
| 209 |
|
| 210 |
-
def
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
def _render_current(state: dict, submit_status: str = "") -> tuple:
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
-
can_go_back = bool(state.get("pending", ()))
|
| 222 |
-
pair_details = f"/ {state['group']} / {state.get('pair_reason', 'unknown')}"
|
| 223 |
|
| 224 |
-
return img_a_html, img_b_html, gr.Button(interactive=can_go_back), html.escape(pair_details), html.escape(submit_status), state
|
| 225 |
|
| 226 |
def _normalize_rating_pref(pref: str | None) -> str:
|
| 227 |
return pref if pref in ("safe", "all") else "safe"
|
| 228 |
|
| 229 |
-
|
| 230 |
-
|
|
|
|
| 231 |
submit_key = _normalize_submit_key(submit_key)
|
| 232 |
-
return rating_pref, submit_key,
|
|
|
|
| 233 |
|
| 234 |
-
def
|
| 235 |
-
|
|
|
|
| 236 |
|
| 237 |
-
def _on_image_height_change(image_height: str) -> tuple[str, str]:
|
| 238 |
-
return image_height, image_height
|
| 239 |
|
| 240 |
def _normalize_submit_key(submit_key: str | None) -> str:
|
| 241 |
-
return
|
|
|
|
| 242 |
|
| 243 |
def _filtered_explorer_df(rating_pref: str) -> pd.DataFrame:
|
| 244 |
return _filtered_explorer_df_by_classifier(rating_pref, ALLOWED_CLASSIFIER_FILTERS[0])
|
| 245 |
|
|
|
|
| 246 |
def _filtered_explorer_df_by_classifier(rating_pref: str, classifier_name: str) -> pd.DataFrame:
|
| 247 |
if rating_pref == "all":
|
| 248 |
rating_filtered = _explorer_df
|
|
@@ -252,10 +240,12 @@ def _filtered_explorer_df_by_classifier(rating_pref: str, classifier_name: str)
|
|
| 252 |
assert classifier_name in ALLOWED_CLASSIFIER_FILTERS, f"Unsupported classifier filter: {classifier_name}"
|
| 253 |
return rating_filtered[rating_filtered["classifier"] == classifier_name]
|
| 254 |
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
|
|
|
|
|
|
| 259 |
filtered_explorer_df = _filtered_explorer_df_by_classifier(rating_pref, classifier_name)
|
| 260 |
summary, score_distribution_plot, distribution_data, gallery_items, page_meta, next_offset, btn_update = build_results_data(
|
| 261 |
filtered_explorer_df,
|
|
@@ -266,11 +256,13 @@ def _load_results(rating_pref: str, sort_mode: str, classifier_filter: str):
|
|
| 266 |
)
|
| 267 |
return summary, score_distribution_plot, distribution_data, gallery_items, btn_update, "Click an image to reveal its ID and link.", page_meta, next_offset
|
| 268 |
|
|
|
|
| 269 |
def _normalize_sort_mode(sort_mode: str | None) -> str:
|
| 270 |
if sort_mode in ("Default", "Rating: Low to High", "Rating: High to Low"):
|
| 271 |
return sort_mode
|
| 272 |
return "Default"
|
| 273 |
|
|
|
|
| 274 |
def _normalize_classifier_filter(classifier_name: str | None) -> str:
|
| 275 |
if classifier_name in ALLOWED_CLASSIFIER_FILTERS:
|
| 276 |
return str(classifier_name)
|
|
@@ -278,22 +270,23 @@ def _normalize_classifier_filter(classifier_name: str | None) -> str:
|
|
| 278 |
|
| 279 |
# -- Gradio callbacks -------------------------------------------------------
|
| 280 |
|
| 281 |
-
def new_round(dataset_name: str,
|
| 282 |
-
if not groups:
|
| 283 |
-
return "", "", gr.skip(), "", "Please select at least one group.", state
|
| 284 |
-
|
| 285 |
cfg = DATASETS[dataset_name]
|
|
|
|
|
|
|
| 286 |
group = random.choice(groups)
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
|
|
|
|
|
|
| 291 |
state.setdefault("session_id", uuid.uuid4().hex)
|
| 292 |
key_a = cfg["get_id"](row_a)
|
| 293 |
key_b = cfg["get_id"](row_b)
|
| 294 |
id_a = int(row_a["id"])
|
| 295 |
id_b = int(row_b["id"])
|
| 296 |
-
state.update(dataset=dataset_name, key_a=key_a, key_b=key_b, id_a=id_a, id_b=id_b, group=group, pair_reason=pair_reason)
|
| 297 |
url_a = cfg["get_image"](row_a)
|
| 298 |
url_b = cfg["get_image"](row_b)
|
| 299 |
state["url_a"] = url_a
|
|
@@ -302,9 +295,8 @@ def new_round(dataset_name: str, groups: list[str], state: dict) -> tuple:
|
|
| 302 |
|
| 303 |
def _queue_decision(winner: str | None, state: dict):
|
| 304 |
assert state.get("session_id"), "Missing session_id: refusing to record vote"
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
pending.append({
|
| 308 |
"winner": winner,
|
| 309 |
"key_a": state["key_a"],
|
| 310 |
"key_b": state["key_b"],
|
|
@@ -313,86 +305,53 @@ def _queue_decision(winner: str | None, state: dict):
|
|
| 313 |
"url_a": state["url_a"],
|
| 314 |
"url_b": state["url_b"],
|
| 315 |
"dataset": state["dataset"],
|
|
|
|
| 316 |
"group": state["group"],
|
| 317 |
"pair_reason": state.get("pair_reason", ""),
|
| 318 |
"session_id": state["session_id"],
|
| 319 |
-
}
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
def _add_vote(idx: int, col_loc: int, delta: int = 1) -> None:
|
| 325 |
-
_pool_df.iloc[idx, [col_loc, VOTES_LOC]] += delta
|
| 326 |
-
|
| 327 |
-
wins, ties, votes = _pool_df.iloc[idx, [WINS_LOC, TIES_LOC, VOTES_LOC]]
|
| 328 |
-
_pool_df.iloc[idx, WINRATE_LOC] = (wins + 0.5 * ties) / max(votes, 1)
|
| 329 |
|
| 330 |
-
def vote(winner: str | None, state: dict,
|
|
|
|
| 331 |
if _normalize_submit_key(submit_key) != SUBMIT_KEY:
|
| 332 |
return _render_current(state, "Wrong submission key.")
|
| 333 |
-
|
| 334 |
-
if not groups:
|
| 335 |
-
return "", "", gr.skip(), "", "Please select at least one group.", state
|
| 336 |
-
|
| 337 |
_queue_decision(winner, state)
|
| 338 |
-
|
| 339 |
-
a_idx = _md5_to_idx[state["key_a"]]
|
| 340 |
-
b_idx = _md5_to_idx[state["key_b"]]
|
| 341 |
-
with _pool_lock:
|
| 342 |
-
match winner:
|
| 343 |
-
case "A":
|
| 344 |
-
_add_vote(a_idx, WINS_LOC)
|
| 345 |
-
_add_vote(b_idx, LOSSES_LOC)
|
| 346 |
-
case "B":
|
| 347 |
-
_add_vote(a_idx, LOSSES_LOC)
|
| 348 |
-
_add_vote(b_idx, WINS_LOC)
|
| 349 |
-
case None:
|
| 350 |
-
_add_vote(a_idx, TIES_LOC)
|
| 351 |
-
_add_vote(b_idx, TIES_LOC)
|
| 352 |
-
case _:
|
| 353 |
-
raise AssertionError
|
| 354 |
-
|
| 355 |
-
return new_round(state["dataset"], groups, state)
|
| 356 |
|
| 357 |
def go_back(state: dict) -> tuple:
|
| 358 |
pending = state.setdefault("pending", [])
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
case "B":
|
| 382 |
-
_add_vote(a_idx, LOSSES_LOC, -1)
|
| 383 |
-
_add_vote(b_idx, WINS_LOC, -1)
|
| 384 |
-
case None:
|
| 385 |
-
_add_vote(a_idx, TIES_LOC, -1)
|
| 386 |
-
_add_vote(b_idx, TIES_LOC, -1)
|
| 387 |
-
case _:
|
| 388 |
-
raise AssertionError
|
| 389 |
-
|
| 390 |
return _render_current(state)
|
| 391 |
|
| 392 |
# -- UI ---------------------------------------------------------------------
|
| 393 |
|
| 394 |
with gr.Blocks(
|
| 395 |
-
title="
|
| 396 |
head="""
|
| 397 |
<script>
|
| 398 |
const VOTE_COOLDOWN_MS = 1500;
|
|
@@ -475,15 +434,9 @@ with gr.Blocks(
|
|
| 475 |
} else if ((e.key === 'z' || e.key === 'Z') && (e.ctrlKey || e.metaKey) && ratingTabActive) {
|
| 476 |
e.preventDefault();
|
| 477 |
backButton?.click();
|
| 478 |
-
} else if (e.key === 'ArrowDown') {
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
backButton?.click();
|
| 482 |
-
}
|
| 483 |
-
if (resultsTabActive) {
|
| 484 |
-
e.preventDefault();
|
| 485 |
-
resultsLoadMoreButton?.click();
|
| 486 |
-
}
|
| 487 |
}
|
| 488 |
});
|
| 489 |
document.addEventListener('click', function (e) {
|
|
@@ -543,6 +496,7 @@ with gr.Blocks(
|
|
| 543 |
}
|
| 544 |
.rating-image-frame {
|
| 545 |
width: 100%;
|
|
|
|
| 546 |
border: 1px solid #e6e6e6;
|
| 547 |
border-radius: 8px;
|
| 548 |
background: #333;
|
|
@@ -604,23 +558,13 @@ with gr.Blocks(
|
|
| 604 |
object-fit: contain !important;
|
| 605 |
background: #1f2937;
|
| 606 |
}
|
| 607 |
-
a {
|
| 608 |
-
padding: 0 !important;
|
| 609 |
-
}
|
| 610 |
""",
|
| 611 |
-
fill_width=True,
|
| 612 |
) as demo:
|
| 613 |
state = gr.State({})
|
| 614 |
rating_pref_store = gr.BrowserState(default_value="safe", storage_key="rating_pref")
|
| 615 |
submit_key_store = gr.BrowserState(default_value="", storage_key="submit_key")
|
| 616 |
results_sort_store = gr.BrowserState(default_value="Default", storage_key="results_sort_mode")
|
| 617 |
results_classifier_store = gr.BrowserState(default_value=ALLOWED_CLASSIFIER_FILTERS[0], storage_key="results_classifier")
|
| 618 |
-
image_height_store = gr.BrowserState(default_value=768, storage_key="image_height")
|
| 619 |
-
groups_store = gr.BrowserState(default_value=[
|
| 620 |
-
group
|
| 621 |
-
for group in DATASETS[DEFAULT_DATASET]["groups"]
|
| 622 |
-
if group.endswith("_safe")
|
| 623 |
-
], storage_key="groups")
|
| 624 |
|
| 625 |
with gr.Tabs():
|
| 626 |
with gr.Tab("Image Quality Rater"):
|
|
@@ -630,41 +574,38 @@ with gr.Blocks(
|
|
| 630 |
img_a = gr.HTML(elem_id="img-a")
|
| 631 |
img_b = gr.HTML(elem_id="img-b")
|
| 632 |
|
| 633 |
-
with gr.Row(
|
| 634 |
-
btn_a = gr.Button("
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
btn_skip = gr.Button("⬆️ Same Quality", elem_id="btn-skip")
|
| 638 |
-
btn_back_action = gr.Button("⬇️ Undo", elem_id="btn-back-action")
|
| 639 |
-
|
| 640 |
-
btn_b = gr.Button("➡️ Prefer B", variant="primary", elem_id="btn-vote-b")
|
| 641 |
|
| 642 |
with gr.Accordion("Settings", open=False):
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
image_height_slider = gr.Slider(
|
| 650 |
-
minimum=512, maximum=2048, step=16, precision=0,
|
| 651 |
-
label="Image Size",
|
| 652 |
)
|
| 653 |
submit_key_tb = gr.Textbox(
|
| 654 |
value="",
|
| 655 |
type="password",
|
| 656 |
-
label="Submit
|
| 657 |
elem_id="submit-key",
|
| 658 |
)
|
| 659 |
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
gr.
|
| 663 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 664 |
|
| 665 |
(
|
| 666 |
results_summary_md,
|
| 667 |
-
results_rating_dd,
|
| 668 |
results_sort_dd,
|
| 669 |
results_classifier_dd,
|
| 670 |
results_score_distribution_plot,
|
|
@@ -676,7 +617,7 @@ with gr.Blocks(
|
|
| 676 |
results_page_offset_state,
|
| 677 |
) = add_results_tab(_pool_df)
|
| 678 |
|
| 679 |
-
outputs = [img_a, img_b,
|
| 680 |
results_outputs = [
|
| 681 |
results_summary_md,
|
| 682 |
results_score_distribution_plot,
|
|
@@ -688,26 +629,22 @@ with gr.Blocks(
|
|
| 688 |
results_page_offset_state,
|
| 689 |
]
|
| 690 |
|
| 691 |
-
btn_a.click(fn=lambda s,
|
| 692 |
-
btn_b.click(fn=lambda s,
|
| 693 |
-
btn_skip.click(fn=lambda s,
|
| 694 |
btn_back_action.click(fn=go_back, inputs=[state], outputs=outputs, queue=False, show_progress="hidden")
|
|
|
|
|
|
|
| 695 |
submit_key_tb.change(fn=_normalize_submit_key, inputs=[submit_key_tb], outputs=[submit_key_store], queue=False, show_progress="hidden")
|
| 696 |
-
|
| 697 |
-
image_height_slider.change(fn=_on_image_height_change, inputs=[image_height_slider], outputs=[image_height_store, image_height], queue=False, show_progress="hidden")
|
| 698 |
-
|
| 699 |
-
results_rating_dd.change(fn=_normalize_rating_pref, inputs=[results_rating_dd], outputs=[rating_pref_store], queue=False, show_progress="hidden")
|
| 700 |
-
results_rating_dd.change(fn=_load_results, inputs=[results_rating_dd, results_sort_store, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
|
| 701 |
results_sort_dd.change(fn=_normalize_sort_mode, inputs=[results_sort_dd], outputs=[results_sort_store], queue=False, show_progress="hidden")
|
| 702 |
results_sort_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_dd, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
|
| 703 |
results_classifier_dd.change(fn=_normalize_classifier_filter, inputs=[results_classifier_dd], outputs=[results_classifier_store], queue=False, show_progress="hidden")
|
| 704 |
results_classifier_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_dd], outputs=results_outputs, queue=False, show_progress="hidden")
|
| 705 |
-
|
| 706 |
-
demo.load(fn=_initial_load, inputs=[state, rating_pref_store, submit_key_store, image_height_store, groups_store], outputs=[results_rating_dd, submit_key_tb, image_height_slider, image_height, groups_select, *outputs], queue=False, show_progress="hidden")
|
| 707 |
demo.load(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
|
| 708 |
demo.load(fn=_normalize_sort_mode, inputs=[results_sort_store], outputs=[results_sort_dd], queue=False, show_progress="hidden")
|
| 709 |
demo.load(fn=_normalize_classifier_filter, inputs=[results_classifier_store], outputs=[results_classifier_dd], queue=False, show_progress="hidden")
|
| 710 |
-
|
| 711 |
results_load_more_btn.click(
|
| 712 |
fn=lambda r, s, c, o: load_more_results(_filtered_explorer_df_by_classifier(_normalize_rating_pref(r), _normalize_classifier_filter(c)), _explorer_df, s, o),
|
| 713 |
inputs=[rating_pref_store, results_sort_store, results_classifier_store, results_page_offset_state],
|
|
|
|
| 5 |
import uuid
|
| 6 |
import os
|
| 7 |
import html
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
import pandas as pd
|
| 10 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 29 |
token=RATINGS_APP_TOKEN
|
| 30 |
)
|
| 31 |
_pool_df = pd.read_parquet(_pool_path)
|
| 32 |
+
_pool_group_dfs = {g: gdf for g, gdf in _pool_df.groupby("group")}
|
| 33 |
+
_stats_lock = threading.Lock()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
_stats_last_loaded_at = 0.0
|
| 35 |
+
_stats_by_key: dict[str, tuple[int, int]] = {}
|
| 36 |
_explorer_df = pd.DataFrame(columns=["group", "id", "md5", "rating", "sample_url", "image_url", "classifier", "classifier_score", "percentile"])
|
| 37 |
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
+
def _reload_stats_if_due(force: bool = False):
|
| 40 |
+
global _stats_last_loaded_at, _stats_by_key, _explorer_df
|
| 41 |
+
now = time.time()
|
| 42 |
+
if not force and (now - _stats_last_loaded_at) < STATS_RELOAD_S:
|
| 43 |
+
return
|
| 44 |
+
with _stats_lock:
|
| 45 |
+
now = time.time()
|
| 46 |
+
if not force and (now - _stats_last_loaded_at) < STATS_RELOAD_S:
|
| 47 |
+
return
|
| 48 |
+
_stats_by_key = load_stats_by_md5(
|
| 49 |
+
repo_id=POOL_REPO_ID,
|
| 50 |
+
token=RATINGS_APP_TOKEN,
|
| 51 |
+
)
|
| 52 |
+
classifier_scores_path = hf_hub_download(
|
| 53 |
+
repo_id=POOL_REPO_ID,
|
| 54 |
+
filename="classifier_scores.parquet",
|
| 55 |
+
repo_type="dataset",
|
| 56 |
+
token=RATINGS_APP_TOKEN,
|
| 57 |
+
)
|
| 58 |
+
validation_set_path = hf_hub_download(
|
| 59 |
+
repo_id=POOL_REPO_ID,
|
| 60 |
+
filename="validation_set.parquet",
|
| 61 |
+
repo_type="dataset",
|
| 62 |
+
token=RATINGS_APP_TOKEN,
|
| 63 |
+
)
|
| 64 |
+
validation_df = pd.read_parquet(
|
| 65 |
+
validation_set_path,
|
| 66 |
+
columns=["group", "id", "md5", "rating", "sample_url", "image_url"],
|
| 67 |
+
)
|
| 68 |
+
classifier_scores_df = pd.read_parquet(classifier_scores_path)
|
| 69 |
+
assert {"classifier", "md5", "classifier_score", "percentile"}.issubset(classifier_scores_df.columns), "classifier_scores.parquet missing expected columns"
|
| 70 |
+
classifier_scores_df = classifier_scores_df[["classifier", "md5", "classifier_score", "percentile"]]
|
| 71 |
+
classifier_scores_df["classifier"] = classifier_scores_df["classifier"].astype(str)
|
| 72 |
+
classifier_scores_df["md5"] = classifier_scores_df["md5"].astype(str)
|
| 73 |
+
validation_df["md5"] = validation_df["md5"].astype(str)
|
| 74 |
+
_explorer_df = validation_df.merge(classifier_scores_df, on="md5", how="left", validate="one_to_many")
|
| 75 |
+
_stats_last_loaded_at = now
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
_reload_stats_if_due(force=True)
|
| 79 |
+
|
| 80 |
+
def _pool_fetch_pair(group_name: str) -> tuple:
|
| 81 |
+
gdf = _pool_group_dfs[group_name]
|
| 82 |
+
assert len(gdf) >= 2, f"Not enough rows for group: {group_name}"
|
| 83 |
+
md5_keys = gdf["md5"].astype(str)
|
| 84 |
+
wins = md5_keys.map(lambda k: _stats_by_key.get(k, (0, 0))[0])
|
| 85 |
+
losses = md5_keys.map(lambda k: _stats_by_key.get(k, (0, 0))[1])
|
| 86 |
+
|
| 87 |
+
def _pick_from_mask(mask: pd.Series):
|
| 88 |
+
candidate_df = gdf[mask]
|
| 89 |
+
if len(candidate_df) < 2:
|
| 90 |
+
return None
|
| 91 |
+
sample = candidate_df.sample(2, replace=False)
|
| 92 |
+
return sample.iloc[0], sample.iloc[1]
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# 1) Repeat the lowest-margin edge participating in a cycle. (To prevent deadlock, stop if all margins are 4+.)
|
| 96 |
+
# a) If deadlocked on a cycle with 4+ images and no inner cycles, sample a random missing edge inside the cycle.
|
| 97 |
+
# 2) Pair images that both have wins only . (One of them will lose/tie. Stop when there is only one left.)
|
| 98 |
+
# 3) Pair images that both have losses only. (One of them will win/tie. Stop when there is only one left.)
|
| 99 |
+
# 4) Pair images with only 2 edges.
|
| 100 |
+
# 5) X% chance, re-sample an existing edge, inversely proportional to existing number of samples.
|
| 101 |
+
# 6) Y% chance, sample a random missing edge between images already sampled.
|
| 102 |
+
# 7) Pair an unsampled image with a random sampled image.
|
| 103 |
+
|
| 104 |
+
# 2) Pair images that currently have wins-only records.
|
| 105 |
+
picked = _pick_from_mask((wins > 0) & (losses == 0))
|
| 106 |
+
if picked is not None:
|
| 107 |
+
return picked[0], picked[1], "wins-only"
|
| 108 |
|
| 109 |
+
# 3) Pair images that currently have losses-only records.
|
| 110 |
+
picked = _pick_from_mask((wins == 0) & (losses > 0))
|
| 111 |
+
if picked is not None:
|
| 112 |
+
return picked[0], picked[1], "losses-only"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
+
# 4) Pair images that currently have exactly 2 total edges.
|
| 115 |
+
vote_totals = wins + losses
|
| 116 |
+
picked = _pick_from_mask(vote_totals == 2)
|
| 117 |
+
if picked is not None:
|
| 118 |
+
return picked[0], picked[1], "total_votes=2"
|
| 119 |
|
| 120 |
+
# 7) Prefer pairing an unsampled image with a random previously sampled image.
|
| 121 |
+
unsampled_mask = vote_totals == 0
|
| 122 |
+
if unsampled_mask.any():
|
| 123 |
+
unsampled_row = gdf[unsampled_mask].sample(1).iloc[0]
|
| 124 |
+
sampled_df = gdf[~unsampled_mask]
|
| 125 |
+
if len(sampled_df) >= 1:
|
| 126 |
+
sampled_row = sampled_df.sample(1).iloc[0]
|
| 127 |
+
else:
|
| 128 |
+
sampled_row = gdf.drop(index=unsampled_row.name).sample(1).iloc[0]
|
| 129 |
+
return unsampled_row, sampled_row, "unsampled+sampled"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
+
# 8) Safety fall back to low-vote weighted sampling.
|
| 132 |
+
sample_weights = 1.0 / (vote_totals + 1.0)
|
| 133 |
+
sample = gdf.sample(2, weights=sample_weights, replace=False)
|
| 134 |
+
return sample.iloc[0], sample.iloc[1], "low-vote"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
def _row_image_url(row) -> str:
|
| 138 |
sample_url = row.get("sample_url")
|
|
|
|
| 148 |
"fetch_pair": _pool_fetch_pair,
|
| 149 |
"get_id": lambda row: row["md5"],
|
| 150 |
"get_image": _row_image_url,
|
| 151 |
+
"groups": {g: g for g in sorted(_pool_df["group"].unique())},
|
| 152 |
},
|
| 153 |
}
|
| 154 |
DEFAULT_DATASET = list(DATASETS.keys())[0]
|
| 155 |
|
| 156 |
+
def _select_groups(cfg: dict, rating_pref: str) -> list[str]:
|
| 157 |
+
groups = list(cfg["groups"].keys())
|
| 158 |
+
if rating_pref == "all":
|
| 159 |
+
return groups
|
| 160 |
+
return [g for g in groups if g.endswith(f"_{rating_pref}")]
|
| 161 |
+
|
| 162 |
+
def _commit_oldest_pending(state: dict):
|
| 163 |
+
pending = state.setdefault("pending", [])
|
| 164 |
+
if len(pending) <= 1:
|
| 165 |
+
return
|
| 166 |
+
oldest = pending.pop(0)
|
| 167 |
+
if oldest.get("winner") in ("A", "B"):
|
| 168 |
+
_apply_local_stats_update(oldest["winner"], oldest["key_a"], oldest["key_b"])
|
| 169 |
+
threading.Thread(target=VOTE_STORAGE.append_vote_row, args=(oldest.copy(), oldest.get("winner")), daemon=True).start()
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _apply_local_stats_update(winner: str, key_a: str, key_b: str):
|
| 173 |
+
assert winner in ("A", "B")
|
| 174 |
+
with _stats_lock:
|
| 175 |
+
wins_a, losses_a = _stats_by_key.get(str(key_a), (0, 0))
|
| 176 |
+
wins_b, losses_b = _stats_by_key.get(str(key_b), (0, 0))
|
| 177 |
+
if winner == "A":
|
| 178 |
+
_stats_by_key[str(key_a)] = (wins_a + 1, losses_a)
|
| 179 |
+
_stats_by_key[str(key_b)] = (wins_b, losses_b + 1)
|
| 180 |
+
else:
|
| 181 |
+
_stats_by_key[str(key_a)] = (wins_a, losses_a + 1)
|
| 182 |
+
_stats_by_key[str(key_b)] = (wins_b + 1, losses_b)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def _format_rating_post_row(post_id: int, wins: int, losses: int, label: str | None = None) -> str:
|
| 186 |
+
total_votes = wins + losses
|
| 187 |
+
url = f"https://e621.net/posts/{post_id}"
|
| 188 |
+
row = f"{url} | Times rated: {total_votes}"
|
| 189 |
+
return f"{label}: {row}" if label else row
|
| 190 |
|
| 191 |
def _render_current(state: dict, submit_status: str = "") -> tuple:
|
| 192 |
+
_reload_stats_if_due()
|
| 193 |
+
wins_a, losses_a = _stats_by_key.get(str(state["key_a"]), (0, 0))
|
| 194 |
+
wins_b, losses_b = _stats_by_key.get(str(state["key_b"]), (0, 0))
|
| 195 |
+
title_a = "Image A"
|
| 196 |
+
title_b = "Image B"
|
| 197 |
+
img_a_html = f"<div class=\"rating-card\"><div class=\"rating-card-title\"><strong>{html.escape(title_a)}</strong></div><div class=\"rating-image-frame\"><img src=\"{html.escape(state['url_a'])}\" class=\"rating-image\" loading=\"eager\" referrerpolicy=\"no-referrer\"></div></div>"
|
| 198 |
+
img_b_html = f"<div class=\"rating-card\"><div class=\"rating-card-title\"><strong>{html.escape(title_b)}</strong></div><div class=\"rating-image-frame\"><img src=\"{html.escape(state['url_b'])}\" class=\"rating-image\" loading=\"eager\" referrerpolicy=\"no-referrer\"></div></div>"
|
| 199 |
+
link_a = _format_rating_post_row(state["id_a"], wins_a, losses_a, label="Image A")
|
| 200 |
+
link_b = _format_rating_post_row(state["id_b"], wins_b, losses_b, label="Image B")
|
| 201 |
+
can_go_back = bool(state.get("can_go_back"))
|
| 202 |
+
back_md = "[Undo Rating (Ctrl+z)](#back)" if can_go_back else "<span class='subtle-back-link-disabled'>Undo Rating (Ctrl+z)</span>"
|
| 203 |
+
group_md = f"<span class='subtle-note'>Group: {state['group']}</span>"
|
| 204 |
+
pair_reason = state.get("pair_reason", "")
|
| 205 |
+
pair_reason_md = f"<span class='subtle-note'>Pair: {html.escape(pair_reason)}</span>" if pair_reason else ""
|
| 206 |
+
status_md = f"<span class='submit-status-msg'>{html.escape(submit_status)}</span>" if submit_status else ""
|
| 207 |
+
return img_a_html, img_b_html, link_a, link_b, back_md, group_md, pair_reason_md, status_md, state
|
| 208 |
|
|
|
|
|
|
|
| 209 |
|
|
|
|
| 210 |
|
| 211 |
def _normalize_rating_pref(pref: str | None) -> str:
|
| 212 |
return pref if pref in ("safe", "all") else "safe"
|
| 213 |
|
| 214 |
+
|
| 215 |
+
def _initial_load(state: dict, pref: str | None, submit_key: str | None):
|
| 216 |
+
rating_pref = _normalize_rating_pref(pref)
|
| 217 |
submit_key = _normalize_submit_key(submit_key)
|
| 218 |
+
return rating_pref, submit_key, *new_round(DEFAULT_DATASET, rating_pref, state)
|
| 219 |
+
|
| 220 |
|
| 221 |
+
def _on_rating_change(rating_pref: str, state: dict):
|
| 222 |
+
rating_pref = _normalize_rating_pref(rating_pref)
|
| 223 |
+
return *new_round(DEFAULT_DATASET, rating_pref, state), rating_pref
|
| 224 |
|
|
|
|
|
|
|
| 225 |
|
| 226 |
def _normalize_submit_key(submit_key: str | None) -> str:
|
| 227 |
+
return submit_key or ""
|
| 228 |
+
|
| 229 |
|
| 230 |
def _filtered_explorer_df(rating_pref: str) -> pd.DataFrame:
|
| 231 |
return _filtered_explorer_df_by_classifier(rating_pref, ALLOWED_CLASSIFIER_FILTERS[0])
|
| 232 |
|
| 233 |
+
|
| 234 |
def _filtered_explorer_df_by_classifier(rating_pref: str, classifier_name: str) -> pd.DataFrame:
|
| 235 |
if rating_pref == "all":
|
| 236 |
rating_filtered = _explorer_df
|
|
|
|
| 240 |
assert classifier_name in ALLOWED_CLASSIFIER_FILTERS, f"Unsupported classifier filter: {classifier_name}"
|
| 241 |
return rating_filtered[rating_filtered["classifier"] == classifier_name]
|
| 242 |
|
| 243 |
+
|
| 244 |
+
def _load_results(rating_pref_value: str, sort_mode_value: str, classifier_filter_value: str):
|
| 245 |
+
rating_pref = _normalize_rating_pref(rating_pref_value)
|
| 246 |
+
sort_mode = _normalize_sort_mode(sort_mode_value)
|
| 247 |
+
classifier_name = _normalize_classifier_filter(classifier_filter_value)
|
| 248 |
+
_reload_stats_if_due()
|
| 249 |
filtered_explorer_df = _filtered_explorer_df_by_classifier(rating_pref, classifier_name)
|
| 250 |
summary, score_distribution_plot, distribution_data, gallery_items, page_meta, next_offset, btn_update = build_results_data(
|
| 251 |
filtered_explorer_df,
|
|
|
|
| 256 |
)
|
| 257 |
return summary, score_distribution_plot, distribution_data, gallery_items, btn_update, "Click an image to reveal its ID and link.", page_meta, next_offset
|
| 258 |
|
| 259 |
+
|
| 260 |
def _normalize_sort_mode(sort_mode: str | None) -> str:
|
| 261 |
if sort_mode in ("Default", "Rating: Low to High", "Rating: High to Low"):
|
| 262 |
return sort_mode
|
| 263 |
return "Default"
|
| 264 |
|
| 265 |
+
|
| 266 |
def _normalize_classifier_filter(classifier_name: str | None) -> str:
|
| 267 |
if classifier_name in ALLOWED_CLASSIFIER_FILTERS:
|
| 268 |
return str(classifier_name)
|
|
|
|
| 270 |
|
| 271 |
# -- Gradio callbacks -------------------------------------------------------
|
| 272 |
|
| 273 |
+
def new_round(dataset_name: str, rating_pref: str, state: dict) -> tuple:
|
|
|
|
|
|
|
|
|
|
| 274 |
cfg = DATASETS[dataset_name]
|
| 275 |
+
groups = _select_groups(cfg, rating_pref)
|
| 276 |
+
assert groups, f"No groups for rating preference: {rating_pref}"
|
| 277 |
group = random.choice(groups)
|
| 278 |
+
pair_data = cfg["fetch_pair"](cfg["groups"][group])
|
| 279 |
+
if len(pair_data) == 3:
|
| 280 |
+
row_a, row_b, pair_reason = pair_data
|
| 281 |
+
else:
|
| 282 |
+
row_a, row_b = pair_data
|
| 283 |
+
pair_reason = ""
|
| 284 |
state.setdefault("session_id", uuid.uuid4().hex)
|
| 285 |
key_a = cfg["get_id"](row_a)
|
| 286 |
key_b = cfg["get_id"](row_b)
|
| 287 |
id_a = int(row_a["id"])
|
| 288 |
id_b = int(row_b["id"])
|
| 289 |
+
state.update(dataset=dataset_name, rating_pref=rating_pref, key_a=key_a, key_b=key_b, id_a=id_a, id_b=id_b, group=group, pair_reason=pair_reason)
|
| 290 |
url_a = cfg["get_image"](row_a)
|
| 291 |
url_b = cfg["get_image"](row_b)
|
| 292 |
state["url_a"] = url_a
|
|
|
|
| 295 |
|
| 296 |
def _queue_decision(winner: str | None, state: dict):
|
| 297 |
assert state.get("session_id"), "Missing session_id: refusing to record vote"
|
| 298 |
+
state.setdefault("pending", [])
|
| 299 |
+
decision = {
|
|
|
|
| 300 |
"winner": winner,
|
| 301 |
"key_a": state["key_a"],
|
| 302 |
"key_b": state["key_b"],
|
|
|
|
| 305 |
"url_a": state["url_a"],
|
| 306 |
"url_b": state["url_b"],
|
| 307 |
"dataset": state["dataset"],
|
| 308 |
+
"rating_pref": state["rating_pref"],
|
| 309 |
"group": state["group"],
|
| 310 |
"pair_reason": state.get("pair_reason", ""),
|
| 311 |
"session_id": state["session_id"],
|
| 312 |
+
}
|
| 313 |
+
state["pending"].append(decision)
|
| 314 |
+
state["last_decision"] = decision
|
| 315 |
+
state["can_go_back"] = True
|
| 316 |
+
_commit_oldest_pending(state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
|
| 318 |
+
def vote(winner: str | None, state: dict, submit_key: str | None) -> tuple:
|
| 319 |
+
assert winner in ("A", "B", None)
|
| 320 |
if _normalize_submit_key(submit_key) != SUBMIT_KEY:
|
| 321 |
return _render_current(state, "Wrong submission key.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
_queue_decision(winner, state)
|
| 323 |
+
return new_round(state["dataset"], state["rating_pref"], state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
|
| 325 |
def go_back(state: dict) -> tuple:
|
| 326 |
pending = state.setdefault("pending", [])
|
| 327 |
+
if not state.get("can_go_back"):
|
| 328 |
+
return _render_current(state)
|
| 329 |
+
last = state.get("last_decision")
|
| 330 |
+
if not last:
|
| 331 |
+
state["can_go_back"] = False
|
| 332 |
+
return _render_current(state)
|
| 333 |
+
if pending and pending[-1] == last:
|
| 334 |
+
pending.pop()
|
| 335 |
+
state["can_go_back"] = False
|
| 336 |
+
state["last_decision"] = None
|
| 337 |
+
state.update(
|
| 338 |
+
dataset=last["dataset"],
|
| 339 |
+
rating_pref=last["rating_pref"],
|
| 340 |
+
key_a=last["key_a"],
|
| 341 |
+
key_b=last["key_b"],
|
| 342 |
+
id_a=last["id_a"],
|
| 343 |
+
id_b=last["id_b"],
|
| 344 |
+
url_a=last["url_a"],
|
| 345 |
+
url_b=last["url_b"],
|
| 346 |
+
group=last["group"],
|
| 347 |
+
pair_reason=last.get("pair_reason", ""),
|
| 348 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
return _render_current(state)
|
| 350 |
|
| 351 |
# -- UI ---------------------------------------------------------------------
|
| 352 |
|
| 353 |
with gr.Blocks(
|
| 354 |
+
title="Image Rater",
|
| 355 |
head="""
|
| 356 |
<script>
|
| 357 |
const VOTE_COOLDOWN_MS = 1500;
|
|
|
|
| 434 |
} else if ((e.key === 'z' || e.key === 'Z') && (e.ctrlKey || e.metaKey) && ratingTabActive) {
|
| 435 |
e.preventDefault();
|
| 436 |
backButton?.click();
|
| 437 |
+
} else if (e.key === 'ArrowDown' && resultsTabActive) {
|
| 438 |
+
e.preventDefault();
|
| 439 |
+
resultsLoadMoreButton?.click();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
}
|
| 441 |
});
|
| 442 |
document.addEventListener('click', function (e) {
|
|
|
|
| 496 |
}
|
| 497 |
.rating-image-frame {
|
| 498 |
width: 100%;
|
| 499 |
+
height: 512px;
|
| 500 |
border: 1px solid #e6e6e6;
|
| 501 |
border-radius: 8px;
|
| 502 |
background: #333;
|
|
|
|
| 558 |
object-fit: contain !important;
|
| 559 |
background: #1f2937;
|
| 560 |
}
|
|
|
|
|
|
|
|
|
|
| 561 |
""",
|
|
|
|
| 562 |
) as demo:
|
| 563 |
state = gr.State({})
|
| 564 |
rating_pref_store = gr.BrowserState(default_value="safe", storage_key="rating_pref")
|
| 565 |
submit_key_store = gr.BrowserState(default_value="", storage_key="submit_key")
|
| 566 |
results_sort_store = gr.BrowserState(default_value="Default", storage_key="results_sort_mode")
|
| 567 |
results_classifier_store = gr.BrowserState(default_value=ALLOWED_CLASSIFIER_FILTERS[0], storage_key="results_classifier")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 568 |
|
| 569 |
with gr.Tabs():
|
| 570 |
with gr.Tab("Image Quality Rater"):
|
|
|
|
| 574 |
img_a = gr.HTML(elem_id="img-a")
|
| 575 |
img_b = gr.HTML(elem_id="img-b")
|
| 576 |
|
| 577 |
+
with gr.Row():
|
| 578 |
+
btn_a = gr.Button("👍 Prefer A", variant="primary", elem_id="btn-vote-a")
|
| 579 |
+
btn_skip = gr.Button("Same quality", elem_id="btn-skip")
|
| 580 |
+
btn_b = gr.Button("👍 Prefer B", variant="primary", elem_id="btn-vote-b")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 581 |
|
| 582 |
with gr.Accordion("Settings", open=False):
|
| 583 |
+
gr.Markdown("<span style='color:#888;font-size:0.9em;'>Advanced options</span>")
|
| 584 |
+
rating_dd = gr.Dropdown(
|
| 585 |
+
choices=["safe", "all"],
|
| 586 |
+
value="safe",
|
| 587 |
+
label="Rating",
|
| 588 |
+
elem_id="rating-pref",
|
|
|
|
|
|
|
|
|
|
| 589 |
)
|
| 590 |
submit_key_tb = gr.Textbox(
|
| 591 |
value="",
|
| 592 |
type="password",
|
| 593 |
+
label="Submit key",
|
| 594 |
elem_id="submit-key",
|
| 595 |
)
|
| 596 |
|
| 597 |
+
link_a = gr.Markdown(label="Image A link")
|
| 598 |
+
link_b = gr.Markdown(label="Image B link")
|
| 599 |
+
back_link = gr.Markdown(elem_classes=["subtle-back-link-wrap"])
|
| 600 |
+
btn_back_action = gr.Button("Undo Rating (Ctrl+z)", elem_id="btn-back-action", elem_classes=["hidden-action-btn"])
|
| 601 |
+
details_md = gr.Markdown()
|
| 602 |
+
pair_reason_md = gr.Markdown()
|
| 603 |
+
submit_status_md = gr.Markdown(elem_id="submit-status")
|
| 604 |
+
gr.Markdown("<span class='subtle-note'>Dataset: <a href='https://huggingface.co/datasets/taigasan/e6-visual-ratings' target='_blank' rel='noopener noreferrer'>taigasan/e6-visual-ratings</a></span>")
|
| 605 |
+
gr.Markdown("<span class='subtle-note'>Keyboard Shortcuts: ⬅️ vote A, ⬆️ same quality, ➡️ vote B, Ctrl+z undo rating</span>")
|
| 606 |
|
| 607 |
(
|
| 608 |
results_summary_md,
|
|
|
|
| 609 |
results_sort_dd,
|
| 610 |
results_classifier_dd,
|
| 611 |
results_score_distribution_plot,
|
|
|
|
| 617 |
results_page_offset_state,
|
| 618 |
) = add_results_tab(_pool_df)
|
| 619 |
|
| 620 |
+
outputs = [img_a, img_b, link_a, link_b, back_link, details_md, pair_reason_md, submit_status_md, state]
|
| 621 |
results_outputs = [
|
| 622 |
results_summary_md,
|
| 623 |
results_score_distribution_plot,
|
|
|
|
| 629 |
results_page_offset_state,
|
| 630 |
]
|
| 631 |
|
| 632 |
+
btn_a.click(fn=lambda s, k: vote("A", s, k), inputs=[state, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
|
| 633 |
+
btn_b.click(fn=lambda s, k: vote("B", s, k), inputs=[state, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
|
| 634 |
+
btn_skip.click(fn=lambda s, k: vote(None, s, k), inputs=[state, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
|
| 635 |
btn_back_action.click(fn=go_back, inputs=[state], outputs=outputs, queue=False, show_progress="hidden")
|
| 636 |
+
rating_dd.change(fn=_on_rating_change, inputs=[rating_dd, state], outputs=[*outputs, rating_pref_store], queue=False, show_progress="hidden")
|
| 637 |
+
submit_key_tb.input(fn=_normalize_submit_key, inputs=[submit_key_tb], outputs=[submit_key_store], queue=False, show_progress="hidden")
|
| 638 |
submit_key_tb.change(fn=_normalize_submit_key, inputs=[submit_key_tb], outputs=[submit_key_store], queue=False, show_progress="hidden")
|
| 639 |
+
rating_dd.change(fn=_load_results, inputs=[rating_dd, results_sort_store, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 640 |
results_sort_dd.change(fn=_normalize_sort_mode, inputs=[results_sort_dd], outputs=[results_sort_store], queue=False, show_progress="hidden")
|
| 641 |
results_sort_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_dd, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
|
| 642 |
results_classifier_dd.change(fn=_normalize_classifier_filter, inputs=[results_classifier_dd], outputs=[results_classifier_store], queue=False, show_progress="hidden")
|
| 643 |
results_classifier_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_dd], outputs=results_outputs, queue=False, show_progress="hidden")
|
| 644 |
+
demo.load(fn=_initial_load, inputs=[state, rating_pref_store, submit_key_store], outputs=[rating_dd, submit_key_tb, *outputs], queue=False, show_progress="hidden")
|
|
|
|
| 645 |
demo.load(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
|
| 646 |
demo.load(fn=_normalize_sort_mode, inputs=[results_sort_store], outputs=[results_sort_dd], queue=False, show_progress="hidden")
|
| 647 |
demo.load(fn=_normalize_classifier_filter, inputs=[results_classifier_store], outputs=[results_classifier_dd], queue=False, show_progress="hidden")
|
|
|
|
| 648 |
results_load_more_btn.click(
|
| 649 |
fn=lambda r, s, c, o: load_more_results(_filtered_explorer_df_by_classifier(_normalize_rating_pref(r), _normalize_classifier_filter(c)), _explorer_df, s, o),
|
| 650 |
inputs=[rating_pref_store, results_sort_store, results_classifier_store, results_page_offset_state],
|
compact_logs.py
CHANGED
|
@@ -4,6 +4,8 @@ from __future__ import annotations
|
|
| 4 |
import os
|
| 5 |
import time
|
| 6 |
import uuid
|
|
|
|
|
|
|
| 7 |
|
| 8 |
import pandas as pd
|
| 9 |
from huggingface_hub import CommitOperationAdd, CommitOperationDelete, HfApi, hf_hub_download
|
|
@@ -13,6 +15,7 @@ VOTES_REPO_TYPE = "dataset"
|
|
| 13 |
VOTES_LOG_SUBDIR = "ratings_log"
|
| 14 |
RATINGS_APP_TOKEN_ENV = "RATINGS_APP_TOKEN"
|
| 15 |
|
|
|
|
| 16 |
def _list_vote_shards(api: HfApi) -> list[str]:
|
| 17 |
files = api.list_repo_files(repo_id=VOTES_REPO_ID, repo_type=VOTES_REPO_TYPE)
|
| 18 |
shard_prefix = f"{VOTES_LOG_SUBDIR}/votes_"
|
|
@@ -23,19 +26,21 @@ def _list_vote_shards(api: HfApi) -> list[str]:
|
|
| 23 |
and f.endswith(".parquet")
|
| 24 |
)
|
| 25 |
|
|
|
|
| 26 |
def _new_compacted_shard_path() -> str:
|
| 27 |
ts = int(time.time())
|
| 28 |
return f"{VOTES_LOG_SUBDIR}/votes_{ts}_{uuid.uuid4().hex}.parquet"
|
| 29 |
|
| 30 |
-
|
|
|
|
| 31 |
token = os.getenv(RATINGS_APP_TOKEN_ENV)
|
| 32 |
api = HfApi(token=token)
|
| 33 |
|
| 34 |
shards = _list_vote_shards(api)
|
| 35 |
-
if
|
| 36 |
-
|
| 37 |
|
| 38 |
-
frames
|
| 39 |
for shard in shards:
|
| 40 |
shard_local = hf_hub_download(
|
| 41 |
repo_id=VOTES_REPO_ID,
|
|
@@ -43,39 +48,47 @@ def compact_votes() -> tuple[int, int, str] | None:
|
|
| 43 |
repo_type=VOTES_REPO_TYPE,
|
| 44 |
token=token,
|
| 45 |
)
|
| 46 |
-
|
|
|
|
| 47 |
|
|
|
|
| 48 |
combined = pd.concat(frames, ignore_index=True, sort=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
-
return len(shards),
|
| 66 |
|
| 67 |
-
def _main() -> None:
|
| 68 |
-
result = compact_votes()
|
| 69 |
-
if result is None:
|
| 70 |
-
print(f"Nothing to compact.")
|
| 71 |
-
return
|
| 72 |
|
| 73 |
-
|
|
|
|
| 74 |
print(
|
| 75 |
-
f"Compacted {shard_count}
|
| 76 |
f"{VOTES_REPO_ID}/{compacted_path} "
|
| 77 |
f"with {row_count} rows."
|
| 78 |
)
|
| 79 |
|
|
|
|
| 80 |
if __name__ == "__main__":
|
| 81 |
-
|
|
|
|
| 4 |
import os
|
| 5 |
import time
|
| 6 |
import uuid
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from tempfile import NamedTemporaryFile
|
| 9 |
|
| 10 |
import pandas as pd
|
| 11 |
from huggingface_hub import CommitOperationAdd, CommitOperationDelete, HfApi, hf_hub_download
|
|
|
|
| 15 |
VOTES_LOG_SUBDIR = "ratings_log"
|
| 16 |
RATINGS_APP_TOKEN_ENV = "RATINGS_APP_TOKEN"
|
| 17 |
|
| 18 |
+
|
| 19 |
def _list_vote_shards(api: HfApi) -> list[str]:
|
| 20 |
files = api.list_repo_files(repo_id=VOTES_REPO_ID, repo_type=VOTES_REPO_TYPE)
|
| 21 |
shard_prefix = f"{VOTES_LOG_SUBDIR}/votes_"
|
|
|
|
| 26 |
and f.endswith(".parquet")
|
| 27 |
)
|
| 28 |
|
| 29 |
+
|
| 30 |
def _new_compacted_shard_path() -> str:
|
| 31 |
ts = int(time.time())
|
| 32 |
return f"{VOTES_LOG_SUBDIR}/votes_{ts}_{uuid.uuid4().hex}.parquet"
|
| 33 |
|
| 34 |
+
|
| 35 |
+
def compact_votes() -> tuple[int, int, str]:
|
| 36 |
token = os.getenv(RATINGS_APP_TOKEN_ENV)
|
| 37 |
api = HfApi(token=token)
|
| 38 |
|
| 39 |
shards = _list_vote_shards(api)
|
| 40 |
+
if not shards:
|
| 41 |
+
raise FileNotFoundError(f"No vote shards found in {VOTES_REPO_ID}/{VOTES_LOG_SUBDIR}")
|
| 42 |
|
| 43 |
+
frames = []
|
| 44 |
for shard in shards:
|
| 45 |
shard_local = hf_hub_download(
|
| 46 |
repo_id=VOTES_REPO_ID,
|
|
|
|
| 48 |
repo_type=VOTES_REPO_TYPE,
|
| 49 |
token=token,
|
| 50 |
)
|
| 51 |
+
frame = pd.read_parquet(shard_local)
|
| 52 |
+
frames.append(frame)
|
| 53 |
|
| 54 |
+
input_row_count = sum(len(frame) for frame in frames)
|
| 55 |
combined = pd.concat(frames, ignore_index=True, sort=False)
|
| 56 |
+
output_row_count = int(len(combined))
|
| 57 |
+
if output_row_count != input_row_count:
|
| 58 |
+
raise RuntimeError(
|
| 59 |
+
f"Refusing to commit: row mismatch during compaction "
|
| 60 |
+
f"({input_row_count} -> {output_row_count})."
|
| 61 |
+
)
|
| 62 |
|
| 63 |
+
with NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
|
| 64 |
+
tmp_path = Path(tmp.name)
|
| 65 |
+
try:
|
| 66 |
+
combined.to_parquet(tmp_path, index=False)
|
| 67 |
+
compacted_path = _new_compacted_shard_path()
|
| 68 |
+
operations = [
|
| 69 |
+
CommitOperationAdd(path_or_fileobj=str(tmp_path), path_in_repo=compacted_path),
|
| 70 |
+
*[CommitOperationDelete(path_in_repo=shard) for shard in shards],
|
| 71 |
+
]
|
| 72 |
+
api.create_commit(
|
| 73 |
+
repo_id=VOTES_REPO_ID,
|
| 74 |
+
repo_type=VOTES_REPO_TYPE,
|
| 75 |
+
commit_message=f"compact {len(shards)} vote shard(s)",
|
| 76 |
+
operations=operations,
|
| 77 |
+
)
|
| 78 |
+
finally:
|
| 79 |
+
tmp_path.unlink(missing_ok=True)
|
| 80 |
|
| 81 |
+
return len(shards), output_row_count, compacted_path
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
+
def main() -> None:
|
| 85 |
+
shard_count, row_count, compacted_path = compact_votes()
|
| 86 |
print(
|
| 87 |
+
f"Compacted {shard_count} shard(s) into "
|
| 88 |
f"{VOTES_REPO_ID}/{compacted_path} "
|
| 89 |
f"with {row_count} rows."
|
| 90 |
)
|
| 91 |
|
| 92 |
+
|
| 93 |
if __name__ == "__main__":
|
| 94 |
+
main()
|
explorer.py
CHANGED
|
@@ -208,11 +208,6 @@ def add_results_tab(pool_df: pd.DataFrame):
|
|
| 208 |
results_load_more_btn = gr.Button("Load more (ArrowDown)", elem_id="btn-results-load-more")
|
| 209 |
selected_image_md = gr.Markdown("Click an image to reveal its ID and link.")
|
| 210 |
results_score_distribution_plot = gr.Plot(label="Classifier score distribution")
|
| 211 |
-
results_rating_dd = gr.Dropdown(
|
| 212 |
-
choices=["safe", "all"],
|
| 213 |
-
value="safe",
|
| 214 |
-
label="Rating",
|
| 215 |
-
)
|
| 216 |
results_sort_dd = gr.Dropdown(
|
| 217 |
choices=SORT_MODES,
|
| 218 |
value="Default",
|
|
@@ -230,7 +225,6 @@ def add_results_tab(pool_df: pd.DataFrame):
|
|
| 230 |
results_page_offset_state = gr.State(0)
|
| 231 |
return (
|
| 232 |
results_summary_md,
|
| 233 |
-
results_rating_dd,
|
| 234 |
results_sort_dd,
|
| 235 |
results_classifier_dd,
|
| 236 |
results_score_distribution_plot,
|
|
|
|
| 208 |
results_load_more_btn = gr.Button("Load more (ArrowDown)", elem_id="btn-results-load-more")
|
| 209 |
selected_image_md = gr.Markdown("Click an image to reveal its ID and link.")
|
| 210 |
results_score_distribution_plot = gr.Plot(label="Classifier score distribution")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
results_sort_dd = gr.Dropdown(
|
| 212 |
choices=SORT_MODES,
|
| 213 |
value="Default",
|
|
|
|
| 225 |
results_page_offset_state = gr.State(0)
|
| 226 |
return (
|
| 227 |
results_summary_md,
|
|
|
|
| 228 |
results_sort_dd,
|
| 229 |
results_classifier_dd,
|
| 230 |
results_score_distribution_plot,
|
stats_from_logs.py
CHANGED
|
@@ -1,24 +1,11 @@
|
|
| 1 |
-
from collections import
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
|
| 4 |
import pandas as pd
|
| 5 |
from huggingface_hub import HfApi, hf_hub_download
|
| 6 |
|
| 7 |
-
VOTES_LOG_SUBDIR = "ratings_log"
|
| 8 |
-
|
| 9 |
-
@dataclass(slots=True)
|
| 10 |
-
class Stats:
|
| 11 |
-
wins: int = 0
|
| 12 |
-
losses: int = 0
|
| 13 |
-
ties: int = 0
|
| 14 |
|
| 15 |
-
|
| 16 |
-
def votes(self) -> int:
|
| 17 |
-
return self.wins + self.losses + self.ties
|
| 18 |
|
| 19 |
-
@property
|
| 20 |
-
def winrate(self) -> float:
|
| 21 |
-
return (self.wins + self.ties * 0.5) / max(self.votes, 1)
|
| 22 |
|
| 23 |
def _list_remote_log_files(repo_id: str, token: str | None) -> list[str]:
|
| 24 |
api = HfApi(token=token)
|
|
@@ -29,8 +16,32 @@ def _list_remote_log_files(repo_id: str, token: str | None) -> list[str]:
|
|
| 29 |
if f.startswith(f"{VOTES_LOG_SUBDIR}/") and f.endswith(".parquet")
|
| 30 |
)
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
for path_in_repo in _list_remote_log_files(repo_id, token):
|
| 36 |
local_path = hf_hub_download(
|
|
@@ -39,17 +50,8 @@ def load_stats_by_md5(*, repo_id: str, token: str | None) -> dict[str, Stats]:
|
|
| 39 |
repo_type="dataset",
|
| 40 |
token=token,
|
| 41 |
)
|
| 42 |
-
|
| 43 |
df = pd.read_parquet(local_path, columns=["md5a", "md5b", "winner_md5"])
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
elif winner_md5 == md5b:
|
| 49 |
-
stats[md5b].wins += 1
|
| 50 |
-
stats[md5a].losses += 1
|
| 51 |
-
else:
|
| 52 |
-
stats[md5a].ties += 1
|
| 53 |
-
stats[md5b].ties += 1
|
| 54 |
-
|
| 55 |
-
return stats
|
|
|
|
| 1 |
+
from collections import Counter
|
|
|
|
| 2 |
|
| 3 |
import pandas as pd
|
| 4 |
from huggingface_hub import HfApi, hf_hub_download
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
VOTES_LOG_SUBDIR = "ratings_log"
|
|
|
|
|
|
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
def _list_remote_log_files(repo_id: str, token: str | None) -> list[str]:
|
| 11 |
api = HfApi(token=token)
|
|
|
|
| 16 |
if f.startswith(f"{VOTES_LOG_SUBDIR}/") and f.endswith(".parquet")
|
| 17 |
)
|
| 18 |
|
| 19 |
+
|
| 20 |
+
def _accumulate_stats_from_df(df: pd.DataFrame, wins_counter: Counter, losses_counter: Counter):
|
| 21 |
+
if df.empty:
|
| 22 |
+
return
|
| 23 |
+
valid = df[df["winner_md5"].notna()].copy()
|
| 24 |
+
if valid.empty:
|
| 25 |
+
return
|
| 26 |
+
valid["md5a"] = valid["md5a"].astype(str)
|
| 27 |
+
valid["md5b"] = valid["md5b"].astype(str)
|
| 28 |
+
valid["winner_md5"] = valid["winner_md5"].astype(str)
|
| 29 |
+
|
| 30 |
+
a_won_mask = valid["winner_md5"] == valid["md5a"]
|
| 31 |
+
b_won_mask = valid["winner_md5"] == valid["md5b"]
|
| 32 |
+
|
| 33 |
+
winner_keys = pd.concat([valid.loc[a_won_mask, "md5a"], valid.loc[b_won_mask, "md5b"]], ignore_index=True)
|
| 34 |
+
loser_keys = pd.concat([valid.loc[a_won_mask, "md5b"], valid.loc[b_won_mask, "md5a"]], ignore_index=True)
|
| 35 |
+
|
| 36 |
+
for key, count in winner_keys.value_counts().items():
|
| 37 |
+
wins_counter[str(key)] += int(count)
|
| 38 |
+
for key, count in loser_keys.value_counts().items():
|
| 39 |
+
losses_counter[str(key)] += int(count)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def load_stats_by_md5(*, repo_id: str, token: str | None) -> dict[str, tuple[int, int]]:
|
| 43 |
+
wins_counter: Counter[str] = Counter()
|
| 44 |
+
losses_counter: Counter[str] = Counter()
|
| 45 |
|
| 46 |
for path_in_repo in _list_remote_log_files(repo_id, token):
|
| 47 |
local_path = hf_hub_download(
|
|
|
|
| 50 |
repo_type="dataset",
|
| 51 |
token=token,
|
| 52 |
)
|
|
|
|
| 53 |
df = pd.read_parquet(local_path, columns=["md5a", "md5b", "winner_md5"])
|
| 54 |
+
_accumulate_stats_from_df(df, wins_counter, losses_counter)
|
| 55 |
+
|
| 56 |
+
all_keys = set(wins_counter) | set(losses_counter)
|
| 57 |
+
return {k: (int(wins_counter.get(k, 0)), int(losses_counter.get(k, 0))) for k in all_keys}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
storage.py
CHANGED
|
@@ -27,120 +27,82 @@ VOTE_COLUMNS = [
|
|
| 27 |
|
| 28 |
|
| 29 |
class VoteStorage:
|
| 30 |
-
def __init__(self, mode: str, token: str | None = None)
|
| 31 |
assert mode in ("hf", "void"), f"Unsupported storage mode: {mode}"
|
| 32 |
self.mode = mode
|
|
|
|
| 33 |
is_debug_mode = self.mode == "void"
|
| 34 |
-
|
| 35 |
self._flush_every = 3 if is_debug_mode else 50
|
| 36 |
self._flush_interval_sec = 15.0 if is_debug_mode else 300.0
|
| 37 |
-
self.
|
| 38 |
-
|
| 39 |
-
self._flush_condition = threading.Condition(threading.Lock())
|
| 40 |
-
self._sync_event = threading.Event()
|
| 41 |
-
self._sync_lock = threading.Lock()
|
| 42 |
self._votes_buffer: list[dict] = []
|
| 43 |
-
self.
|
| 44 |
-
|
| 45 |
self._flush_thread = threading.Thread(target=self._flush_loop, daemon=True)
|
| 46 |
self._flush_thread.start()
|
| 47 |
-
|
| 48 |
atexit.register(self.close)
|
| 49 |
|
| 50 |
-
def
|
| 51 |
-
|
| 52 |
-
if self.mode == "void":
|
| 53 |
-
return
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
for col in VOTE_COLUMNS:
|
| 58 |
-
if col not in df.columns:
|
| 59 |
-
df[col] = None
|
| 60 |
-
|
| 61 |
-
df = df[VOTE_COLUMNS]
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
ts = int(time.time())
|
| 64 |
shard = f"votes_{ts}_{uuid.uuid4().hex}.parquet"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
repo_id=VOTES_REPO_ID,
|
| 70 |
-
repo_type=VOTES_REPO_TYPE,
|
| 71 |
-
commit_message=f"upload {len(df)} vote rows",
|
| 72 |
-
)
|
| 73 |
-
|
| 74 |
-
def _flush_loop(self) -> None:
|
| 75 |
-
while True:
|
| 76 |
-
with self._flush_condition:
|
| 77 |
-
while True:
|
| 78 |
-
# Forced sync.
|
| 79 |
-
if not self._sync_event.is_set():
|
| 80 |
-
if self._votes_buffer:
|
| 81 |
-
break
|
| 82 |
-
|
| 83 |
-
self._sync_event.set()
|
| 84 |
-
|
| 85 |
-
# Shutdown wanted.
|
| 86 |
-
if self._shutdown:
|
| 87 |
-
if self._votes_buffer:
|
| 88 |
-
break
|
| 89 |
-
|
| 90 |
-
return
|
| 91 |
-
|
| 92 |
-
# Have enough votes to flush now.
|
| 93 |
-
if len(self._votes_buffer) >= self._flush_every:
|
| 94 |
-
break
|
| 95 |
-
|
| 96 |
-
# Wait for a notify to flush early or shutdown.
|
| 97 |
-
if not self._flush_condition.wait(self._flush_interval_sec):
|
| 98 |
-
# Interval elapsed. Flush if there is at least one vote.
|
| 99 |
-
if self._votes_buffer:
|
| 100 |
-
break
|
| 101 |
-
|
| 102 |
-
# Atomically take the batch of votes.
|
| 103 |
-
batch = self._votes_buffer
|
| 104 |
-
self._votes_buffer = []
|
| 105 |
-
|
| 106 |
-
self._upload_votes_batch(batch)
|
| 107 |
-
|
| 108 |
-
def sync(self) -> None:
|
| 109 |
-
with self._sync_lock:
|
| 110 |
-
with self._flush_condition:
|
| 111 |
-
is_shutdown = self._shutdown
|
| 112 |
-
if not is_shutdown:
|
| 113 |
-
self._sync_event.clear()
|
| 114 |
-
self._flush_condition.notify()
|
| 115 |
-
|
| 116 |
-
if not is_shutdown:
|
| 117 |
-
self._sync_event.wait()
|
| 118 |
-
|
| 119 |
-
if is_shutdown:
|
| 120 |
-
self._flush_thread.join()
|
| 121 |
-
|
| 122 |
-
def close(self) -> None:
|
| 123 |
-
with self._flush_condition:
|
| 124 |
-
self._shutdown = True
|
| 125 |
-
self._flush_condition.notify()
|
| 126 |
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
-
def
|
| 130 |
id_a = int(state["id_a"])
|
| 131 |
id_b = int(state["id_b"])
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
case "B":
|
| 138 |
-
winner_md5 = state["key_b"]
|
| 139 |
-
case None:
|
| 140 |
-
winner_md5 = None
|
| 141 |
-
case _:
|
| 142 |
-
raise AssertionError
|
| 143 |
-
|
| 144 |
vote_row = {
|
| 145 |
"vote_id": uuid.uuid4().hex,
|
| 146 |
"timestamp": datetime.now(timezone.utc).isoformat(timespec="seconds"),
|
|
@@ -153,9 +115,6 @@ class VoteStorage:
|
|
| 153 |
"group": state["group"],
|
| 154 |
"session_id": state["session_id"],
|
| 155 |
}
|
| 156 |
-
|
| 157 |
-
with self._flush_condition:
|
| 158 |
self._votes_buffer.append(vote_row)
|
| 159 |
-
|
| 160 |
-
if len(self._votes_buffer) == self._flush_every:
|
| 161 |
-
self._flush_condition.notify()
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
class VoteStorage:
|
| 30 |
+
def __init__(self, mode: str, token: str | None = None):
|
| 31 |
assert mode in ("hf", "void"), f"Unsupported storage mode: {mode}"
|
| 32 |
self.mode = mode
|
| 33 |
+
self._token = token
|
| 34 |
is_debug_mode = self.mode == "void"
|
|
|
|
| 35 |
self._flush_every = 3 if is_debug_mode else 50
|
| 36 |
self._flush_interval_sec = 15.0 if is_debug_mode else 300.0
|
| 37 |
+
self._votes_lock = threading.Lock()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
self._votes_buffer: list[dict] = []
|
| 39 |
+
self._stop_event = threading.Event()
|
|
|
|
| 40 |
self._flush_thread = threading.Thread(target=self._flush_loop, daemon=True)
|
| 41 |
self._flush_thread.start()
|
|
|
|
| 42 |
atexit.register(self.close)
|
| 43 |
|
| 44 |
+
def _hf_token(self) -> str | None:
|
| 45 |
+
return self._token
|
|
|
|
|
|
|
| 46 |
|
| 47 |
+
def _empty_votes_df(self) -> pd.DataFrame:
|
| 48 |
+
return pd.DataFrame(columns=VOTE_COLUMNS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
+
def _upload_votes_batch(self, df: pd.DataFrame, commit_message: str):
|
| 51 |
+
assert set(VOTE_COLUMNS).issubset(df.columns), "Missing vote columns in upload batch"
|
| 52 |
+
if self.mode == "void":
|
| 53 |
+
_ = commit_message
|
| 54 |
+
return
|
| 55 |
ts = int(time.time())
|
| 56 |
shard = f"votes_{ts}_{uuid.uuid4().hex}.parquet"
|
| 57 |
+
api = HfApi(token=self._hf_token())
|
| 58 |
+
with NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
|
| 59 |
+
tmp_path = tmp.name
|
| 60 |
+
try:
|
| 61 |
+
df[VOTE_COLUMNS].to_parquet(tmp_path, index=False)
|
| 62 |
+
api.upload_file(
|
| 63 |
+
path_or_fileobj=tmp_path,
|
| 64 |
+
path_in_repo=f"{VOTES_LOG_SUBDIR}/{shard}",
|
| 65 |
+
repo_id=VOTES_REPO_ID,
|
| 66 |
+
repo_type=VOTES_REPO_TYPE,
|
| 67 |
+
commit_message=commit_message,
|
| 68 |
+
)
|
| 69 |
+
finally:
|
| 70 |
+
if os.path.exists(tmp_path):
|
| 71 |
+
os.remove(tmp_path)
|
| 72 |
+
|
| 73 |
+
def _flush_votes(self, force: bool = False):
|
| 74 |
+
with self._votes_lock:
|
| 75 |
+
if not self._votes_buffer:
|
| 76 |
+
return
|
| 77 |
+
if not force and len(self._votes_buffer) < self._flush_every:
|
| 78 |
+
return
|
| 79 |
+
batch = list(self._votes_buffer)
|
| 80 |
+
self._votes_buffer.clear()
|
| 81 |
+
incoming = pd.DataFrame(batch)
|
| 82 |
+
for col in VOTE_COLUMNS:
|
| 83 |
+
if col not in incoming.columns:
|
| 84 |
+
incoming[col] = None
|
| 85 |
+
self._upload_votes_batch(incoming[VOTE_COLUMNS], commit_message=f"append {len(batch)} vote rows")
|
| 86 |
|
| 87 |
+
def _flush_loop(self):
|
| 88 |
+
while not self._stop_event.wait(self._flush_interval_sec):
|
| 89 |
+
self._flush_votes(force=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
+
def close(self):
|
| 92 |
+
if self._stop_event.is_set():
|
| 93 |
+
return
|
| 94 |
+
self._stop_event.set()
|
| 95 |
+
self._flush_thread.join(timeout=1.0)
|
| 96 |
+
self._flush_votes(force=True)
|
| 97 |
|
| 98 |
+
def append_vote_row(self, state: dict, winner: str | None):
|
| 99 |
id_a = int(state["id_a"])
|
| 100 |
id_b = int(state["id_b"])
|
| 101 |
+
winner_md5 = None
|
| 102 |
+
if winner == "A":
|
| 103 |
+
winner_md5 = state["key_a"]
|
| 104 |
+
elif winner == "B":
|
| 105 |
+
winner_md5 = state["key_b"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
vote_row = {
|
| 107 |
"vote_id": uuid.uuid4().hex,
|
| 108 |
"timestamp": datetime.now(timezone.utc).isoformat(timespec="seconds"),
|
|
|
|
| 115 |
"group": state["group"],
|
| 116 |
"session_id": state["session_id"],
|
| 117 |
}
|
| 118 |
+
with self._votes_lock:
|
|
|
|
| 119 |
self._votes_buffer.append(vote_row)
|
| 120 |
+
self._flush_votes()
|
|
|
|
|
|