Spaces:
Running
Running
deploy app, storage, readme
Browse files
app.py
CHANGED
|
@@ -1,192 +1,111 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import random
|
| 3 |
-
import json
|
| 4 |
-
import os
|
| 5 |
-
import time
|
| 6 |
import threading
|
| 7 |
-
import
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def _e621_media_url(row: dict) -> str | None:
|
| 46 |
-
# Prefer sample URL for faster client-side loading, fallback to original file URL.
|
| 47 |
-
sample_url = row.get("sample", {}).get("url")
|
| 48 |
-
file_url = row.get("file", {}).get("url")
|
| 49 |
-
if _is_e621_media_url(sample_url):
|
| 50 |
return sample_url
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def _valid_image_post(row: dict) -> bool:
|
| 56 |
-
ext = row.get("file", {}).get("ext")
|
| 57 |
-
return ext in E621_IMAGE_EXTS and _e621_media_url(row) is not None
|
| 58 |
-
|
| 59 |
-
def _e621_fetch_pair(group_tags: list[str]) -> tuple:
|
| 60 |
-
posts = _e621_request(group_tags, limit=20)
|
| 61 |
-
valid = [p for p in posts if _valid_image_post(p)]
|
| 62 |
-
assert len(valid) >= 2, f"Not enough image posts for tags: {group_tags}"
|
| 63 |
-
row_a, row_b = valid[0], valid[1]
|
| 64 |
-
tags_a = set(row_a["tags"]["general"] + row_a["tags"].get("species", []) + row_a["tags"].get("character", []))
|
| 65 |
-
tags_b = set(row_b["tags"]["general"] + row_b["tags"].get("species", []) + row_b["tags"].get("character", []))
|
| 66 |
-
common = sorted(tags_a & tags_b)
|
| 67 |
-
return row_a, row_b, common
|
| 68 |
|
| 69 |
DATASETS: dict[str, dict] = {
|
| 70 |
-
"
|
| 71 |
-
"fetch_pair":
|
| 72 |
-
"get_id": lambda row: row["
|
| 73 |
-
"get_image":
|
| 74 |
-
"groups": {
|
| 75 |
-
"e_male": ["male", "solo", "rating:e"],
|
| 76 |
-
"e_female": ["female", "solo", "rating:e"],
|
| 77 |
-
"e_male_female": ["male", "female", "rating:e"],
|
| 78 |
-
"q_male": ["male", "solo", "rating:q"],
|
| 79 |
-
"q_female": ["female", "solo", "rating:q"],
|
| 80 |
-
"s_male": ["male", "solo", "rating:s"],
|
| 81 |
-
"s_female": ["female", "solo", "rating:s"],
|
| 82 |
-
},
|
| 83 |
},
|
| 84 |
}
|
| 85 |
-
|
| 86 |
-
RATINGS_FILE = "elo_ratings.json"
|
| 87 |
-
DEFAULT_ELO = 1500
|
| 88 |
-
K = 32
|
| 89 |
-
RATINGS_MEM: dict[str, int] = {}
|
| 90 |
-
RATING_PREFIX = {
|
| 91 |
-
"safe": "s_",
|
| 92 |
-
"questionable": "q_",
|
| 93 |
-
"explicit": "e_",
|
| 94 |
-
"all": None,
|
| 95 |
-
}
|
| 96 |
-
|
| 97 |
-
# -- Prefetch ---------------------------------------------------------------
|
| 98 |
-
|
| 99 |
-
_prefetch: dict[tuple[str, str], tuple | None] = {}
|
| 100 |
-
_prefetch_threads: dict[tuple[str, str], threading.Thread] = {}
|
| 101 |
|
| 102 |
def _select_groups(cfg: dict, rating_pref: str) -> list[str]:
|
| 103 |
-
prefix = RATING_PREFIX[rating_pref]
|
| 104 |
groups = list(cfg["groups"].keys())
|
| 105 |
-
if
|
| 106 |
return groups
|
| 107 |
-
return [g for g in groups if g.
|
| 108 |
-
|
| 109 |
-
def _do_prefetch(dataset_name: str, rating_pref: str):
|
| 110 |
-
try:
|
| 111 |
-
cfg = DATASETS[dataset_name]
|
| 112 |
-
groups = _select_groups(cfg, rating_pref)
|
| 113 |
-
assert groups, f"No groups for rating preference: {rating_pref}"
|
| 114 |
-
group = random.choice(groups)
|
| 115 |
-
row_a, row_b, common = cfg["fetch_pair"](cfg["groups"][group])
|
| 116 |
-
_prefetch[(dataset_name, rating_pref)] = (row_a, row_b, common, group)
|
| 117 |
-
except Exception:
|
| 118 |
-
_prefetch[(dataset_name, rating_pref)] = None
|
| 119 |
-
|
| 120 |
-
def prefetch(dataset_name: str, rating_pref: str):
|
| 121 |
-
key = (dataset_name, rating_pref)
|
| 122 |
-
_prefetch[key] = None
|
| 123 |
-
t = threading.Thread(target=_do_prefetch, args=(dataset_name, rating_pref), daemon=True)
|
| 124 |
-
_prefetch_threads[key] = t
|
| 125 |
-
t.start()
|
| 126 |
-
|
| 127 |
-
def consume_prefetch(dataset_name: str, rating_pref: str) -> tuple:
|
| 128 |
-
key = (dataset_name, rating_pref)
|
| 129 |
-
# Wait for prefetch to finish (should be near-instant since we started it earlier)
|
| 130 |
-
t = _prefetch_threads.get(key)
|
| 131 |
-
if t:
|
| 132 |
-
t.join(timeout=15)
|
| 133 |
-
result = _prefetch.pop(key, None)
|
| 134 |
-
# Kick off the next prefetch immediately
|
| 135 |
-
prefetch(dataset_name, rating_pref)
|
| 136 |
-
if result is not None:
|
| 137 |
-
return result
|
| 138 |
-
# Fallback: fetch synchronously if prefetch failed
|
| 139 |
-
cfg = DATASETS[dataset_name]
|
| 140 |
-
groups = _select_groups(cfg, rating_pref)
|
| 141 |
-
assert groups, f"No groups for rating preference: {rating_pref}"
|
| 142 |
-
group = random.choice(groups)
|
| 143 |
-
row_a, row_b, common = cfg["fetch_pair"](cfg["groups"][group])
|
| 144 |
-
return row_a, row_b, common, group
|
| 145 |
-
|
| 146 |
-
# -- ELO helpers ------------------------------------------------------------
|
| 147 |
-
|
| 148 |
-
def load_ratings() -> dict:
|
| 149 |
-
return RATINGS_MEM.copy()
|
| 150 |
-
|
| 151 |
-
def save_ratings(ratings: dict):
|
| 152 |
-
# Stubbed persistence for now: keep ratings only in memory.
|
| 153 |
-
RATINGS_MEM.clear()
|
| 154 |
-
RATINGS_MEM.update(ratings)
|
| 155 |
-
|
| 156 |
-
def elo_update(ratings: dict, winner_key: str, loser_key: str) -> dict:
|
| 157 |
-
rw = ratings.get(winner_key, DEFAULT_ELO)
|
| 158 |
-
rl = ratings.get(loser_key, DEFAULT_ELO)
|
| 159 |
-
ea = 1 / (1 + 10 ** ((rl - rw) / 400))
|
| 160 |
-
ratings[winner_key] = round(rw + K * (1 - ea))
|
| 161 |
-
ratings[loser_key] = round(rl + K * (0 - (1 - ea)))
|
| 162 |
-
return ratings
|
| 163 |
|
| 164 |
def _commit_oldest_pending(state: dict):
|
| 165 |
pending = state.setdefault("pending", [])
|
| 166 |
-
if len(pending) <=
|
| 167 |
return
|
| 168 |
oldest = pending.pop(0)
|
| 169 |
-
|
| 170 |
-
if winner is None:
|
| 171 |
-
return
|
| 172 |
-
ratings = load_ratings()
|
| 173 |
-
if winner == "A":
|
| 174 |
-
ratings = elo_update(ratings, oldest["key_a"], oldest["key_b"])
|
| 175 |
-
else:
|
| 176 |
-
ratings = elo_update(ratings, oldest["key_b"], oldest["key_a"])
|
| 177 |
-
save_ratings(ratings)
|
| 178 |
|
| 179 |
def _render_current(state: dict) -> tuple:
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
# -- Gradio callbacks -------------------------------------------------------
|
| 183 |
|
| 184 |
def new_round(dataset_name: str, rating_pref: str, state: dict) -> tuple:
|
| 185 |
cfg = DATASETS[dataset_name]
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
key_a = cfg["get_id"](row_a)
|
| 188 |
key_b = cfg["get_id"](row_b)
|
| 189 |
-
|
|
|
|
|
|
|
| 190 |
url_a = cfg["get_image"](row_a)
|
| 191 |
url_b = cfg["get_image"](row_b)
|
| 192 |
state["url_a"] = url_a
|
|
@@ -194,19 +113,24 @@ def new_round(dataset_name: str, rating_pref: str, state: dict) -> tuple:
|
|
| 194 |
return _render_current(state)
|
| 195 |
|
| 196 |
def _queue_decision(winner: str | None, state: dict):
|
|
|
|
| 197 |
state.setdefault("pending", [])
|
| 198 |
-
state.setdefault("decision_history", [])
|
| 199 |
decision = {
|
| 200 |
"winner": winner,
|
| 201 |
"key_a": state["key_a"],
|
| 202 |
"key_b": state["key_b"],
|
|
|
|
|
|
|
| 203 |
"url_a": state["url_a"],
|
| 204 |
"url_b": state["url_b"],
|
| 205 |
"dataset": state["dataset"],
|
| 206 |
"rating_pref": state["rating_pref"],
|
|
|
|
|
|
|
| 207 |
}
|
| 208 |
state["pending"].append(decision)
|
| 209 |
-
state["
|
|
|
|
| 210 |
_commit_oldest_pending(state)
|
| 211 |
|
| 212 |
def vote(winner: str | None, state: dict) -> tuple:
|
|
@@ -215,31 +139,58 @@ def vote(winner: str | None, state: dict) -> tuple:
|
|
| 215 |
return new_round(state["dataset"], state["rating_pref"], state)
|
| 216 |
|
| 217 |
def go_back(state: dict) -> tuple:
|
| 218 |
-
history = state.setdefault("decision_history", [])
|
| 219 |
pending = state.setdefault("pending", [])
|
| 220 |
-
if not
|
| 221 |
return _render_current(state)
|
| 222 |
-
last =
|
| 223 |
-
if
|
|
|
|
|
|
|
|
|
|
| 224 |
pending.pop()
|
|
|
|
|
|
|
| 225 |
state.update(
|
| 226 |
dataset=last["dataset"],
|
| 227 |
rating_pref=last["rating_pref"],
|
| 228 |
key_a=last["key_a"],
|
| 229 |
key_b=last["key_b"],
|
|
|
|
|
|
|
| 230 |
url_a=last["url_a"],
|
| 231 |
url_b=last["url_b"],
|
|
|
|
| 232 |
)
|
| 233 |
return _render_current(state)
|
| 234 |
|
| 235 |
-
# Warm up prefetch for all datasets at startup (safe by default)
|
| 236 |
-
for _ds in DATASETS:
|
| 237 |
-
prefetch(_ds, "safe")
|
| 238 |
-
|
| 239 |
# -- UI ---------------------------------------------------------------------
|
| 240 |
|
| 241 |
with gr.Blocks(
|
| 242 |
title="Image Rater",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
css="""
|
| 244 |
.subtle-link button {
|
| 245 |
background: none !important;
|
|
@@ -255,42 +206,95 @@ with gr.Blocks(
|
|
| 255 |
.subtle-link button:hover {
|
| 256 |
color: #5a5a5a !important;
|
| 257 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
""",
|
| 259 |
) as demo:
|
| 260 |
-
gr.Markdown("## Image Rater\
|
| 261 |
|
| 262 |
state = gr.State({})
|
|
|
|
| 263 |
|
| 264 |
with gr.Row():
|
| 265 |
-
img_a = gr.
|
| 266 |
-
img_b = gr.
|
| 267 |
|
| 268 |
with gr.Row():
|
| 269 |
-
btn_a = gr.Button("👍 Prefer A", variant="primary")
|
| 270 |
-
btn_skip = gr.Button("
|
| 271 |
-
btn_b = gr.Button("👍 Prefer B", variant="primary")
|
| 272 |
|
| 273 |
with gr.Accordion("Settings", open=False):
|
| 274 |
gr.Markdown("<span style='color:#888;font-size:0.9em;'>Advanced options</span>")
|
| 275 |
rating_dd = gr.Dropdown(
|
| 276 |
-
choices=["safe", "
|
| 277 |
value="safe",
|
| 278 |
label="Rating",
|
|
|
|
| 279 |
)
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
|
|
|
|
|
|
| 294 |
|
| 295 |
if __name__ == "__main__":
|
| 296 |
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import random
|
|
|
|
|
|
|
|
|
|
| 3 |
import threading
|
| 4 |
+
import uuid
|
| 5 |
+
import os
|
| 6 |
+
import html
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import pandas as pd
|
| 10 |
+
from huggingface_hub import hf_hub_download
|
| 11 |
+
|
| 12 |
+
from storage import VoteStorage
|
| 13 |
+
|
| 14 |
+
LOCAL_DATA_DIR = 'data'
|
| 15 |
+
DEBUG_MODE = os.getenv("DEBUG", "0").lower() in ("1", "true", "yes", "on")
|
| 16 |
+
VOTE_STORAGE = VoteStorage(mode="local" if DEBUG_MODE else "hf", local_dir=LOCAL_DATA_DIR)
|
| 17 |
+
|
| 18 |
+
# -- Pool dataset -----------------------------------------------------------
|
| 19 |
+
if DEBUG_MODE:
|
| 20 |
+
_pool_path = str(Path(__file__).resolve().parent / LOCAL_DATA_DIR / "pool.parquet")
|
| 21 |
+
assert Path(_pool_path).exists(), f"Missing local debug pool file: {_pool_path}"
|
| 22 |
+
else:
|
| 23 |
+
_pool_path = hf_hub_download(
|
| 24 |
+
repo_id="taigasan/e6-visual-ratings",
|
| 25 |
+
filename="pool.parquet",
|
| 26 |
+
repo_type="dataset",
|
| 27 |
+
)
|
| 28 |
+
_pool_df = pd.read_parquet(_pool_path)
|
| 29 |
+
_pool_group_dfs = {g: gdf for g, gdf in _pool_df.groupby("group")}
|
| 30 |
+
|
| 31 |
+
def _pool_fetch_pair(group_name: str) -> tuple:
|
| 32 |
+
gdf = _pool_group_dfs[group_name]
|
| 33 |
+
assert len(gdf) >= 2, f"Not enough rows for group: {group_name}"
|
| 34 |
+
sample = gdf.sample(2)
|
| 35 |
+
return sample.iloc[0], sample.iloc[1]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _row_image_url(row) -> str:
|
| 39 |
+
sample_url = row.get("sample_url")
|
| 40 |
+
if isinstance(sample_url, str) and sample_url:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
return sample_url
|
| 42 |
+
image_url = row.get("image_url")
|
| 43 |
+
if isinstance(image_url, str) and image_url:
|
| 44 |
+
return image_url
|
| 45 |
+
return ''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
DATASETS: dict[str, dict] = {
|
| 48 |
+
"pool": {
|
| 49 |
+
"fetch_pair": _pool_fetch_pair,
|
| 50 |
+
"get_id": lambda row: row["md5"],
|
| 51 |
+
"get_image": _row_image_url,
|
| 52 |
+
"groups": {g: g for g in sorted(_pool_df["group"].unique())},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
},
|
| 54 |
}
|
| 55 |
+
DEFAULT_DATASET = list(DATASETS.keys())[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
def _select_groups(cfg: dict, rating_pref: str) -> list[str]:
|
|
|
|
| 58 |
groups = list(cfg["groups"].keys())
|
| 59 |
+
if rating_pref == "all":
|
| 60 |
return groups
|
| 61 |
+
return [g for g in groups if g.endswith(f"_{rating_pref}")]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
def _commit_oldest_pending(state: dict):
|
| 64 |
pending = state.setdefault("pending", [])
|
| 65 |
+
if len(pending) <= 1:
|
| 66 |
return
|
| 67 |
oldest = pending.pop(0)
|
| 68 |
+
threading.Thread(target=VOTE_STORAGE.append_vote_row, args=(oldest.copy(), oldest.get("winner")), daemon=True).start()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
def _render_current(state: dict) -> tuple:
|
| 71 |
+
img_a_html = f"<div class=\"rating-card\"><div class=\"rating-card-title\"><strong>Image A</strong></div><div class=\"rating-image-frame\"><img src=\"{html.escape(state['url_a'])}\" class=\"rating-image\" loading=\"eager\" referrerpolicy=\"no-referrer\"></div></div>"
|
| 72 |
+
img_b_html = f"<div class=\"rating-card\"><div class=\"rating-card-title\"><strong>Image B</strong></div><div class=\"rating-image-frame\"><img src=\"{html.escape(state['url_b'])}\" class=\"rating-image\" loading=\"eager\" referrerpolicy=\"no-referrer\"></div></div>"
|
| 73 |
+
link_a = f"Image A: https://e621.net/posts/{state['id_a']}"
|
| 74 |
+
link_b = f"Image B: https://e621.net/posts/{state['id_b']}"
|
| 75 |
+
can_go_back = bool(state.get("can_go_back"))
|
| 76 |
+
back_md = "[back](#back)" if can_go_back else "<span class='subtle-back-link-disabled'>back</span>"
|
| 77 |
+
details = f"<span class='subtle-note'>Group: {state['group']}</span>"
|
| 78 |
+
return img_a_html, img_b_html, link_a, link_b, back_md, details, state
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _normalize_rating_pref(pref: str | None) -> str:
|
| 83 |
+
return pref if pref in ("safe", "all") else "safe"
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _initial_load(state: dict, pref: str | None):
|
| 87 |
+
rating_pref = _normalize_rating_pref(pref)
|
| 88 |
+
return rating_pref, *new_round(DEFAULT_DATASET, rating_pref, state)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _on_rating_change(rating_pref: str, state: dict):
|
| 92 |
+
rating_pref = _normalize_rating_pref(rating_pref)
|
| 93 |
+
return *new_round(DEFAULT_DATASET, rating_pref, state), rating_pref
|
| 94 |
|
| 95 |
# -- Gradio callbacks -------------------------------------------------------
|
| 96 |
|
| 97 |
def new_round(dataset_name: str, rating_pref: str, state: dict) -> tuple:
|
| 98 |
cfg = DATASETS[dataset_name]
|
| 99 |
+
groups = _select_groups(cfg, rating_pref)
|
| 100 |
+
assert groups, f"No groups for rating preference: {rating_pref}"
|
| 101 |
+
group = random.choice(groups)
|
| 102 |
+
row_a, row_b = cfg["fetch_pair"](cfg["groups"][group])
|
| 103 |
+
state.setdefault("session_id", uuid.uuid4().hex)
|
| 104 |
key_a = cfg["get_id"](row_a)
|
| 105 |
key_b = cfg["get_id"](row_b)
|
| 106 |
+
id_a = int(row_a["id"])
|
| 107 |
+
id_b = int(row_b["id"])
|
| 108 |
+
state.update(dataset=dataset_name, rating_pref=rating_pref, key_a=key_a, key_b=key_b, id_a=id_a, id_b=id_b, group=group)
|
| 109 |
url_a = cfg["get_image"](row_a)
|
| 110 |
url_b = cfg["get_image"](row_b)
|
| 111 |
state["url_a"] = url_a
|
|
|
|
| 113 |
return _render_current(state)
|
| 114 |
|
| 115 |
def _queue_decision(winner: str | None, state: dict):
|
| 116 |
+
assert state.get("session_id"), "Missing session_id: refusing to record vote"
|
| 117 |
state.setdefault("pending", [])
|
|
|
|
| 118 |
decision = {
|
| 119 |
"winner": winner,
|
| 120 |
"key_a": state["key_a"],
|
| 121 |
"key_b": state["key_b"],
|
| 122 |
+
"id_a": state["id_a"],
|
| 123 |
+
"id_b": state["id_b"],
|
| 124 |
"url_a": state["url_a"],
|
| 125 |
"url_b": state["url_b"],
|
| 126 |
"dataset": state["dataset"],
|
| 127 |
"rating_pref": state["rating_pref"],
|
| 128 |
+
"group": state["group"],
|
| 129 |
+
"session_id": state["session_id"],
|
| 130 |
}
|
| 131 |
state["pending"].append(decision)
|
| 132 |
+
state["last_decision"] = decision
|
| 133 |
+
state["can_go_back"] = True
|
| 134 |
_commit_oldest_pending(state)
|
| 135 |
|
| 136 |
def vote(winner: str | None, state: dict) -> tuple:
|
|
|
|
| 139 |
return new_round(state["dataset"], state["rating_pref"], state)
|
| 140 |
|
| 141 |
def go_back(state: dict) -> tuple:
|
|
|
|
| 142 |
pending = state.setdefault("pending", [])
|
| 143 |
+
if not state.get("can_go_back"):
|
| 144 |
return _render_current(state)
|
| 145 |
+
last = state.get("last_decision")
|
| 146 |
+
if not last:
|
| 147 |
+
state["can_go_back"] = False
|
| 148 |
+
return _render_current(state)
|
| 149 |
+
if pending and pending[-1] == last:
|
| 150 |
pending.pop()
|
| 151 |
+
state["can_go_back"] = False
|
| 152 |
+
state["last_decision"] = None
|
| 153 |
state.update(
|
| 154 |
dataset=last["dataset"],
|
| 155 |
rating_pref=last["rating_pref"],
|
| 156 |
key_a=last["key_a"],
|
| 157 |
key_b=last["key_b"],
|
| 158 |
+
id_a=last["id_a"],
|
| 159 |
+
id_b=last["id_b"],
|
| 160 |
url_a=last["url_a"],
|
| 161 |
url_b=last["url_b"],
|
| 162 |
+
group=last["group"],
|
| 163 |
)
|
| 164 |
return _render_current(state)
|
| 165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
# -- UI ---------------------------------------------------------------------
|
| 167 |
|
| 168 |
with gr.Blocks(
|
| 169 |
title="Image Rater",
|
| 170 |
+
head="""
|
| 171 |
+
<script>
|
| 172 |
+
window.addEventListener('keydown', function (e) {
|
| 173 |
+
const t = e.target;
|
| 174 |
+
if (t && (t.tagName === 'INPUT' || t.tagName === 'TEXTAREA' || t.isContentEditable)) return;
|
| 175 |
+
if (e.key === 'ArrowLeft') {
|
| 176 |
+
e.preventDefault();
|
| 177 |
+
document.querySelector('#btn-vote-a button, button#btn-vote-a')?.click();
|
| 178 |
+
} else if (e.key === 'ArrowRight') {
|
| 179 |
+
e.preventDefault();
|
| 180 |
+
document.querySelector('#btn-vote-b button, button#btn-vote-b')?.click();
|
| 181 |
+
} else if (e.key === 'Backspace') {
|
| 182 |
+
e.preventDefault();
|
| 183 |
+
document.querySelector('#btn-back-action button, button#btn-back-action')?.click();
|
| 184 |
+
}
|
| 185 |
+
});
|
| 186 |
+
document.addEventListener('click', function (e) {
|
| 187 |
+
const a = e.target.closest('a[href="#back"]');
|
| 188 |
+
if (!a) return;
|
| 189 |
+
e.preventDefault();
|
| 190 |
+
document.querySelector('#btn-back-action button, button#btn-back-action')?.click();
|
| 191 |
+
});
|
| 192 |
+
</script>
|
| 193 |
+
""",
|
| 194 |
css="""
|
| 195 |
.subtle-link button {
|
| 196 |
background: none !important;
|
|
|
|
| 206 |
.subtle-link button:hover {
|
| 207 |
color: #5a5a5a !important;
|
| 208 |
}
|
| 209 |
+
.subtle-link {
|
| 210 |
+
width: fit-content !important;
|
| 211 |
+
}
|
| 212 |
+
.subtle-link button {
|
| 213 |
+
width: fit-content !important;
|
| 214 |
+
}
|
| 215 |
+
.subtle-note {
|
| 216 |
+
color: #888;
|
| 217 |
+
font-size: 0.9em;
|
| 218 |
+
}
|
| 219 |
+
.rating-card {
|
| 220 |
+
width: 100%;
|
| 221 |
+
}
|
| 222 |
+
.rating-card-title {
|
| 223 |
+
min-height: 24px;
|
| 224 |
+
margin-bottom: 8px;
|
| 225 |
+
}
|
| 226 |
+
.rating-image-frame {
|
| 227 |
+
width: 100%;
|
| 228 |
+
height: 512px;
|
| 229 |
+
border: 1px solid #e6e6e6;
|
| 230 |
+
border-radius: 8px;
|
| 231 |
+
background: #fafafa;
|
| 232 |
+
display: flex;
|
| 233 |
+
align-items: center;
|
| 234 |
+
justify-content: center;
|
| 235 |
+
overflow: hidden;
|
| 236 |
+
}
|
| 237 |
+
.rating-image {
|
| 238 |
+
width: 100%;
|
| 239 |
+
height: 100%;
|
| 240 |
+
object-fit: contain;
|
| 241 |
+
}
|
| 242 |
+
.subtle-back-link-wrap a {
|
| 243 |
+
color: #7a7a7a !important;
|
| 244 |
+
font-size: 0.9em;
|
| 245 |
+
text-decoration: underline;
|
| 246 |
+
}
|
| 247 |
+
.subtle-back-link-wrap a:hover {
|
| 248 |
+
color: #5a5a5a !important;
|
| 249 |
+
}
|
| 250 |
+
.subtle-back-link-disabled {
|
| 251 |
+
color: #b8b8b8 !important;
|
| 252 |
+
pointer-events: none;
|
| 253 |
+
text-decoration: none;
|
| 254 |
+
}
|
| 255 |
+
.hidden-action-btn {
|
| 256 |
+
display: none !important;
|
| 257 |
+
}
|
| 258 |
""",
|
| 259 |
) as demo:
|
| 260 |
+
gr.Markdown("## Image Quality Rater\nRate relative image quality. Choose the image with better quality, or select same quality if they are comparable. Both images are drawn from the same group to avoid cross-group bias.")
|
| 261 |
|
| 262 |
state = gr.State({})
|
| 263 |
+
rating_pref_store = gr.BrowserState(default_value="safe", storage_key="rating_pref")
|
| 264 |
|
| 265 |
with gr.Row():
|
| 266 |
+
img_a = gr.HTML()
|
| 267 |
+
img_b = gr.HTML()
|
| 268 |
|
| 269 |
with gr.Row():
|
| 270 |
+
btn_a = gr.Button("👍 Prefer A", variant="primary", elem_id="btn-vote-a")
|
| 271 |
+
btn_skip = gr.Button("Same quality", elem_id="btn-skip")
|
| 272 |
+
btn_b = gr.Button("👍 Prefer B", variant="primary", elem_id="btn-vote-b")
|
| 273 |
|
| 274 |
with gr.Accordion("Settings", open=False):
|
| 275 |
gr.Markdown("<span style='color:#888;font-size:0.9em;'>Advanced options</span>")
|
| 276 |
rating_dd = gr.Dropdown(
|
| 277 |
+
choices=["safe", "all"],
|
| 278 |
value="safe",
|
| 279 |
label="Rating",
|
| 280 |
+
elem_id="rating-pref",
|
| 281 |
)
|
| 282 |
+
|
| 283 |
+
link_a = gr.Markdown(label="Image A link")
|
| 284 |
+
link_b = gr.Markdown(label="Image B link")
|
| 285 |
+
back_link = gr.Markdown(elem_classes=["subtle-back-link-wrap"])
|
| 286 |
+
btn_back_action = gr.Button("back", elem_id="btn-back-action", elem_classes=["hidden-action-btn"])
|
| 287 |
+
details_md = gr.Markdown()
|
| 288 |
+
gr.Markdown("<span class='subtle-note'>Dataset: <a href='https://huggingface.co/datasets/taigasan/e6-visual-ratings' target='_blank' rel='noopener noreferrer'>taigasan/e6-visual-ratings</a></span>")
|
| 289 |
+
gr.Markdown("<span class='subtle-note'>Shortcuts: Left = vote A, Right = vote B, Backspace = back</span>")
|
| 290 |
+
outputs = [img_a, img_b, link_a, link_b, back_link, details_md, state]
|
| 291 |
+
|
| 292 |
+
btn_a.click(fn=lambda s: vote("A", s), inputs=[state], outputs=outputs, queue=False, show_progress="hidden")
|
| 293 |
+
btn_b.click(fn=lambda s: vote("B", s), inputs=[state], outputs=outputs, queue=False, show_progress="hidden")
|
| 294 |
+
btn_skip.click(fn=lambda s: vote(None, s), inputs=[state], outputs=outputs, queue=False, show_progress="hidden")
|
| 295 |
+
btn_back_action.click(fn=go_back, inputs=[state], outputs=outputs, queue=False, show_progress="hidden")
|
| 296 |
+
rating_dd.change(fn=_on_rating_change, inputs=[rating_dd, state], outputs=[*outputs, rating_pref_store], queue=False, show_progress="hidden")
|
| 297 |
+
demo.load(fn=_initial_load, inputs=[state, rating_pref_store], outputs=[rating_dd, *outputs], queue=False, show_progress="hidden")
|
| 298 |
|
| 299 |
if __name__ == "__main__":
|
| 300 |
demo.launch()
|