RedHotTensors's picture
Improve diversity by using weighted sampling instead of discrete bands.
e2b56f8
import html
import hmac
import os
import random
import sys
import threading
import time
import traceback
import uuid
from typing import Callable

import gradio as gr
import pandas as pd
from huggingface_hub import hf_hub_download

from storage import VoteStorage
from stats_from_logs import load_stats_by_md5
from explorer import ALLOWED_CLASSIFIER_FILTERS, add_results_tab, build_results_data, load_more_results, on_gallery_select
# Truthy DEBUG env values select VoteStorage's "void" mode instead of "hf".
DEBUG_MODE = os.getenv("DEBUG", "0").lower() in ("1", "true", "yes", "on")
# Hugging Face token used for dataset downloads and vote storage; may be unset (None).
RATINGS_APP_TOKEN = os.getenv("RATINGS_APP_TOKEN")
# Shared secret a user must enter before votes are accepted. Surrounding whitespace is
# stripped, just as _normalize_submit_key() strips user input, so a stray trailing
# newline in the env var cannot make every submission mismatch.
SUBMIT_KEY = (os.getenv("RATINGS_SUBMIT_KEY") or "").strip()
if not SUBMIT_KEY:
    # Explicit raise rather than `assert`, which is removed under `python -O`.
    raise RuntimeError("Missing required env var: RATINGS_SUBMIT_KEY")
# Dataset repo the pool, classifier scores, and validation set are downloaded from.
POOL_REPO_ID = "taigasan/e6-visual-ratings"
VOTE_STORAGE = VoteStorage(mode="void" if DEBUG_MODE else "hf", token=RATINGS_APP_TOKEN)
# Interval between background stats reloads, in seconds (30 minutes).
STATS_RELOAD_S = 30 * 60
# -- Pool dataset -----------------------------------------------------------
# Download the image pool once at import time.
_pool_path = hf_hub_download(
    repo_id=POOL_REPO_ID,
    filename="pool.parquet",
    repo_type="dataset",
    token=RATINGS_APP_TOKEN
)
_pool_df = pd.read_parquet(_pool_path)
# Vote-stat columns start at zero; _load_stats() overwrites them from the vote logs.
_pool_df[["wins", "losses", "ties", "votes", "winrate"]] = (0, 0, 0, 0, 0.0)
# Positional column indices, used with .iloc by _load_stats(), _add_vote() and _render_current().
WINS_LOC = _pool_df.columns.get_loc("wins")
LOSSES_LOC = _pool_df.columns.get_loc("losses")
TIES_LOC = _pool_df.columns.get_loc("ties")
VOTES_LOC = _pool_df.columns.get_loc("votes")
WINRATE_LOC = _pool_df.columns.get_loc("winrate")
# md5 -> positional row index into _pool_df. NOTE(review): assumes md5 is unique in the
# pool; a duplicated md5 would map to its last row.
_md5_to_idx = { md5: idx for idx, md5 in enumerate(_pool_df["md5"]) }
# Held around writes to the vote-stat columns (stats reload, vote, undo).
_pool_lock = threading.Lock()
# NOTE(review): not read or written anywhere else in the visible code.
_stats_last_loaded_at = 0.0
# Explorer table (validation set joined with classifier scores); rebound by _load_stats().
_explorer_df = pd.DataFrame(columns=["group", "id", "md5", "rating", "sample_url", "image_url", "classifier", "classifier_score", "percentile"])
def _apply_pool_stats(stats_by_key) -> int:
    """Write per-md5 vote stats into the stat columns of ``_pool_df``.

    Called with ``_pool_lock`` held. Each stats value must expose ``wins``,
    ``losses``, ``ties``, ``votes`` and ``winrate``.

    Returns:
        How many md5s have stats but no row in the pool.
    """
    stat_locs = [WINS_LOC, LOSSES_LOC, TIES_LOC, VOTES_LOC, WINRATE_LOC]
    unknown = 0
    for md5, stats in stats_by_key.items():
        row_idx = _md5_to_idx.get(md5)
        if row_idx is None:
            unknown += 1
            continue
        _pool_df.iloc[row_idx, stat_locs] = (
            stats.wins, stats.losses, stats.ties, stats.votes, stats.winrate
        )
    return unknown


def _rebuild_explorer_df() -> None:
    """Download validation set + classifier scores and rebind ``_explorer_df``.

    The new table is a left join of validation rows with classifier scores on
    ``md5``; each validation md5 may match several score rows.
    """
    global _explorer_df
    scores_path = hf_hub_download(
        repo_id=POOL_REPO_ID,
        filename="classifier_scores.parquet",
        repo_type="dataset",
        token=RATINGS_APP_TOKEN,
    )
    validation_path = hf_hub_download(
        repo_id=POOL_REPO_ID,
        filename="validation_set.parquet",
        repo_type="dataset",
        token=RATINGS_APP_TOKEN,
    )
    validation = pd.read_parquet(
        validation_path,
        columns=["group", "id", "md5", "rating", "sample_url", "image_url"],
    )
    score_columns = ["classifier", "md5", "classifier_score", "percentile"]
    scores = pd.read_parquet(scores_path)
    assert set(score_columns).issubset(scores.columns), "classifier_scores.parquet missing expected columns"
    scores = scores[score_columns]
    # Normalize join/filter keys to str on both sides before merging.
    scores["classifier"] = scores["classifier"].astype(str)
    scores["md5"] = scores["md5"].astype(str)
    validation["md5"] = validation["md5"].astype(str)
    _explorer_df = validation.merge(scores, on="md5", how="left", validate="one_to_many")


def _load_stats() -> None:
    """Refresh vote stats in ``_pool_df`` and rebuild the explorer table."""
    # First pass outside the lock; its result is discarded. Presumably a warm-up so the
    # slow sync/download work is done before _pool_lock is taken — confirm.
    VOTE_STORAGE.sync()
    load_stats_by_md5(repo_id=POOL_REPO_ID, token=RATINGS_APP_TOKEN)
    with _pool_lock:
        VOTE_STORAGE.sync()
        fresh_stats = load_stats_by_md5(repo_id=POOL_REPO_ID, token=RATINGS_APP_TOKEN)
        n_missing = _apply_pool_stats(fresh_stats)
    if n_missing:
        print(f"{n_missing} md5s have stats but are not in the pool!", file=sys.stderr)
    _rebuild_explorer_df()
def _stats_reloader() -> None:
    """Reload stats every ``STATS_RELOAD_S`` seconds for the life of the process.

    Runs on a daemon thread. A failed reload (for example a transient Hub/network
    error) is logged to stderr and retried on the next cycle, so one failure does
    not stop all future refreshes.
    """
    while True:
        time.sleep(STATS_RELOAD_S)
        try:
            _load_stats()
        except Exception:
            # Background best-effort refresh: keep serving whatever data was last
            # loaded and try again after the next sleep.
            print("Periodic stats reload failed; retrying next cycle.", file=sys.stderr)
            traceback.print_exc()


# Initial load runs synchronously at import; a failure here aborts startup.
_load_stats()
threading.Thread(target=_stats_reloader, daemon=True).start()
def _pick_from(df: pd.DataFrame, *, weights: pd.Series | None = None) -> tuple[pd.Series, pd.Series, int] | None:
if len(df) < 2:
return None
sample = df.sample(2, weights=weights, replace=False)
return sample.iloc[0], sample.iloc[1], len(df)
def _pick_similar(
df: pd.DataFrame,
distance: Callable[[pd.DataFrame, pd.Series], pd.Series],
*,
weights: Callable[[pd.DataFrame], pd.Series] | None = None,
other_df: pd.DataFrame | None = None,
) -> tuple[pd.Series, pd.Series, int] | None:
if len(df) < 2:
return None
if other_df is None:
other_df = df
elif len(other_df) < 2:
return None
weight_vals: pd.Series | None = None
if weights is not None:
weight_vals = weights(df)
first = df.sample(weights=weight_vals).iloc[0]
weight_vals = 1.0 / (1.0 + distance(other_df, first))
while True:
other = other_df.sample(weights=weight_vals).iloc[0]
if other["md5"] != first["md5"]:
return first, other, len(df)
def _record_distance(df: pd.DataFrame, pivot: pd.Series) -> pd.Series:
    """Per-row distance between (wins, losses) records and the pivot's record.

    Euclidean distance raised to the 1.5 power (squared distance ** 0.75);
    plain L2 was judged a bit too loose.
    """
    win_gap = df["wins"] - pivot["wins"]
    loss_gap = df["losses"] - pivot["losses"]
    return (win_gap**2 + loss_gap**2) ** 0.75


def _pool_fetch_pair(group: str) -> tuple[pd.Series, pd.Series, int, str]:
    """Choose the next pair of pool images to compare within ``group``.

    Strategies are tried in order; the returned reason names the one used:
      1. "new-winners": two images whose only vote so far was a win.
      2. "new-losers": two images whose only vote so far was a loss.
      3. "sparse": an image with only ties or exactly two votes, paired with a
         similar-record image that has more than three decisive votes.
      4. "new": when fewer than 8 images have votes, or with probability 0.33,
         introduce unvoted images (if any remain).
      5. "fair-probe": vote-weighted sampling between images with similar records,
         slightly biased against images with many losses.

    Returns:
        ``(row_a, row_b, pool_size, reason)``; ``pool_size`` is the size of the pool
        the first image was drawn from (1 when pairing the last unvoted image).
    """
    in_group = _pool_df[_pool_df["group"] == group]
    voted = in_group[in_group["votes"] > 0]
    vote_counts = voted["votes"]

    # Pair first-time winners with each other, then first-time losers.
    single_vote = vote_counts == 1
    for column, reason in (("wins", "new-winners"), ("losses", "new-losers")):
        picked = _pick_from(voted[single_vote & (voted[column] == 1)])
        if picked is not None:
            return *picked, reason

    # Link cliques to the main network and break ties.
    decisive = vote_counts - voted["ties"]
    picked = _pick_similar(
        voted[(decisive == 0) | (vote_counts == 2)],
        _record_distance,
        other_df=voted[decisive > 3],
    )
    if picked is not None:
        return *picked, "sparse"

    # Introduce new images.
    if len(voted) < 8 or random.random() < 0.33:
        unvoted = in_group[in_group["votes"] == 0]
        if len(unvoted) == 1:
            return unvoted.iloc[0], voted.iloc[0], 1, "new"
        if len(unvoted) > 1:
            picked = _pick_from(unvoted)
            assert picked is not None
            return *picked, "new"

    # Vote-weighted random sampling between similar records, slightly biased against picking losers.
    picked = _pick_similar(
        voted,
        _record_distance,
        weights=lambda rows: 1.0 / (rows["votes"]**1.25 + 0.1 * rows["losses"]),
    )
    assert picked is not None
    return *picked, "fair-probe"
def _row_image_url(row) -> str:
sample_url = row.get("sample_url")
if isinstance(sample_url, str) and sample_url:
return sample_url
image_url = row.get("image_url")
if isinstance(image_url, str) and image_url:
return image_url
return ''
# Registry of rateable datasets: name -> callbacks and group list used by new_round().
DATASETS: dict[str, dict] = {
    "pool": dict(
        fetch_pair=_pool_fetch_pair,
        get_id=lambda row: row["md5"],
        get_image=_row_image_url,
        groups=sorted(_pool_df["group"].unique()),
    ),
}
# The first registered dataset is the default.
DEFAULT_DATASET = next(iter(DATASETS))
def _format_rating_post_title(post_id: int, votes: int, label: str) -> str:
return f"<strong>{label}</strong>: <a href=\"https://e621.net/posts/{post_id}\" target=\"_blank\" rel=\"noreferrer\">Post #{post_id}</a> | {votes} {'Vote' if votes == 1 else 'Votes'}"
def _rating_card_html(title: str, url: str) -> str:
    """Render one rating card: title line above an eagerly loaded, escaped image URL."""
    return f"<div class=\"rating-card\"><div class=\"rating-card-title\">{title}</div><div class=\"rating-image-frame\"><img src=\"{html.escape(url)}\" class=\"rating-image\" loading=\"eager\" referrerpolicy=\"no-referrer\"></div></div>"


def _render_current(state: dict, submit_status: str = "") -> tuple:
    """Render the session's current pair as the 6 rating-view outputs.

    Args:
        state: Per-session state; ``key_*``/``id_*``/``url_*``/``group`` describe the pair.
        submit_status: Optional status message (HTML-escaped before display).

    Returns:
        ``(img_a_html, img_b_html, back_button, pair_details, status, state)``.
        If no pair has been loaded yet (e.g. no groups were selected at page load),
        blank cards are returned instead of raising ``KeyError``.
    """
    can_go_back = bool(state.get("pending", ()))
    if "key_a" not in state:
        return "", "", gr.Button(interactive=can_go_back), "", html.escape(submit_status), state
    votes_a = _pool_df.iloc[_md5_to_idx[state["key_a"]], VOTES_LOC]
    votes_b = _pool_df.iloc[_md5_to_idx[state["key_b"]], VOTES_LOC]
    img_a_html = _rating_card_html(_format_rating_post_title(state["id_a"], votes_a, "Image A"), state["url_a"])
    img_b_html = _rating_card_html(_format_rating_post_title(state["id_b"], votes_b, "Image B"), state["url_b"])
    pair_details = f"/ {state['group']} / {state.get('pair_reason', 'unknown')}"
    return img_a_html, img_b_html, gr.Button(interactive=can_go_back), html.escape(pair_details), html.escape(submit_status), state
def _normalize_rating_pref(pref: str | None) -> str:
return pref if pref in ("safe", "all") else "safe"
def _initial_load(state: dict, rating_pref: str | None, submit_key: str | None, image_height: str, groups: list[str]):
    """Restore stored browser preferences on page load and serve the first pair.

    Returns values for the results rating dropdown, submit-key textbox, image-size
    slider, image-height style component and group checkboxes, followed by the
    rating-view outputs of ``new_round``.
    """
    restored = (
        _normalize_rating_pref(rating_pref),
        _normalize_submit_key(submit_key),
        image_height,  # slider value
        image_height,  # <style> component value
        groups,
    )
    return (*restored, *new_round(DEFAULT_DATASET, groups, state))
def _on_groups_change(groups: list[str], state: dict):
    """Serve a pair from the newly selected groups and persist the selection.

    Returns ``new_round``'s rating-view outputs followed by ``groups``.
    """
    round_outputs = new_round(DEFAULT_DATASET, groups, state)
    return (*round_outputs, groups)
def _on_image_height_change(image_height: str) -> tuple[str, str]:
return image_height, image_height
def _normalize_submit_key(submit_key: str | None) -> str:
return (submit_key or "").strip()
def _filtered_explorer_df(rating_pref: str) -> pd.DataFrame:
    """Explorer rows for ``rating_pref`` using the first allowed classifier."""
    default_classifier = ALLOWED_CLASSIFIER_FILTERS[0]
    return _filtered_explorer_df_by_classifier(rating_pref, default_classifier)
def _filtered_explorer_df_by_classifier(rating_pref: str, classifier_name: str) -> pd.DataFrame:
    """Return explorer rows for one rating preference and one classifier.

    Args:
        rating_pref: "all" keeps every row; "safe" keeps rows whose rating is "s".
        classifier_name: Must be one of ``ALLOWED_CLASSIFIER_FILTERS``.

    Raises:
        AssertionError: on an unsupported preference or classifier (callers are
            expected to normalize both first).
    """
    # Read the global once: the background reloader may rebind _explorer_df at any
    # moment, and a boolean mask built from one frame must not index another.
    explorer_df = _explorer_df
    if rating_pref == "all":
        rating_filtered = explorer_df
    else:
        assert rating_pref == "safe", f"Unsupported rating preference: {rating_pref}"
        rating_filtered = explorer_df[explorer_df["rating"] == "s"]
    assert classifier_name in ALLOWED_CLASSIFIER_FILTERS, f"Unsupported classifier filter: {classifier_name}"
    return rating_filtered[rating_filtered["classifier"] == classifier_name]
def _load_results(rating_pref: str, sort_mode: str, classifier_filter: str):
    """Normalize the results-tab controls and rebuild every results-tab output.

    Returns values in the order of the UI's ``results_outputs`` list.
    """
    pref = _normalize_rating_pref(rating_pref)
    mode = _normalize_sort_mode(sort_mode)
    classifier = _normalize_classifier_filter(classifier_filter)
    (summary, plot, distribution, gallery, page_meta, next_offset, load_more_btn) = build_results_data(
        _filtered_explorer_df_by_classifier(pref, classifier),
        _explorer_df,
        pref,
        mode,
        classifier,
    )
    hint = "Click an image to reveal its ID and link."
    return summary, plot, distribution, gallery, load_more_btn, hint, page_meta, next_offset
def _normalize_sort_mode(sort_mode: str | None) -> str:
if sort_mode in ("Default", "Rating: Low to High", "Rating: High to Low"):
return sort_mode
return "Default"
def _normalize_classifier_filter(classifier_name: str | None) -> str:
    """Map a stored/submitted classifier name to an allowed one (first allowed by default)."""
    if classifier_name not in ALLOWED_CLASSIFIER_FILTERS:
        return ALLOWED_CLASSIFIER_FILTERS[0]
    return str(classifier_name)
# -- Gradio callbacks -------------------------------------------------------
def new_round(dataset_name: str, groups: list[str], state: dict) -> tuple:
    """Pick a fresh pair from a randomly chosen selected group and render it.

    ``groups`` is restored from browser storage, so it can name groups that no
    longer exist in the dataset; those are ignored rather than crashing the pair
    picker.

    Returns:
        The 6 rating-view outputs (see ``_render_current``), or blank outputs plus a
        prompt when no usable group is selected.
    """
    cfg = DATASETS[dataset_name]
    known_groups = [group for group in groups or () if group in cfg["groups"]]
    if not known_groups:
        return "", "", gr.skip(), "", "Please select at least one group.", state
    group = random.choice(known_groups)
    row_a, row_b, pool_size, pair_reason = cfg["fetch_pair"](group)
    if "session_id" not in state:
        state["session_id"] = uuid.uuid4().hex
    state.update(
        dataset=dataset_name,
        key_a=cfg["get_id"](row_a),
        key_b=cfg["get_id"](row_b),
        id_a=int(row_a["id"]),
        id_b=int(row_b["id"]),
        group=group,
        pair_reason=f"{pair_reason} ({pool_size})",
        url_a=cfg["get_image"](row_a),
        url_b=cfg["get_image"](row_b),
    )
    return _render_current(state)
def _queue_decision(winner: str | None, state: dict):
assert state.get("session_id"), "Missing session_id: refusing to record vote"
pending = state.setdefault("pending", [])
pending.append({
"winner": winner,
"key_a": state["key_a"],
"key_b": state["key_b"],
"id_a": state["id_a"],
"id_b": state["id_b"],
"url_a": state["url_a"],
"url_b": state["url_b"],
"dataset": state["dataset"],
"group": state["group"],
"pair_reason": state.get("pair_reason", ""),
"session_id": state["session_id"],
})
if len(pending) > 1:
VOTE_STORAGE.queue_row(pending.pop(0))
def _add_vote(idx: int, col_loc: int, delta: int = 1) -> None:
    """Add ``delta`` to one outcome counter and to the vote total of pool row ``idx``.

    Then recompute that row's winrate as (wins + 0.5 * ties) / votes, with the
    denominator floored at 1. Callers in this file invoke it with ``_pool_lock`` held.
    """
    touched = [col_loc, VOTES_LOC]
    _pool_df.iloc[idx, touched] = _pool_df.iloc[idx, touched] + delta
    wins, ties, votes = _pool_df.iloc[idx, [WINS_LOC, TIES_LOC, VOTES_LOC]]
    _pool_df.iloc[idx, WINRATE_LOC] = (wins + 0.5 * ties) / max(votes, 1)
def vote(winner: str | None, state: dict, groups: list[str], submit_key: str | None) -> tuple:
if _normalize_submit_key(submit_key) != SUBMIT_KEY:
return _render_current(state, "Wrong submission key.")
if not groups:
return "", "", gr.skip(), "", "Please select at least one group.", state
_queue_decision(winner, state)
a_idx = _md5_to_idx[state["key_a"]]
b_idx = _md5_to_idx[state["key_b"]]
with _pool_lock:
match winner:
case "A":
_add_vote(a_idx, WINS_LOC)
_add_vote(b_idx, LOSSES_LOC)
case "B":
_add_vote(a_idx, LOSSES_LOC)
_add_vote(b_idx, WINS_LOC)
case None:
_add_vote(a_idx, TIES_LOC)
_add_vote(b_idx, TIES_LOC)
case _:
raise AssertionError
return new_round(state["dataset"], groups, state)
def go_back(state: dict) -> tuple:
    """Undo the session's most recent pending decision and show that pair again.

    Only a decision still in ``state["pending"]`` (not yet handed to storage) can be
    undone; with nothing pending, the current pair is simply re-rendered.
    """
    pending = state.setdefault("pending", [])
    if not pending:
        return _render_current(state)

    last = pending.pop()
    restored = {
        field: last[field]
        for field in ("dataset", "key_a", "key_b", "id_a", "id_b", "url_a", "url_b", "group")
    }
    restored["pair_reason"] = last.get("pair_reason", "")
    state.update(restored)

    a_idx = _md5_to_idx[state["key_a"]]
    b_idx = _md5_to_idx[state["key_b"]]
    with _pool_lock:
        # Reverse exactly the counters vote() incremented for this decision.
        winner = last["winner"]
        if winner == "A":
            a_col, b_col = WINS_LOC, LOSSES_LOC
        elif winner == "B":
            a_col, b_col = LOSSES_LOC, WINS_LOC
        elif winner is None:
            a_col = b_col = TIES_LOC
        else:
            raise AssertionError
        _add_vote(a_idx, a_col, -1)
        _add_vote(b_idx, b_col, -1)
    return _render_current(state)
# -- UI ---------------------------------------------------------------------
# Build the Gradio UI: the rating tab, the results tab from add_results_tab(), and all event wiring.
with gr.Blocks(
    title="e621 Visual Ratings",
    # Client-side script: a 1500 ms cooldown (with toast) on the vote/skip buttons,
    # clearing of the shown images as soon as a vote/skip is accepted, and keyboard
    # shortcuts (arrows for vote/skip/undo, Ctrl/Cmd+Z for undo, ArrowDown for "load more"
    # on the results tab).
    head="""
<script>
const VOTE_COOLDOWN_MS = 1500;
let lastVoteAtMs = 0;
let voteToastTimer = null;
function showVoteToast(message) {
let toast = document.getElementById('vote-cooldown-toast');
if (!toast) {
toast = document.createElement('div');
toast.id = 'vote-cooldown-toast';
toast.style.position = 'fixed';
toast.style.left = '50%';
toast.style.bottom = '20px';
toast.style.transform = 'translateX(-50%)';
toast.style.background = 'rgba(20, 20, 20, 0.92)';
toast.style.color = '#fff';
toast.style.padding = '8px 12px';
toast.style.borderRadius = '8px';
toast.style.fontSize = '0.92rem';
toast.style.zIndex = '9999';
toast.style.pointerEvents = 'none';
toast.style.opacity = '0';
toast.style.transition = 'opacity 120ms ease';
document.body.appendChild(toast);
}
toast.textContent = message;
toast.style.opacity = '1';
if (voteToastTimer) clearTimeout(voteToastTimer);
voteToastTimer = setTimeout(function () {
toast.style.opacity = '0';
}, 1400);
}
function showVoteBlockedMessage(remainingMs) {
const remainingS = Math.max(0.1, remainingMs / 1000).toFixed(1);
showVoteToast(`Please wait ${remainingS}s before submitting again.`);
}
function findVoteButtonTarget(target) {
return target?.closest('#btn-vote-a button, button#btn-vote-a, #btn-vote-b button, button#btn-vote-b, #btn-skip button, button#btn-skip');
}
function clearImageContainers() {
const leftImg = document.querySelector('#img-a img');
const rightImg = document.querySelector('#img-b img');
if (leftImg) {
leftImg.src = '';
leftImg.removeAttribute('srcset');
}
if (rightImg) {
rightImg.src = '';
rightImg.removeAttribute('srcset');
}
}
function isVisible(el) {
return !!(el && el.offsetParent !== null);
}
window.addEventListener('keydown', function (e) {
const t = e.target;
const voteAButton = document.querySelector('#btn-vote-a button, button#btn-vote-a');
const voteBButton = document.querySelector('#btn-vote-b button, button#btn-vote-b');
const skipButton = document.querySelector('#btn-skip button, button#btn-skip');
const backButton = document.querySelector('#btn-back-action button, button#btn-back-action');
const resultsLoadMoreButton = document.querySelector('#btn-results-load-more button, button#btn-results-load-more');
const ratingTabActive = isVisible(voteAButton) || isVisible(voteBButton) || isVisible(skipButton);
const resultsTabActive = isVisible(resultsLoadMoreButton);
if (t && (t.tagName === 'INPUT' || t.tagName === 'TEXTAREA' || t.isContentEditable)) return;
if (e.key === 'ArrowLeft' && ratingTabActive) {
e.preventDefault();
voteAButton?.click();
} else if (e.key === 'ArrowRight' && ratingTabActive) {
e.preventDefault();
voteBButton?.click();
} else if (e.key === 'ArrowUp' && ratingTabActive) {
e.preventDefault();
skipButton?.click();
} else if ((e.key === 'z' || e.key === 'Z') && (e.ctrlKey || e.metaKey) && ratingTabActive) {
e.preventDefault();
backButton?.click();
} else if (e.key === 'ArrowDown') {
if (ratingTabActive) {
e.preventDefault();
backButton?.click();
}
if (resultsTabActive) {
e.preventDefault();
resultsLoadMoreButton?.click();
}
}
});
document.addEventListener('click', function (e) {
const voteBtn = findVoteButtonTarget(e.target);
if (voteBtn) {
const nowMs = Date.now();
const elapsedMs = nowMs - lastVoteAtMs;
if (elapsedMs < VOTE_COOLDOWN_MS) {
e.preventDefault();
e.stopPropagation();
if (typeof e.stopImmediatePropagation === 'function') e.stopImmediatePropagation();
showVoteBlockedMessage(VOTE_COOLDOWN_MS - elapsedMs);
return;
}
lastVoteAtMs = nowMs;
clearImageContainers();
return;
}
const a = e.target.closest('a[href="#back"]');
if (!a) return;
e.preventDefault();
document.querySelector('#btn-back-action button, button#btn-back-action')?.click();
}, true);
</script>
""",
    css="""
.subtle-link button {
background: none !important;
border: none !important;
box-shadow: none !important;
color: #7a7a7a !important;
text-decoration: underline !important;
padding: 0 !important;
min-height: 0 !important;
font-size: 0.9em !important;
justify-content: flex-start !important;
}
.subtle-link button:hover {
color: #5a5a5a !important;
}
.subtle-link {
width: fit-content !important;
}
.subtle-link button {
width: fit-content !important;
}
.subtle-note {
color: #888;
font-size: 0.9em;
}
.rating-card {
width: 100%;
}
.rating-card-title {
min-height: 24px;
margin-bottom: 8px;
}
.rating-image-frame {
width: 100%;
border: 1px solid #e6e6e6;
border-radius: 8px;
background: #333;
display: flex;
align-items: center;
justify-content: center;
overflow: hidden;
}
.rating-image {
width: auto !important;
height: auto !important;
max-width: 100% !important;
max-height: 100% !important;
object-fit: contain !important;
object-position: center center !important;
display: block;
}
.subtle-back-link-wrap a {
color: #7a7a7a !important;
text-decoration: underline;
}
.subtle-back-link-wrap a:hover {
color: #5a5a5a !important;
}
.subtle-back-link-disabled {
color: #b8b8b8 !important;
pointer-events: none;
text-decoration: none;
}
.hidden-action-btn {
display: none !important;
}
#submit-status {
position: fixed;
left: 50%;
bottom: 20px;
transform: translateX(-50%);
z-index: 9998;
pointer-events: none;
min-height: 1.2em;
}
.submit-status-msg {
background: rgba(20, 20, 20, 0.92);
color: #fff;
padding: 8px 12px;
border-radius: 8px;
font-size: 0.92rem;
}
#results-gallery {
--explorer-thumb-ratio: 1 / 1;
}
#results-gallery button,
#results-gallery .thumbnail-item {
aspect-ratio: var(--explorer-thumb-ratio) !important;
}
#results-gallery img {
width: 100% !important;
height: 100% !important;
object-fit: contain !important;
background: #1f2937;
}
a {
padding: 0 !important;
}
""",
    fill_width=True,
) as demo:
    # Per-session server-side state: current pair, pending decisions, session id.
    state = gr.State({})
    # User preferences persisted client-side by gr.BrowserState under the given storage keys.
    rating_pref_store = gr.BrowserState(default_value="safe", storage_key="rating_pref")
    submit_key_store = gr.BrowserState(default_value="", storage_key="submit_key")
    results_sort_store = gr.BrowserState(default_value="Default", storage_key="results_sort_mode")
    results_classifier_store = gr.BrowserState(default_value=ALLOWED_CLASSIFIER_FILTERS[0], storage_key="results_classifier")
    image_height_store = gr.BrowserState(default_value=768, storage_key="image_height")
    # Default category selection: every group whose name ends in "_safe".
    groups_store = gr.BrowserState(default_value=[
        group
        for group in DATASETS[DEFAULT_DATASET]["groups"]
        if group.endswith("_safe")
    ], storage_key="groups")
    with gr.Tabs():
        with gr.Tab("Image Quality Rater"):
            gr.Markdown("Rate relative image quality. Choose the image with better quality, or select same quality if they are comparable. Both images are drawn from the same group to avoid cross-group bias.")
            # Side-by-side image cards rendered by _render_current().
            with gr.Row():
                img_a = gr.HTML(elem_id="img-a")
                img_b = gr.HTML(elem_id="img-b")
            # elem_ids here are the selectors the head script uses for shortcuts and cooldown.
            with gr.Row(equal_height=True):
                btn_a = gr.Button("⬅️ Prefer A", variant="primary", elem_id="btn-vote-a")
                with gr.Column(scale=0), gr.Group():
                    btn_skip = gr.Button("⬆️ Same Quality", elem_id="btn-skip")
                    btn_back_action = gr.Button("⬇️ Undo", elem_id="btn-back-action")
                btn_b = gr.Button("➡️ Prefer B", variant="primary", elem_id="btn-vote-b")
            with gr.Accordion("Settings", open=False):
                groups_select = gr.CheckboxGroup(
                    choices=DATASETS[DEFAULT_DATASET]["groups"],
                    label="Categories",
                    show_label=True,
                    show_select_all=True
                )
                image_height_slider = gr.Slider(
                    minimum=512, maximum=2048, step=16, precision=0,
                    label="Image Size",
                )
                submit_key_tb = gr.Textbox(
                    value="",
                    type="password",
                    label="Submit Key",
                    elem_id="submit-key",
                )
            # Dataset link followed by "/ group / pair reason" from _render_current().
            pair_details = gr.HTML(html_template="Dataset: <a href='https://huggingface.co/datasets/taigasan/e6-visual-ratings' target='_blank' rel='noopener noreferrer'>taigasan/e6-visual-ratings</a> ${value}")
            submit_status = gr.HTML(html_template="<span class='submit-status-msg'>${value}</span>")
            gr.HTML("<span class='subtle-note'>Keyboard Shortcuts: ⬅️ Vote A, ⬆️ Same Quality, ➡️ Vote B, ⬇️ or Ctrl+Z Undo</span>")
            # Injects a <style> rule that sets the image frame height from the slider value.
            image_height = gr.HTML(html_template="<style>.rating-image-frame { height:${value}px; }</style>", apply_default_css=False)
        (
            results_summary_md,
            results_rating_dd,
            results_sort_dd,
            results_classifier_dd,
            results_score_distribution_plot,
            results_distribution_state,
            results_gallery,
            results_load_more_btn,
            selected_image_md,
            results_page_meta_state,
            results_page_offset_state,
        ) = add_results_tab(_pool_df)
    # Order matches the 6-tuple returned by _render_current()/new_round().
    outputs = [img_a, img_b, btn_back_action, pair_details, submit_status, state]
    # Order matches the tuple returned by _load_results().
    results_outputs = [
        results_summary_md,
        results_score_distribution_plot,
        results_distribution_state,
        results_gallery,
        results_load_more_btn,
        selected_image_md,
        results_page_meta_state,
        results_page_offset_state,
    ]
    # Rating actions.
    btn_a.click(fn=lambda s, g, k: vote("A", s, g, k), inputs=[state, groups_store, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
    btn_b.click(fn=lambda s, g, k: vote("B", s, g, k), inputs=[state, groups_store, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
    btn_skip.click(fn=lambda s, g, k: vote(None, s, g, k), inputs=[state, groups_store, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
    btn_back_action.click(fn=go_back, inputs=[state], outputs=outputs, queue=False, show_progress="hidden")
    # Settings: persist changes to browser state (groups also trigger a new pair).
    submit_key_tb.change(fn=_normalize_submit_key, inputs=[submit_key_tb], outputs=[submit_key_store], queue=False, show_progress="hidden")
    groups_select.change(fn=_on_groups_change, inputs=[groups_select, state], outputs=[*outputs, groups_store], queue=False, show_progress="hidden")
    image_height_slider.change(fn=_on_image_height_change, inputs=[image_height_slider], outputs=[image_height_store, image_height], queue=False, show_progress="hidden")
    # Each results dropdown both persists its normalized value and reloads the results.
    results_rating_dd.change(fn=_normalize_rating_pref, inputs=[results_rating_dd], outputs=[rating_pref_store], queue=False, show_progress="hidden")
    results_rating_dd.change(fn=_load_results, inputs=[results_rating_dd, results_sort_store, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
    results_sort_dd.change(fn=_normalize_sort_mode, inputs=[results_sort_dd], outputs=[results_sort_store], queue=False, show_progress="hidden")
    results_sort_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_dd, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
    results_classifier_dd.change(fn=_normalize_classifier_filter, inputs=[results_classifier_dd], outputs=[results_classifier_store], queue=False, show_progress="hidden")
    results_classifier_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_dd], outputs=results_outputs, queue=False, show_progress="hidden")
    # Page load: restore stored settings, serve the first pair, and populate the results tab.
    demo.load(fn=_initial_load, inputs=[state, rating_pref_store, submit_key_store, image_height_store, groups_store], outputs=[results_rating_dd, submit_key_tb, image_height_slider, image_height, groups_select, *outputs], queue=False, show_progress="hidden")
    demo.load(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
    demo.load(fn=_normalize_sort_mode, inputs=[results_sort_store], outputs=[results_sort_dd], queue=False, show_progress="hidden")
    demo.load(fn=_normalize_classifier_filter, inputs=[results_classifier_store], outputs=[results_classifier_dd], queue=False, show_progress="hidden")
    # Results pagination and image selection.
    results_load_more_btn.click(
        fn=lambda r, s, c, o: load_more_results(_filtered_explorer_df_by_classifier(_normalize_rating_pref(r), _normalize_classifier_filter(c)), _explorer_df, s, o),
        inputs=[rating_pref_store, results_sort_store, results_classifier_store, results_page_offset_state],
        outputs=[results_gallery, results_page_meta_state, results_page_offset_state, results_load_more_btn],
        queue=False,
        show_progress="hidden",
    )
    results_gallery.select(
        fn=on_gallery_select,
        inputs=[results_page_meta_state, results_distribution_state],
        outputs=[selected_image_md, results_score_distribution_plot],
        queue=False,
        show_progress="hidden",
    )
if __name__ == "__main__":
    demo.launch()