e6-visual-ratings / explorer.py
RedHotTensors's picture
Allow picking from individual image groups.
2e679fc
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import html
from matplotlib.figure import Figure
# Number of gallery items produced per page (initial load and each "Load more").
PAGE_SIZE = 30

# Sort options offered in the Explorer tab. "Default" keeps the incoming row
# order; the two rating modes sort by classifier score.
SORT_MODES = [
    "Default",
    "Rating: Low to High",
    "Rating: High to Low",
]

# Classifier names offered in the Explorer's classifier dropdown; the first
# entry is the dropdown's initial value.
ALLOWED_CLASSIFIER_FILTERS = (
    "SiglipFinetune",
    "ScoringHead-v5",
    "TaggerExperiment",
)
def _row_image_url(row) -> str | None:
sample_url = row.get("sample_url")
if isinstance(sample_url, str) and sample_url:
return sample_url
image_url = row.get("image_url")
if isinstance(image_url, str) and image_url:
return image_url
return None
def _gallery_items(meta: list[dict]) -> list[tuple[str, str]]:
return [
(
str(item["url"]),
f"Score: {float(item['classifier_score']):.4f} | Percentile: {int(item['percentile'])}",
)
for item in meta
]
def _distribution_data(
explorer_df: pd.DataFrame,
) -> dict[str, list[float] | int]:
valid_scores = [
float(score)
for score in explorer_df["classifier_score"].tolist()
if pd.notna(score)
]
if not valid_scores:
return {"bin_edges": [], "counts": [], "total": 0}
counts, bin_edges = np.histogram(valid_scores, bins=40)
return {
"bin_edges": bin_edges.astype(float).tolist(),
"counts": counts.astype(float).tolist(),
"total": int(len(valid_scores)),
}
def _classifier_score_distribution_plot(
    distribution_data: dict[str, list[float] | int],
    selected_score: float | None = None,
) -> Figure:
    """Render the classifier-score histogram as a dark-themed matplotlib figure.

    Args:
        distribution_data: Mapping with ``bin_edges`` and ``counts`` lists, in the
            shape returned by ``_distribution_data``. When ``counts`` is empty or
            ``bin_edges`` is not exactly one longer than ``counts``, a placeholder
            message is drawn instead of bars.
        selected_score: Optional score to mark with a vertical orange line. It is
            only drawn when histogram bars are drawn.

    Returns:
        A standalone ``Figure`` that is not registered with pyplot.
    """
    # Build the Figure directly instead of via ``plt.subplots``. pyplot keeps
    # every figure it creates alive until ``plt.close`` is called, and this
    # function runs on every gallery selection. Pyplot-managed figures would
    # therefore accumulate for the lifetime of the process.
    fig = Figure(figsize=(6, 2.2))
    ax = fig.add_subplot()
    fig.patch.set_facecolor("#0f1117")
    ax.set_facecolor("#151922")
    bin_edges = np.asarray(distribution_data.get("bin_edges", []), dtype=float)
    counts = np.asarray(distribution_data.get("counts", []), dtype=float)
    if counts.size > 0 and bin_edges.size == counts.size + 1:
        widths = np.diff(bin_edges)
        # Left-aligned bars spanning each [edge_i, edge_{i+1}) bin.
        ax.bar(bin_edges[:-1], counts, width=widths, align="edge", color="#3b82f6", alpha=0.9, edgecolor="#93c5fd", linewidth=0.35)
        ax.set_ylabel("Count", color="#e5e7eb")
        if selected_score is not None:
            ax.axvline(float(selected_score), color="#f97316", linewidth=2.0)
    else:
        ax.text(0.5, 0.5, "No classifier scores available.", ha="center", va="center", transform=ax.transAxes, color="#e5e7eb")
        ax.set_yticks([])
    ax.set_title("Classifier Score Distribution", color="#f3f4f6")
    ax.set_xlabel("Classifier score", color="#e5e7eb")
    ax.tick_params(colors="#d1d5db")
    for spine in ax.spines.values():
        spine.set_color("#4b5563")
    ax.grid(axis="y", color="#374151", alpha=0.4)
    fig.tight_layout()
    return fig
def _select_displayable_rows(explorer_df: pd.DataFrame, sort_mode: str) -> pd.DataFrame:
    """Return the rows that can be shown in the gallery, ordered per *sort_mode*.

    A row is kept when it has a non-null, non-empty ``sample_url`` or
    ``image_url`` and non-null ``classifier_score`` and ``percentile``.
    "Default" (or any unrecognised mode) keeps the incoming row order. The two
    rating modes use a stable sort, so ties keep their incoming order.
    """
    columns = ["id", "md5", "sample_url", "image_url", "classifier_score", "percentile"]
    filtered = explorer_df[columns].copy()
    has_sample = filtered["sample_url"].notna() & (filtered["sample_url"] != "")
    has_image = filtered["image_url"].notna() & (filtered["image_url"] != "")
    filtered = filtered[has_sample | has_image]
    filtered = filtered[filtered["classifier_score"].notna() & filtered["percentile"].notna()]
    if sort_mode == "Rating: Low to High":
        filtered = filtered.sort_values("classifier_score", ascending=True, kind="mergesort")
    elif sort_mode == "Rating: High to Low":
        filtered = filtered.sort_values("classifier_score", ascending=False, kind="mergesort")
    return filtered


def _classifier_rows_by_md5(
    all_explorer_df: pd.DataFrame,
    md5s: set[str],
) -> dict[str, list[dict[str, str | float | int]]]:
    """Collect every classifier's score and percentile for the given md5s.

    Rows missing ``classifier``, ``classifier_score`` or ``percentile`` are
    dropped. When an (md5, classifier) pair occurs more than once, the last
    occurrence wins. Each md5's rows are ordered by classifier name.
    md5s with no usable rows are absent from the result.
    """
    if not md5s:
        return {}
    complete = (
        all_explorer_df["classifier"].notna()
        & all_explorer_df["classifier_score"].notna()
        & all_explorer_df["percentile"].notna()
    )
    scores = all_explorer_df.loc[complete, ["md5", "classifier", "classifier_score", "percentile"]]
    scores = scores.astype({"md5": str, "classifier": str})
    # Only the current page's md5s are needed. Filtering before grouping keeps
    # each page load proportional to the page, not to the full score table.
    scores = scores[scores["md5"].isin(md5s)]
    scores = scores.drop_duplicates(subset=["md5", "classifier"], keep="last")
    scores = scores.sort_values(["md5", "classifier"], kind="mergesort")
    return {
        str(md5): [
            {
                "classifier": str(r.classifier),
                "classifier_score": float(r.classifier_score),
                "percentile": int(r.percentile),
            }
            for r in group.itertuples(index=False)
        ]
        for md5, group in scores.groupby("md5", sort=False)
    }


def _build_page_meta(
    explorer_df: pd.DataFrame,
    all_explorer_df: pd.DataFrame,
    sort_mode: str,
    offset: int,
) -> tuple[list[dict[str, str | int]], int, bool, int]:
    """Build gallery metadata for one page (up to ``PAGE_SIZE`` rows) of results.

    Args:
        explorer_df: Rows for the currently selected classifier/rating filter.
        all_explorer_df: Rows for all classifiers. It is used to attach every
            classifier's score and percentile to each image on the page.
        sort_mode: One of ``SORT_MODES``.
        offset: Index of the first displayable row to include. Negative values
            are treated as 0.

    Returns:
        ``(page_meta, next_offset, has_more, total)``:

        - ``page_meta``: one dict per image with ``id``, ``md5``, ``url``,
          ``classifier_score``, ``percentile`` and ``all_classifier_rows``.
        - ``next_offset``: the offset for the following page.
        - ``has_more``: whether further displayable rows exist.
        - ``total``: the number of displayable rows.

    Raises:
        AssertionError: If a page row has no string URL or no classifier rows
            for its md5.
    """
    # A negative offset would make ``iloc`` count from the end of the frame.
    offset = max(0, int(offset))
    filtered = _select_displayable_rows(explorer_df, sort_mode)
    page_rows = filtered.iloc[offset:offset + PAGE_SIZE].to_dict("records")
    scores_by_md5 = _classifier_rows_by_md5(
        all_explorer_df, {str(row["md5"]) for row in page_rows}
    )

    page_meta: list[dict[str, str | int]] = []
    for row in page_rows:
        url = _row_image_url(row)
        # The row filter only checks for non-null/non-empty values. _row_image_url
        # also requires a str, so a non-string URL column would trip this.
        assert url is not None
        md5 = str(row["md5"])
        all_classifier_rows = scores_by_md5.get(md5)
        assert all_classifier_rows, f"Missing classifier rows for md5: {md5}"
        page_meta.append(
            {
                "id": int(row["id"]),
                "md5": md5,
                "url": url,
                "classifier_score": float(row["classifier_score"]),
                "percentile": int(row["percentile"]),
                "all_classifier_rows": all_classifier_rows,
            }
        )
    next_offset = offset + len(page_meta)
    total = len(filtered)
    return page_meta, next_offset, next_offset < total, total
def build_results_data(
    explorer_df: pd.DataFrame,
    all_explorer_df: pd.DataFrame,
    rating_pref: str,
    sort_mode: str,
    classifier_name: str,
) -> tuple[str, Figure, dict[str, list[float] | int], list[tuple[str, str]], list[dict[str, str | int]], int, dict]:
    """Assemble everything the Explorer tab shows for the first page of results.

    Returns:
        A 7-tuple of:

        - the summary text;
        - the score-distribution figure (no score marked);
        - the histogram data;
        - the gallery items;
        - the page metadata;
        - the offset of the next page;
        - a ``gr.update`` toggling the "Load more" button's visibility.
    """
    first_page, offset_after, more_available, shown_total = _build_page_meta(
        explorer_df, all_explorer_df, sort_mode, offset=0
    )
    summary_text = (
        f"Showing {shown_total} images from validation_set.parquet joined with "
        f"classifier_scores.parquet (rating: {rating_pref}, classifier: {classifier_name}, sort: {sort_mode})."
    )
    histogram = _distribution_data(explorer_df)
    figure = _classifier_score_distribution_plot(histogram)
    gallery = _gallery_items(first_page)
    load_more_visibility = gr.update(visible=more_available)
    return (
        summary_text,
        figure,
        histogram,
        gallery,
        first_page,
        offset_after,
        load_more_visibility,
    )
def load_more_results(
    explorer_df: pd.DataFrame,
    all_explorer_df: pd.DataFrame,
    sort_mode: str,
    offset: int,
):
    """Fetch the page of results that starts at *offset*.

    Returns:
        ``(gallery_items, page_meta, next_offset, load_more_update)``, where
        ``load_more_update`` is a ``gr.update`` that hides the "Load more"
        button once no rows remain.
    """
    batch, following_offset, more_available, _ = _build_page_meta(
        explorer_df, all_explorer_df, sort_mode, offset=int(offset)
    )
    gallery = _gallery_items(batch)
    return gallery, batch, following_offset, gr.update(visible=more_available)
def _classifier_row_html(row: dict) -> str:
    """Render one classifier's name, score, percentile and a percentile bar as HTML.

    The bar width is the percentile clamped to 0-100. Its colour runs through HSL
    hues from 0 (red) at 0% to 120 (green) at 100%.
    """
    name = str(row["classifier"])
    score = float(row["classifier_score"])
    percentile = int(row["percentile"])
    clamped = min(max(percentile, 0), 100)
    fill_color = f"hsl({int(round((clamped / 100.0) * 120.0))}, 78%, 42%)"
    bar = "".join(
        [
            "<span style='display:inline-block;vertical-align:middle;width:72px;height:8px;",
            "border-radius:999px;background:#e5e7eb;overflow:hidden;margin-left:6px;'>",
            f"<span style='display:block;height:100%;width:{clamped}%;background:{fill_color};'></span>",
            "</span>",
        ]
    )
    return "".join(
        [
            "<div style='display:flex;align-items:center;column-gap:8px;'>",
            f"<span style='display:inline-block;width:160px;white-space:nowrap;overflow:hidden;text-overflow:ellipsis;'>{html.escape(name)}</span>",
            f"<span style='display:inline-block;width:132px;'>Score {score:.4f}</span>",
            f"<span style='display:inline-block;width:84px;'>Percentile {percentile:>3}</span>",
            bar,
            "</div>",
        ]
    )


def on_gallery_select(
    evt: gr.SelectData,
    meta: list[dict],
    distribution_data: dict[str, list[float] | int],
) -> tuple[str, Figure]:
    """Describe the clicked gallery image and mark its score on the distribution plot.

    Args:
        evt: Gradio selection event. A tuple index uses its first element.
        meta: Page metadata for the gallery items, in gallery order.
        distribution_data: Histogram data to redraw.

    Returns:
        ``(info_html, figure)``. For an invalid index, the info is
        "No image selected." and the figure has no score marker.

    Raises:
        AssertionError: If the selected item has no classifier rows.
    """
    raw_index = evt.index
    index = raw_index[0] if isinstance(raw_index, tuple) else raw_index
    if not (isinstance(index, int) and 0 <= index < len(meta)):
        return "No image selected.", _classifier_score_distribution_plot(distribution_data)

    chosen = meta[index]
    post_id = int(chosen["id"])
    md5 = str(chosen["md5"])
    classifier_rows = chosen.get("all_classifier_rows")
    assert isinstance(classifier_rows, list) and classifier_rows, f"Missing classifier rows for md5: {md5}"
    marker_score = float(chosen["classifier_score"])

    parts = [f"<div>MD5: {html.escape(md5)} | ID: {post_id}</div>"]
    parts.extend(_classifier_row_html(entry) for entry in classifier_rows)
    parts.append(f"<div>https://e621.net/posts/{post_id}</div>")
    return "".join(parts), _classifier_score_distribution_plot(distribution_data, selected_score=marker_score)
def add_results_tab(pool_df: pd.DataFrame):
    """Create the "Explorer" tab and return its components.

    This function only builds the components. No event handlers are attached
    here, presumably so the caller can wire them up using the returned tuple.
    ``pool_df`` is not used by this function.

    Returns:
        A tuple, in this order, of:

        - summary Markdown;
        - rating dropdown;
        - sort dropdown;
        - classifier dropdown;
        - score-distribution plot;
        - distribution-data state;
        - gallery;
        - "Load more" button;
        - selected-image Markdown;
        - page-metadata state;
        - page-offset state.
    """
    with gr.Tab("Explorer"):
        # In Gradio Blocks, components render in the order they are created
        # inside the context, so reordering these statements changes the layout.
        results_summary_md = gr.Markdown()
        results_gallery = gr.Gallery(
            label="Category Mosaic",
            columns=[6],
            object_fit="contain",
            preview=True,
            height="auto",
            elem_id="results-gallery",
        )
        # The label advertises an ArrowDown shortcut. The key handling is not in
        # this function; presumably custom JS targets this elem_id. Verify elsewhere.
        results_load_more_btn = gr.Button("Load more (ArrowDown)", elem_id="btn-results-load-more")
        selected_image_md = gr.Markdown("Click an image to reveal its ID and link.")
        results_score_distribution_plot = gr.Plot(label="Classifier score distribution")
        results_rating_dd = gr.Dropdown(
            choices=["safe", "all"],
            value="safe",
            label="Rating",
        )
        results_sort_dd = gr.Dropdown(
            choices=SORT_MODES,
            value="Default",
            label="Sort",
            elem_id="results-sort-mode",
        )
        # Defaults to the first allowed classifier.
        results_classifier_dd = gr.Dropdown(
            choices=list(ALLOWED_CLASSIFIER_FILTERS),
            value=ALLOWED_CLASSIFIER_FILTERS[0],
            label="Classifier",
            elem_id="results-classifier",
        )
        # Initial value has the same shape as the empty result of _distribution_data.
        results_distribution_state = gr.State({"bin_edges": [], "counts": [], "total": 0})
        # Metadata for the gallery items currently shown (list of per-image dicts).
        results_page_meta_state = gr.State([])
        # Offset of the next page to request from the filtered results.
        results_page_offset_state = gr.State(0)
    return (
        results_summary_md,
        results_rating_dd,
        results_sort_dd,
        results_classifier_dd,
        results_score_distribution_plot,
        results_distribution_state,
        results_gallery,
        results_load_more_btn,
        selected_image_md,
        results_page_meta_state,
        results_page_offset_state,
    )