Spaces:
Running
Running
Commit ·
2e679fc
1
Parent(s): ae08ba5
Allow picking from individual image groups.
Browse files- app.py +42 -48
- explorer.py +6 -0
app.py
CHANGED
|
@@ -250,16 +250,6 @@ DATASETS: dict[str, dict] = {
|
|
| 250 |
}
|
| 251 |
DEFAULT_DATASET = list(DATASETS.keys())[0]
|
| 252 |
|
| 253 |
-
def _select_groups(cfg: dict, rating_pref: str) -> list[str]:
|
| 254 |
-
if rating_pref == "all":
|
| 255 |
-
return cfg["groups"]
|
| 256 |
-
|
| 257 |
-
return [
|
| 258 |
-
g
|
| 259 |
-
for g in cfg["groups"]
|
| 260 |
-
if g.endswith(f"_{rating_pref}")
|
| 261 |
-
]
|
| 262 |
-
|
| 263 |
def _format_rating_post_title(post_id: int, votes: int, label: str) -> str:
|
| 264 |
return f"<strong>{label}</strong>: <a href=\"https://e621.net/posts/{post_id}\" target=\"_blank\" rel=\"noreferrer\">Post #{post_id}</a> | {votes} {'Vote' if votes == 1 else 'Votes'}"
|
| 265 |
|
|
@@ -279,15 +269,13 @@ def _render_current(state: dict, submit_status: str = "") -> tuple:
|
|
| 279 |
def _normalize_rating_pref(pref: str | None) -> str:
|
| 280 |
return pref if pref in ("safe", "all") else "safe"
|
| 281 |
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
rating_pref = _normalize_rating_pref(pref)
|
| 285 |
submit_key = _normalize_submit_key(submit_key)
|
| 286 |
-
return rating_pref, submit_key, image_height, image_height, *new_round(DEFAULT_DATASET,
|
| 287 |
|
| 288 |
-
def
|
| 289 |
-
|
| 290 |
-
return *new_round(DEFAULT_DATASET, rating_pref, state), rating_pref
|
| 291 |
|
| 292 |
def _on_image_height_change(image_height: str) -> tuple[str, str]:
|
| 293 |
return image_height, image_height
|
|
@@ -298,7 +286,6 @@ def _normalize_submit_key(submit_key: str | None) -> str:
|
|
| 298 |
def _filtered_explorer_df(rating_pref: str) -> pd.DataFrame:
|
| 299 |
return _filtered_explorer_df_by_classifier(rating_pref, ALLOWED_CLASSIFIER_FILTERS[0])
|
| 300 |
|
| 301 |
-
|
| 302 |
def _filtered_explorer_df_by_classifier(rating_pref: str, classifier_name: str) -> pd.DataFrame:
|
| 303 |
if rating_pref == "all":
|
| 304 |
rating_filtered = _explorer_df
|
|
@@ -308,11 +295,10 @@ def _filtered_explorer_df_by_classifier(rating_pref: str, classifier_name: str)
|
|
| 308 |
assert classifier_name in ALLOWED_CLASSIFIER_FILTERS, f"Unsupported classifier filter: {classifier_name}"
|
| 309 |
return rating_filtered[rating_filtered["classifier"] == classifier_name]
|
| 310 |
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
classifier_name = _normalize_classifier_filter(classifier_filter_value)
|
| 316 |
filtered_explorer_df = _filtered_explorer_df_by_classifier(rating_pref, classifier_name)
|
| 317 |
summary, score_distribution_plot, distribution_data, gallery_items, page_meta, next_offset, btn_update = build_results_data(
|
| 318 |
filtered_explorer_df,
|
|
@@ -323,13 +309,11 @@ def _load_results(rating_pref_value: str, sort_mode_value: str, classifier_filte
|
|
| 323 |
)
|
| 324 |
return summary, score_distribution_plot, distribution_data, gallery_items, btn_update, "Click an image to reveal its ID and link.", page_meta, next_offset
|
| 325 |
|
| 326 |
-
|
| 327 |
def _normalize_sort_mode(sort_mode: str | None) -> str:
|
| 328 |
if sort_mode in ("Default", "Rating: Low to High", "Rating: High to Low"):
|
| 329 |
return sort_mode
|
| 330 |
return "Default"
|
| 331 |
|
| 332 |
-
|
| 333 |
def _normalize_classifier_filter(classifier_name: str | None) -> str:
|
| 334 |
if classifier_name in ALLOWED_CLASSIFIER_FILTERS:
|
| 335 |
return str(classifier_name)
|
|
@@ -337,11 +321,11 @@ def _normalize_classifier_filter(classifier_name: str | None) -> str:
|
|
| 337 |
|
| 338 |
# -- Gradio callbacks -------------------------------------------------------
|
| 339 |
|
| 340 |
-
def new_round(dataset_name: str,
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
assert groups, f"No groups for rating preference: {rating_pref}"
|
| 344 |
|
|
|
|
| 345 |
group = random.choice(groups)
|
| 346 |
row_a, row_b, reason_remaining, pair_reason = cfg["fetch_pair"](group)
|
| 347 |
|
|
@@ -352,7 +336,7 @@ def new_round(dataset_name: str, rating_pref: str, state: dict) -> tuple:
|
|
| 352 |
key_b = cfg["get_id"](row_b)
|
| 353 |
id_a = int(row_a["id"])
|
| 354 |
id_b = int(row_b["id"])
|
| 355 |
-
state.update(dataset=dataset_name,
|
| 356 |
url_a = cfg["get_image"](row_a)
|
| 357 |
url_b = cfg["get_image"](row_b)
|
| 358 |
state["url_a"] = url_a
|
|
@@ -372,7 +356,6 @@ def _queue_decision(winner: str | None, state: dict):
|
|
| 372 |
"url_a": state["url_a"],
|
| 373 |
"url_b": state["url_b"],
|
| 374 |
"dataset": state["dataset"],
|
| 375 |
-
"rating_pref": state["rating_pref"],
|
| 376 |
"group": state["group"],
|
| 377 |
"pair_reason": state.get("pair_reason", ""),
|
| 378 |
"session_id": state["session_id"],
|
|
@@ -387,10 +370,13 @@ def _add_vote(idx: int, col_loc: int, delta: int = 1) -> None:
|
|
| 387 |
wins, ties, votes = _pool_df.iloc[idx, [WINS_LOC, TIES_LOC, VOTES_LOC]]
|
| 388 |
_pool_df.iloc[idx, WINRATE_LOC] = (wins + 0.5 * ties) / max(votes, 1)
|
| 389 |
|
| 390 |
-
def vote(winner: str | None, state: dict, submit_key: str | None) -> tuple:
|
| 391 |
if _normalize_submit_key(submit_key) != SUBMIT_KEY:
|
| 392 |
return _render_current(state, "Wrong submission key.")
|
| 393 |
|
|
|
|
|
|
|
|
|
|
| 394 |
_queue_decision(winner, state)
|
| 395 |
|
| 396 |
a_idx = _md5_to_idx[state["key_a"]]
|
|
@@ -409,7 +395,7 @@ def vote(winner: str | None, state: dict, submit_key: str | None) -> tuple:
|
|
| 409 |
case _:
|
| 410 |
raise AssertionError
|
| 411 |
|
| 412 |
-
return new_round(state["dataset"],
|
| 413 |
|
| 414 |
def go_back(state: dict) -> tuple:
|
| 415 |
pending = state.setdefault("pending", [])
|
|
@@ -418,7 +404,6 @@ def go_back(state: dict) -> tuple:
|
|
| 418 |
last = pending.pop()
|
| 419 |
state.update(
|
| 420 |
dataset=last["dataset"],
|
| 421 |
-
rating_pref=last["rating_pref"],
|
| 422 |
key_a=last["key_a"],
|
| 423 |
key_b=last["key_b"],
|
| 424 |
id_a=last["id_a"],
|
|
@@ -674,6 +659,11 @@ with gr.Blocks(
|
|
| 674 |
results_sort_store = gr.BrowserState(default_value="Default", storage_key="results_sort_mode")
|
| 675 |
results_classifier_store = gr.BrowserState(default_value=ALLOWED_CLASSIFIER_FILTERS[0], storage_key="results_classifier")
|
| 676 |
image_height_store = gr.BrowserState(default_value=768, storage_key="image_height")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 677 |
|
| 678 |
with gr.Tabs():
|
| 679 |
with gr.Tab("Image Quality Rater"):
|
|
@@ -693,16 +683,16 @@ with gr.Blocks(
|
|
| 693 |
btn_b = gr.Button("➡️ Prefer B", variant="primary", elem_id="btn-vote-b")
|
| 694 |
|
| 695 |
with gr.Accordion("Settings", open=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 696 |
image_height_slider = gr.Slider(
|
| 697 |
minimum=512, maximum=2048, step=16, precision=0,
|
| 698 |
label="Image Size",
|
| 699 |
)
|
| 700 |
-
rating_dd = gr.Dropdown(
|
| 701 |
-
choices=["safe", "all"],
|
| 702 |
-
value="safe",
|
| 703 |
-
label="Rating",
|
| 704 |
-
elem_id="rating-pref",
|
| 705 |
-
)
|
| 706 |
submit_key_tb = gr.Textbox(
|
| 707 |
value="",
|
| 708 |
type="password",
|
|
@@ -717,6 +707,7 @@ with gr.Blocks(
|
|
| 717 |
|
| 718 |
(
|
| 719 |
results_summary_md,
|
|
|
|
| 720 |
results_sort_dd,
|
| 721 |
results_classifier_dd,
|
| 722 |
results_score_distribution_plot,
|
|
@@ -740,23 +731,26 @@ with gr.Blocks(
|
|
| 740 |
results_page_offset_state,
|
| 741 |
]
|
| 742 |
|
| 743 |
-
btn_a.click(fn=lambda s, k: vote("A", s, k), inputs=[state, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
|
| 744 |
-
btn_b.click(fn=lambda s, k: vote("B", s, k), inputs=[state, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
|
| 745 |
-
btn_skip.click(fn=lambda s, k: vote(None, s, k), inputs=[state, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
|
| 746 |
btn_back_action.click(fn=go_back, inputs=[state], outputs=outputs, queue=False, show_progress="hidden")
|
| 747 |
-
rating_dd.change(fn=_on_rating_change, inputs=[rating_dd, state], outputs=[*outputs, rating_pref_store], queue=False, show_progress="hidden")
|
| 748 |
-
submit_key_tb.input(fn=_normalize_submit_key, inputs=[submit_key_tb], outputs=[submit_key_store], queue=False, show_progress="hidden")
|
| 749 |
submit_key_tb.change(fn=_normalize_submit_key, inputs=[submit_key_tb], outputs=[submit_key_store], queue=False, show_progress="hidden")
|
| 750 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 751 |
results_sort_dd.change(fn=_normalize_sort_mode, inputs=[results_sort_dd], outputs=[results_sort_store], queue=False, show_progress="hidden")
|
| 752 |
results_sort_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_dd, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
|
| 753 |
results_classifier_dd.change(fn=_normalize_classifier_filter, inputs=[results_classifier_dd], outputs=[results_classifier_store], queue=False, show_progress="hidden")
|
| 754 |
results_classifier_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_dd], outputs=results_outputs, queue=False, show_progress="hidden")
|
| 755 |
-
|
| 756 |
-
demo.load(fn=_initial_load, inputs=[state, rating_pref_store, submit_key_store, image_height_store], outputs=[
|
| 757 |
demo.load(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
|
| 758 |
demo.load(fn=_normalize_sort_mode, inputs=[results_sort_store], outputs=[results_sort_dd], queue=False, show_progress="hidden")
|
| 759 |
demo.load(fn=_normalize_classifier_filter, inputs=[results_classifier_store], outputs=[results_classifier_dd], queue=False, show_progress="hidden")
|
|
|
|
| 760 |
results_load_more_btn.click(
|
| 761 |
fn=lambda r, s, c, o: load_more_results(_filtered_explorer_df_by_classifier(_normalize_rating_pref(r), _normalize_classifier_filter(c)), _explorer_df, s, o),
|
| 762 |
inputs=[rating_pref_store, results_sort_store, results_classifier_store, results_page_offset_state],
|
|
|
|
| 250 |
}
|
| 251 |
DEFAULT_DATASET = list(DATASETS.keys())[0]
|
| 252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
def _format_rating_post_title(post_id: int, votes: int, label: str) -> str:
|
| 254 |
return f"<strong>{label}</strong>: <a href=\"https://e621.net/posts/{post_id}\" target=\"_blank\" rel=\"noreferrer\">Post #{post_id}</a> | {votes} {'Vote' if votes == 1 else 'Votes'}"
|
| 255 |
|
|
|
|
| 269 |
def _normalize_rating_pref(pref: str | None) -> str:
|
| 270 |
return pref if pref in ("safe", "all") else "safe"
|
| 271 |
|
| 272 |
+
def _initial_load(state: dict, rating_pref: str | None, submit_key: str | None, image_height: str, groups: list[str]):
|
| 273 |
+
rating_pref = _normalize_rating_pref(rating_pref)
|
|
|
|
| 274 |
submit_key = _normalize_submit_key(submit_key)
|
| 275 |
+
return rating_pref, submit_key, image_height, image_height, groups, *new_round(DEFAULT_DATASET, groups, state)
|
| 276 |
|
| 277 |
+
def _on_groups_change(groups: list[str], state: dict):
|
| 278 |
+
return *new_round(DEFAULT_DATASET, groups, state), groups
|
|
|
|
| 279 |
|
| 280 |
def _on_image_height_change(image_height: str) -> tuple[str, str]:
|
| 281 |
return image_height, image_height
|
|
|
|
| 286 |
def _filtered_explorer_df(rating_pref: str) -> pd.DataFrame:
|
| 287 |
return _filtered_explorer_df_by_classifier(rating_pref, ALLOWED_CLASSIFIER_FILTERS[0])
|
| 288 |
|
|
|
|
| 289 |
def _filtered_explorer_df_by_classifier(rating_pref: str, classifier_name: str) -> pd.DataFrame:
|
| 290 |
if rating_pref == "all":
|
| 291 |
rating_filtered = _explorer_df
|
|
|
|
| 295 |
assert classifier_name in ALLOWED_CLASSIFIER_FILTERS, f"Unsupported classifier filter: {classifier_name}"
|
| 296 |
return rating_filtered[rating_filtered["classifier"] == classifier_name]
|
| 297 |
|
| 298 |
+
def _load_results(rating_pref: str, sort_mode: str, classifier_filter: str):
|
| 299 |
+
rating_pref = _normalize_rating_pref(rating_pref)
|
| 300 |
+
sort_mode = _normalize_sort_mode(sort_mode)
|
| 301 |
+
classifier_name = _normalize_classifier_filter(classifier_filter)
|
|
|
|
| 302 |
filtered_explorer_df = _filtered_explorer_df_by_classifier(rating_pref, classifier_name)
|
| 303 |
summary, score_distribution_plot, distribution_data, gallery_items, page_meta, next_offset, btn_update = build_results_data(
|
| 304 |
filtered_explorer_df,
|
|
|
|
| 309 |
)
|
| 310 |
return summary, score_distribution_plot, distribution_data, gallery_items, btn_update, "Click an image to reveal its ID and link.", page_meta, next_offset
|
| 311 |
|
|
|
|
| 312 |
def _normalize_sort_mode(sort_mode: str | None) -> str:
|
| 313 |
if sort_mode in ("Default", "Rating: Low to High", "Rating: High to Low"):
|
| 314 |
return sort_mode
|
| 315 |
return "Default"
|
| 316 |
|
|
|
|
| 317 |
def _normalize_classifier_filter(classifier_name: str | None) -> str:
|
| 318 |
if classifier_name in ALLOWED_CLASSIFIER_FILTERS:
|
| 319 |
return str(classifier_name)
|
|
|
|
| 321 |
|
| 322 |
# -- Gradio callbacks -------------------------------------------------------
|
| 323 |
|
| 324 |
+
def new_round(dataset_name: str, groups: list[str], state: dict) -> tuple:
|
| 325 |
+
if not groups:
|
| 326 |
+
return "", "", gr.skip(), "", "Please select at least one group.", state
|
|
|
|
| 327 |
|
| 328 |
+
cfg = DATASETS[dataset_name]
|
| 329 |
group = random.choice(groups)
|
| 330 |
row_a, row_b, reason_remaining, pair_reason = cfg["fetch_pair"](group)
|
| 331 |
|
|
|
|
| 336 |
key_b = cfg["get_id"](row_b)
|
| 337 |
id_a = int(row_a["id"])
|
| 338 |
id_b = int(row_b["id"])
|
| 339 |
+
state.update(dataset=dataset_name, key_a=key_a, key_b=key_b, id_a=id_a, id_b=id_b, group=group, pair_reason=pair_reason)
|
| 340 |
url_a = cfg["get_image"](row_a)
|
| 341 |
url_b = cfg["get_image"](row_b)
|
| 342 |
state["url_a"] = url_a
|
|
|
|
| 356 |
"url_a": state["url_a"],
|
| 357 |
"url_b": state["url_b"],
|
| 358 |
"dataset": state["dataset"],
|
|
|
|
| 359 |
"group": state["group"],
|
| 360 |
"pair_reason": state.get("pair_reason", ""),
|
| 361 |
"session_id": state["session_id"],
|
|
|
|
| 370 |
wins, ties, votes = _pool_df.iloc[idx, [WINS_LOC, TIES_LOC, VOTES_LOC]]
|
| 371 |
_pool_df.iloc[idx, WINRATE_LOC] = (wins + 0.5 * ties) / max(votes, 1)
|
| 372 |
|
| 373 |
+
def vote(winner: str | None, state: dict, groups: list[str], submit_key: str | None) -> tuple:
|
| 374 |
if _normalize_submit_key(submit_key) != SUBMIT_KEY:
|
| 375 |
return _render_current(state, "Wrong submission key.")
|
| 376 |
|
| 377 |
+
if not groups:
|
| 378 |
+
return "", "", gr.skip(), "", "Please select at least one group.", state
|
| 379 |
+
|
| 380 |
_queue_decision(winner, state)
|
| 381 |
|
| 382 |
a_idx = _md5_to_idx[state["key_a"]]
|
|
|
|
| 395 |
case _:
|
| 396 |
raise AssertionError
|
| 397 |
|
| 398 |
+
return new_round(state["dataset"], groups, state)
|
| 399 |
|
| 400 |
def go_back(state: dict) -> tuple:
|
| 401 |
pending = state.setdefault("pending", [])
|
|
|
|
| 404 |
last = pending.pop()
|
| 405 |
state.update(
|
| 406 |
dataset=last["dataset"],
|
|
|
|
| 407 |
key_a=last["key_a"],
|
| 408 |
key_b=last["key_b"],
|
| 409 |
id_a=last["id_a"],
|
|
|
|
| 659 |
results_sort_store = gr.BrowserState(default_value="Default", storage_key="results_sort_mode")
|
| 660 |
results_classifier_store = gr.BrowserState(default_value=ALLOWED_CLASSIFIER_FILTERS[0], storage_key="results_classifier")
|
| 661 |
image_height_store = gr.BrowserState(default_value=768, storage_key="image_height")
|
| 662 |
+
groups_store = gr.BrowserState(default_value=[
|
| 663 |
+
group
|
| 664 |
+
for group in DATASETS[DEFAULT_DATASET]["groups"]
|
| 665 |
+
if group.endswith("_safe")
|
| 666 |
+
], storage_key="groups")
|
| 667 |
|
| 668 |
with gr.Tabs():
|
| 669 |
with gr.Tab("Image Quality Rater"):
|
|
|
|
| 683 |
btn_b = gr.Button("➡️ Prefer B", variant="primary", elem_id="btn-vote-b")
|
| 684 |
|
| 685 |
with gr.Accordion("Settings", open=False):
|
| 686 |
+
groups_select = gr.CheckboxGroup(
|
| 687 |
+
choices=DATASETS[DEFAULT_DATASET]["groups"],
|
| 688 |
+
label="Categories",
|
| 689 |
+
show_label=True,
|
| 690 |
+
show_select_all=True
|
| 691 |
+
)
|
| 692 |
image_height_slider = gr.Slider(
|
| 693 |
minimum=512, maximum=2048, step=16, precision=0,
|
| 694 |
label="Image Size",
|
| 695 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 696 |
submit_key_tb = gr.Textbox(
|
| 697 |
value="",
|
| 698 |
type="password",
|
|
|
|
| 707 |
|
| 708 |
(
|
| 709 |
results_summary_md,
|
| 710 |
+
results_rating_dd,
|
| 711 |
results_sort_dd,
|
| 712 |
results_classifier_dd,
|
| 713 |
results_score_distribution_plot,
|
|
|
|
| 731 |
results_page_offset_state,
|
| 732 |
]
|
| 733 |
|
| 734 |
+
btn_a.click(fn=lambda s, g, k: vote("A", s, g, k), inputs=[state, groups_store, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
|
| 735 |
+
btn_b.click(fn=lambda s, g, k: vote("B", s, g, k), inputs=[state, groups_store, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
|
| 736 |
+
btn_skip.click(fn=lambda s, g, k: vote(None, s, g, k), inputs=[state, groups_store, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
|
| 737 |
btn_back_action.click(fn=go_back, inputs=[state], outputs=outputs, queue=False, show_progress="hidden")
|
|
|
|
|
|
|
| 738 |
submit_key_tb.change(fn=_normalize_submit_key, inputs=[submit_key_tb], outputs=[submit_key_store], queue=False, show_progress="hidden")
|
| 739 |
+
groups_select.change(fn=_on_groups_change, inputs=[groups_select, state], outputs=[*outputs, groups_store], queue=False, show_progress="hidden")
|
| 740 |
+
image_height_slider.change(fn=_on_image_height_change, inputs=[image_height_slider], outputs=[image_height_store, image_height], queue=False, show_progress="hidden")
|
| 741 |
+
|
| 742 |
+
results_rating_dd.change(fn=_normalize_rating_pref, inputs=[results_rating_dd], outputs=[rating_pref_store], queue=False, show_progress="hidden")
|
| 743 |
+
results_rating_dd.change(fn=_load_results, inputs=[results_rating_dd, results_sort_store, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
|
| 744 |
results_sort_dd.change(fn=_normalize_sort_mode, inputs=[results_sort_dd], outputs=[results_sort_store], queue=False, show_progress="hidden")
|
| 745 |
results_sort_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_dd, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
|
| 746 |
results_classifier_dd.change(fn=_normalize_classifier_filter, inputs=[results_classifier_dd], outputs=[results_classifier_store], queue=False, show_progress="hidden")
|
| 747 |
results_classifier_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_dd], outputs=results_outputs, queue=False, show_progress="hidden")
|
| 748 |
+
|
| 749 |
+
demo.load(fn=_initial_load, inputs=[state, rating_pref_store, submit_key_store, image_height_store, groups_store], outputs=[results_rating_dd, submit_key_tb, image_height_slider, image_height, groups_select, *outputs], queue=False, show_progress="hidden")
|
| 750 |
demo.load(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
|
| 751 |
demo.load(fn=_normalize_sort_mode, inputs=[results_sort_store], outputs=[results_sort_dd], queue=False, show_progress="hidden")
|
| 752 |
demo.load(fn=_normalize_classifier_filter, inputs=[results_classifier_store], outputs=[results_classifier_dd], queue=False, show_progress="hidden")
|
| 753 |
+
|
| 754 |
results_load_more_btn.click(
|
| 755 |
fn=lambda r, s, c, o: load_more_results(_filtered_explorer_df_by_classifier(_normalize_rating_pref(r), _normalize_classifier_filter(c)), _explorer_df, s, o),
|
| 756 |
inputs=[rating_pref_store, results_sort_store, results_classifier_store, results_page_offset_state],
|
explorer.py
CHANGED
|
@@ -208,6 +208,11 @@ def add_results_tab(pool_df: pd.DataFrame):
|
|
| 208 |
results_load_more_btn = gr.Button("Load more (ArrowDown)", elem_id="btn-results-load-more")
|
| 209 |
selected_image_md = gr.Markdown("Click an image to reveal its ID and link.")
|
| 210 |
results_score_distribution_plot = gr.Plot(label="Classifier score distribution")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
results_sort_dd = gr.Dropdown(
|
| 212 |
choices=SORT_MODES,
|
| 213 |
value="Default",
|
|
@@ -225,6 +230,7 @@ def add_results_tab(pool_df: pd.DataFrame):
|
|
| 225 |
results_page_offset_state = gr.State(0)
|
| 226 |
return (
|
| 227 |
results_summary_md,
|
|
|
|
| 228 |
results_sort_dd,
|
| 229 |
results_classifier_dd,
|
| 230 |
results_score_distribution_plot,
|
|
|
|
| 208 |
results_load_more_btn = gr.Button("Load more (ArrowDown)", elem_id="btn-results-load-more")
|
| 209 |
selected_image_md = gr.Markdown("Click an image to reveal its ID and link.")
|
| 210 |
results_score_distribution_plot = gr.Plot(label="Classifier score distribution")
|
| 211 |
+
results_rating_dd = gr.Dropdown(
|
| 212 |
+
choices=["safe", "all"],
|
| 213 |
+
value="safe",
|
| 214 |
+
label="Rating",
|
| 215 |
+
)
|
| 216 |
results_sort_dd = gr.Dropdown(
|
| 217 |
choices=SORT_MODES,
|
| 218 |
value="Default",
|
|
|
|
| 230 |
results_page_offset_state = gr.State(0)
|
| 231 |
return (
|
| 232 |
results_summary_md,
|
| 233 |
+
results_rating_dd,
|
| 234 |
results_sort_dd,
|
| 235 |
results_classifier_dd,
|
| 236 |
results_score_distribution_plot,
|