Spaces:
Running
Running
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import html | |
| from matplotlib.figure import Figure | |
# Number of rows sliced out per page by the pagination helper.
PAGE_SIZE = 30

# Options for the "Sort" dropdown. Only the two "Rating: ..." entries trigger
# a sort; "Default" leaves rows in their incoming order.
SORT_MODES = [
    "Default",
    "Rating: Low to High",
    "Rating: High to Low",
]

# Options for the "Classifier" dropdown; the first entry is its initial value.
ALLOWED_CLASSIFIER_FILTERS = (
    "SiglipFinetune",
    "ScoringHead-v5",
    "TaggerExperiment",
)
| def _row_image_url(row) -> str | None: | |
| sample_url = row.get("sample_url") | |
| if isinstance(sample_url, str) and sample_url: | |
| return sample_url | |
| image_url = row.get("image_url") | |
| if isinstance(image_url, str) and image_url: | |
| return image_url | |
| return None | |
| def _gallery_items(meta: list[dict]) -> list[tuple[str, str]]: | |
| return [ | |
| ( | |
| str(item["url"]), | |
| f"Score: {float(item['classifier_score']):.4f} | Percentile: {int(item['percentile'])}", | |
| ) | |
| for item in meta | |
| ] | |
| def _distribution_data( | |
| explorer_df: pd.DataFrame, | |
| ) -> dict[str, list[float] | int]: | |
| valid_scores = [ | |
| float(score) | |
| for score in explorer_df["classifier_score"].tolist() | |
| if pd.notna(score) | |
| ] | |
| if not valid_scores: | |
| return {"bin_edges": [], "counts": [], "total": 0} | |
| counts, bin_edges = np.histogram(valid_scores, bins=40) | |
| return { | |
| "bin_edges": bin_edges.astype(float).tolist(), | |
| "counts": counts.astype(float).tolist(), | |
| "total": int(len(valid_scores)), | |
| } | |
def _classifier_score_distribution_plot(
    distribution_data: dict[str, list[float] | int],
    selected_score: float | None = None,
) -> Figure:
    """Draw the pre-binned classifier-score histogram on a dark background.

    Args:
        distribution_data: Mapping with ``bin_edges`` and ``counts`` lists.
            The histogram is drawn only when ``counts`` is non-empty and
            ``bin_edges`` has exactly one more element than ``counts``;
            otherwise a "No classifier scores available." placeholder is shown.
        selected_score: Optional score to mark with a vertical line. It is
            drawn only when the histogram itself is drawn.

    Returns:
        The populated matplotlib ``Figure``.
    """
    # Build the figure directly rather than through ``plt.subplots``: pyplot
    # keeps every figure it creates in a global registry until it is closed
    # explicitly, so a callback that runs on each gallery click would pile up
    # figures for the lifetime of the server process.
    fig = Figure(figsize=(6, 2.2))
    ax = fig.subplots()
    fig.patch.set_facecolor("#0f1117")
    ax.set_facecolor("#151922")

    bin_edges = np.asarray(distribution_data.get("bin_edges", []), dtype=float)
    counts = np.asarray(distribution_data.get("counts", []), dtype=float)

    if counts.size > 0 and bin_edges.size == counts.size + 1:
        # Bars are left-aligned on each bin edge and as wide as the bin.
        widths = np.diff(bin_edges)
        ax.bar(
            bin_edges[:-1],
            counts,
            width=widths,
            align="edge",
            color="#3b82f6",
            alpha=0.9,
            edgecolor="#93c5fd",
            linewidth=0.35,
        )
        ax.set_ylabel("Count", color="#e5e7eb")
        if selected_score is not None:
            ax.axvline(float(selected_score), color="#f97316", linewidth=2.0)
    else:
        ax.text(
            0.5,
            0.5,
            "No classifier scores available.",
            ha="center",
            va="center",
            transform=ax.transAxes,
            color="#e5e7eb",
        )
        ax.set_yticks([])

    ax.set_title("Classifier Score Distribution", color="#f3f4f6")
    ax.set_xlabel("Classifier score", color="#e5e7eb")
    ax.tick_params(colors="#d1d5db")
    for spine in ax.spines.values():
        spine.set_color("#4b5563")
    ax.grid(axis="y", color="#374151", alpha=0.4)
    fig.tight_layout()
    return fig
def _build_page_meta(
    explorer_df: pd.DataFrame,
    all_explorer_df: pd.DataFrame,
    sort_mode: str,
    offset: int,
) -> tuple[list[dict[str, str | int]], int, bool, int]:
    """Build the metadata for one page of gallery results.

    Rows of ``explorer_df`` are kept when they have a non-empty ``sample_url``
    or ``image_url`` and non-null ``classifier_score`` and ``percentile``.
    They are optionally sorted by score, then ``PAGE_SIZE`` rows starting at
    ``offset`` are taken. Each page row is enriched with every classifier
    result recorded for its md5 in ``all_explorer_df``.

    Args:
        explorer_df: Rows for the currently selected classifier/filter.
        all_explorer_df: Rows for all classifiers; must contain ``md5``,
            ``classifier``, ``classifier_score`` and ``percentile``.
        sort_mode: ``"Rating: Low to High"`` or ``"Rating: High to Low"``
            sorts by score (stable); any other value keeps the incoming order.
        offset: Index of the first filtered row on the page. Negative values
            are treated as 0.

    Returns:
        ``(page_meta, next_offset, has_more, total)`` where ``total`` is the
        number of rows that survive filtering.

    Raises:
        AssertionError: If a page row has no usable URL string, or no
            classifier rows exist for its md5 in ``all_explorer_df``.
    """
    # A negative offset would make ``iloc`` slice from the end and leave
    # ``next_offset`` inconsistent with the rows actually returned.
    offset = max(0, int(offset))

    filtered = explorer_df[["id", "md5", "sample_url", "image_url", "classifier_score", "percentile"]].copy()
    has_sample = filtered["sample_url"].notna() & (filtered["sample_url"] != "")
    has_image = filtered["image_url"].notna() & (filtered["image_url"] != "")
    filtered = filtered[has_sample | has_image]
    filtered = filtered[filtered["classifier_score"].notna() & filtered["percentile"].notna()]

    # mergesort is stable, so rows with equal scores keep their relative order.
    if sort_mode == "Rating: Low to High":
        filtered = filtered.sort_values("classifier_score", ascending=True, kind="mergesort")
    elif sort_mode == "Rating: High to Low":
        filtered = filtered.sort_values("classifier_score", ascending=False, kind="mergesort")

    page_df = filtered.iloc[offset:offset + PAGE_SIZE]

    # Only the md5s on this page are looked up below, so restrict the
    # all-classifier table to them before deduplicating and grouping instead
    # of rebuilding the lookup for the whole dataset on every page.
    page_md5s = set(page_df["md5"].astype(str))
    all_scores = all_explorer_df[["md5", "classifier", "classifier_score", "percentile"]]
    all_scores = all_scores[
        all_scores["classifier"].notna()
        & all_scores["classifier_score"].notna()
        & all_scores["percentile"].notna()
        & all_scores["md5"].astype(str).isin(page_md5s)
    ].copy()
    all_scores["md5"] = all_scores["md5"].astype(str)
    all_scores["classifier"] = all_scores["classifier"].astype(str)
    # Keep the last record per (md5, classifier), then order each md5's rows
    # by classifier name.
    all_scores = all_scores.drop_duplicates(subset=["md5", "classifier"], keep="last")
    all_scores = all_scores.sort_values(["md5", "classifier"], kind="mergesort")
    all_scores_by_md5 = {
        str(md5): [
            {
                "classifier": str(r.classifier),
                "classifier_score": float(r.classifier_score),
                "percentile": int(r.percentile),
            }
            for r in gdf.itertuples(index=False)
        ]
        for md5, gdf in all_scores.groupby("md5", sort=False)
    }

    page_meta: list[dict[str, str | int]] = []
    for row in page_df.to_dict("records"):
        md5 = str(row["md5"])
        # These checks are raised explicitly (not via ``assert``) so they are
        # not stripped under ``python -O``; the exception type is unchanged.
        url = _row_image_url(row)
        if url is None:
            raise AssertionError(f"No usable image URL for md5: {md5}")
        all_classifier_rows = all_scores_by_md5.get(md5)
        if not all_classifier_rows:
            raise AssertionError(f"Missing classifier rows for md5: {md5}")
        page_meta.append(
            {
                "id": int(row["id"]),
                "md5": md5,
                "url": url,
                "classifier_score": float(row["classifier_score"]),
                "percentile": int(row["percentile"]),
                "all_classifier_rows": all_classifier_rows,
            }
        )

    next_offset = offset + len(page_meta)
    has_more = next_offset < len(filtered)
    return page_meta, next_offset, has_more, len(filtered)
def build_results_data(
    explorer_df: pd.DataFrame,
    all_explorer_df: pd.DataFrame,
    rating_pref: str,
    sort_mode: str,
    classifier_name: str,
) -> tuple[str, Figure, dict[str, list[float] | int], list[tuple[str, str]], list[dict[str, str | int]], int, dict]:
    """Produce everything needed to render the first page of results.

    Returns, in order: the summary text, the score-distribution figure, the
    histogram data backing that figure, the gallery items for the first page,
    the first page's metadata, the offset of the next page, and a Gradio
    update toggling visibility according to whether more rows remain.
    """
    first_page, next_offset, more_available, match_count = _build_page_meta(
        explorer_df,
        all_explorer_df,
        sort_mode,
        offset=0,
    )
    summary = (
        f"Showing {match_count} images from validation_set.parquet joined with "
        f"classifier_scores.parquet (rating: {rating_pref}, classifier: {classifier_name}, "
        f"sort: {sort_mode})."
    )
    histogram = _distribution_data(explorer_df)
    figure = _classifier_score_distribution_plot(histogram)
    return (
        summary,
        figure,
        histogram,
        _gallery_items(first_page),
        first_page,
        next_offset,
        gr.update(visible=more_available),
    )
def load_more_results(
    explorer_df: pd.DataFrame,
    all_explorer_df: pd.DataFrame,
    sort_mode: str,
    offset: int,
):
    """Fetch the page of results starting at ``offset``.

    Returns the gallery items for that page, the page metadata, the offset of
    the following page, and a Gradio update toggling visibility according to
    whether more rows remain.
    """
    start = int(offset)
    page, next_offset, more_available, _ = _build_page_meta(
        explorer_df,
        all_explorer_df,
        sort_mode,
        offset=start,
    )
    gallery = _gallery_items(page)
    return gallery, page, next_offset, gr.update(visible=more_available)
def _classifier_row_html(row: dict) -> str:
    """Render one classifier result as a flex row: name, score, percentile, bar."""
    classifier_name = str(row["classifier"])
    score = float(row["classifier_score"])
    percentile = int(row["percentile"])
    # Clamp for display only; the printed percentile stays the raw value.
    pct = max(0, min(100, percentile))
    # Map 0..100 onto hue 0..120 for the bar fill colour.
    hue = int(round((pct / 100.0) * 120.0))
    fill_color = f"hsl({hue}, 78%, 42%)"
    bar_html = (
        "<span style='display:inline-block;vertical-align:middle;width:72px;height:8px;"
        "border-radius:999px;background:#e5e7eb;overflow:hidden;margin-left:6px;'>"
        f"<span style='display:block;height:100%;width:{pct}%;background:{fill_color};'></span>"
        "</span>"
    )
    return (
        "<div style='display:flex;align-items:center;column-gap:8px;'>"
        f"<span style='display:inline-block;width:160px;white-space:nowrap;overflow:hidden;text-overflow:ellipsis;'>{html.escape(classifier_name)}</span>"
        f"<span style='display:inline-block;width:132px;'>Score {score:.4f}</span>"
        f"<span style='display:inline-block;width:84px;'>Percentile {percentile:>3}</span>"
        f"{bar_html}"
        "</div>"
    )


def on_gallery_select(
    evt: gr.SelectData,
    meta: list[dict],
    distribution_data: dict[str, list[float] | int],
) -> tuple[str, Figure]:
    """Describe the clicked gallery image and mark its score on the histogram.

    Args:
        evt: Selection event; ``evt.index`` is an int, or a tuple whose first
            element is used.
        meta: Metadata for the items currently shown; ``None`` is treated as
            an empty list.
        distribution_data: Histogram data passed through to the plot helper.

    Returns:
        ``(info_html, figure)``. When the index is not a valid position in
        ``meta``, the text is "No image selected." and the figure has no
        marker.

    Raises:
        AssertionError: If the selected entry has no non-empty
            ``all_classifier_rows`` list.
    """
    meta = meta or []
    index = evt.index[0] if isinstance(evt.index, tuple) else evt.index
    if not isinstance(index, int) or index < 0 or index >= len(meta):
        return "No image selected.", _classifier_score_distribution_plot(distribution_data)

    selected = meta[index]
    post_id = int(selected["id"])
    md5 = str(selected["md5"])
    all_classifier_rows = selected.get("all_classifier_rows")
    # Raised explicitly (not via ``assert``) so the check is not stripped
    # under ``python -O``; the exception type is unchanged.
    if not (isinstance(all_classifier_rows, list) and all_classifier_rows):
        raise AssertionError(f"Missing classifier rows for md5: {md5}")
    classifier_score = float(selected["classifier_score"])

    rows_html = [f"<div>MD5: {html.escape(md5)} | ID: {post_id}</div>"]
    rows_html.extend(_classifier_row_html(row) for row in all_classifier_rows)
    rows_html.append(f"<div>https://e621.net/posts/{post_id}</div>")
    info = "".join(rows_html)
    return info, _classifier_score_distribution_plot(distribution_data, selected_score=classifier_score)
def add_results_tab(pool_df: pd.DataFrame):
    """Create the "Explorer" tab and return its components.

    ``pool_df`` is accepted but not used by this function.

    Returns a tuple, in this order: summary markdown, rating dropdown, sort
    dropdown, classifier dropdown, score-distribution plot, distribution
    state, gallery, load-more button, selected-image markdown, page-meta
    state, page-offset state.
    """
    # Components are created in the same order as they should appear in the tab.
    with gr.Tab("Explorer"):
        summary_md = gr.Markdown()
        gallery = gr.Gallery(
            elem_id="results-gallery",
            label="Category Mosaic",
            columns=[6],
            height="auto",
            object_fit="contain",
            preview=True,
        )
        load_more_btn = gr.Button("Load more (ArrowDown)", elem_id="btn-results-load-more")
        selected_md = gr.Markdown("Click an image to reveal its ID and link.")
        distribution_plot = gr.Plot(label="Classifier score distribution")
        rating_dd = gr.Dropdown(label="Rating", choices=["safe", "all"], value="safe")
        sort_dd = gr.Dropdown(
            label="Sort",
            choices=SORT_MODES,
            value="Default",
            elem_id="results-sort-mode",
        )
        classifier_dd = gr.Dropdown(
            label="Classifier",
            choices=list(ALLOWED_CLASSIFIER_FILTERS),
            value=ALLOWED_CLASSIFIER_FILTERS[0],
            elem_id="results-classifier",
        )
        distribution_state = gr.State({"bin_edges": [], "counts": [], "total": 0})
        page_meta_state = gr.State([])
        page_offset_state = gr.State(0)

    return (
        summary_md,
        rating_dd,
        sort_dd,
        classifier_dd,
        distribution_plot,
        distribution_state,
        gallery,
        load_more_btn,
        selected_md,
        page_meta_state,
        page_offset_state,
    )