import html

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure

# Number of gallery items produced per page by `_build_page_meta`.
PAGE_SIZE = 30
# Sort labels; the two "Rating" entries are the values `_build_page_meta`
# compares against, any other value leaves the incoming row order untouched.
SORT_MODES = ["Default", "Rating: Low to High", "Rating: High to Low"]
ALLOWED_CLASSIFIER_FILTERS = ("SiglipFinetune", "ScoringHead-v5", "TaggerExperiment")


def _row_image_url(row) -> str | None:
    """Return the first usable image URL of ``row``, or ``None``.

    ``sample_url`` is preferred over ``image_url``; a value only counts when
    it is a non-empty ``str``.
    """
    for key in ("sample_url", "image_url"):
        candidate = row.get(key)
        if isinstance(candidate, str) and candidate:
            return candidate
    return None


def _gallery_items(meta: list[dict]) -> list[tuple[str, str]]:
    """Convert page metadata into ``(url, caption)`` pairs.

    Each entry of ``meta`` must provide ``url``, ``classifier_score`` and
    ``percentile``.
    """
    return [
        (
            str(item["url"]),
            f"Score: {float(item['classifier_score']):.4f} | Percentile: {int(item['percentile'])}",
        )
        for item in meta
    ]


def _distribution_data(
    explorer_df: pd.DataFrame,
) -> dict[str, list[float] | int]:
    """Compute a 40-bin histogram of the ``classifier_score`` column.

    Returns a dict with ``bin_edges`` (41 floats), ``counts`` (40 floats) and
    ``total`` (number of scores binned). All three are empty/zero when there
    is no usable score.
    """
    scores = explorer_df["classifier_score"].dropna().astype(float).to_numpy()
    # +/-inf passes a not-NA check but makes np.histogram raise ValueError
    # (its auto-detected range must be finite), so drop non-finite values.
    scores = scores[np.isfinite(scores)]
    if scores.size == 0:
        return {"bin_edges": [], "counts": [], "total": 0}

    counts, bin_edges = np.histogram(scores, bins=40)
    return {
        "bin_edges": bin_edges.astype(float).tolist(),
        "counts": counts.astype(float).tolist(),
        "total": int(scores.size),
    }


def _classifier_score_distribution_plot(
    distribution_data: dict[str, list[float] | int],
    selected_score: float | None = None,
) -> Figure:
    """Render the score histogram as a dark-themed matplotlib figure.

    Args:
        distribution_data: Output of ``_distribution_data``.
        selected_score: When given (and histogram data is present), a vertical
            marker line is drawn at this score.

    Returns:
        The figure; a placeholder message is shown instead of bars when the
        histogram data is empty or inconsistent.
    """
    # Build the Figure directly rather than through pyplot: pyplot keeps a
    # global reference to every figure it creates until it is explicitly
    # closed, which leaks one figure per call in a long-running app.
    fig = Figure(figsize=(6, 2.2))
    ax = fig.subplots()
    fig.patch.set_facecolor("#0f1117")
    ax.set_facecolor("#151922")

    bin_edges = np.asarray(distribution_data.get("bin_edges", []), dtype=float)
    counts = np.asarray(distribution_data.get("counts", []), dtype=float)

    # A valid histogram has exactly one more edge than it has bins.
    if counts.size > 0 and bin_edges.size == counts.size + 1:
        widths = np.diff(bin_edges)
        ax.bar(
            bin_edges[:-1],
            counts,
            width=widths,
            align="edge",
            color="#3b82f6",
            alpha=0.9,
            edgecolor="#93c5fd",
            linewidth=0.35,
        )
        ax.set_ylabel("Count", color="#e5e7eb")
        if selected_score is not None:
            ax.axvline(float(selected_score), color="#f97316", linewidth=2.0)
    else:
        ax.text(
            0.5,
            0.5,
            "No classifier scores available.",
            ha="center",
            va="center",
            transform=ax.transAxes,
            color="#e5e7eb",
        )
        ax.set_yticks([])

    ax.set_title("Classifier Score Distribution", color="#f3f4f6")
    ax.set_xlabel("Classifier score", color="#e5e7eb")
    ax.tick_params(colors="#d1d5db")
    for spine in ax.spines.values():
        spine.set_color("#4b5563")
    ax.grid(axis="y", color="#374151", alpha=0.4)
    fig.tight_layout()
    return fig


def _build_page_meta(
    explorer_df: pd.DataFrame,
    all_explorer_df: pd.DataFrame,
    sort_mode: str,
    offset: int,
) -> tuple[list[dict[str, object]], int, bool, int]:
    """Build the metadata for one page of gallery results.

    Args:
        explorer_df: Rows to page through. Rows without a non-empty
            ``sample_url``/``image_url`` or with a missing
            ``classifier_score``/``percentile`` are excluded.
        all_explorer_df: Rows used to collect, per md5, the score of every
            classifier (deduplicated on ``(md5, classifier)``, last one wins).
        sort_mode: One of the "Rating" modes to sort by score; any other value
            keeps the incoming order.
        offset: Index of the first row of the page; negative values are
            treated as 0.

    Returns:
        ``(page_meta, next_offset, has_more, total)`` where ``total`` is the
        number of rows that survived filtering.

    Raises:
        AssertionError: If a page row has no usable URL or no classifier rows
            in ``all_explorer_df``.
    """
    # A negative start would make iloc slice from the end of the frame.
    offset = max(int(offset), 0)

    filtered = explorer_df[
        ["id", "md5", "sample_url", "image_url", "classifier_score", "percentile"]
    ].copy()
    has_sample = filtered["sample_url"].notna() & (filtered["sample_url"] != "")
    has_image = filtered["image_url"].notna() & (filtered["image_url"] != "")
    filtered = filtered[has_sample | has_image]
    filtered = filtered[
        filtered["classifier_score"].notna() & filtered["percentile"].notna()
    ]

    # Stable sort so rows with equal scores keep their relative order.
    if sort_mode == "Rating: Low to High":
        filtered = filtered.sort_values("classifier_score", ascending=True, kind="mergesort")
    elif sort_mode == "Rating: High to Low":
        filtered = filtered.sort_values("classifier_score", ascending=False, kind="mergesort")

    page_records = filtered.iloc[offset : offset + PAGE_SIZE].to_dict("records")
    page_md5s = {str(record["md5"]) for record in page_records}

    # dropna + assign produce fresh frames, so no column is ever assigned on a
    # filtered slice.
    all_scores = (
        all_explorer_df[["md5", "classifier", "classifier_score", "percentile"]]
        .dropna(subset=["classifier", "classifier_score", "percentile"])
        .assign(
            md5=lambda df: df["md5"].astype(str),
            classifier=lambda df: df["classifier"].astype(str),
        )
    )
    # Only the md5s shown on this page are ever looked up; skip the rest
    # before the dedup/sort work.
    all_scores = all_scores[all_scores["md5"].isin(page_md5s)]
    all_scores = all_scores.drop_duplicates(subset=["md5", "classifier"], keep="last")
    all_scores = all_scores.sort_values(["md5", "classifier"], kind="mergesort")

    all_scores_by_md5: dict[str, list[dict[str, object]]] = {}
    for scored in all_scores.itertuples(index=False):
        all_scores_by_md5.setdefault(str(scored.md5), []).append(
            {
                "classifier": str(scored.classifier),
                "classifier_score": float(scored.classifier_score),
                "percentile": int(scored.percentile),
            }
        )

    page_meta: list[dict[str, object]] = []
    for row in page_records:
        url = _row_image_url(row)
        assert url is not None
        md5 = str(row["md5"])
        all_classifier_rows = all_scores_by_md5.get(md5)
        assert all_classifier_rows, f"Missing classifier rows for md5: {md5}"
        page_meta.append(
            {
                "id": int(row["id"]),
                "md5": md5,
                "url": url,
                "classifier_score": float(row["classifier_score"]),
                "percentile": int(row["percentile"]),
                "all_classifier_rows": all_classifier_rows,
            }
        )

    next_offset = offset + len(page_meta)
    has_more = next_offset < len(filtered)
    return page_meta, next_offset, has_more, len(filtered)


def build_results_data(
    explorer_df: pd.DataFrame,
    all_explorer_df: pd.DataFrame,
    rating_pref: str,
    sort_mode: str,
    classifier_name: str,
) -> tuple[
    str,
    Figure,
    dict[str, list[float] | int],
    list[tuple[str, str]],
    list[dict[str, object]],
    int,
    dict,
]:
    """Build the first page (offset 0) of results together with a summary line."""
    page_meta, next_offset, has_more, total = _build_page_meta(explorer_df, all_explorer_df, sort_mode, offset=0)
    summary = f"Showing {total} images from validation_set.parquet joined with classifier_scores.parquet (rating: {rating_pref}, classifier: {classifier_name}, sort: {sort_mode})."
distribution_data = _distribution_data(explorer_df) score_distribution_plot = _classifier_score_distribution_plot(distribution_data) return summary, score_distribution_plot, distribution_data, _gallery_items(page_meta), page_meta, next_offset, gr.update(visible=has_more) def load_more_results( explorer_df: pd.DataFrame, all_explorer_df: pd.DataFrame, sort_mode: str, offset: int, ): page_meta, next_offset, has_more, _total = _build_page_meta(explorer_df, all_explorer_df, sort_mode, offset=int(offset)) return _gallery_items(page_meta), page_meta, next_offset, gr.update(visible=has_more) def on_gallery_select( evt: gr.SelectData, meta: list[dict], distribution_data: dict[str, list[float] | int], ) -> tuple[str, Figure]: index = evt.index[0] if isinstance(evt.index, tuple) else evt.index if not isinstance(index, int) or index < 0 or index >= len(meta): return "No image selected.", _classifier_score_distribution_plot(distribution_data) selected = meta[index] post_id = int(selected["id"]) md5 = str(selected["md5"]) all_classifier_rows = selected.get("all_classifier_rows") assert isinstance(all_classifier_rows, list) and all_classifier_rows, f"Missing classifier rows for md5: {md5}" classifier_score = float(selected["classifier_score"]) rows_html = [f"