# Hugging Face Spaces status: Running
import hmac
import html
import logging
import os
import random
import threading
import time
import uuid

import gradio as gr
import pandas as pd
from huggingface_hub import hf_hub_download

from storage import VoteStorage
from stats_from_logs import load_stats_by_md5
from explorer import ALLOWED_CLASSIFIER_FILTERS, add_results_tab, build_results_data, load_more_results, on_gallery_select
# -- Configuration ------------------------------------------------------------
# DEBUG accepts "1", "true", "yes" or "on" (case-insensitive).
DEBUG_MODE = os.getenv("DEBUG", "0").lower() in ("1", "true", "yes", "on")
RATINGS_APP_TOKEN = os.getenv("RATINGS_APP_TOKEN")
SUBMIT_KEY = os.getenv("RATINGS_SUBMIT_KEY")
# Explicit raise instead of `assert`: assertions are stripped under `python -O`,
# which would let the app start without a submit key configured.
if not SUBMIT_KEY:
    raise RuntimeError("Missing required env var: RATINGS_SUBMIT_KEY")
POOL_REPO_ID = "taigasan/e6-visual-ratings"
# Debug mode selects the "void" storage mode instead of "hf".
VOTE_STORAGE = VoteStorage(mode="void" if DEBUG_MODE else "hf", token=RATINGS_APP_TOKEN)
# Minimum number of seconds (compared against time.time()) between stats reloads.
STATS_RELOAD_S = 30 * 60
# -- Pool dataset -----------------------------------------------------------
# The rating pool is downloaded once at import time and pre-split by its
# "group" column so pair selection can index a group's rows directly.
_pool_path = hf_hub_download(
    repo_id=POOL_REPO_ID,
    repo_type="dataset",
    filename="pool.parquet",
    token=RATINGS_APP_TOKEN,
)
_pool_df = pd.read_parquet(_pool_path)
_pool_group_dfs = {group_key: group_rows for group_key, group_rows in _pool_df.groupby("group")}

# Mutable stats state, replaced by _reload_stats_if_due().
_stats_lock = threading.Lock()
_stats_last_loaded_at = 0.0
# md5 (as str) -> (wins, losses)
_stats_by_key: dict[str, tuple[int, int]] = {}
# Empty placeholder carrying the merged explorer schema until the first reload.
_explorer_df = pd.DataFrame(
    columns=[
        "group",
        "id",
        "md5",
        "rating",
        "sample_url",
        "image_url",
        "classifier",
        "classifier_score",
        "percentile",
    ]
)
_logger = logging.getLogger(__name__)


def _load_explorer_df() -> pd.DataFrame:
    """Download the validation set and classifier scores and join them on md5.

    Returns a left join of the validation rows with their classifier scores:
    validation images without any score are kept with missing score columns.

    Raises:
        AssertionError: classifier_scores.parquet lacks an expected column.
        pandas.errors.MergeError: an md5 appears more than once in the
            validation set (``validate="one_to_many"``).
    """
    classifier_scores_path = hf_hub_download(
        repo_id=POOL_REPO_ID,
        filename="classifier_scores.parquet",
        repo_type="dataset",
        token=RATINGS_APP_TOKEN,
    )
    validation_set_path = hf_hub_download(
        repo_id=POOL_REPO_ID,
        filename="validation_set.parquet",
        repo_type="dataset",
        token=RATINGS_APP_TOKEN,
    )
    validation_df = pd.read_parquet(
        validation_set_path,
        columns=["group", "id", "md5", "rating", "sample_url", "image_url"],
    )
    classifier_scores_df = pd.read_parquet(classifier_scores_path)
    assert {"classifier", "md5", "classifier_score", "percentile"}.issubset(classifier_scores_df.columns), "classifier_scores.parquet missing expected columns"
    # Cast via .astype() on the column subset instead of assigning into the
    # slice, which pandas flags with SettingWithCopyWarning.
    classifier_scores_df = classifier_scores_df[["classifier", "md5", "classifier_score", "percentile"]].astype(
        {"classifier": str, "md5": str}
    )
    validation_df["md5"] = validation_df["md5"].astype(str)
    return validation_df.merge(classifier_scores_df, on="md5", how="left", validate="one_to_many")


def _reload_stats_if_due(force: bool = False):
    """Refresh vote stats and the results-explorer frame when they are stale.

    Reloads at most once every STATS_RELOAD_S seconds unless ``force`` is set.
    The new stats and explorer frame are fully built before either global is
    replaced, so a failed download never leaves them out of sync.

    Args:
        force: reload regardless of the last load time. Errors are re-raised
            when forced (the startup call below), so a broken dataset fails
            loudly. Otherwise they are logged and the previous data is kept.
    """
    global _stats_last_loaded_at, _stats_by_key, _explorer_df
    now = time.time()
    # Cheap lock-free check first; repeated under the lock below.
    if not force and (now - _stats_last_loaded_at) < STATS_RELOAD_S:
        return
    with _stats_lock:
        now = time.time()
        if not force and (now - _stats_last_loaded_at) < STATS_RELOAD_S:
            return
        try:
            new_stats = load_stats_by_md5(
                repo_id=POOL_REPO_ID,
                token=RATINGS_APP_TOKEN,
            )
            new_explorer_df = _load_explorer_df()
        except Exception:
            if force:
                raise
            # Keep serving the previous data. Record the attempt time so a
            # failing download is retried after STATS_RELOAD_S rather than on
            # every request, each of which would wait on _stats_lock.
            _logger.exception("Stats reload failed; keeping previously loaded data")
            _stats_last_loaded_at = now
            return
        # NOTE(review): replacing _stats_by_key drops increments made by
        # _apply_local_stats_update; this assumes load_stats_by_md5 already
        # reflects those votes. Confirm.
        _stats_by_key = new_stats
        _explorer_df = new_explorer_df
        _stats_last_loaded_at = now


# Load once at import time; a failure here aborts startup.
_reload_stats_if_due(force=True)
def _pool_fetch_pair(group_name: str) -> tuple:
    """Choose two rows of group ``group_name`` to show side by side.

    Strategies are tried in order, and the first one that yields a pair is
    used:

    1. "wins-only": two images whose recorded results are all wins.
    2. "losses-only": two images whose recorded results are all losses.
    3. "total_votes=2": two images with exactly two recorded wins+losses.
    4. "unsampled+sampled": an image with no wins/losses paired with one that
       has some (or with any other row if nothing has been sampled yet).
    5. "low-vote": weighted random pair, weight 1 / (wins + losses + 1).

    Only wins and losses are counted. Local updates from
    _apply_local_stats_update never record ties ("same quality").

    Planned but not implemented: repeating the lowest-margin edge of a cycle
    (stopping once all margins are 4+), breaking a 4+-image cycle via a missing
    inner edge, re-sampling existing edges inversely to their sample count,
    and sampling missing edges between already-sampled images.

    Returns:
        (row_a, row_b, reason): two pool rows (pandas Series) and the label of
        the strategy that produced them.

    Raises:
        KeyError: ``group_name`` is not a pool group.
        AssertionError: the group has fewer than two rows.
    """
    gdf = _pool_group_dfs[group_name]
    assert len(gdf) >= 2, f"Not enough rows for group: {group_name}"
    # Take one reference to the stats dict (a reload rebinds the global) and
    # read each key's (wins, losses) tuple once, so both series come from the
    # same snapshot even if a vote is applied concurrently.
    stats = _stats_by_key
    records = gdf["md5"].astype(str).map(lambda k: stats.get(k, (0, 0)))
    wins = records.map(lambda r: r[0])
    losses = records.map(lambda r: r[1])

    def _pick_from_mask(mask: pd.Series):
        """Return two distinct random rows where ``mask`` is true, or None if fewer than two."""
        candidate_df = gdf[mask]
        if len(candidate_df) < 2:
            return None
        sample = candidate_df.sample(2, replace=False)
        return sample.iloc[0], sample.iloc[1]

    # 1) Undefeated images: one of them will lose or tie.
    picked = _pick_from_mask((wins > 0) & (losses == 0))
    if picked is not None:
        return picked[0], picked[1], "wins-only"

    # 2) Images that never won: one of them will win or tie.
    picked = _pick_from_mask((wins == 0) & (losses > 0))
    if picked is not None:
        return picked[0], picked[1], "losses-only"

    # 3) Images with exactly two recorded results.
    vote_totals = wins + losses
    picked = _pick_from_mask(vote_totals == 2)
    if picked is not None:
        return picked[0], picked[1], "total_votes=2"

    # 4) Bring an unsampled image in against a previously sampled one.
    unsampled_mask = vote_totals == 0
    if unsampled_mask.any():
        unsampled_row = gdf[unsampled_mask].sample(1).iloc[0]
        sampled_df = gdf[~unsampled_mask]
        if len(sampled_df) >= 1:
            sampled_row = sampled_df.sample(1).iloc[0]
        else:
            # Nothing sampled yet: pair with any other row (dropped by index label).
            sampled_row = gdf.drop(index=unsampled_row.name).sample(1).iloc[0]
        return unsampled_row, sampled_row, "unsampled+sampled"

    # 5) Safety fallback: favor images with fewer recorded results.
    sample_weights = 1.0 / (vote_totals + 1.0)
    sample = gdf.sample(2, weights=sample_weights, replace=False)
    return sample.iloc[0], sample.iloc[1], "low-vote"
| def _row_image_url(row) -> str: | |
| sample_url = row.get("sample_url") | |
| if isinstance(sample_url, str) and sample_url: | |
| return sample_url | |
| image_url = row.get("image_url") | |
| if isinstance(image_url, str) and image_url: | |
| return image_url | |
| return '' | |
# Registry of pair sources. Each entry provides:
#   fetch_pair: group value -> (row_a, row_b) or (row_a, row_b, reason)
#   get_id:     row -> key used for stats and recorded decisions (md5 here)
#   get_image:  row -> image URL to display
#   groups:     group key (shown to the user) -> value passed to fetch_pair
DATASETS: dict[str, dict] = {
    "pool": {
        "fetch_pair": _pool_fetch_pair,
        "get_id": lambda row: row["md5"],
        "get_image": _row_image_url,
        "groups": {name: name for name in sorted(_pool_df["group"].unique())},
    },
}
# The first registered dataset is the default.
DEFAULT_DATASET = next(iter(DATASETS))
| def _select_groups(cfg: dict, rating_pref: str) -> list[str]: | |
| groups = list(cfg["groups"].keys()) | |
| if rating_pref == "all": | |
| return groups | |
| return [g for g in groups if g.endswith(f"_{rating_pref}")] | |
def _commit_oldest_pending(state: dict):
    """Commit the oldest queued decision once a newer one is queued behind it.

    The newest decision always stays in ``state["pending"]`` so it can still be
    undone. Committing a decision does two things:
    - for an "A"/"B" winner, it updates the in-memory stats;
    - for every decision, ties included, it passes a copy of the decision to
      VOTE_STORAGE.append_vote_row on a daemon thread.

    NOTE(review): the newest decision of a session is only committed when
    another vote follows, and daemon threads are not joined at interpreter
    exit. Confirm both are acceptable.
    """
    queue = state.setdefault("pending", [])
    if len(queue) < 2:
        return
    decision = queue.pop(0)
    winner = decision.get("winner")
    if winner in ("A", "B"):
        _apply_local_stats_update(winner, decision["key_a"], decision["key_b"])
    writer = threading.Thread(
        target=VOTE_STORAGE.append_vote_row,
        args=(decision.copy(), winner),
        daemon=True,
    )
    writer.start()
def _apply_local_stats_update(winner: str, key_a: str, key_b: str):
    """Record one decisive vote in the in-memory ``_stats_by_key`` table.

    Under ``_stats_lock``, the winner's win count and the loser's loss count
    are each incremented by one. Keys are normalized with ``str()``.
    """
    assert winner in ("A", "B")
    a_key, b_key = str(key_a), str(key_b)
    # (win increment, loss increment) for side A and for side B.
    if winner == "A":
        inc_a, inc_b = (1, 0), (0, 1)
    else:
        inc_a, inc_b = (0, 1), (1, 0)
    with _stats_lock:
        # Read both entries before writing either, then write A before B.
        a_wins, a_losses = _stats_by_key.get(a_key, (0, 0))
        b_wins, b_losses = _stats_by_key.get(b_key, (0, 0))
        _stats_by_key[a_key] = (a_wins + inc_a[0], a_losses + inc_a[1])
        _stats_by_key[b_key] = (b_wins + inc_b[0], b_losses + inc_b[1])
| def _format_rating_post_row(post_id: int, wins: int, losses: int, label: str | None = None) -> str: | |
| total_votes = wins + losses | |
| url = f"https://e621.net/posts/{post_id}" | |
| row = f"{url} | Times rated: {total_votes}" | |
| return f"{label}: {row}" if label else row | |
def _rating_card_html(title: str, image_url: str) -> str:
    """Build the HTML card (escaped title plus framed image) for one side of a pair."""
    return f"<div class=\"rating-card\"><div class=\"rating-card-title\"><strong>{html.escape(title)}</strong></div><div class=\"rating-image-frame\"><img src=\"{html.escape(image_url)}\" class=\"rating-image\" loading=\"eager\" referrerpolicy=\"no-referrer\"></div></div>"


def _render_current(state: dict, submit_status: str = "") -> tuple:
    """Render the pair stored in ``state`` into the rater tab's output values.

    Reloads stats first if they are due.

    Args:
        state: per-session dict holding the current pair (key_*, id_*, url_*,
            group, optional pair_reason and can_go_back).
        submit_status: optional message shown in the status toast.

    Returns:
        A tuple, in this order: image A HTML, image B HTML, link line A,
        link line B, undo-link markdown, group note, pair-reason note,
        status message, and ``state`` itself. The order matches the
        ``outputs`` component list in the UI section.
    """
    _reload_stats_if_due()
    wins_a, losses_a = _stats_by_key.get(str(state["key_a"]), (0, 0))
    wins_b, losses_b = _stats_by_key.get(str(state["key_b"]), (0, 0))
    img_a_html = _rating_card_html("Image A", state["url_a"])
    img_b_html = _rating_card_html("Image B", state["url_b"])
    link_a = _format_rating_post_row(state["id_a"], wins_a, losses_a, label="Image A")
    link_b = _format_rating_post_row(state["id_b"], wins_b, losses_b, label="Image B")
    if state.get("can_go_back"):
        # The page script intercepts clicks on "#back" and presses the hidden undo button.
        back_md = "[Undo Rating (Ctrl+z)](#back)"
    else:
        back_md = "<span class='subtle-back-link-disabled'>Undo Rating (Ctrl+z)</span>"
    # Escape the group name like every other value interpolated into markup.
    group_md = f"<span class='subtle-note'>Group: {html.escape(str(state['group']))}</span>"
    pair_reason = state.get("pair_reason", "")
    pair_reason_md = f"<span class='subtle-note'>Pair: {html.escape(pair_reason)}</span>" if pair_reason else ""
    status_md = f"<span class='submit-status-msg'>{html.escape(submit_status)}</span>" if submit_status else ""
    return img_a_html, img_b_html, link_a, link_b, back_md, group_md, pair_reason_md, status_md, state
| def _normalize_rating_pref(pref: str | None) -> str: | |
| return pref if pref in ("safe", "all") else "safe" | |
def _initial_load(state: dict, pref: str | None, submit_key: str | None):
    """Page-load handler: restore the saved settings and draw the first pair.

    Returns the normalized rating preference and submit key (to populate the
    settings controls), followed by every value returned by ``new_round``.
    """
    normalized_pref = _normalize_rating_pref(pref)
    normalized_key = _normalize_submit_key(submit_key)
    round_outputs = new_round(DEFAULT_DATASET, normalized_pref, state)
    return (normalized_pref, normalized_key, *round_outputs)
def _on_rating_change(rating_pref: str, state: dict):
    """Rating-dropdown handler: draw a new pair for the chosen preference.

    Returns the ``new_round`` values followed by the normalized preference,
    which the UI writes back to browser storage.
    """
    pref = _normalize_rating_pref(rating_pref)
    round_outputs = new_round(DEFAULT_DATASET, pref, state)
    return (*round_outputs, pref)
| def _normalize_submit_key(submit_key: str | None) -> str: | |
| return submit_key or "" | |
def _filtered_explorer_df(rating_pref: str) -> pd.DataFrame:
    """Explorer rows for ``rating_pref``, scored by the first allowed classifier.

    NOTE(review): not called anywhere in this file; possibly dead code.
    """
    default_classifier = ALLOWED_CLASSIFIER_FILTERS[0]
    return _filtered_explorer_df_by_classifier(rating_pref, default_classifier)
def _filtered_explorer_df_by_classifier(rating_pref: str, classifier_name: str) -> pd.DataFrame:
    """Select explorer rows that match a rating preference and one classifier.

    Args:
        rating_pref: "all" keeps every rating; "safe" keeps rows whose
            ``rating`` column equals "s".
        classifier_name: must be one of ALLOWED_CLASSIFIER_FILTERS.

    Raises:
        AssertionError: either argument is unsupported. Callers normalize
            both values before calling.
    """
    if rating_pref != "all":
        assert rating_pref == "safe", f"Unsupported rating preference: {rating_pref}"
        subset = _explorer_df[_explorer_df["rating"] == "s"]
    else:
        subset = _explorer_df
    assert classifier_name in ALLOWED_CLASSIFIER_FILTERS, f"Unsupported classifier filter: {classifier_name}"
    return subset[subset["classifier"] == classifier_name]
def _load_results(rating_pref_value: str, sort_mode_value: str, classifier_filter_value: str):
    """Results-tab handler: rebuild the summary, score plot and first gallery page.

    Normalizes the three inputs, reloads stats if due, filters the explorer
    frame, and delegates to ``build_results_data``. The returned tuple matches
    the ``results_outputs`` component order in the UI section, and includes
    the default prompt for the selected-image panel.
    """
    rating_pref = _normalize_rating_pref(rating_pref_value)
    sort_mode = _normalize_sort_mode(sort_mode_value)
    classifier_name = _normalize_classifier_filter(classifier_filter_value)
    _reload_stats_if_due()
    (
        summary,
        score_distribution_plot,
        distribution_data,
        gallery_items,
        page_meta,
        next_offset,
        btn_update,
    ) = build_results_data(
        _filtered_explorer_df_by_classifier(rating_pref, classifier_name),
        _explorer_df,
        rating_pref,
        sort_mode,
        classifier_name,
    )
    selection_hint = "Click an image to reveal its ID and link."
    return (
        summary,
        score_distribution_plot,
        distribution_data,
        gallery_items,
        btn_update,
        selection_hint,
        page_meta,
        next_offset,
    )
| def _normalize_sort_mode(sort_mode: str | None) -> str: | |
| if sort_mode in ("Default", "Rating: Low to High", "Rating: High to Low"): | |
| return sort_mode | |
| return "Default" | |
def _normalize_classifier_filter(classifier_name: str | None) -> str:
    """Return ``classifier_name`` as str if it is allowed, else the first allowed classifier."""
    if classifier_name not in ALLOWED_CLASSIFIER_FILTERS:
        return ALLOWED_CLASSIFIER_FILTERS[0]
    return str(classifier_name)
# -- Gradio callbacks -------------------------------------------------------
def new_round(dataset_name: str, rating_pref: str, state: dict) -> tuple:
    """Draw a fresh pair and render it.

    Picks a random group eligible for ``rating_pref`` and asks the dataset's
    ``fetch_pair`` for two rows, plus an optional pair-selection reason. It
    then stores the pair's keys, post ids, image URLs, group and reason in
    ``state``, assigning a session id on first use.

    Returns:
        The result of ``_render_current(state)``.

    Raises:
        AssertionError: no group matches ``rating_pref``.
    """
    cfg = DATASETS[dataset_name]
    eligible = _select_groups(cfg, rating_pref)
    assert eligible, f"No groups for rating preference: {rating_pref}"
    group = random.choice(eligible)
    fetched = cfg["fetch_pair"](cfg["groups"][group])
    # fetch_pair may or may not report why it chose this pair.
    if len(fetched) == 3:
        row_a, row_b, pair_reason = fetched
    else:
        (row_a, row_b), pair_reason = fetched, ""
    state.setdefault("session_id", uuid.uuid4().hex)
    get_id = cfg["get_id"]
    key_a, key_b = get_id(row_a), get_id(row_b)
    id_a, id_b = int(row_a["id"]), int(row_b["id"])
    state.update(
        dataset=dataset_name,
        rating_pref=rating_pref,
        key_a=key_a,
        key_b=key_b,
        id_a=id_a,
        id_b=id_b,
        group=group,
        pair_reason=pair_reason,
    )
    get_image = cfg["get_image"]
    state["url_a"] = get_image(row_a)
    state["url_b"] = get_image(row_b)
    return _render_current(state)
def _queue_decision(winner: str | None, state: dict):
    """Record the user's choice for the current pair as an undoable decision.

    Args:
        winner: "A", "B", or None for "same quality".
        state: per-session dict holding the current pair.

    The decision, a snapshot of the current pair, is appended to
    ``state["pending"]`` and remembered as the undoable decision. Then
    ``_commit_oldest_pending`` commits any older pending decision.

    Raises:
        AssertionError: ``state`` has no session id.
    """
    assert state.get("session_id"), "Missing session_id: refusing to record vote"
    pending = state.setdefault("pending", [])
    copied_fields = ("key_a", "key_b", "id_a", "id_b", "url_a", "url_b", "dataset", "rating_pref", "group")
    decision = {"winner": winner}
    decision.update({field: state[field] for field in copied_fields})
    decision["pair_reason"] = state.get("pair_reason", "")
    decision["session_id"] = state["session_id"]
    pending.append(decision)
    state["last_decision"] = decision
    state["can_go_back"] = True
    _commit_oldest_pending(state)
def vote(winner: str | None, state: dict, submit_key: str | None) -> tuple:
    """Vote-button handler: check the submit key, queue the decision, draw the next pair.

    Args:
        winner: "A", "B", or None for "same quality".
        state: per-session state dict.
        submit_key: key from browser storage; it must equal SUBMIT_KEY.

    Returns:
        The render tuple for the next pair. If the key does not match, returns
        the current pair's render tuple with a "Wrong submission key." status.
    """
    assert winner in ("A", "B", None)
    provided = _normalize_submit_key(submit_key)
    # Constant-time comparison, so response timing does not reveal how much of
    # the secret matched. Bytes avoid compare_digest's ASCII-only rule for str.
    # A missing SUBMIT_KEY rejects every vote, as the original `!=` check did.
    if not SUBMIT_KEY or not hmac.compare_digest(provided.encode("utf-8"), SUBMIT_KEY.encode("utf-8")):
        return _render_current(state, "Wrong submission key.")
    _queue_decision(winner, state)
    return new_round(state["dataset"], state["rating_pref"], state)
def go_back(state: dict) -> tuple:
    """Undo handler: make the most recent decision's pair current again.

    Only one level of undo is supported. The last decision is removed from
    ``state["pending"]`` if it is still the newest entry, its pair is restored,
    and undo stays disabled until the next vote. When there is nothing to
    undo, the current pair is re-rendered unchanged.
    """
    pending = state.setdefault("pending", [])
    if not state.get("can_go_back"):
        return _render_current(state)
    last = state.get("last_decision")
    if not last:
        state["can_go_back"] = False
        return _render_current(state)
    if pending and pending[-1] == last:
        pending.pop()
    state["can_go_back"] = False
    state["last_decision"] = None
    restored_fields = ("dataset", "rating_pref", "key_a", "key_b", "id_a", "id_b", "url_a", "url_b", "group")
    restored = {field: last[field] for field in restored_fields}
    restored["pair_reason"] = last.get("pair_reason", "")
    state.update(restored)
    return _render_current(state)
# -- UI ---------------------------------------------------------------------
# Build the Gradio app: an "Image Quality Rater" tab plus the results tab
# created by explorer.add_results_tab, then wire every event handler.
# NOTE(review): SOURCE lost its indentation, so the nesting below is
# reconstructed. Confirm two placements: whether link_a ... the keyboard note
# sit inside the "Settings" accordion or directly in the rater tab, and
# whether add_results_tab() is meant to run inside gr.Tabs().
with gr.Blocks(
    title="Image Rater",
    # Client-side script injected into <head>:
    #  - keyboard shortcuts, ignored while typing in inputs/textareas. On the
    #    rater tab: ArrowLeft votes A, ArrowRight votes B, ArrowUp is "same
    #    quality", Ctrl/Cmd+Z is undo. On the results tab: ArrowDown is
    #    "load more";
    #  - a VOTE_COOLDOWN_MS (1500 ms) throttle on vote/skip clicks, including
    #    clicks triggered by the shortcuts. A blocked click is stopped by a
    #    capture-phase document listener and a toast is shown;
    #  - clearing both displayed images as soon as an accepted vote is clicked;
    #  - routing clicks on the "#back" markdown link to the hidden undo button.
    # NOTE(review): the throttle is client-side only; vote() does not rate-limit.
    head="""
<script>
const VOTE_COOLDOWN_MS = 1500;
let lastVoteAtMs = 0;
let voteToastTimer = null;
function showVoteToast(message) {
let toast = document.getElementById('vote-cooldown-toast');
if (!toast) {
toast = document.createElement('div');
toast.id = 'vote-cooldown-toast';
toast.style.position = 'fixed';
toast.style.left = '50%';
toast.style.bottom = '20px';
toast.style.transform = 'translateX(-50%)';
toast.style.background = 'rgba(20, 20, 20, 0.92)';
toast.style.color = '#fff';
toast.style.padding = '8px 12px';
toast.style.borderRadius = '8px';
toast.style.fontSize = '0.92rem';
toast.style.zIndex = '9999';
toast.style.pointerEvents = 'none';
toast.style.opacity = '0';
toast.style.transition = 'opacity 120ms ease';
document.body.appendChild(toast);
}
toast.textContent = message;
toast.style.opacity = '1';
if (voteToastTimer) clearTimeout(voteToastTimer);
voteToastTimer = setTimeout(function () {
toast.style.opacity = '0';
}, 1400);
}
function showVoteBlockedMessage(remainingMs) {
const remainingS = Math.max(0.1, remainingMs / 1000).toFixed(1);
showVoteToast(`Please wait ${remainingS}s before submitting again.`);
}
function findVoteButtonTarget(target) {
return target?.closest('#btn-vote-a button, button#btn-vote-a, #btn-vote-b button, button#btn-vote-b, #btn-skip button, button#btn-skip');
}
function clearImageContainers() {
const leftImg = document.querySelector('#img-a img');
const rightImg = document.querySelector('#img-b img');
if (leftImg) {
leftImg.src = '';
leftImg.removeAttribute('srcset');
}
if (rightImg) {
rightImg.src = '';
rightImg.removeAttribute('srcset');
}
}
function isVisible(el) {
return !!(el && el.offsetParent !== null);
}
window.addEventListener('keydown', function (e) {
const t = e.target;
const voteAButton = document.querySelector('#btn-vote-a button, button#btn-vote-a');
const voteBButton = document.querySelector('#btn-vote-b button, button#btn-vote-b');
const skipButton = document.querySelector('#btn-skip button, button#btn-skip');
const backButton = document.querySelector('#btn-back-action button, button#btn-back-action');
const resultsLoadMoreButton = document.querySelector('#btn-results-load-more button, button#btn-results-load-more');
const ratingTabActive = isVisible(voteAButton) || isVisible(voteBButton) || isVisible(skipButton);
const resultsTabActive = isVisible(resultsLoadMoreButton);
if (t && (t.tagName === 'INPUT' || t.tagName === 'TEXTAREA' || t.isContentEditable)) return;
if (e.key === 'ArrowLeft' && ratingTabActive) {
e.preventDefault();
voteAButton?.click();
} else if (e.key === 'ArrowRight' && ratingTabActive) {
e.preventDefault();
voteBButton?.click();
} else if (e.key === 'ArrowUp' && ratingTabActive) {
e.preventDefault();
skipButton?.click();
} else if ((e.key === 'z' || e.key === 'Z') && (e.ctrlKey || e.metaKey) && ratingTabActive) {
e.preventDefault();
backButton?.click();
} else if (e.key === 'ArrowDown' && resultsTabActive) {
e.preventDefault();
resultsLoadMoreButton?.click();
}
});
document.addEventListener('click', function (e) {
const voteBtn = findVoteButtonTarget(e.target);
if (voteBtn) {
const nowMs = Date.now();
const elapsedMs = nowMs - lastVoteAtMs;
if (elapsedMs < VOTE_COOLDOWN_MS) {
e.preventDefault();
e.stopPropagation();
if (typeof e.stopImmediatePropagation === 'function') e.stopImmediatePropagation();
showVoteBlockedMessage(VOTE_COOLDOWN_MS - elapsedMs);
return;
}
lastVoteAtMs = nowMs;
clearImageContainers();
return;
}
const a = e.target.closest('a[href="#back"]');
if (!a) return;
e.preventDefault();
document.querySelector('#btn-back-action button, button#btn-back-action')?.click();
}, true);
</script>
""",
    # Page styles: subtle links/notes, the rating cards (512px image frame),
    # the hidden undo button, the fixed-position submit-status message, and
    # square results-gallery thumbnails.
    # NOTE(review): the `.subtle-link button` selector is declared twice, and
    # no component shown here uses the `subtle-link` class.
    css="""
.subtle-link button {
background: none !important;
border: none !important;
box-shadow: none !important;
color: #7a7a7a !important;
text-decoration: underline !important;
padding: 0 !important;
min-height: 0 !important;
font-size: 0.9em !important;
justify-content: flex-start !important;
}
.subtle-link button:hover {
color: #5a5a5a !important;
}
.subtle-link {
width: fit-content !important;
}
.subtle-link button {
width: fit-content !important;
}
.subtle-note {
color: #888;
font-size: 0.9em;
}
.rating-card {
width: 100%;
}
.rating-card-title {
min-height: 24px;
margin-bottom: 8px;
}
.rating-image-frame {
width: 100%;
height: 512px;
border: 1px solid #e6e6e6;
border-radius: 8px;
background: #333;
display: flex;
align-items: center;
justify-content: center;
overflow: hidden;
}
.rating-image {
width: auto !important;
height: auto !important;
max-width: 100% !important;
max-height: 100% !important;
object-fit: contain !important;
object-position: center center !important;
display: block;
}
.subtle-back-link-wrap a {
color: #7a7a7a !important;
text-decoration: underline;
}
.subtle-back-link-wrap a:hover {
color: #5a5a5a !important;
}
.subtle-back-link-disabled {
color: #b8b8b8 !important;
pointer-events: none;
text-decoration: none;
}
.hidden-action-btn {
display: none !important;
}
#submit-status {
position: fixed;
left: 50%;
bottom: 20px;
transform: translateX(-50%);
z-index: 9998;
pointer-events: none;
min-height: 1.2em;
}
.submit-status-msg {
background: rgba(20, 20, 20, 0.92);
color: #fff;
padding: 8px 12px;
border-radius: 8px;
font-size: 0.92rem;
}
#results-gallery {
--explorer-thumb-ratio: 1 / 1;
}
#results-gallery button,
#results-gallery .thumbnail-item {
aspect-ratio: var(--explorer-thumb-ratio) !important;
}
#results-gallery img {
width: 100% !important;
height: 100% !important;
object-fit: contain !important;
background: #1f2937;
}
""",
) as demo:
    # Per-session server-side state: current pair, pending decisions, undo info.
    state = gr.State({})
    # Settings persisted in the browser under the given storage keys.
    rating_pref_store = gr.BrowserState(default_value="safe", storage_key="rating_pref")
    submit_key_store = gr.BrowserState(default_value="", storage_key="submit_key")
    results_sort_store = gr.BrowserState(default_value="Default", storage_key="results_sort_mode")
    results_classifier_store = gr.BrowserState(default_value=ALLOWED_CLASSIFIER_FILTERS[0], storage_key="results_classifier")
    with gr.Tabs():
        with gr.Tab("Image Quality Rater"):
            gr.Markdown("Rate relative image quality. Choose the image with better quality, or select same quality if they are comparable. Both images are drawn from the same group to avoid cross-group bias.")
            # Side-by-side image cards (filled by _render_current); ids used by the page script.
            with gr.Row():
                img_a = gr.HTML(elem_id="img-a")
                img_b = gr.HTML(elem_id="img-b")
            # Vote buttons; elem_ids are targeted by the shortcut/cooldown script.
            with gr.Row():
                btn_a = gr.Button("👍 Prefer A", variant="primary", elem_id="btn-vote-a")
                btn_skip = gr.Button("Same quality", elem_id="btn-skip")
                btn_b = gr.Button("👍 Prefer B", variant="primary", elem_id="btn-vote-b")
            with gr.Accordion("Settings", open=False):
                gr.Markdown("<span style='color:#888;font-size:0.9em;'>Advanced options</span>")
                rating_dd = gr.Dropdown(
                    choices=["safe", "all"],
                    value="safe",
                    label="Rating",
                    elem_id="rating-pref",
                )
                submit_key_tb = gr.Textbox(
                    value="",
                    type="password",
                    label="Submit key",
                    elem_id="submit-key",
                )
            link_a = gr.Markdown(label="Image A link")
            link_b = gr.Markdown(label="Image B link")
            back_link = gr.Markdown(elem_classes=["subtle-back-link-wrap"])
            # Hidden via CSS; pressed by the "#back" link and Ctrl/Cmd+Z in the page script.
            btn_back_action = gr.Button("Undo Rating (Ctrl+z)", elem_id="btn-back-action", elem_classes=["hidden-action-btn"])
            details_md = gr.Markdown()
            pair_reason_md = gr.Markdown()
            # Positioned fixed at the bottom of the viewport by the #submit-status CSS rule.
            submit_status_md = gr.Markdown(elem_id="submit-status")
            gr.Markdown("<span class='subtle-note'>Dataset: <a href='https://huggingface.co/datasets/taigasan/e6-visual-ratings' target='_blank' rel='noopener noreferrer'>taigasan/e6-visual-ratings</a></span>")
            gr.Markdown("<span class='subtle-note'>Keyboard Shortcuts: ⬅️ vote A, ⬆️ same quality, ➡️ vote B, Ctrl+z undo rating</span>")
        # Components of the results tab, built by explorer.add_results_tab.
        (
            results_summary_md,
            results_sort_dd,
            results_classifier_dd,
            results_score_distribution_plot,
            results_distribution_state,
            results_gallery,
            results_load_more_btn,
            selected_image_md,
            results_page_meta_state,
            results_page_offset_state,
        ) = add_results_tab(_pool_df)
    # Output component lists; their order must match the tuples returned by
    # _render_current (outputs) and _load_results (results_outputs).
    outputs = [img_a, img_b, link_a, link_b, back_link, details_md, pair_reason_md, submit_status_md, state]
    results_outputs = [
        results_summary_md,
        results_score_distribution_plot,
        results_distribution_state,
        results_gallery,
        results_load_more_btn,
        selected_image_md,
        results_page_meta_state,
        results_page_offset_state,
    ]
    # Rater tab. Votes read the submit key from browser storage, which the
    # Settings textbox keeps up to date below.
    btn_a.click(fn=lambda s, k: vote("A", s, k), inputs=[state, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
    btn_b.click(fn=lambda s, k: vote("B", s, k), inputs=[state, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
    btn_skip.click(fn=lambda s, k: vote(None, s, k), inputs=[state, submit_key_store], outputs=outputs, queue=False, show_progress="hidden")
    btn_back_action.click(fn=go_back, inputs=[state], outputs=outputs, queue=False, show_progress="hidden")
    # A rating change draws a new pair, persists the preference, and (second
    # listener below) reloads the results tab.
    rating_dd.change(fn=_on_rating_change, inputs=[rating_dd, state], outputs=[*outputs, rating_pref_store], queue=False, show_progress="hidden")
    # NOTE(review): both .input and .change persist the key. Per Gradio's docs,
    # .change also fires on programmatic updates, so .input looks redundant. Confirm.
    submit_key_tb.input(fn=_normalize_submit_key, inputs=[submit_key_tb], outputs=[submit_key_store], queue=False, show_progress="hidden")
    submit_key_tb.change(fn=_normalize_submit_key, inputs=[submit_key_tb], outputs=[submit_key_store], queue=False, show_progress="hidden")
    # Results tab: each control both persists its value and reloads results.
    rating_dd.change(fn=_load_results, inputs=[rating_dd, results_sort_store, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
    results_sort_dd.change(fn=_normalize_sort_mode, inputs=[results_sort_dd], outputs=[results_sort_store], queue=False, show_progress="hidden")
    results_sort_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_dd, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
    results_classifier_dd.change(fn=_normalize_classifier_filter, inputs=[results_classifier_dd], outputs=[results_classifier_store], queue=False, show_progress="hidden")
    results_classifier_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_dd], outputs=results_outputs, queue=False, show_progress="hidden")
    # Page load: restore the saved settings into the controls, draw the first
    # pair, and build the results tab.
    demo.load(fn=_initial_load, inputs=[state, rating_pref_store, submit_key_store], outputs=[rating_dd, submit_key_tb, *outputs], queue=False, show_progress="hidden")
    demo.load(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_store], outputs=results_outputs, queue=False, show_progress="hidden")
    demo.load(fn=_normalize_sort_mode, inputs=[results_sort_store], outputs=[results_sort_dd], queue=False, show_progress="hidden")
    demo.load(fn=_normalize_classifier_filter, inputs=[results_classifier_store], outputs=[results_classifier_dd], queue=False, show_progress="hidden")
    # "Load more" re-filters with the stored settings and continues from the saved page offset.
    results_load_more_btn.click(
        fn=lambda r, s, c, o: load_more_results(_filtered_explorer_df_by_classifier(_normalize_rating_pref(r), _normalize_classifier_filter(c)), _explorer_df, s, o),
        inputs=[rating_pref_store, results_sort_store, results_classifier_store, results_page_offset_state],
        outputs=[results_gallery, results_page_meta_state, results_page_offset_state, results_load_more_btn],
        queue=False,
        show_progress="hidden",
    )
    results_gallery.select(
        fn=on_gallery_select,
        inputs=[results_page_meta_state, results_distribution_state],
        outputs=[selected_image_md, results_score_distribution_plot],
        queue=False,
        show_progress="hidden",
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()