taigasan commited on
Commit
982d7de
·
verified ·
1 Parent(s): 40e6ade

deploy app, storage, readme

Browse files
Files changed (1) hide show
  1. app.py +198 -194
app.py CHANGED
@@ -1,192 +1,111 @@
1
  import gradio as gr
2
  import random
3
- import json
4
- import os
5
- import time
6
  import threading
7
- import requests
8
-
9
- # -- Dataset config ---------------------------------------------------------
10
- # Each entry:
11
- # fetch_pair : callable(group_name, group_cfg) -> (row_a, row_b, common_tags)
12
- # get_id : callable(row) -> unique image identifier (ELO key)
13
- # get_image : callable(row) -> image URL
14
- # groups : dict of group_name -> list of API tag strings
15
- # Both images in a pair come from the same group — prevents
16
- # cross-group bias (e.g. rating male vs female images).
17
-
18
- E621_API = "https://e621.net/posts.json"
19
- E621_HEADERS = {"User-Agent": "ImageRater/1.0 (HuggingFace Space)"}
20
- E621_MEDIA_PREFIXES = ("https://static1.e621.net/", "https://static1.e926.net/")
21
- E621_IMAGE_EXTS = {"jpg", "jpeg", "png", "gif", "webp"}
22
- _e621_last_req = 0.0
23
- _e621_lock = threading.Lock()
24
-
25
- def _e621_request(tags: list[str], limit: int) -> list[dict]:
26
- global _e621_last_req
27
- with _e621_lock:
28
- # respect sustained ≤1 req/sec
29
- gap = time.monotonic() - _e621_last_req
30
- if gap < 1.0:
31
- time.sleep(1.0 - gap)
32
- resp = requests.get(
33
- E621_API,
34
- headers=E621_HEADERS,
35
- params={"tags": " ".join(tags + ["order:random"]), "limit": limit},
36
- timeout=10,
37
- )
38
- _e621_last_req = time.monotonic()
39
- resp.raise_for_status()
40
- return resp.json()["posts"]
41
-
42
- def _is_e621_media_url(url: str | None) -> bool:
43
- return isinstance(url, str) and url.startswith(E621_MEDIA_PREFIXES)
44
-
45
- def _e621_media_url(row: dict) -> str | None:
46
- # Prefer sample URL for faster client-side loading, fallback to original file URL.
47
- sample_url = row.get("sample", {}).get("url")
48
- file_url = row.get("file", {}).get("url")
49
- if _is_e621_media_url(sample_url):
50
  return sample_url
51
- if _is_e621_media_url(file_url):
52
- return file_url
53
- return None
54
-
55
- def _valid_image_post(row: dict) -> bool:
56
- ext = row.get("file", {}).get("ext")
57
- return ext in E621_IMAGE_EXTS and _e621_media_url(row) is not None
58
-
59
- def _e621_fetch_pair(group_tags: list[str]) -> tuple:
60
- posts = _e621_request(group_tags, limit=20)
61
- valid = [p for p in posts if _valid_image_post(p)]
62
- assert len(valid) >= 2, f"Not enough image posts for tags: {group_tags}"
63
- row_a, row_b = valid[0], valid[1]
64
- tags_a = set(row_a["tags"]["general"] + row_a["tags"].get("species", []) + row_a["tags"].get("character", []))
65
- tags_b = set(row_b["tags"]["general"] + row_b["tags"].get("species", []) + row_b["tags"].get("character", []))
66
- common = sorted(tags_a & tags_b)
67
- return row_a, row_b, common
68
 
69
  DATASETS: dict[str, dict] = {
70
- "e621": {
71
- "fetch_pair": _e621_fetch_pair,
72
- "get_id": lambda row: row["file"]["md5"],
73
- "get_image": _e621_media_url,
74
- "groups": {
75
- "e_male": ["male", "solo", "rating:e"],
76
- "e_female": ["female", "solo", "rating:e"],
77
- "e_male_female": ["male", "female", "rating:e"],
78
- "q_male": ["male", "solo", "rating:q"],
79
- "q_female": ["female", "solo", "rating:q"],
80
- "s_male": ["male", "solo", "rating:s"],
81
- "s_female": ["female", "solo", "rating:s"],
82
- },
83
  },
84
  }
85
-
86
- RATINGS_FILE = "elo_ratings.json"
87
- DEFAULT_ELO = 1500
88
- K = 32
89
- RATINGS_MEM: dict[str, int] = {}
90
- RATING_PREFIX = {
91
- "safe": "s_",
92
- "questionable": "q_",
93
- "explicit": "e_",
94
- "all": None,
95
- }
96
-
97
- # -- Prefetch ---------------------------------------------------------------
98
-
99
- _prefetch: dict[tuple[str, str], tuple | None] = {}
100
- _prefetch_threads: dict[tuple[str, str], threading.Thread] = {}
101
 
102
  def _select_groups(cfg: dict, rating_pref: str) -> list[str]:
103
- prefix = RATING_PREFIX[rating_pref]
104
  groups = list(cfg["groups"].keys())
105
- if prefix is None:
106
  return groups
107
- return [g for g in groups if g.startswith(prefix)]
108
-
109
- def _do_prefetch(dataset_name: str, rating_pref: str):
110
- try:
111
- cfg = DATASETS[dataset_name]
112
- groups = _select_groups(cfg, rating_pref)
113
- assert groups, f"No groups for rating preference: {rating_pref}"
114
- group = random.choice(groups)
115
- row_a, row_b, common = cfg["fetch_pair"](cfg["groups"][group])
116
- _prefetch[(dataset_name, rating_pref)] = (row_a, row_b, common, group)
117
- except Exception:
118
- _prefetch[(dataset_name, rating_pref)] = None
119
-
120
- def prefetch(dataset_name: str, rating_pref: str):
121
- key = (dataset_name, rating_pref)
122
- _prefetch[key] = None
123
- t = threading.Thread(target=_do_prefetch, args=(dataset_name, rating_pref), daemon=True)
124
- _prefetch_threads[key] = t
125
- t.start()
126
-
127
- def consume_prefetch(dataset_name: str, rating_pref: str) -> tuple:
128
- key = (dataset_name, rating_pref)
129
- # Wait for prefetch to finish (should be near-instant since we started it earlier)
130
- t = _prefetch_threads.get(key)
131
- if t:
132
- t.join(timeout=15)
133
- result = _prefetch.pop(key, None)
134
- # Kick off the next prefetch immediately
135
- prefetch(dataset_name, rating_pref)
136
- if result is not None:
137
- return result
138
- # Fallback: fetch synchronously if prefetch failed
139
- cfg = DATASETS[dataset_name]
140
- groups = _select_groups(cfg, rating_pref)
141
- assert groups, f"No groups for rating preference: {rating_pref}"
142
- group = random.choice(groups)
143
- row_a, row_b, common = cfg["fetch_pair"](cfg["groups"][group])
144
- return row_a, row_b, common, group
145
-
146
- # -- ELO helpers ------------------------------------------------------------
147
-
148
- def load_ratings() -> dict:
149
- return RATINGS_MEM.copy()
150
-
151
- def save_ratings(ratings: dict):
152
- # Stubbed persistence for now: keep ratings only in memory.
153
- RATINGS_MEM.clear()
154
- RATINGS_MEM.update(ratings)
155
-
156
- def elo_update(ratings: dict, winner_key: str, loser_key: str) -> dict:
157
- rw = ratings.get(winner_key, DEFAULT_ELO)
158
- rl = ratings.get(loser_key, DEFAULT_ELO)
159
- ea = 1 / (1 + 10 ** ((rl - rw) / 400))
160
- ratings[winner_key] = round(rw + K * (1 - ea))
161
- ratings[loser_key] = round(rl + K * (0 - (1 - ea)))
162
- return ratings
163
 
164
  def _commit_oldest_pending(state: dict):
165
  pending = state.setdefault("pending", [])
166
- if len(pending) <= 2:
167
  return
168
  oldest = pending.pop(0)
169
- winner = oldest.get("winner")
170
- if winner is None:
171
- return
172
- ratings = load_ratings()
173
- if winner == "A":
174
- ratings = elo_update(ratings, oldest["key_a"], oldest["key_b"])
175
- else:
176
- ratings = elo_update(ratings, oldest["key_b"], oldest["key_a"])
177
- save_ratings(ratings)
178
 
179
  def _render_current(state: dict) -> tuple:
180
- return state["url_a"], state["url_b"], state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  # -- Gradio callbacks -------------------------------------------------------
183
 
184
  def new_round(dataset_name: str, rating_pref: str, state: dict) -> tuple:
185
  cfg = DATASETS[dataset_name]
186
- row_a, row_b, common, group = consume_prefetch(dataset_name, rating_pref)
 
 
 
 
187
  key_a = cfg["get_id"](row_a)
188
  key_b = cfg["get_id"](row_b)
189
- state.update(dataset=dataset_name, rating_pref=rating_pref, key_a=key_a, key_b=key_b)
 
 
190
  url_a = cfg["get_image"](row_a)
191
  url_b = cfg["get_image"](row_b)
192
  state["url_a"] = url_a
@@ -194,19 +113,24 @@ def new_round(dataset_name: str, rating_pref: str, state: dict) -> tuple:
194
  return _render_current(state)
195
 
196
  def _queue_decision(winner: str | None, state: dict):
 
197
  state.setdefault("pending", [])
198
- state.setdefault("decision_history", [])
199
  decision = {
200
  "winner": winner,
201
  "key_a": state["key_a"],
202
  "key_b": state["key_b"],
 
 
203
  "url_a": state["url_a"],
204
  "url_b": state["url_b"],
205
  "dataset": state["dataset"],
206
  "rating_pref": state["rating_pref"],
 
 
207
  }
208
  state["pending"].append(decision)
209
- state["decision_history"].append(decision)
 
210
  _commit_oldest_pending(state)
211
 
212
  def vote(winner: str | None, state: dict) -> tuple:
@@ -215,31 +139,58 @@ def vote(winner: str | None, state: dict) -> tuple:
215
  return new_round(state["dataset"], state["rating_pref"], state)
216
 
217
  def go_back(state: dict) -> tuple:
218
- history = state.setdefault("decision_history", [])
219
  pending = state.setdefault("pending", [])
220
- if not history:
221
  return _render_current(state)
222
- last = history.pop()
223
- if pending and pending[-1] is last:
 
 
 
224
  pending.pop()
 
 
225
  state.update(
226
  dataset=last["dataset"],
227
  rating_pref=last["rating_pref"],
228
  key_a=last["key_a"],
229
  key_b=last["key_b"],
 
 
230
  url_a=last["url_a"],
231
  url_b=last["url_b"],
 
232
  )
233
  return _render_current(state)
234
 
235
- # Warm up prefetch for all datasets at startup (safe by default)
236
- for _ds in DATASETS:
237
- prefetch(_ds, "safe")
238
-
239
  # -- UI ---------------------------------------------------------------------
240
 
241
  with gr.Blocks(
242
  title="Image Rater",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  css="""
244
  .subtle-link button {
245
  background: none !important;
@@ -255,42 +206,95 @@ with gr.Blocks(
255
  .subtle-link button:hover {
256
  color: #5a5a5a !important;
257
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  """,
259
  ) as demo:
260
- gr.Markdown("## Image Rater\nPick the image you prefer. Both images are drawn from the same group to avoid cross-group bias.")
261
 
262
  state = gr.State({})
 
263
 
264
  with gr.Row():
265
- img_a = gr.Image(label="Image A", interactive=False, height=512)
266
- img_b = gr.Image(label="Image B", interactive=False, height=512)
267
 
268
  with gr.Row():
269
- btn_a = gr.Button("👍 Prefer A", variant="primary")
270
- btn_skip = gr.Button("About the same")
271
- btn_b = gr.Button("👍 Prefer B", variant="primary")
272
 
273
  with gr.Accordion("Settings", open=False):
274
  gr.Markdown("<span style='color:#888;font-size:0.9em;'>Advanced options</span>")
275
  rating_dd = gr.Dropdown(
276
- choices=["safe", "questionable", "explicit", "all"],
277
  value="safe",
278
  label="Rating",
 
279
  )
280
- dataset_dd = gr.Dropdown(choices=list(DATASETS.keys()), value=list(DATASETS.keys())[0], label="Dataset")
281
- btn_back = gr.Button("Back", elem_classes=["subtle-link"])
282
-
283
- outputs = [img_a, img_b, state]
284
-
285
- btn_a.click(fn=lambda s: vote("A", s), inputs=[state], outputs=outputs)
286
- btn_b.click(fn=lambda s: vote("B", s), inputs=[state], outputs=outputs)
287
- btn_skip.click(fn=lambda s: vote(None, s), inputs=[state], outputs=outputs)
288
- btn_back.click(fn=go_back, inputs=[state], outputs=outputs)
289
- img_a.select(fn=lambda s, evt: vote("A", s), inputs=[state], outputs=outputs)
290
- img_b.select(fn=lambda s, evt: vote("B", s), inputs=[state], outputs=outputs)
291
- dataset_dd.change(fn=new_round, inputs=[dataset_dd, rating_dd, state], outputs=outputs)
292
- rating_dd.change(fn=new_round, inputs=[dataset_dd, rating_dd, state], outputs=outputs)
293
- demo.load(fn=lambda s: new_round(list(DATASETS.keys())[0], "safe", s), inputs=[state], outputs=outputs)
 
 
294
 
295
  if __name__ == "__main__":
296
  demo.launch()
 
1
  import gradio as gr
2
  import random
 
 
 
3
  import threading
4
+ import uuid
5
+ import os
6
+ import html
7
+ from pathlib import Path
8
+
9
+ import pandas as pd
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ from storage import VoteStorage
13
+
14
+ LOCAL_DATA_DIR = 'data'
15
+ DEBUG_MODE = os.getenv("DEBUG", "0").lower() in ("1", "true", "yes", "on")
16
+ VOTE_STORAGE = VoteStorage(mode="local" if DEBUG_MODE else "hf", local_dir=LOCAL_DATA_DIR)
17
+
18
+ # -- Pool dataset -----------------------------------------------------------
19
+ if DEBUG_MODE:
20
+ _pool_path = str(Path(__file__).resolve().parent / LOCAL_DATA_DIR / "pool.parquet")
21
+ assert Path(_pool_path).exists(), f"Missing local debug pool file: {_pool_path}"
22
+ else:
23
+ _pool_path = hf_hub_download(
24
+ repo_id="taigasan/e6-visual-ratings",
25
+ filename="pool.parquet",
26
+ repo_type="dataset",
27
+ )
28
+ _pool_df = pd.read_parquet(_pool_path)
29
+ _pool_group_dfs = {g: gdf for g, gdf in _pool_df.groupby("group")}
30
+
31
+ def _pool_fetch_pair(group_name: str) -> tuple:
32
+ gdf = _pool_group_dfs[group_name]
33
+ assert len(gdf) >= 2, f"Not enough rows for group: {group_name}"
34
+ sample = gdf.sample(2)
35
+ return sample.iloc[0], sample.iloc[1]
36
+
37
+
38
+ def _row_image_url(row) -> str:
39
+ sample_url = row.get("sample_url")
40
+ if isinstance(sample_url, str) and sample_url:
 
 
 
 
 
 
41
  return sample_url
42
+ image_url = row.get("image_url")
43
+ if isinstance(image_url, str) and image_url:
44
+ return image_url
45
+ return ''
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  DATASETS: dict[str, dict] = {
48
+ "pool": {
49
+ "fetch_pair": _pool_fetch_pair,
50
+ "get_id": lambda row: row["md5"],
51
+ "get_image": _row_image_url,
52
+ "groups": {g: g for g in sorted(_pool_df["group"].unique())},
 
 
 
 
 
 
 
 
53
  },
54
  }
55
+ DEFAULT_DATASET = list(DATASETS.keys())[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  def _select_groups(cfg: dict, rating_pref: str) -> list[str]:
 
58
  groups = list(cfg["groups"].keys())
59
+ if rating_pref == "all":
60
  return groups
61
+ return [g for g in groups if g.endswith(f"_{rating_pref}")]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  def _commit_oldest_pending(state: dict):
64
  pending = state.setdefault("pending", [])
65
+ if len(pending) <= 1:
66
  return
67
  oldest = pending.pop(0)
68
+ threading.Thread(target=VOTE_STORAGE.append_vote_row, args=(oldest.copy(), oldest.get("winner")), daemon=True).start()
 
 
 
 
 
 
 
 
69
 
70
  def _render_current(state: dict) -> tuple:
71
+ img_a_html = f"<div class=\"rating-card\"><div class=\"rating-card-title\"><strong>Image A</strong></div><div class=\"rating-image-frame\"><img src=\"{html.escape(state['url_a'])}\" class=\"rating-image\" loading=\"eager\" referrerpolicy=\"no-referrer\"></div></div>"
72
+ img_b_html = f"<div class=\"rating-card\"><div class=\"rating-card-title\"><strong>Image B</strong></div><div class=\"rating-image-frame\"><img src=\"{html.escape(state['url_b'])}\" class=\"rating-image\" loading=\"eager\" referrerpolicy=\"no-referrer\"></div></div>"
73
+ link_a = f"Image A: https://e621.net/posts/{state['id_a']}"
74
+ link_b = f"Image B: https://e621.net/posts/{state['id_b']}"
75
+ can_go_back = bool(state.get("can_go_back"))
76
+ back_md = "[back](#back)" if can_go_back else "<span class='subtle-back-link-disabled'>back</span>"
77
+ details = f"<span class='subtle-note'>Group: {state['group']}</span>"
78
+ return img_a_html, img_b_html, link_a, link_b, back_md, details, state
79
+
80
+
81
+
82
+ def _normalize_rating_pref(pref: str | None) -> str:
83
+ return pref if pref in ("safe", "all") else "safe"
84
+
85
+
86
+ def _initial_load(state: dict, pref: str | None):
87
+ rating_pref = _normalize_rating_pref(pref)
88
+ return rating_pref, *new_round(DEFAULT_DATASET, rating_pref, state)
89
+
90
+
91
+ def _on_rating_change(rating_pref: str, state: dict):
92
+ rating_pref = _normalize_rating_pref(rating_pref)
93
+ return *new_round(DEFAULT_DATASET, rating_pref, state), rating_pref
94
 
95
  # -- Gradio callbacks -------------------------------------------------------
96
 
97
  def new_round(dataset_name: str, rating_pref: str, state: dict) -> tuple:
98
  cfg = DATASETS[dataset_name]
99
+ groups = _select_groups(cfg, rating_pref)
100
+ assert groups, f"No groups for rating preference: {rating_pref}"
101
+ group = random.choice(groups)
102
+ row_a, row_b = cfg["fetch_pair"](cfg["groups"][group])
103
+ state.setdefault("session_id", uuid.uuid4().hex)
104
  key_a = cfg["get_id"](row_a)
105
  key_b = cfg["get_id"](row_b)
106
+ id_a = int(row_a["id"])
107
+ id_b = int(row_b["id"])
108
+ state.update(dataset=dataset_name, rating_pref=rating_pref, key_a=key_a, key_b=key_b, id_a=id_a, id_b=id_b, group=group)
109
  url_a = cfg["get_image"](row_a)
110
  url_b = cfg["get_image"](row_b)
111
  state["url_a"] = url_a
 
113
  return _render_current(state)
114
 
115
  def _queue_decision(winner: str | None, state: dict):
116
+ assert state.get("session_id"), "Missing session_id: refusing to record vote"
117
  state.setdefault("pending", [])
 
118
  decision = {
119
  "winner": winner,
120
  "key_a": state["key_a"],
121
  "key_b": state["key_b"],
122
+ "id_a": state["id_a"],
123
+ "id_b": state["id_b"],
124
  "url_a": state["url_a"],
125
  "url_b": state["url_b"],
126
  "dataset": state["dataset"],
127
  "rating_pref": state["rating_pref"],
128
+ "group": state["group"],
129
+ "session_id": state["session_id"],
130
  }
131
  state["pending"].append(decision)
132
+ state["last_decision"] = decision
133
+ state["can_go_back"] = True
134
  _commit_oldest_pending(state)
135
 
136
  def vote(winner: str | None, state: dict) -> tuple:
 
139
  return new_round(state["dataset"], state["rating_pref"], state)
140
 
141
  def go_back(state: dict) -> tuple:
 
142
  pending = state.setdefault("pending", [])
143
+ if not state.get("can_go_back"):
144
  return _render_current(state)
145
+ last = state.get("last_decision")
146
+ if not last:
147
+ state["can_go_back"] = False
148
+ return _render_current(state)
149
+ if pending and pending[-1] == last:
150
  pending.pop()
151
+ state["can_go_back"] = False
152
+ state["last_decision"] = None
153
  state.update(
154
  dataset=last["dataset"],
155
  rating_pref=last["rating_pref"],
156
  key_a=last["key_a"],
157
  key_b=last["key_b"],
158
+ id_a=last["id_a"],
159
+ id_b=last["id_b"],
160
  url_a=last["url_a"],
161
  url_b=last["url_b"],
162
+ group=last["group"],
163
  )
164
  return _render_current(state)
165
 
 
 
 
 
166
  # -- UI ---------------------------------------------------------------------
167
 
168
  with gr.Blocks(
169
  title="Image Rater",
170
+ head="""
171
+ <script>
172
+ window.addEventListener('keydown', function (e) {
173
+ const t = e.target;
174
+ if (t && (t.tagName === 'INPUT' || t.tagName === 'TEXTAREA' || t.isContentEditable)) return;
175
+ if (e.key === 'ArrowLeft') {
176
+ e.preventDefault();
177
+ document.querySelector('#btn-vote-a button, button#btn-vote-a')?.click();
178
+ } else if (e.key === 'ArrowRight') {
179
+ e.preventDefault();
180
+ document.querySelector('#btn-vote-b button, button#btn-vote-b')?.click();
181
+ } else if (e.key === 'Backspace') {
182
+ e.preventDefault();
183
+ document.querySelector('#btn-back-action button, button#btn-back-action')?.click();
184
+ }
185
+ });
186
+ document.addEventListener('click', function (e) {
187
+ const a = e.target.closest('a[href="#back"]');
188
+ if (!a) return;
189
+ e.preventDefault();
190
+ document.querySelector('#btn-back-action button, button#btn-back-action')?.click();
191
+ });
192
+ </script>
193
+ """,
194
  css="""
195
  .subtle-link button {
196
  background: none !important;
 
206
  .subtle-link button:hover {
207
  color: #5a5a5a !important;
208
  }
209
+ .subtle-link {
210
+ width: fit-content !important;
211
+ }
212
+ .subtle-link button {
213
+ width: fit-content !important;
214
+ }
215
+ .subtle-note {
216
+ color: #888;
217
+ font-size: 0.9em;
218
+ }
219
+ .rating-card {
220
+ width: 100%;
221
+ }
222
+ .rating-card-title {
223
+ min-height: 24px;
224
+ margin-bottom: 8px;
225
+ }
226
+ .rating-image-frame {
227
+ width: 100%;
228
+ height: 512px;
229
+ border: 1px solid #e6e6e6;
230
+ border-radius: 8px;
231
+ background: #fafafa;
232
+ display: flex;
233
+ align-items: center;
234
+ justify-content: center;
235
+ overflow: hidden;
236
+ }
237
+ .rating-image {
238
+ width: 100%;
239
+ height: 100%;
240
+ object-fit: contain;
241
+ }
242
+ .subtle-back-link-wrap a {
243
+ color: #7a7a7a !important;
244
+ font-size: 0.9em;
245
+ text-decoration: underline;
246
+ }
247
+ .subtle-back-link-wrap a:hover {
248
+ color: #5a5a5a !important;
249
+ }
250
+ .subtle-back-link-disabled {
251
+ color: #b8b8b8 !important;
252
+ pointer-events: none;
253
+ text-decoration: none;
254
+ }
255
+ .hidden-action-btn {
256
+ display: none !important;
257
+ }
258
  """,
259
  ) as demo:
260
+ gr.Markdown("## Image Quality Rater\nRate relative image quality. Choose the image with better quality, or select same quality if they are comparable. Both images are drawn from the same group to avoid cross-group bias.")
261
 
262
  state = gr.State({})
263
+ rating_pref_store = gr.BrowserState(default_value="safe", storage_key="rating_pref")
264
 
265
  with gr.Row():
266
+ img_a = gr.HTML()
267
+ img_b = gr.HTML()
268
 
269
  with gr.Row():
270
+ btn_a = gr.Button("👍 Prefer A", variant="primary", elem_id="btn-vote-a")
271
+ btn_skip = gr.Button("Same quality", elem_id="btn-skip")
272
+ btn_b = gr.Button("👍 Prefer B", variant="primary", elem_id="btn-vote-b")
273
 
274
  with gr.Accordion("Settings", open=False):
275
  gr.Markdown("<span style='color:#888;font-size:0.9em;'>Advanced options</span>")
276
  rating_dd = gr.Dropdown(
277
+ choices=["safe", "all"],
278
  value="safe",
279
  label="Rating",
280
+ elem_id="rating-pref",
281
  )
282
+
283
+ link_a = gr.Markdown(label="Image A link")
284
+ link_b = gr.Markdown(label="Image B link")
285
+ back_link = gr.Markdown(elem_classes=["subtle-back-link-wrap"])
286
+ btn_back_action = gr.Button("back", elem_id="btn-back-action", elem_classes=["hidden-action-btn"])
287
+ details_md = gr.Markdown()
288
+ gr.Markdown("<span class='subtle-note'>Dataset: <a href='https://huggingface.co/datasets/taigasan/e6-visual-ratings' target='_blank' rel='noopener noreferrer'>taigasan/e6-visual-ratings</a></span>")
289
+ gr.Markdown("<span class='subtle-note'>Shortcuts: Left = vote A, Right = vote B, Backspace = back</span>")
290
+ outputs = [img_a, img_b, link_a, link_b, back_link, details_md, state]
291
+
292
+ btn_a.click(fn=lambda s: vote("A", s), inputs=[state], outputs=outputs, queue=False, show_progress="hidden")
293
+ btn_b.click(fn=lambda s: vote("B", s), inputs=[state], outputs=outputs, queue=False, show_progress="hidden")
294
+ btn_skip.click(fn=lambda s: vote(None, s), inputs=[state], outputs=outputs, queue=False, show_progress="hidden")
295
+ btn_back_action.click(fn=go_back, inputs=[state], outputs=outputs, queue=False, show_progress="hidden")
296
+ rating_dd.change(fn=_on_rating_change, inputs=[rating_dd, state], outputs=[*outputs, rating_pref_store], queue=False, show_progress="hidden")
297
+ demo.load(fn=_initial_load, inputs=[state, rating_pref_store], outputs=[rating_dd, *outputs], queue=False, show_progress="hidden")
298
 
299
  if __name__ == "__main__":
300
  demo.launch()