"""Gradio app for pairwise image-quality voting plus a classifier results explorer.

NOTE(review): this module was reconstructed from a whitespace-collapsed view.
Statement order is preserved, but block nesting (especially in the UI section)
was inferred from syntax and should be confirmed against the deployed layout.
"""

import hmac
import html
import os
import random
import sys
import threading
import time
import traceback
import uuid
from typing import Callable

import gradio as gr
import pandas as pd
from huggingface_hub import hf_hub_download

from explorer import (
    ALLOWED_CLASSIFIER_FILTERS,
    add_results_tab,
    build_results_data,
    load_more_results,
    on_gallery_select,
)
from stats_from_logs import load_stats_by_md5
from storage import VoteStorage

DEBUG_MODE = os.getenv("DEBUG", "0").lower() in ("1", "true", "yes", "on")
RATINGS_APP_TOKEN = os.getenv("RATINGS_APP_TOKEN")
SUBMIT_KEY = os.getenv("RATINGS_SUBMIT_KEY")
# Explicit raise instead of `assert`: asserts are stripped under `python -O`,
# which would let the app start without a usable submit key.
if not SUBMIT_KEY:
    raise RuntimeError("Missing required env var: RATINGS_SUBMIT_KEY")
POOL_REPO_ID = "taigasan/e6-visual-ratings"
VOTE_STORAGE = VoteStorage(mode="void" if DEBUG_MODE else "hf", token=RATINGS_APP_TOKEN)
STATS_RELOAD_S = 30 * 60

# -- Pool dataset -----------------------------------------------------------
_pool_path = hf_hub_download(
    repo_id=POOL_REPO_ID,
    filename="pool.parquet",
    repo_type="dataset",
    token=RATINGS_APP_TOKEN,
)
_pool_df = pd.read_parquet(_pool_path)
_pool_df[["wins", "losses", "ties", "votes", "winrate"]] = (0, 0, 0, 0, 0.0)

# Positional column indices, so per-row updates can go through `.iloc`.
WINS_LOC = _pool_df.columns.get_loc("wins")
LOSSES_LOC = _pool_df.columns.get_loc("losses")
TIES_LOC = _pool_df.columns.get_loc("ties")
VOTES_LOC = _pool_df.columns.get_loc("votes")
WINRATE_LOC = _pool_df.columns.get_loc("winrate")

# md5 -> positional row index into `_pool_df`.
_md5_to_idx = {md5: idx for idx, md5 in enumerate(_pool_df["md5"])}
_pool_lock = threading.Lock()
_stats_last_loaded_at = 0.0
_explorer_df = pd.DataFrame(
    columns=[
        "group",
        "id",
        "md5",
        "rating",
        "sample_url",
        "image_url",
        "classifier",
        "classifier_score",
        "percentile",
    ]
)

# Winner label -> (stat column for image A, stat column for image B).
_DECISION_COLUMNS = {
    "A": (WINS_LOC, LOSSES_LOC),
    "B": (LOSSES_LOC, WINS_LOC),
    None: (TIES_LOC, TIES_LOC),
}


def _load_stats() -> None:
    """Refresh per-image vote stats in `_pool_df` and rebuild `_explorer_df`.

    Stats whose md5 is not present in the pool are counted and reported on
    stderr. The explorer frame is the validation set left-joined with the
    classifier scores on ``md5``.
    """
    global _explorer_df

    # NOTE(review): this first sync/load discards its result; presumably it
    # warms caches so the locked section below stays short -- confirm.
    VOTE_STORAGE.sync()
    load_stats_by_md5(repo_id=POOL_REPO_ID, token=RATINGS_APP_TOKEN)

    n_missing = 0
    with _pool_lock:
        VOTE_STORAGE.sync()
        stats_by_key = load_stats_by_md5(repo_id=POOL_REPO_ID, token=RATINGS_APP_TOKEN)
        stat_cols = [WINS_LOC, LOSSES_LOC, TIES_LOC, VOTES_LOC, WINRATE_LOC]
        for md5, stats in stats_by_key.items():
            if (idx := _md5_to_idx.get(md5)) is not None:
                _pool_df.iloc[idx, stat_cols] = (
                    stats.wins,
                    stats.losses,
                    stats.ties,
                    stats.votes,
                    stats.winrate,
                )
            else:
                n_missing += 1
    if n_missing:
        print(f"{n_missing} md5s have stats but are not in the pool!", file=sys.stderr)

    classifier_scores_path = hf_hub_download(
        repo_id=POOL_REPO_ID,
        filename="classifier_scores.parquet",
        repo_type="dataset",
        token=RATINGS_APP_TOKEN,
    )
    validation_set_path = hf_hub_download(
        repo_id=POOL_REPO_ID,
        filename="validation_set.parquet",
        repo_type="dataset",
        token=RATINGS_APP_TOKEN,
    )
    validation_df = pd.read_parquet(
        validation_set_path,
        columns=["group", "id", "md5", "rating", "sample_url", "image_url"],
    )
    classifier_scores_df = pd.read_parquet(classifier_scores_path)
    assert {"classifier", "md5", "classifier_score", "percentile"}.issubset(classifier_scores_df.columns), "classifier_scores.parquet missing expected columns"
    classifier_scores_df = classifier_scores_df[["classifier", "md5", "classifier_score", "percentile"]]
    # Normalise join/filter keys to str on both sides before merging.
    classifier_scores_df["classifier"] = classifier_scores_df["classifier"].astype(str)
    classifier_scores_df["md5"] = classifier_scores_df["md5"].astype(str)
    validation_df["md5"] = validation_df["md5"].astype(str)
    _explorer_df = validation_df.merge(classifier_scores_df, on="md5", how="left", validate="one_to_many")


def _stats_reloader() -> None:
    """Reload stats every `STATS_RELOAD_S` seconds, forever.

    A failed reload is logged to stderr and retried on the next cycle, so a
    transient error does not kill the background thread.
    """
    while True:
        time.sleep(STATS_RELOAD_S)
        try:
            _load_stats()
        except Exception:  # thread boundary: log and keep the loop alive
            traceback.print_exc()


_load_stats()
threading.Thread(target=_stats_reloader, daemon=True).start()


def _pick_from(df: pd.DataFrame, *, weights: pd.Series | None = None) -> tuple[pd.Series, pd.Series, int] | None:
    """Sample two distinct rows from `df`.

    Returns ``(row_a, row_b, len(df))``, or ``None`` if `df` has fewer than
    two rows.
    """
    if len(df) < 2:
        return None
    sample = df.sample(2, weights=weights, replace=False)
    return sample.iloc[0], sample.iloc[1], len(df)


def _pick_similar(
    df: pd.DataFrame,
    distance: Callable[[pd.DataFrame, pd.Series], pd.Series],
    *,
    weights: Callable[[pd.DataFrame], pd.Series] | None = None,
    other_df: pd.DataFrame | None = None,
) -> tuple[pd.Series, pd.Series, int] | None:
    """Pick a row from `df` and a partner that is close to it under `distance`.

    The first row is sampled from `df` (weighted by ``weights(df)`` if given).
    The partner is sampled from `other_df` (defaults to `df`) with weight
    ``1 / (1 + distance(other_df, first))`` and is re-drawn until its md5
    differs from the first row's.

    Returns ``(first, other, len(df))``, or ``None`` if either frame has
    fewer than two rows.
    """
    if len(df) < 2:
        return None
    if other_df is None:
        other_df = df
    elif len(other_df) < 2:
        return None

    first_weights: pd.Series | None = None
    if weights is not None:
        first_weights = weights(df)
    first = df.sample(weights=first_weights).iloc[0]

    partner_weights = 1.0 / (1.0 + distance(other_df, first))
    # Terminates as long as `other_df` holds a row with a different md5;
    # presumably md5 is unique per row, so two rows suffice -- verify.
    while True:
        other = other_df.sample(weights=partner_weights).iloc[0]
        if other["md5"] != first["md5"]:
            return first, other, len(df)


def _pool_fetch_pair(group: str) -> tuple[pd.Series, pd.Series, int, str]:
    """Choose the next pair of pool images to compare within `group`.

    Strategies are tried in order; the first that yields a pair wins.
    Returns ``(row_a, row_b, candidate_count, reason)``, where `reason`
    names the strategy that produced the pair.

    Raises:
        ValueError: if `group` has fewer than two images, so no pair exists.
    """
    gdf = _pool_df[_pool_df["group"] == group]
    # Every strategy below needs two distinct rows. Without this guard a tiny
    # or unknown group failed later with IndexError or a bare assert.
    if len(gdf) < 2:
        raise ValueError(f"Group {group!r} has fewer than two images; cannot form a pair")

    voted = gdf[gdf["votes"] > 0]
    votes = voted["votes"]

    # Pair first-time winners.
    picked = _pick_from(voted[(votes == 1) & (voted["wins"] == 1)])
    if picked is not None:
        return *picked, "new-winners"

    # Pair first-time losers.
    picked = _pick_from(voted[(votes == 1) & (voted["losses"] == 1)])
    if picked is not None:
        return *picked, "new-losers"

    def record_distance(df: pd.DataFrame, pivot: pd.Series) -> pd.Series:
        return (
            (df["wins"] - pivot["wins"])**2 +
            (df["losses"] - pivot["losses"])**2
        )**0.75  # L2 is a bit too loose

    # Link cliques to main network and break ties.
    nonties = votes - voted["ties"]
    picked = _pick_similar(
        voted[(nonties == 0) | (votes == 2)],
        record_distance,
        other_df=voted[nonties > 3],
    )
    if picked is not None:
        return *picked, "sparse"

    # Introduce new images.
    if len(voted) < 8 or random.random() < 0.33:
        unvoted = gdf[gdf["votes"] == 0]
        if len(unvoted) == 1:
            # len(gdf) >= 2 guarantees at least one voted row to pair with.
            return unvoted.iloc[0], voted.iloc[0], 1, "new"
        if len(unvoted) > 1:
            picked = _pick_from(unvoted)
            assert picked is not None
            return *picked, "new"

    # Vote-weighted random sampling between similar winrates, slightly biased against picking losers.
    picked = _pick_similar(
        voted,
        record_distance,
        weights=lambda df: 1.0 / (df["votes"]**1.25 + 0.1 * df["losses"]),
    )
    assert picked is not None
    return *picked, "fair-probe"


def _row_image_url(row) -> str:
    """Return the row's `sample_url`, else its `image_url`, else ``''``.

    A URL is used only if it is a non-empty string.
    """
    for column in ("sample_url", "image_url"):
        url = row.get(column)
        if isinstance(url, str) and url:
            return url
    return ''


DATASETS: dict[str, dict] = {
    "pool": {
        "fetch_pair": _pool_fetch_pair,
        "get_id": lambda row: row["md5"],
        "get_image": _row_image_url,
        "groups": sorted(_pool_df["group"].unique()),
    },
}
DEFAULT_DATASET = next(iter(DATASETS))


def _format_rating_post_title(post_id: int, votes: int, label: str) -> str:
    """Format a card title such as ``"Image A: Post #12 | 1 Vote"``."""
    return f"{label}: Post #{post_id} | {votes} {'Vote' if votes == 1 else 'Votes'}"


def _render_current(state: dict, submit_status: str = "") -> tuple:
    """Build the rater outputs for the pair currently held in `state`.

    Returns ``(img_a_html, img_b_html, back_button, pair_details,
    submit_status, state)``, matching the `outputs` list wired in the UI.
    If `state` holds no pair yet, the image and detail slots are empty and
    the back button is left untouched.
    """
    if "key_a" not in state or "key_b" not in state:
        # No round has been dealt (e.g. no group selected); previously this
        # path raised KeyError.
        return "", "", gr.skip(), "", html.escape(submit_status), state

    votes_a = _pool_df.iloc[_md5_to_idx[state["key_a"]], VOTES_LOC]
    votes_b = _pool_df.iloc[_md5_to_idx[state["key_b"]], VOTES_LOC]
    title_a = _format_rating_post_title(state["id_a"], votes_a, "Image A")
    title_b = _format_rating_post_title(state["id_b"], votes_b, "Image B")
    # NOTE(review): only the title is visible inside these templates in this
    # view; any surrounding markup (image tag etc.) appears to be missing --
    # confirm against the deployed file before relying on this literal.
    img_a_html = f"\n{title_a}\n"
    img_b_html = f"\n{title_b}\n"
    can_go_back = bool(state.get("pending", ()))
    pair_details = f"/ {state['group']} / {state.get('pair_reason', 'unknown')}"
    return (
        img_a_html,
        img_b_html,
        gr.Button(interactive=can_go_back),
        html.escape(pair_details),
        html.escape(submit_status),
        state,
    )


def _normalize_rating_pref(pref: str | None) -> str:
    """Return `pref` if it is ``"safe"`` or ``"all"``, otherwise ``"safe"``."""
    return pref if pref in ("safe", "all") else "safe"


def _initial_load(state: dict, rating_pref: str | None, submit_key: str | None, image_height: int, groups: list[str]):
    """Seed the UI from browser-stored preferences and deal the first round."""
    rating_pref = _normalize_rating_pref(rating_pref)
    submit_key = _normalize_submit_key(submit_key)
    return rating_pref, submit_key, image_height, image_height, groups, *new_round(DEFAULT_DATASET, groups, state)


def _on_groups_change(groups: list[str], state: dict):
    """Deal a new round for the changed group selection and persist it."""
    return *new_round(DEFAULT_DATASET, groups, state), groups


def _on_image_height_change(image_height: int) -> tuple[int, int]:
    """Fan the slider value out to the stored preference and the HTML component."""
    return image_height, image_height


def _normalize_submit_key(submit_key: str | None) -> str:
    """Return `submit_key` stripped of surrounding whitespace (``""`` for None)."""
    return (submit_key or "").strip()


def _filtered_explorer_df(rating_pref: str) -> pd.DataFrame:
    """Filter the explorer frame by rating, using the first allowed classifier."""
    return _filtered_explorer_df_by_classifier(rating_pref, ALLOWED_CLASSIFIER_FILTERS[0])


def _filtered_explorer_df_by_classifier(rating_pref: str, classifier_name: str) -> pd.DataFrame:
    """Return explorer rows matching the rating preference and classifier.

    ``"all"`` keeps every rating; ``"safe"`` keeps rows whose rating is
    ``"s"``. Callers are expected to pass normalized values; anything else
    trips an assertion.
    """
    if rating_pref == "all":
        rating_filtered = _explorer_df
    else:
        assert rating_pref == "safe", f"Unsupported rating preference: {rating_pref}"
        rating_filtered = _explorer_df[_explorer_df["rating"] == "s"]
    assert classifier_name in ALLOWED_CLASSIFIER_FILTERS, f"Unsupported classifier filter: {classifier_name}"
    return rating_filtered[rating_filtered["classifier"] == classifier_name]


def _load_results(rating_pref: str, sort_mode: str, classifier_filter: str):
    """Build the results-tab outputs for the given rating/sort/classifier choice.

    Returns values in the order of the UI's `results_outputs` list.
    """
    rating_pref = _normalize_rating_pref(rating_pref)
    sort_mode = _normalize_sort_mode(sort_mode)
    classifier_name = _normalize_classifier_filter(classifier_filter)
    filtered_explorer_df = _filtered_explorer_df_by_classifier(rating_pref, classifier_name)
    (
        summary,
        score_distribution_plot,
        distribution_data,
        gallery_items,
        page_meta,
        next_offset,
        btn_update,
    ) = build_results_data(
        filtered_explorer_df,
        _explorer_df,
        rating_pref,
        sort_mode,
        classifier_name,
    )
    return (
        summary,
        score_distribution_plot,
        distribution_data,
        gallery_items,
        btn_update,
        "Click an image to reveal its ID and link.",
        page_meta,
        next_offset,
    )


def _normalize_sort_mode(sort_mode: str | None) -> str:
    """Return `sort_mode` if it is a known option, otherwise ``"Default"``."""
    if sort_mode in ("Default", "Rating: Low to High", "Rating: High to Low"):
        return sort_mode
    return "Default"


def _normalize_classifier_filter(classifier_name: str | None) -> str:
    """Return `classifier_name` if allowed, otherwise the first allowed filter."""
    if classifier_name in ALLOWED_CLASSIFIER_FILTERS:
        return str(classifier_name)
    return ALLOWED_CLASSIFIER_FILTERS[0]


# -- Gradio callbacks -------------------------------------------------------
def new_round(dataset_name: str, groups: list[str], state: dict) -> tuple:
    """Deal a fresh pair from a randomly chosen group and render it.

    Group names come from browser storage, so names the dataset does not
    know are ignored. If no usable group remains, the "select a group"
    prompt is returned and `state` is left unchanged.
    """
    cfg = DATASETS[dataset_name]
    known_groups = set(cfg["groups"])
    usable_groups = [g for g in (groups or []) if g in known_groups]
    if not usable_groups:
        return "", "", gr.skip(), "", "Please select at least one group.", state

    group = random.choice(usable_groups)
    row_a, row_b, reason_remaining, pair_reason = cfg["fetch_pair"](group)
    pair_reason = f"{pair_reason} ({reason_remaining})"

    state.setdefault("session_id", uuid.uuid4().hex)
    state.update(
        dataset=dataset_name,
        key_a=cfg["get_id"](row_a),
        key_b=cfg["get_id"](row_b),
        id_a=int(row_a["id"]),
        id_b=int(row_b["id"]),
        group=group,
        pair_reason=pair_reason,
    )
    state["url_a"] = cfg["get_image"](row_a)
    state["url_b"] = cfg["get_image"](row_b)
    return _render_current(state)


def _queue_decision(winner: str | None, state: dict):
    """Append the current decision to ``state["pending"]``.

    Only the newest decision stays pending (so it can be undone); once a
    second one arrives, the oldest is handed to `VOTE_STORAGE.queue_row`.
    """
    assert state.get("session_id"), "Missing session_id: refusing to record vote"
    pending = state.setdefault("pending", [])
    pending.append({
        "winner": winner,
        "key_a": state["key_a"],
        "key_b": state["key_b"],
        "id_a": state["id_a"],
        "id_b": state["id_b"],
        "url_a": state["url_a"],
        "url_b": state["url_b"],
        "dataset": state["dataset"],
        "group": state["group"],
        "pair_reason": state.get("pair_reason", ""),
        "session_id": state["session_id"],
    })
    if len(pending) > 1:
        VOTE_STORAGE.queue_row(pending.pop(0))


def _add_vote(idx: int, col_loc: int, delta: int = 1) -> None:
    """Add `delta` to one stat column and to `votes` for row `idx`.

    Winrate is then recomputed as ``(wins + 0.5 * ties) / max(votes, 1)``.
    Callers in this module hold `_pool_lock` while calling this.
    """
    _pool_df.iloc[idx, [col_loc, VOTES_LOC]] += delta
    wins, ties, votes = _pool_df.iloc[idx, [WINS_LOC, TIES_LOC, VOTES_LOC]]
    _pool_df.iloc[idx, WINRATE_LOC] = (wins + 0.5 * ties) / max(votes, 1)


def _apply_decision(winner: str | None, key_a: str, key_b: str, delta: int = 1) -> None:
    """Apply (``delta=1``) or revert (``delta=-1``) a decision on the pool stats.

    Raises:
        AssertionError: if `winner` is not ``"A"``, ``"B"`` or ``None``.
    """
    if winner not in _DECISION_COLUMNS:
        raise AssertionError(f"Unexpected winner: {winner!r}")
    col_a, col_b = _DECISION_COLUMNS[winner]
    a_idx = _md5_to_idx[key_a]
    b_idx = _md5_to_idx[key_b]
    with _pool_lock:
        _add_vote(a_idx, col_a, delta)
        _add_vote(b_idx, col_b, delta)


def vote(winner: str | None, state: dict, groups: list[str], submit_key: str | None) -> tuple:
    """Record a decision for the current pair and deal the next round.

    `winner` is ``"A"``, ``"B"`` or ``None`` (same quality). Nothing is
    recorded when the submit key is wrong, no group is selected, or no pair
    is currently shown.

    Raises:
        AssertionError: if `winner` is not one of the accepted values.
    """
    # Constant-time comparison; bytes are used because compare_digest
    # rejects non-ASCII `str` arguments.
    supplied_key = _normalize_submit_key(submit_key).encode("utf-8")
    if not hmac.compare_digest(supplied_key, SUBMIT_KEY.encode("utf-8")):
        return _render_current(state, "Wrong submission key.")
    if not groups:
        return "", "", gr.skip(), "", "Please select at least one group.", state
    if "key_a" not in state or "key_b" not in state:
        # Nothing on screen to vote on yet; just try to deal a round.
        return new_round(state.get("dataset", DEFAULT_DATASET), groups, state)
    # Validate before queueing so an invalid winner never reaches storage.
    if winner not in _DECISION_COLUMNS:
        raise AssertionError(f"Unexpected winner: {winner!r}")

    _queue_decision(winner, state)
    _apply_decision(winner, state["key_a"], state["key_b"])
    return new_round(state["dataset"], groups, state)


def go_back(state: dict) -> tuple:
    """Undo the most recent pending decision and show its pair again.

    With nothing pending, the current view is simply re-rendered.
    """
    pending = state.setdefault("pending", [])
    if pending:
        last = pending.pop()
        state.update(
            dataset=last["dataset"],
            key_a=last["key_a"],
            key_b=last["key_b"],
            id_a=last["id_a"],
            id_b=last["id_b"],
            url_a=last["url_a"],
            url_b=last["url_b"],
            group=last["group"],
            pair_reason=last.get("pair_reason", ""),
        )
        _apply_decision(last["winner"], state["key_a"], state["key_b"], -1)
    return _render_current(state)


# -- UI ---------------------------------------------------------------------
_APP_CSS = """
.subtle-link button { background: none !important; border: none !important; box-shadow: none !important; color: #7a7a7a !important; text-decoration: underline !important; padding: 0 !important; min-height: 0 !important; font-size: 0.9em !important; justify-content: flex-start !important; }
.subtle-link button:hover { color: #5a5a5a !important; }
.subtle-link { width: fit-content !important; }
.subtle-link button { width: fit-content !important; }
.subtle-note { color: #888; font-size: 0.9em; }
.rating-card { width: 100%; }
.rating-card-title { min-height: 24px; margin-bottom: 8px; }
.rating-image-frame { width: 100%; border: 1px solid #e6e6e6; border-radius: 8px; background: #333; display: flex; align-items: center; justify-content: center; overflow: hidden; }
.rating-image { width: auto !important; height: auto !important; max-width: 100% !important; max-height: 100% !important; object-fit: contain !important; object-position: center center !important; display: block; }
.subtle-back-link-wrap a { color: #7a7a7a !important; text-decoration: underline; }
.subtle-back-link-wrap a:hover { color: #5a5a5a !important; }
.subtle-back-link-disabled { color: #b8b8b8 !important; pointer-events: none; text-decoration: none; }
.hidden-action-btn { display: none !important; }
#submit-status { position: fixed; left: 50%; bottom: 20px; transform: translateX(-50%); z-index: 9998; pointer-events: none; min-height: 1.2em; }
.submit-status-msg { background: rgba(20, 20, 20, 0.92); color: #fff; padding: 8px 12px; border-radius: 8px; font-size: 0.92rem; }
#results-gallery { --explorer-thumb-ratio: 1 / 1; }
#results-gallery button, #results-gallery .thumbnail-item { aspect-ratio: var(--explorer-thumb-ratio) !important; }
#results-gallery img { width: 100% !important; height: 100% !important; object-fit: contain !important; background: #1f2937; }
a { padding: 0 !important; }
"""

# NOTE(review): `head` and several `html_template` values below look empty or
# near-empty in this view; the keyboard-shortcut script and wrapper markup may
# have been stripped -- confirm against the deployed file.
with gr.Blocks(
    title="e621 Visual Ratings",
    head=""" """,
    css=_APP_CSS,
    fill_width=True,
) as demo:
    state = gr.State({})
    rating_pref_store = gr.BrowserState(default_value="safe", storage_key="rating_pref")
    submit_key_store = gr.BrowserState(default_value="", storage_key="submit_key")
    results_sort_store = gr.BrowserState(default_value="Default", storage_key="results_sort_mode")
    results_classifier_store = gr.BrowserState(default_value=ALLOWED_CLASSIFIER_FILTERS[0], storage_key="results_classifier")
    image_height_store = gr.BrowserState(default_value=768, storage_key="image_height")
    groups_store = gr.BrowserState(
        default_value=[
            group for group in DATASETS[DEFAULT_DATASET]["groups"] if group.endswith("_safe")
        ],
        storage_key="groups",
    )

    with gr.Tabs():
        with gr.Tab("Image Quality Rater"):
            gr.Markdown("Rate relative image quality. Choose the image with better quality, or select same quality if they are comparable. Both images are drawn from the same group to avoid cross-group bias.")
            with gr.Row():
                img_a = gr.HTML(elem_id="img-a")
                img_b = gr.HTML(elem_id="img-b")
            with gr.Row(equal_height=True):
                btn_a = gr.Button("⬅️ Prefer A", variant="primary", elem_id="btn-vote-a")
                with gr.Column(scale=0), gr.Group():
                    btn_skip = gr.Button("⬆️ Same Quality", elem_id="btn-skip")
                    btn_back_action = gr.Button("⬇️ Undo", elem_id="btn-back-action")
                btn_b = gr.Button("➡️ Prefer B", variant="primary", elem_id="btn-vote-b")
            with gr.Accordion("Settings", open=False):
                groups_select = gr.CheckboxGroup(
                    choices=DATASETS[DEFAULT_DATASET]["groups"],
                    label="Categories",
                    show_label=True,
                    show_select_all=True,
                )
                image_height_slider = gr.Slider(
                    minimum=512,
                    maximum=2048,
                    step=16,
                    precision=0,
                    label="Image Size",
                )
                submit_key_tb = gr.Textbox(
                    value="",
                    type="password",
                    label="Submit Key",
                    elem_id="submit-key",
                )
            pair_details = gr.HTML(html_template="Dataset: taigasan/e6-visual-ratings ${value}")
            submit_status = gr.HTML(html_template="${value}")
            gr.HTML("Keyboard Shortcuts: ⬅️ Vote A, ⬆️ Same Quality, ➡️ Vote B, ⬇️ or Ctrl+Z Undo")
            image_height = gr.HTML(html_template="", apply_default_css=False)

        (
            results_summary_md,
            results_rating_dd,
            results_sort_dd,
            results_classifier_dd,
            results_score_distribution_plot,
            results_distribution_state,
            results_gallery,
            results_load_more_btn,
            selected_image_md,
            results_page_meta_state,
            results_page_offset_state,
        ) = add_results_tab(_pool_df)

    outputs = [img_a, img_b, btn_back_action, pair_details, submit_status, state]
    results_outputs = [
        results_summary_md,
        results_score_distribution_plot,
        results_distribution_state,
        results_gallery,
        results_load_more_btn,
        selected_image_md,
        results_page_meta_state,
        results_page_offset_state,
    ]
    vote_inputs = [state, groups_store, submit_key_store]
    results_inputs = [rating_pref_store, results_sort_store, results_classifier_store]
    # Every listener runs unqueued and without a progress indicator.
    quiet = {"queue": False, "show_progress": "hidden"}

    # Rater tab.
    btn_a.click(fn=lambda s, g, k: vote("A", s, g, k), inputs=vote_inputs, outputs=outputs, **quiet)
    btn_b.click(fn=lambda s, g, k: vote("B", s, g, k), inputs=vote_inputs, outputs=outputs, **quiet)
    btn_skip.click(fn=lambda s, g, k: vote(None, s, g, k), inputs=vote_inputs, outputs=outputs, **quiet)
    btn_back_action.click(fn=go_back, inputs=[state], outputs=outputs, **quiet)
    submit_key_tb.change(fn=_normalize_submit_key, inputs=[submit_key_tb], outputs=[submit_key_store], **quiet)
    groups_select.change(fn=_on_groups_change, inputs=[groups_select, state], outputs=[*outputs, groups_store], **quiet)
    image_height_slider.change(fn=_on_image_height_change, inputs=[image_height_slider], outputs=[image_height_store, image_height], **quiet)

    # Results tab: each dropdown persists its value and reloads the results,
    # reading its own fresh value plus the stored values of the other two.
    results_rating_dd.change(fn=_normalize_rating_pref, inputs=[results_rating_dd], outputs=[rating_pref_store], **quiet)
    results_rating_dd.change(fn=_load_results, inputs=[results_rating_dd, results_sort_store, results_classifier_store], outputs=results_outputs, **quiet)
    results_sort_dd.change(fn=_normalize_sort_mode, inputs=[results_sort_dd], outputs=[results_sort_store], **quiet)
    results_sort_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_dd, results_classifier_store], outputs=results_outputs, **quiet)
    results_classifier_dd.change(fn=_normalize_classifier_filter, inputs=[results_classifier_dd], outputs=[results_classifier_store], **quiet)
    results_classifier_dd.change(fn=_load_results, inputs=[rating_pref_store, results_sort_store, results_classifier_dd], outputs=results_outputs, **quiet)

    # Page load: restore preferences, deal the first pair, fill the results tab.
    demo.load(
        fn=_initial_load,
        inputs=[state, rating_pref_store, submit_key_store, image_height_store, groups_store],
        outputs=[results_rating_dd, submit_key_tb, image_height_slider, image_height, groups_select, *outputs],
        **quiet,
    )
    demo.load(fn=_load_results, inputs=results_inputs, outputs=results_outputs, **quiet)
    demo.load(fn=_normalize_sort_mode, inputs=[results_sort_store], outputs=[results_sort_dd], **quiet)
    demo.load(fn=_normalize_classifier_filter, inputs=[results_classifier_store], outputs=[results_classifier_dd], **quiet)

    results_load_more_btn.click(
        fn=lambda r, s, c, o: load_more_results(
            _filtered_explorer_df_by_classifier(_normalize_rating_pref(r), _normalize_classifier_filter(c)),
            _explorer_df,
            s,
            o,
        ),
        inputs=[rating_pref_store, results_sort_store, results_classifier_store, results_page_offset_state],
        outputs=[results_gallery, results_page_meta_state, results_page_offset_state, results_load_more_btn],
        **quiet,
    )
    results_gallery.select(
        fn=on_gallery_select,
        inputs=[results_page_meta_state, results_distribution_state],
        outputs=[selected_image_md, results_score_distribution_plot],
        **quiet,
    )

if __name__ == "__main__":
    demo.launch()