RedHotTensors committed on
Commit
2e679fc
·
1 Parent(s): ae08ba5

Allow picking from individual image groups.

Browse files
Files changed (2) hide show
  1. app.py +42 -48
  2. explorer.py +6 -0
app.py CHANGED
@@ -250,16 +250,6 @@ DATASETS: dict[str, dict] = {
250
  }
251
  DEFAULT_DATASET = list(DATASETS.keys())[0]
252
 
253
- def _select_groups(cfg: dict, rating_pref: str) -> list[str]:
254
- if rating_pref == "all":
255
- return cfg["groups"]
256
-
257
- return [
258
- g
259
- for g in cfg["groups"]
260
- if g.endswith(f"_{rating_pref}")
261
- ]
262
-
263
  def _format_rating_post_title(post_id: int, votes: int, label: str) -> str:
264
  return f"<strong>{label}</strong>: <a href=\"https://e621.net/posts/{post_id}\" target=\"_blank\" rel=\"noreferrer\">Post #{post_id}</a> | {votes} {'Vote' if votes == 1 else 'Votes'}"
265
 
@@ -279,15 +269,13 @@ def _render_current(state: dict, submit_status: str = "") -> tuple:
279
  def _normalize_rating_pref(pref: str | None) -> str:
280
  return pref if pref in ("safe", "all") else "safe"
281
 
282
-
283
- def _initial_load(state: dict, pref: str | None, submit_key: str | None, image_height: str):
284
- rating_pref = _normalize_rating_pref(pref)
285
  submit_key = _normalize_submit_key(submit_key)
286
- return rating_pref, submit_key, image_height, image_height, *new_round(DEFAULT_DATASET, rating_pref, state)
287
 
288
- def _on_rating_change(rating_pref: str, state: dict):
289
- rating_pref = _normalize_rating_pref(rating_pref)
290
- return *new_round(DEFAULT_DATASET, rating_pref, state), rating_pref
291
 
292
  def _on_image_height_change(image_height: str) -> tuple[str, str]:
293
  return image_height, image_height
@@ -298,7 +286,6 @@ def _normalize_submit_key(submit_key: str | None) -> str:
298
  def _filtered_explorer_df(rating_pref: str) -> pd.DataFrame:
299
  return _filtered_explorer_df_by_classifier(rating_pref, ALLOWED_CLASSIFIER_FILTERS[0])
300
 
301
-
302
  def _filtered_explorer_df_by_classifier(rating_pref: str, classifier_name: str) -> pd.DataFrame:
303
  if rating_pref == "all":
304
  rating_filtered = _explorer_df
@@ -308,11 +295,10 @@ def _filtered_explorer_df_by_classifier(rating_pref: str, classifier_name: str)
308
  assert classifier_name in ALLOWED_CLASSIFIER_FILTERS, f"Unsupported classifier filter: {classifier_name}"
309
  return rating_filtered[rating_filtered["classifier"] == classifier_name]
310
 
311
-
312
- def _load_results(rating_pref_value: str, sort_mode_value: str, classifier_filter_value: str):
313
- rating_pref = _normalize_rating_pref(rating_pref_value)
314
- sort_mode = _normalize_sort_mode(sort_mode_value)
315
- classifier_name = _normalize_classifier_filter(classifier_filter_value)
316
  filtered_explorer_df = _filtered_explorer_df_by_classifier(rating_pref, classifier_name)
317
  summary, score_distribution_plot, distribution_data, gallery_items, page_meta, next_offset, btn_update = build_results_data(
318
  filtered_explorer_df,
@@ -323,13 +309,11 @@ def _load_results(rating_pref_value: str, sort_mode_value: str, classifier_filte
323
  )
324
  return summary, score_distribution_plot, distribution_data, gallery_items, btn_update, "Click an image to reveal its ID and link.", page_meta, next_offset
325
 
326
-
327
  def _normalize_sort_mode(sort_mode: str | None) -> str:
328
  if sort_mode in ("Default", "Rating: Low to High", "Rating: High to Low"):
329
  return sort_mode
330
  return "Default"
331
 
332
-
333
  def _normalize_classifier_filter(classifier_name: str | None) -> str:
334
  if classifier_name in ALLOWED_CLASSIFIER_FILTERS:
335
  return str(classifier_name)
@@ -337,11 +321,11 @@ def _normalize_classifier_filter(classifier_name: str | None) -> str:
337
 
338
  # -- Gradio callbacks -------------------------------------------------------
339
 
340
- def new_round(dataset_name: str, rating_pref: str, state: dict) -> tuple:
341
- cfg = DATASETS[dataset_name]
342
- groups = _select_groups(cfg, rating_pref)
343
- assert groups, f"No groups for rating preference: {rating_pref}"
344
 
 
345
  group = random.choice(groups)
346
  row_a, row_b, reason_remaining, pair_reason = cfg["fetch_pair"](group)
347
 
@@ -352,7 +336,7 @@ def new_round(dataset_name: str, rating_pref: str, state: dict) -> tuple:
352
  key_b = cfg["get_id"](row_b)
353
  id_a = int(row_a["id"])
354
  id_b = int(row_b["id"])
355
- state.update(dataset=dataset_name, rating_pref=rating_pref, key_a=key_a, key_b=key_b, id_a=id_a, id_b=id_b, group=group, pair_reason=pair_reason)
356
  url_a = cfg["get_image"](row_a)
357
  url_b = cfg["get_image"](row_b)
358
  state["url_a"] = url_a
@@ -372,7 +356,6 @@ def _queue_decision(winner: str | None, state: dict):
372
  "url_a": state["url_a"],
373
  "url_b": state["url_b"],
374
  "dataset": state["dataset"],
375
- "rating_pref": state["rating_pref"],
376
  "group": state["group"],
377
  "pair_reason": state.get("pair_reason", ""),
378
  "session_id": state["session_id"],
@@ -387,10 +370,13 @@ def _add_vote(idx: int, col_loc: int, delta: int = 1) -> None:
387
  wins, ties, votes = _pool_df.iloc[idx, [WINS_LOC, TIES_LOC, VOTES_LOC]]
388
  _pool_df.iloc[idx, WINRATE_LOC] = (wins + 0.5 * ties) / max(votes, 1)
389
 
390
- def vote(winner: str | None, state: dict, submit_key: str | None) -> tuple:
391
  if _normalize_submit_key(submit_key) != SUBMIT_KEY:
392
  return _render_current(state, "Wrong submission key.")
393
 
 
 
 
394
  _queue_decision(winner, state)
395
 
396
  a_idx = _md5_to_idx[state["key_a"]]
@@ -409,7 +395,7 @@ def vote(winner: str | None, state: dict, submit_key: str | None) -> tuple:
409
  case _:
410
  raise AssertionError
411
 
412
- return new_round(state["dataset"], state["rating_pref"], state)
413
 
414
  def go_back(state: dict) -> tuple:
415
  pending = state.setdefault("pending", [])
@@ -418,7 +404,6 @@ def go_back(state: dict) -> tuple:
418
  last = pending.pop()
419
  state.update(
420
  dataset=last["dataset"],
421
- rating_pref=last["rating_pref"],
422
  key_a=last["key_a"],
423
  key_b=last["key_b"],
424
  id_a=last["id_a"],
@@ -674,6 +659,11 @@ with gr.Blocks(
674
  results_sort_store = gr.BrowserState(default_value="Default", storage_key="results_sort_mode")
675
  results_classifier_store = gr.BrowserState(default_value=ALLOWED_CLASSIFIER_FILTERS[0], storage_key="results_classifier")
676
  image_height_store = gr.BrowserState(default_value=768, storage_key="image_height")
 
 
 
 
 
677
 
678
  with gr.Tabs():
679
  with gr.Tab("Image Quality Rater"):
@@ -693,16 +683,16 @@ with gr.Blocks(
693
  btn_b = gr.Button("➡️ Prefer B", variant="primary", elem_id="btn-vote-b")
694
 
695
  with gr.Accordion("Settings", open=False):
 
 
 
 
 
 
696
  image_height_slider = gr.Slider(
697
  minimum=512, maximum=2048, step=16, precision=0,
698
  label="Image Size",
699
  )
700
- rating_dd = gr.Dropdown(
701
- choices=["safe", "all"],
702
- value="safe",
703
- label="Rating",
704
- elem_id="rating-pref",
705
- )
706
  submit_key_tb = gr.Textbox(
707
  value="",
708
  type="password",
@@ -717,6 +707,7 @@ with gr.Blocks(
717
 
718
  (
719
  results_summary_md,
 
720
  results_sort_dd,
721
  results_classifier_dd,
722
  results_score_distribution_plot,
@@ -740,23 +731,26 @@ with gr.Blocks(
740
  results_page_offset_state,
741
  ]
742
 
743
- btn_a.click(fn=lambda s, k: vote("A", s, k), inputs=[state, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
744
- btn_b.click(fn=lambda s, k: vote("B", s, k), inputs=[state, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
745
- btn_skip.click(fn=lambda s, k: vote(None, s, k), inputs=[state, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
746
  btn_back_action.click(fn=go_back, inputs=[state], outputs=outputs, queue=False, show_progress="hidden")
747
- rating_dd.change(fn=_on_rating_change, inputs=[rating_dd, state], outputs=[*outputs, rating_pref_store], queue=False, show_progress="hidden")
748
- submit_key_tb.input(fn=_normalize_submit_key, inputs=[submit_key_tb], outputs=[submit_key_store], queue=False, show_progress="hidden")
749
  submit_key_tb.change(fn=_normalize_submit_key, inputs=[submit_key_tb], outputs=[submit_key_store], queue=False, show_progress="hidden")
750
- rating_dd.change(fn=_load_results, inputs=[rating_dd, results_sort_store, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
 
 
 
 
751
  results_sort_dd.change(fn=_normalize_sort_mode, inputs=[results_sort_dd], outputs=[results_sort_store], queue=False, show_progress="hidden")
752
  results_sort_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_dd, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
753
  results_classifier_dd.change(fn=_normalize_classifier_filter, inputs=[results_classifier_dd], outputs=[results_classifier_store], queue=False, show_progress="hidden")
754
  results_classifier_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_dd], outputs=results_outputs, queue=False, show_progress="hidden")
755
- image_height_slider.change(fn=_on_image_height_change, inputs=[image_height_slider], outputs=[image_height_store, image_height], queue=False, show_progress="hidden")
756
- demo.load(fn=_initial_load, inputs=[state, rating_pref_store, submit_key_store, image_height_store], outputs=[rating_dd, submit_key_tb, image_height_slider, image_height, *outputs], queue=False, show_progress="hidden")
757
  demo.load(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
758
  demo.load(fn=_normalize_sort_mode, inputs=[results_sort_store], outputs=[results_sort_dd], queue=False, show_progress="hidden")
759
  demo.load(fn=_normalize_classifier_filter, inputs=[results_classifier_store], outputs=[results_classifier_dd], queue=False, show_progress="hidden")
 
760
  results_load_more_btn.click(
761
  fn=lambda r, s, c, o: load_more_results(_filtered_explorer_df_by_classifier(_normalize_rating_pref(r), _normalize_classifier_filter(c)), _explorer_df, s, o),
762
  inputs=[rating_pref_store, results_sort_store, results_classifier_store, results_page_offset_state],
 
250
  }
251
  DEFAULT_DATASET = list(DATASETS.keys())[0]
252
 
 
 
 
 
 
 
 
 
 
 
253
  def _format_rating_post_title(post_id: int, votes: int, label: str) -> str:
254
  return f"<strong>{label}</strong>: <a href=\"https://e621.net/posts/{post_id}\" target=\"_blank\" rel=\"noreferrer\">Post #{post_id}</a> | {votes} {'Vote' if votes == 1 else 'Votes'}"
255
 
 
269
  def _normalize_rating_pref(pref: str | None) -> str:
270
  return pref if pref in ("safe", "all") else "safe"
271
 
272
+ def _initial_load(state: dict, rating_pref: str | None, submit_key: str | None, image_height: str, groups: list[str]):
273
+ rating_pref = _normalize_rating_pref(rating_pref)
 
274
  submit_key = _normalize_submit_key(submit_key)
275
+ return rating_pref, submit_key, image_height, image_height, groups, *new_round(DEFAULT_DATASET, groups, state)
276
 
277
+ def _on_groups_change(groups: list[str], state: dict):
278
+ return *new_round(DEFAULT_DATASET, groups, state), groups
 
279
 
280
  def _on_image_height_change(image_height: str) -> tuple[str, str]:
281
  return image_height, image_height
 
286
  def _filtered_explorer_df(rating_pref: str) -> pd.DataFrame:
287
  return _filtered_explorer_df_by_classifier(rating_pref, ALLOWED_CLASSIFIER_FILTERS[0])
288
 
 
289
  def _filtered_explorer_df_by_classifier(rating_pref: str, classifier_name: str) -> pd.DataFrame:
290
  if rating_pref == "all":
291
  rating_filtered = _explorer_df
 
295
  assert classifier_name in ALLOWED_CLASSIFIER_FILTERS, f"Unsupported classifier filter: {classifier_name}"
296
  return rating_filtered[rating_filtered["classifier"] == classifier_name]
297
 
298
+ def _load_results(rating_pref: str, sort_mode: str, classifier_filter: str):
299
+ rating_pref = _normalize_rating_pref(rating_pref)
300
+ sort_mode = _normalize_sort_mode(sort_mode)
301
+ classifier_name = _normalize_classifier_filter(classifier_filter)
 
302
  filtered_explorer_df = _filtered_explorer_df_by_classifier(rating_pref, classifier_name)
303
  summary, score_distribution_plot, distribution_data, gallery_items, page_meta, next_offset, btn_update = build_results_data(
304
  filtered_explorer_df,
 
309
  )
310
  return summary, score_distribution_plot, distribution_data, gallery_items, btn_update, "Click an image to reveal its ID and link.", page_meta, next_offset
311
 
 
312
  def _normalize_sort_mode(sort_mode: str | None) -> str:
313
  if sort_mode in ("Default", "Rating: Low to High", "Rating: High to Low"):
314
  return sort_mode
315
  return "Default"
316
 
 
317
  def _normalize_classifier_filter(classifier_name: str | None) -> str:
318
  if classifier_name in ALLOWED_CLASSIFIER_FILTERS:
319
  return str(classifier_name)
 
321
 
322
  # -- Gradio callbacks -------------------------------------------------------
323
 
324
+ def new_round(dataset_name: str, groups: list[str], state: dict) -> tuple:
325
+ if not groups:
326
+ return "", "", gr.skip(), "", "Please select at least one group.", state
 
327
 
328
+ cfg = DATASETS[dataset_name]
329
  group = random.choice(groups)
330
  row_a, row_b, reason_remaining, pair_reason = cfg["fetch_pair"](group)
331
 
 
336
  key_b = cfg["get_id"](row_b)
337
  id_a = int(row_a["id"])
338
  id_b = int(row_b["id"])
339
+ state.update(dataset=dataset_name, key_a=key_a, key_b=key_b, id_a=id_a, id_b=id_b, group=group, pair_reason=pair_reason)
340
  url_a = cfg["get_image"](row_a)
341
  url_b = cfg["get_image"](row_b)
342
  state["url_a"] = url_a
 
356
  "url_a": state["url_a"],
357
  "url_b": state["url_b"],
358
  "dataset": state["dataset"],
 
359
  "group": state["group"],
360
  "pair_reason": state.get("pair_reason", ""),
361
  "session_id": state["session_id"],
 
370
  wins, ties, votes = _pool_df.iloc[idx, [WINS_LOC, TIES_LOC, VOTES_LOC]]
371
  _pool_df.iloc[idx, WINRATE_LOC] = (wins + 0.5 * ties) / max(votes, 1)
372
 
373
+ def vote(winner: str | None, state: dict, groups: list[str], submit_key: str | None) -> tuple:
374
  if _normalize_submit_key(submit_key) != SUBMIT_KEY:
375
  return _render_current(state, "Wrong submission key.")
376
 
377
+ if not groups:
378
+ return "", "", gr.skip(), "", "Please select at least one group.", state
379
+
380
  _queue_decision(winner, state)
381
 
382
  a_idx = _md5_to_idx[state["key_a"]]
 
395
  case _:
396
  raise AssertionError
397
 
398
+ return new_round(state["dataset"], groups, state)
399
 
400
  def go_back(state: dict) -> tuple:
401
  pending = state.setdefault("pending", [])
 
404
  last = pending.pop()
405
  state.update(
406
  dataset=last["dataset"],
 
407
  key_a=last["key_a"],
408
  key_b=last["key_b"],
409
  id_a=last["id_a"],
 
659
  results_sort_store = gr.BrowserState(default_value="Default", storage_key="results_sort_mode")
660
  results_classifier_store = gr.BrowserState(default_value=ALLOWED_CLASSIFIER_FILTERS[0], storage_key="results_classifier")
661
  image_height_store = gr.BrowserState(default_value=768, storage_key="image_height")
662
+ groups_store = gr.BrowserState(default_value=[
663
+ group
664
+ for group in DATASETS[DEFAULT_DATASET]["groups"]
665
+ if group.endswith("_safe")
666
+ ], storage_key="groups")
667
 
668
  with gr.Tabs():
669
  with gr.Tab("Image Quality Rater"):
 
683
  btn_b = gr.Button("➡️ Prefer B", variant="primary", elem_id="btn-vote-b")
684
 
685
  with gr.Accordion("Settings", open=False):
686
+ groups_select = gr.CheckboxGroup(
687
+ choices=DATASETS[DEFAULT_DATASET]["groups"],
688
+ label="Categories",
689
+ show_label=True,
690
+ show_select_all=True
691
+ )
692
  image_height_slider = gr.Slider(
693
  minimum=512, maximum=2048, step=16, precision=0,
694
  label="Image Size",
695
  )
 
 
 
 
 
 
696
  submit_key_tb = gr.Textbox(
697
  value="",
698
  type="password",
 
707
 
708
  (
709
  results_summary_md,
710
+ results_rating_dd,
711
  results_sort_dd,
712
  results_classifier_dd,
713
  results_score_distribution_plot,
 
731
  results_page_offset_state,
732
  ]
733
 
734
+ btn_a.click(fn=lambda s, g, k: vote("A", s, g, k), inputs=[state, groups_store, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
735
+ btn_b.click(fn=lambda s, g, k: vote("B", s, g, k), inputs=[state, groups_store, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
736
+ btn_skip.click(fn=lambda s, g, k: vote(None, s, g, k), inputs=[state, groups_store, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
737
  btn_back_action.click(fn=go_back, inputs=[state], outputs=outputs, queue=False, show_progress="hidden")
 
 
738
  submit_key_tb.change(fn=_normalize_submit_key, inputs=[submit_key_tb], outputs=[submit_key_store], queue=False, show_progress="hidden")
739
+ groups_select.change(fn=_on_groups_change, inputs=[groups_select, state], outputs=[*outputs, groups_store], queue=False, show_progress="hidden")
740
+ image_height_slider.change(fn=_on_image_height_change, inputs=[image_height_slider], outputs=[image_height_store, image_height], queue=False, show_progress="hidden")
741
+
742
+ results_rating_dd.change(fn=_normalize_rating_pref, inputs=[results_rating_dd], outputs=[rating_pref_store], queue=False, show_progress="hidden")
743
+ results_rating_dd.change(fn=_load_results, inputs=[results_rating_dd, results_sort_store, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
744
  results_sort_dd.change(fn=_normalize_sort_mode, inputs=[results_sort_dd], outputs=[results_sort_store], queue=False, show_progress="hidden")
745
  results_sort_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_dd, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
746
  results_classifier_dd.change(fn=_normalize_classifier_filter, inputs=[results_classifier_dd], outputs=[results_classifier_store], queue=False, show_progress="hidden")
747
  results_classifier_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_dd], outputs=results_outputs, queue=False, show_progress="hidden")
748
+
749
+ demo.load(fn=_initial_load, inputs=[state, rating_pref_store, submit_key_store, image_height_store, groups_store], outputs=[results_rating_dd, submit_key_tb, image_height_slider, image_height, groups_select, *outputs], queue=False, show_progress="hidden")
750
  demo.load(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
751
  demo.load(fn=_normalize_sort_mode, inputs=[results_sort_store], outputs=[results_sort_dd], queue=False, show_progress="hidden")
752
  demo.load(fn=_normalize_classifier_filter, inputs=[results_classifier_store], outputs=[results_classifier_dd], queue=False, show_progress="hidden")
753
+
754
  results_load_more_btn.click(
755
  fn=lambda r, s, c, o: load_more_results(_filtered_explorer_df_by_classifier(_normalize_rating_pref(r), _normalize_classifier_filter(c)), _explorer_df, s, o),
756
  inputs=[rating_pref_store, results_sort_store, results_classifier_store, results_page_offset_state],
explorer.py CHANGED
@@ -208,6 +208,11 @@ def add_results_tab(pool_df: pd.DataFrame):
208
  results_load_more_btn = gr.Button("Load more (ArrowDown)", elem_id="btn-results-load-more")
209
  selected_image_md = gr.Markdown("Click an image to reveal its ID and link.")
210
  results_score_distribution_plot = gr.Plot(label="Classifier score distribution")
 
 
 
 
 
211
  results_sort_dd = gr.Dropdown(
212
  choices=SORT_MODES,
213
  value="Default",
@@ -225,6 +230,7 @@ def add_results_tab(pool_df: pd.DataFrame):
225
  results_page_offset_state = gr.State(0)
226
  return (
227
  results_summary_md,
 
228
  results_sort_dd,
229
  results_classifier_dd,
230
  results_score_distribution_plot,
 
208
  results_load_more_btn = gr.Button("Load more (ArrowDown)", elem_id="btn-results-load-more")
209
  selected_image_md = gr.Markdown("Click an image to reveal its ID and link.")
210
  results_score_distribution_plot = gr.Plot(label="Classifier score distribution")
211
+ results_rating_dd = gr.Dropdown(
212
+ choices=["safe", "all"],
213
+ value="safe",
214
+ label="Rating",
215
+ )
216
  results_sort_dd = gr.Dropdown(
217
  choices=SORT_MODES,
218
  value="Default",
 
230
  results_page_offset_state = gr.State(0)
231
  return (
232
  results_summary_md,
233
+ results_rating_dd,
234
  results_sort_dd,
235
  results_classifier_dd,
236
  results_score_distribution_plot,