Spaces:
Running
Running
| import gradio as gr | |
| import random | |
| import threading | |
| import time | |
| import uuid | |
| import os | |
| import html | |
| import sys | |
| from typing import Callable | |
| import pandas as pd | |
| from huggingface_hub import hf_hub_download | |
| from storage import VoteStorage | |
| from stats_from_logs import load_stats_by_md5 | |
| from explorer import ALLOWED_CLASSIFIER_FILTERS, add_results_tab, build_results_data, load_more_results, on_gallery_select | |
# -- Configuration ----------------------------------------------------------
# DEBUG=1/true/yes/on switches vote storage to the "void" mode (see below).
DEBUG_MODE = os.getenv("DEBUG", "0").lower() in ("1", "true", "yes", "on")
# Token used for every Hub download and for vote storage.
RATINGS_APP_TOKEN = os.getenv("RATINGS_APP_TOKEN")
# Shared secret users must enter before their votes are accepted.
SUBMIT_KEY = os.getenv("RATINGS_SUBMIT_KEY")
if not SUBMIT_KEY:
    # An explicit raise (not `assert`) so the check survives `python -O`.
    raise RuntimeError("Missing required env var: RATINGS_SUBMIT_KEY")
POOL_REPO_ID = "taigasan/e6-visual-ratings"
VOTE_STORAGE = VoteStorage(mode="void" if DEBUG_MODE else "hf", token=RATINGS_APP_TOKEN)
# Interval, in seconds, between background stats reloads (30 minutes).
STATS_RELOAD_S = 30 * 60
# -- Pool dataset -----------------------------------------------------------
# The image pool is downloaded once at import time. Every image's stat columns
# start at zero and are overwritten by _load_stats().
_pool_path = hf_hub_download(
    repo_id=POOL_REPO_ID,
    filename="pool.parquet",
    repo_type="dataset",
    token=RATINGS_APP_TOKEN,
)
_pool_df = pd.read_parquet(_pool_path)
_pool_df[["wins", "losses", "ties", "votes", "winrate"]] = (0, 0, 0, 0, 0.0)
# Positional column indices of the stat columns, used with .iloc updates.
WINS_LOC, LOSSES_LOC, TIES_LOC, VOTES_LOC, WINRATE_LOC = (
    _pool_df.columns.get_loc(column)
    for column in ("wins", "losses", "ties", "votes", "winrate")
)
# md5 -> positional row index into _pool_df (a repeated md5 keeps its last row).
_md5_to_idx = dict(zip(_pool_df["md5"], range(len(_pool_df))))
# Held while the stat columns of _pool_df are mutated (votes, undo, reloads).
_pool_lock = threading.Lock()
# NOTE(review): not read or written anywhere else in this file.
_stats_last_loaded_at = 0.0
# Empty placeholder with the explorer schema until _load_stats() replaces it.
_explorer_df = pd.DataFrame(columns=["group", "id", "md5", "rating", "sample_url", "image_url", "classifier", "classifier_score", "percentile"])
def _load_stats() -> None:
    """Refresh per-image vote stats in ``_pool_df`` and rebuild ``_explorer_df``.

    After ``VOTE_STORAGE.sync()``, stats from ``load_stats_by_md5`` overwrite
    the wins/losses/ties/votes/winrate columns of every pool row whose md5
    matches. The count of md5s that have stats but are not in the pool is
    reported on stderr. The explorer table is then rebuilt from the Hub
    dataset files.

    Raises:
        ValueError: if classifier_scores.parquet lacks an expected column.
    """
    global _explorer_df
    # NOTE(review): this unlocked sync + load discards its result. Presumably it
    # pre-fetches/caches remote data so the identical calls made under
    # _pool_lock below finish quickly and block voters for less time —
    # confirm against VoteStorage / stats_from_logs before removing it.
    VOTE_STORAGE.sync()
    load_stats_by_md5(repo_id=POOL_REPO_ID, token=RATINGS_APP_TOKEN)

    stat_locs = [WINS_LOC, LOSSES_LOC, TIES_LOC, VOTES_LOC, WINRATE_LOC]
    n_missing = 0
    with _pool_lock:
        VOTE_STORAGE.sync()
        stats_by_key = load_stats_by_md5(repo_id=POOL_REPO_ID, token=RATINGS_APP_TOKEN)
        for md5, stats in stats_by_key.items():
            idx = _md5_to_idx.get(md5)
            if idx is None:
                n_missing += 1
                continue
            _pool_df.iloc[idx, stat_locs] = (
                stats.wins, stats.losses, stats.ties, stats.votes, stats.winrate
            )
    if n_missing:
        print(f"{n_missing} md5s have stats but are not in the pool!", file=sys.stderr)

    _explorer_df = _build_explorer_df()


def _build_explorer_df() -> pd.DataFrame:
    """Download the validation set and classifier scores and left-join them on md5.

    Returns:
        One row per (validation image, classifier score) match. Validation
        images without any score keep NaN in the classifier columns.

    Raises:
        ValueError: if classifier_scores.parquet lacks an expected column.
    """
    score_columns = ["classifier", "md5", "classifier_score", "percentile"]
    classifier_scores_path = hf_hub_download(
        repo_id=POOL_REPO_ID,
        filename="classifier_scores.parquet",
        repo_type="dataset",
        token=RATINGS_APP_TOKEN,
    )
    validation_set_path = hf_hub_download(
        repo_id=POOL_REPO_ID,
        filename="validation_set.parquet",
        repo_type="dataset",
        token=RATINGS_APP_TOKEN,
    )
    validation_df = pd.read_parquet(
        validation_set_path,
        columns=["group", "id", "md5", "rating", "sample_url", "image_url"],
    )
    classifier_scores_df = pd.read_parquet(classifier_scores_path)
    missing = set(score_columns).difference(classifier_scores_df.columns)
    if missing:
        raise ValueError(f"classifier_scores.parquet missing expected columns: {sorted(missing)}")
    # .copy() makes an independent frame, so the dtype casts below do not write
    # into a slice of the parquet frame (pandas SettingWithCopyWarning).
    classifier_scores_df = classifier_scores_df[score_columns].copy()
    classifier_scores_df["classifier"] = classifier_scores_df["classifier"].astype(str)
    classifier_scores_df["md5"] = classifier_scores_df["md5"].astype(str)
    validation_df["md5"] = validation_df["md5"].astype(str)
    # validate="one_to_many": md5 must be unique in the validation set, while
    # one validation md5 may match several score rows.
    return validation_df.merge(classifier_scores_df, on="md5", how="left", validate="one_to_many")
def _stats_reloader() -> None:
    """Reload stats every ``STATS_RELOAD_S`` seconds, forever (daemon thread body).

    A failed reload (for example a transient Hub or network error) is logged
    to stderr and retried on the next cycle. Letting the exception escape
    would end this thread and stop all future reloads until the process
    restarts.
    """
    import traceback  # function-scope: only needed for failure logging

    while True:
        time.sleep(STATS_RELOAD_S)
        try:
            _load_stats()
        except Exception:
            print("Stats reload failed; retrying next cycle.", file=sys.stderr)
            traceback.print_exc()


# Initial load runs synchronously at import so the app starts with stats;
# a failure here still aborts startup.
_load_stats()
threading.Thread(target=_stats_reloader, daemon=True).start()
| def _pick_from(df: pd.DataFrame, *, weights: pd.Series | None = None) -> tuple[pd.Series, pd.Series, int] | None: | |
| if len(df) < 2: | |
| return None | |
| sample = df.sample(2, weights=weights, replace=False) | |
| return sample.iloc[0], sample.iloc[1], len(df) | |
| def _pick_similar( | |
| df: pd.DataFrame, | |
| distance: Callable[[pd.DataFrame, pd.Series], pd.Series], | |
| *, | |
| weights: Callable[[pd.DataFrame], pd.Series] | None = None, | |
| other_df: pd.DataFrame | None = None, | |
| ) -> tuple[pd.Series, pd.Series, int] | None: | |
| if len(df) < 2: | |
| return None | |
| if other_df is None: | |
| other_df = df | |
| elif len(other_df) < 2: | |
| return None | |
| weight_vals: pd.Series | None = None | |
| if weights is not None: | |
| weight_vals = weights(df) | |
| first = df.sample(weights=weight_vals).iloc[0] | |
| weight_vals = 1.0 / (1.0 + distance(other_df, first)) | |
| while True: | |
| other = other_df.sample(weights=weight_vals).iloc[0] | |
| if other["md5"] != first["md5"]: | |
| return first, other, len(df) | |
def _pool_fetch_pair(group: str) -> tuple[pd.Series, pd.Series, int, str]:
    """Choose the next pair of pool images to compare within *group*.

    Strategies are tried in priority order. The first one that yields a pair
    is used:

    1. "new-winners": two images whose only vote so far was a win.
    2. "new-losers": two images whose only vote so far was a loss.
    3. "sparse": an image with only ties or exactly two votes, paired with a
       similar-record image that has more than three non-tie votes.
    4. "new": unvoted images (always while fewer than 8 images are voted,
       otherwise with probability 0.33).
    5. "fair-probe": similar records, weighted towards images with few votes.

    Returns:
        ``(row_a, row_b, n_candidates, strategy)``. *n_candidates* is the size
        of the pool the first image was drawn from (1 when a single new image
        is introduced).

    Raises:
        AssertionError: if the group has too few images to form any pair.
    """
    # NOTE(review): reads _pool_df without _pool_lock while vote handlers may
    # be updating stat cells concurrently.
    gdf = _pool_df[_pool_df["group"] == group]
    voted = gdf[gdf["votes"] > 0]
    votes = voted["votes"]

    # Pair first-time winners.
    picked = _pick_from(voted[(votes == 1) & (voted["wins"] == 1)])
    if picked is not None:
        return *picked, "new-winners"

    # Pair first-time losers.
    picked = _pick_from(voted[(votes == 1) & (voted["losses"] == 1)])
    if picked is not None:
        return *picked, "new-losers"

    def record_distance(df: pd.DataFrame, pivot: pd.Series) -> pd.Series:
        """Distance between win/loss records; exponent 0.75 on the squared sum."""
        return (
            (df["wins"] - pivot["wins"])**2 +
            (df["losses"] - pivot["losses"])**2
        )**0.75  # L2 is a bit too loose

    # Link cliques to main network and break ties.
    nonties = votes - voted["ties"]
    picked = _pick_similar(
        voted[(nonties == 0) | (votes == 2)],
        record_distance,
        other_df=voted[nonties > 3],
    )
    if picked is not None:
        return *picked, "sparse"

    # Introduce new images.
    if len(voted) < 8 or random.random() < 0.33:
        unvoted = gdf[gdf["votes"] == 0]
        if len(unvoted) >= 2:
            picked = _pick_from(unvoted)
            assert picked is not None
            return *picked, "new"
        if len(unvoted) == 1 and not voted.empty:
            # Last unvoted image: pair it with a random voted image so it is not
            # always compared against the same one. With no voted images yet,
            # fall through instead of indexing an empty frame.
            return unvoted.iloc[0], voted.sample(1).iloc[0], 1, "new"

    # Vote-weighted random sampling between similar winrates, slightly biased against picking losers.
    picked = _pick_similar(
        voted, record_distance,
        weights=lambda df: 1.0 / (df["votes"]**1.25 + 0.1 * df["losses"]),
    )
    assert picked is not None, f"group {group!r} has too few images to form a pair"
    return *picked, "fair-probe"
| def _row_image_url(row) -> str: | |
| sample_url = row.get("sample_url") | |
| if isinstance(sample_url, str) and sample_url: | |
| return sample_url | |
| image_url = row.get("image_url") | |
| if isinstance(image_url, str) and image_url: | |
| return image_url | |
| return '' | |
# Registry of selectable datasets. Each entry supplies:
#   fetch_pair(group) -> (row_a, row_b, n_candidates, reason)
#   get_id(row)       -> stable key for the row (the md5 for the pool)
#   get_image(row)    -> URL to display
#   groups            -> sorted list of group names
DATASETS: dict[str, dict] = {
    "pool": dict(
        fetch_pair=_pool_fetch_pair,
        get_id=lambda row: row["md5"],
        get_image=_row_image_url,
        groups=sorted(_pool_df["group"].unique()),
    ),
}
# The first registered dataset is the default.
DEFAULT_DATASET = next(iter(DATASETS))
| def _format_rating_post_title(post_id: int, votes: int, label: str) -> str: | |
| return f"<strong>{label}</strong>: <a href=\"https://e621.net/posts/{post_id}\" target=\"_blank\" rel=\"noreferrer\">Post #{post_id}</a> | {votes} {'Vote' if votes == 1 else 'Votes'}" | |
def _render_current(state: dict, submit_status: str = "") -> tuple:
    """Render the pair stored in *state* as the rating tab's output tuple.

    Returns ``(img_a_html, img_b_html, undo_button_update, pair_details,
    submit_status, state)``. The pair-details and status strings are
    HTML-escaped, and image URLs are escaped inside the ``src`` attribute. The
    Undo button is interactive only while ``state["pending"]`` is non-empty.
    """

    def card(side: str, label: str) -> str:
        # Current in-memory vote count for this side's image.
        vote_count = _pool_df.iloc[_md5_to_idx[state[f"key_{side}"]], VOTES_LOC]
        title = _format_rating_post_title(state[f"id_{side}"], vote_count, label)
        src = html.escape(state[f"url_{side}"])
        return (
            '<div class="rating-card">'
            f'<div class="rating-card-title">{title}</div>'
            '<div class="rating-image-frame">'
            f'<img src="{src}" class="rating-image" loading="eager" referrerpolicy="no-referrer">'
            "</div></div>"
        )

    left, right = card("a", "Image A"), card("b", "Image B")
    undo_enabled = bool(state.get("pending", ()))
    details = f"/ {state['group']} / {state.get('pair_reason', 'unknown')}"
    return left, right, gr.Button(interactive=undo_enabled), html.escape(details), html.escape(submit_status), state
| def _normalize_rating_pref(pref: str | None) -> str: | |
| return pref if pref in ("safe", "all") else "safe" | |
def _initial_load(state: dict, rating_pref: str | None, submit_key: str | None, image_height: str, groups: list[str]):
    """Restore persisted browser settings and render the first pair.

    Returns the normalized rating preference, the normalized submit key, the
    image height twice (slider value and CSS rule), the group selection, and
    then the six values produced by ``new_round``.
    """
    normalized_pref = _normalize_rating_pref(rating_pref)
    normalized_key = _normalize_submit_key(submit_key)
    first_round = new_round(DEFAULT_DATASET, groups, state)
    return (normalized_pref, normalized_key, image_height, image_height, groups, *first_round)
def _on_groups_change(groups: list[str], state: dict):
    """Start a new round for the new group selection, then echo the selection (for persistence)."""
    round_outputs = new_round(DEFAULT_DATASET, groups, state)
    return (*round_outputs, groups)
| def _on_image_height_change(image_height: str) -> tuple[str, str]: | |
| return image_height, image_height | |
| def _normalize_submit_key(submit_key: str | None) -> str: | |
| return (submit_key or "").strip() | |
def _filtered_explorer_df(rating_pref: str) -> pd.DataFrame:
    """Filter the explorer table by *rating_pref* using the default (first allowed) classifier."""
    default_classifier = ALLOWED_CLASSIFIER_FILTERS[0]
    return _filtered_explorer_df_by_classifier(rating_pref=rating_pref, classifier_name=default_classifier)
def _filtered_explorer_df_by_classifier(rating_pref: str, classifier_name: str) -> pd.DataFrame:
    """Return explorer rows matching a rating preference and a single classifier.

    Args:
        rating_pref: "all" keeps every rating; "safe" keeps rows with rating "s".
        classifier_name: Must be one of ``ALLOWED_CLASSIFIER_FILTERS``.

    Raises:
        AssertionError: on an unsupported *rating_pref* or *classifier_name*.
    """
    # Read the global once. _load_stats() rebinds _explorer_df from the
    # background reloader thread, and building a mask from one frame while
    # indexing another would misalign their indexes.
    explorer_df = _explorer_df
    if rating_pref == "all":
        rating_filtered = explorer_df
    else:
        assert rating_pref == "safe", f"Unsupported rating preference: {rating_pref}"
        rating_filtered = explorer_df[explorer_df["rating"] == "s"]
    assert classifier_name in ALLOWED_CLASSIFIER_FILTERS, f"Unsupported classifier filter: {classifier_name}"
    return rating_filtered[rating_filtered["classifier"] == classifier_name]
def _load_results(rating_pref: str, sort_mode: str, classifier_filter: str):
    """Rebuild the results tab for the given (normalized) filters.

    Returns, in order: summary, score-distribution plot, distribution data,
    gallery items, load-more button update, the selection prompt text, page
    metadata, and the next page offset.
    """
    pref = _normalize_rating_pref(rating_pref)
    mode = _normalize_sort_mode(sort_mode)
    classifier = _normalize_classifier_filter(classifier_filter)
    (
        summary,
        distribution_plot,
        distribution_data,
        gallery_items,
        page_meta,
        next_offset,
        load_more_update,
    ) = build_results_data(
        _filtered_explorer_df_by_classifier(pref, classifier),
        _explorer_df,
        pref,
        mode,
        classifier,
    )
    prompt = "Click an image to reveal its ID and link."
    return summary, distribution_plot, distribution_data, gallery_items, load_more_update, prompt, page_meta, next_offset
| def _normalize_sort_mode(sort_mode: str | None) -> str: | |
| if sort_mode in ("Default", "Rating: Low to High", "Rating: High to Low"): | |
| return sort_mode | |
| return "Default" | |
def _normalize_classifier_filter(classifier_name: str | None) -> str:
    """Return *classifier_name* as a str if allowed, else the first allowed classifier."""
    if classifier_name not in ALLOWED_CLASSIFIER_FILTERS:
        return ALLOWED_CLASSIFIER_FILTERS[0]
    return str(classifier_name)
| # -- Gradio callbacks ------------------------------------------------------- | |
def new_round(dataset_name: str, groups: list[str], state: dict) -> tuple:
    """Pick a fresh pair from a random selected group, store it in *state*, and render it.

    With no groups selected, the pair is left unchanged and a status message
    is returned instead (the Undo button output is skipped).
    """
    if not groups:
        return "", "", gr.skip(), "", "Please select at least one group.", state

    cfg = DATASETS[dataset_name]
    group = random.choice(groups)
    row_a, row_b, remaining, reason = cfg["fetch_pair"](group)
    # One session id per browser session; it is attached to queued votes.
    state.setdefault("session_id", uuid.uuid4().hex)

    get_id = cfg["get_id"]
    get_image = cfg["get_image"]
    state.update(
        dataset=dataset_name,
        key_a=get_id(row_a),
        key_b=get_id(row_b),
        id_a=int(row_a["id"]),
        id_b=int(row_b["id"]),
        group=group,
        pair_reason=f"{reason} ({remaining})",
    )
    state.update(url_a=get_image(row_a), url_b=get_image(row_b))
    return _render_current(state)
| def _queue_decision(winner: str | None, state: dict): | |
| assert state.get("session_id"), "Missing session_id: refusing to record vote" | |
| pending = state.setdefault("pending", []) | |
| pending.append({ | |
| "winner": winner, | |
| "key_a": state["key_a"], | |
| "key_b": state["key_b"], | |
| "id_a": state["id_a"], | |
| "id_b": state["id_b"], | |
| "url_a": state["url_a"], | |
| "url_b": state["url_b"], | |
| "dataset": state["dataset"], | |
| "group": state["group"], | |
| "pair_reason": state.get("pair_reason", ""), | |
| "session_id": state["session_id"], | |
| }) | |
| if len(pending) > 1: | |
| VOTE_STORAGE.queue_row(pending.pop(0)) | |
def _add_vote(idx: int, col_loc: int, delta: int = 1) -> None:
    """Add *delta* to column *col_loc* and to the vote count of pool row *idx*.

    The row's winrate is then recomputed as ``(wins + 0.5 * ties) / max(votes, 1)``.
    Every caller in this file holds ``_pool_lock`` around the call.
    """
    _pool_df.iloc[idx, [col_loc, VOTES_LOC]] += delta
    wins = _pool_df.iloc[idx, WINS_LOC]
    ties = _pool_df.iloc[idx, TIES_LOC]
    votes = _pool_df.iloc[idx, VOTES_LOC]
    _pool_df.iloc[idx, WINRATE_LOC] = (wins + 0.5 * ties) / max(votes, 1)
def vote(winner: str | None, state: dict, groups: list[str], submit_key: str | None) -> tuple:
    """Record the user's decision for the current pair and move on to a new pair.

    Args:
        winner: "A", "B", or None for "same quality".
        state: Per-session state holding the current pair.
        groups: Groups to draw the next pair from.
        submit_key: Key entered by the user; must equal RATINGS_SUBMIT_KEY.

    Returns:
        The rating-tab output tuple (see ``_render_current`` / ``new_round``).

    Raises:
        AssertionError: if *winner* is not "A", "B", or None.
    """
    import hmac  # function-scope: only needed for the constant-time key check

    provided = _normalize_submit_key(submit_key).encode("utf-8")
    # Constant-time comparison, so response timing does not reveal how much of
    # the key matched. A missing server key never authorizes a vote.
    if not SUBMIT_KEY or not hmac.compare_digest(provided, SUBMIT_KEY.encode("utf-8")):
        return _render_current(state, "Wrong submission key.")
    if not groups:
        return "", "", gr.skip(), "", "Please select at least one group.", state

    # Stat column to increment for (image A, image B) under each outcome.
    outcome_cols = {
        "A": (WINS_LOC, LOSSES_LOC),
        "B": (LOSSES_LOC, WINS_LOC),
        None: (TIES_LOC, TIES_LOC),
    }
    # Validate before queueing so an unknown value is never persisted.
    if winner not in outcome_cols:
        raise AssertionError(f"Unsupported winner: {winner!r}")

    _queue_decision(winner, state)
    a_idx = _md5_to_idx[state["key_a"]]
    b_idx = _md5_to_idx[state["key_b"]]
    a_col, b_col = outcome_cols[winner]
    with _pool_lock:
        _add_vote(a_idx, a_col)
        _add_vote(b_idx, b_col)
    return new_round(state["dataset"], groups, state)
def go_back(state: dict) -> tuple:
    """Retract the session's pending decision and show that pair again.

    The retracted decision's wins/losses/ties are subtracted from the
    in-memory pool stats. With nothing pending, the current pair is simply
    re-rendered.

    Raises:
        AssertionError: if the pending record holds an unknown winner value.
    """
    pending = state.setdefault("pending", [])
    if not pending:
        return _render_current(state)

    last = pending.pop()
    restored = {
        field: last[field]
        for field in ("dataset", "key_a", "key_b", "id_a", "id_b", "url_a", "url_b", "group")
    }
    restored["pair_reason"] = last.get("pair_reason", "")
    state.update(restored)

    a_idx = _md5_to_idx[state["key_a"]]
    b_idx = _md5_to_idx[state["key_b"]]
    with _pool_lock:
        outcome = last["winner"]
        if outcome == "A":
            _add_vote(a_idx, WINS_LOC, -1)
            _add_vote(b_idx, LOSSES_LOC, -1)
        elif outcome == "B":
            _add_vote(a_idx, LOSSES_LOC, -1)
            _add_vote(b_idx, WINS_LOC, -1)
        elif outcome is None:
            _add_vote(a_idx, TIES_LOC, -1)
            _add_vote(b_idx, TIES_LOC, -1)
        else:
            raise AssertionError
    return _render_current(state)
# -- UI ---------------------------------------------------------------------
# Builds the Gradio app: an "Image Quality Rater" tab for pairwise voting plus
# the results/explorer tab created by explorer.add_results_tab, then wires
# every component event to the callbacks defined above.
with gr.Blocks(
    title="e621 Visual Ratings",
    # Browser-side script:
    # - a 1500 ms cooldown between vote/skip clicks, with a toast while blocked;
    # - blanks both images when a vote click is accepted;
    # - keyboard shortcuts on the rating tab (arrows vote/skip/undo, Ctrl/Cmd+Z
    #   undo) and ArrowDown for "load more" on the results tab;
    # - clicks on a[href="#back"] links trigger the Undo button.
    head="""
<script>
const VOTE_COOLDOWN_MS = 1500;
let lastVoteAtMs = 0;
let voteToastTimer = null;
function showVoteToast(message) {
let toast = document.getElementById('vote-cooldown-toast');
if (!toast) {
toast = document.createElement('div');
toast.id = 'vote-cooldown-toast';
toast.style.position = 'fixed';
toast.style.left = '50%';
toast.style.bottom = '20px';
toast.style.transform = 'translateX(-50%)';
toast.style.background = 'rgba(20, 20, 20, 0.92)';
toast.style.color = '#fff';
toast.style.padding = '8px 12px';
toast.style.borderRadius = '8px';
toast.style.fontSize = '0.92rem';
toast.style.zIndex = '9999';
toast.style.pointerEvents = 'none';
toast.style.opacity = '0';
toast.style.transition = 'opacity 120ms ease';
document.body.appendChild(toast);
}
toast.textContent = message;
toast.style.opacity = '1';
if (voteToastTimer) clearTimeout(voteToastTimer);
voteToastTimer = setTimeout(function () {
toast.style.opacity = '0';
}, 1400);
}
function showVoteBlockedMessage(remainingMs) {
const remainingS = Math.max(0.1, remainingMs / 1000).toFixed(1);
showVoteToast(`Please wait ${remainingS}s before submitting again.`);
}
function findVoteButtonTarget(target) {
return target?.closest('#btn-vote-a button, button#btn-vote-a, #btn-vote-b button, button#btn-vote-b, #btn-skip button, button#btn-skip');
}
function clearImageContainers() {
const leftImg = document.querySelector('#img-a img');
const rightImg = document.querySelector('#img-b img');
if (leftImg) {
leftImg.src = '';
leftImg.removeAttribute('srcset');
}
if (rightImg) {
rightImg.src = '';
rightImg.removeAttribute('srcset');
}
}
function isVisible(el) {
return !!(el && el.offsetParent !== null);
}
window.addEventListener('keydown', function (e) {
const t = e.target;
const voteAButton = document.querySelector('#btn-vote-a button, button#btn-vote-a');
const voteBButton = document.querySelector('#btn-vote-b button, button#btn-vote-b');
const skipButton = document.querySelector('#btn-skip button, button#btn-skip');
const backButton = document.querySelector('#btn-back-action button, button#btn-back-action');
const resultsLoadMoreButton = document.querySelector('#btn-results-load-more button, button#btn-results-load-more');
const ratingTabActive = isVisible(voteAButton) || isVisible(voteBButton) || isVisible(skipButton);
const resultsTabActive = isVisible(resultsLoadMoreButton);
if (t && (t.tagName === 'INPUT' || t.tagName === 'TEXTAREA' || t.isContentEditable)) return;
if (e.key === 'ArrowLeft' && ratingTabActive) {
e.preventDefault();
voteAButton?.click();
} else if (e.key === 'ArrowRight' && ratingTabActive) {
e.preventDefault();
voteBButton?.click();
} else if (e.key === 'ArrowUp' && ratingTabActive) {
e.preventDefault();
skipButton?.click();
} else if ((e.key === 'z' || e.key === 'Z') && (e.ctrlKey || e.metaKey) && ratingTabActive) {
e.preventDefault();
backButton?.click();
} else if (e.key === 'ArrowDown') {
if (ratingTabActive) {
e.preventDefault();
backButton?.click();
}
if (resultsTabActive) {
e.preventDefault();
resultsLoadMoreButton?.click();
}
}
});
document.addEventListener('click', function (e) {
const voteBtn = findVoteButtonTarget(e.target);
if (voteBtn) {
const nowMs = Date.now();
const elapsedMs = nowMs - lastVoteAtMs;
if (elapsedMs < VOTE_COOLDOWN_MS) {
e.preventDefault();
e.stopPropagation();
if (typeof e.stopImmediatePropagation === 'function') e.stopImmediatePropagation();
showVoteBlockedMessage(VOTE_COOLDOWN_MS - elapsedMs);
return;
}
lastVoteAtMs = nowMs;
clearImageContainers();
return;
}
const a = e.target.closest('a[href="#back"]');
if (!a) return;
e.preventDefault();
document.querySelector('#btn-back-action button, button#btn-back-action')?.click();
}, true);
</script>
""",
    # Page styles: rating cards/image frames, subtle links and notes, the
    # status message, and square thumbnails in the results gallery.
    css="""
.subtle-link button {
background: none !important;
border: none !important;
box-shadow: none !important;
color: #7a7a7a !important;
text-decoration: underline !important;
padding: 0 !important;
min-height: 0 !important;
font-size: 0.9em !important;
justify-content: flex-start !important;
}
.subtle-link button:hover {
color: #5a5a5a !important;
}
.subtle-link {
width: fit-content !important;
}
.subtle-link button {
width: fit-content !important;
}
.subtle-note {
color: #888;
font-size: 0.9em;
}
.rating-card {
width: 100%;
}
.rating-card-title {
min-height: 24px;
margin-bottom: 8px;
}
.rating-image-frame {
width: 100%;
border: 1px solid #e6e6e6;
border-radius: 8px;
background: #333;
display: flex;
align-items: center;
justify-content: center;
overflow: hidden;
}
.rating-image {
width: auto !important;
height: auto !important;
max-width: 100% !important;
max-height: 100% !important;
object-fit: contain !important;
object-position: center center !important;
display: block;
}
.subtle-back-link-wrap a {
color: #7a7a7a !important;
text-decoration: underline;
}
.subtle-back-link-wrap a:hover {
color: #5a5a5a !important;
}
.subtle-back-link-disabled {
color: #b8b8b8 !important;
pointer-events: none;
text-decoration: none;
}
.hidden-action-btn {
display: none !important;
}
#submit-status {
position: fixed;
left: 50%;
bottom: 20px;
transform: translateX(-50%);
z-index: 9998;
pointer-events: none;
min-height: 1.2em;
}
.submit-status-msg {
background: rgba(20, 20, 20, 0.92);
color: #fff;
padding: 8px 12px;
border-radius: 8px;
font-size: 0.92rem;
}
#results-gallery {
--explorer-thumb-ratio: 1 / 1;
}
#results-gallery button,
#results-gallery .thumbnail-item {
aspect-ratio: var(--explorer-thumb-ratio) !important;
}
#results-gallery img {
width: 100% !important;
height: 100% !important;
object-fit: contain !important;
background: #1f2937;
}
a {
padding: 0 !important;
}
""",
    fill_width=True,
) as demo:
    # Per-session server-side state: the current pair, pending decision, session id.
    state = gr.State({})
    # Settings persisted in the browser under the given storage keys.
    rating_pref_store = gr.BrowserState(default_value="safe", storage_key="rating_pref")
    submit_key_store = gr.BrowserState(default_value="", storage_key="submit_key")
    results_sort_store = gr.BrowserState(default_value="Default", storage_key="results_sort_mode")
    results_classifier_store = gr.BrowserState(default_value=ALLOWED_CLASSIFIER_FILTERS[0], storage_key="results_classifier")
    image_height_store = gr.BrowserState(default_value=768, storage_key="image_height")
    # By default only the groups whose name ends in "_safe" are selected.
    groups_store = gr.BrowserState(default_value=[
        group
        for group in DATASETS[DEFAULT_DATASET]["groups"]
        if group.endswith("_safe")
    ], storage_key="groups")
    with gr.Tabs():
        with gr.Tab("Image Quality Rater"):
            gr.Markdown("Rate relative image quality. Choose the image with better quality, or select same quality if they are comparable. Both images are drawn from the same group to avoid cross-group bias.")
            with gr.Row():
                img_a = gr.HTML(elem_id="img-a")
                img_b = gr.HTML(elem_id="img-b")
            # The elem_ids below are the selectors the head script looks up.
            with gr.Row(equal_height=True):
                btn_a = gr.Button("⬅️ Prefer A", variant="primary", elem_id="btn-vote-a")
                with gr.Column(scale=0), gr.Group():
                    btn_skip = gr.Button("⬆️ Same Quality", elem_id="btn-skip")
                    btn_back_action = gr.Button("⬇️ Undo", elem_id="btn-back-action")
                btn_b = gr.Button("➡️ Prefer B", variant="primary", elem_id="btn-vote-b")
            with gr.Accordion("Settings", open=False):
                groups_select = gr.CheckboxGroup(
                    choices=DATASETS[DEFAULT_DATASET]["groups"],
                    label="Categories",
                    show_label=True,
                    show_select_all=True
                )
                image_height_slider = gr.Slider(
                    minimum=512, maximum=2048, step=16, precision=0,
                    label="Image Size",
                )
                submit_key_tb = gr.Textbox(
                    value="",
                    type="password",
                    label="Submit Key",
                    elem_id="submit-key",
                )
            pair_details = gr.HTML(html_template="Dataset: <a href='https://huggingface.co/datasets/taigasan/e6-visual-ratings' target='_blank' rel='noopener noreferrer'>taigasan/e6-visual-ratings</a> ${value}")
            submit_status = gr.HTML(html_template="<span class='submit-status-msg'>${value}</span>")
            gr.HTML("<span class='subtle-note'>Keyboard Shortcuts: ⬅️ Vote A, ⬆️ Same Quality, ➡️ Vote B, ⬇️ or Ctrl+Z Undo</span>")
            # Injects a <style> rule that sets the image frame height from the slider value.
            image_height = gr.HTML(html_template="<style>.rating-image-frame { height:${value}px; }</style>", apply_default_css=False)
        (
            results_summary_md,
            results_rating_dd,
            results_sort_dd,
            results_classifier_dd,
            results_score_distribution_plot,
            results_distribution_state,
            results_gallery,
            results_load_more_btn,
            selected_image_md,
            results_page_meta_state,
            results_page_offset_state,
        ) = add_results_tab(_pool_df)
    # Order must match the tuples returned by _render_current / new_round:
    # (image A html, image B html, Undo button update, pair details, status, state).
    outputs = [img_a, img_b, btn_back_action, pair_details, submit_status, state]
    # Order must match the tuple returned by _load_results.
    results_outputs = [
        results_summary_md,
        results_score_distribution_plot,
        results_distribution_state,
        results_gallery,
        results_load_more_btn,
        selected_image_md,
        results_page_meta_state,
        results_page_offset_state,
    ]
    # Vote/skip read the persisted group selection and submit key, not the widgets directly.
    btn_a.click(fn=lambda s, g, k: vote("A", s, g, k), inputs=[state, groups_store, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
    btn_b.click(fn=lambda s, g, k: vote("B", s, g, k), inputs=[state, groups_store, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
    btn_skip.click(fn=lambda s, g, k: vote(None, s, g, k), inputs=[state, groups_store, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
    btn_back_action.click(fn=go_back, inputs=[state], outputs=outputs, queue=False, show_progress="hidden")
    # Widget changes are normalized and written back to their BrowserState stores.
    submit_key_tb.change(fn=_normalize_submit_key, inputs=[submit_key_tb], outputs=[submit_key_store], queue=False, show_progress="hidden")
    groups_select.change(fn=_on_groups_change, inputs=[groups_select, state], outputs=[*outputs, groups_store], queue=False, show_progress="hidden")
    image_height_slider.change(fn=_on_image_height_change, inputs=[image_height_slider], outputs=[image_height_store, image_height], queue=False, show_progress="hidden")
    # Each results filter has two handlers: one persists the choice, one reloads
    # the results using the new widget value together with the other stored filters.
    results_rating_dd.change(fn=_normalize_rating_pref, inputs=[results_rating_dd], outputs=[rating_pref_store], queue=False, show_progress="hidden")
    results_rating_dd.change(fn=_load_results, inputs=[results_rating_dd, results_sort_store, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
    results_sort_dd.change(fn=_normalize_sort_mode, inputs=[results_sort_dd], outputs=[results_sort_store], queue=False, show_progress="hidden")
    results_sort_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_dd, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
    results_classifier_dd.change(fn=_normalize_classifier_filter, inputs=[results_classifier_dd], outputs=[results_classifier_store], queue=False, show_progress="hidden")
    results_classifier_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_dd], outputs=results_outputs, queue=False, show_progress="hidden")
    # On page load: restore stored settings into the widgets, show the first
    # pair, and build the results tab.
    demo.load(fn=_initial_load, inputs=[state, rating_pref_store, submit_key_store, image_height_store, groups_store], outputs=[results_rating_dd, submit_key_tb, image_height_slider, image_height, groups_select, *outputs], queue=False, show_progress="hidden")
    demo.load(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
    demo.load(fn=_normalize_sort_mode, inputs=[results_sort_store], outputs=[results_sort_dd], queue=False, show_progress="hidden")
    demo.load(fn=_normalize_classifier_filter, inputs=[results_classifier_store], outputs=[results_classifier_dd], queue=False, show_progress="hidden")
    # Appends the next page of results starting at the stored offset.
    results_load_more_btn.click(
        fn=lambda r, s, c, o: load_more_results(_filtered_explorer_df_by_classifier(_normalize_rating_pref(r), _normalize_classifier_filter(c)), _explorer_df, s, o),
        inputs=[rating_pref_store, results_sort_store, results_classifier_store, results_page_offset_state],
        outputs=[results_gallery, results_page_meta_state, results_page_offset_state, results_load_more_btn],
        queue=False,
        show_progress="hidden",
    )
    results_gallery.select(
        fn=on_gallery_select,
        inputs=[results_page_meta_state, results_distribution_state],
        outputs=[selected_image_md, results_score_distribution_plot],
        queue=False,
        show_progress="hidden",
    )
if __name__ == "__main__":
    demo.launch()