"""StyleSquirrel Gradio front-end: drop an image, see style tags and e621 neighbors."""

import html
import re
from typing import Dict, Tuple
from urllib.parse import quote

import gradio as gr
from PIL import Image

# 🔧 Your model lives here.
# Implement load() and predict(image) in model.py.
import model  # noqa: F401

# --------- Glue code (kept intentionally tiny) ---------

# Minimum score a tag needs in order to be listed in the Style Tags panel.
TAG_THRESHOLD = 0.01

# e621 search for one tag, most-favourited first, animations excluded.
_TAG_SEARCH_URL = "https://e621.net/posts?tags=order%3Afavcount+-animated+{tag}"
# e621 page for a single post ID.
_POST_URL = "https://e621.net/posts/{post_id}"
# Shared attributes for outbound links (new tab, no opener/referrer leakage).
_LINK_ATTRS = "target='_blank' rel='noopener noreferrer'"
# The first run of digits in a neighbor filename (e.g. "12345.png") is taken as its post ID.
_POST_ID_RE = re.compile(r"(\d+)")


def _format_outputs(scores: Dict[str, float], neighbors: list, threshold: float):
    """Return (comma-joined tag string, filtered score dict, newline-joined neighbors).

    Scores are sorted by value, highest first, and entries below ``threshold`` are
    dropped. ``neighbors`` must be a list of strings.
    NOTE(review): not called anywhere in this file (infer() renders HTML instead);
    kept for backward compatibility.
    """
    filtered = {
        k: v
        for k, v in sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
        if v >= threshold
    }
    tag_string = ", ".join(filtered.keys())
    return tag_string, filtered, "\n".join(neighbors)


def _ensure_model_loaded() -> None:
    """Call ``model.load()`` unless the model already reports ``_READY``.

    A failure is printed rather than raised, so ``model.predict()`` still gets a
    chance to run (matching the best-effort start-up load below).
    """
    if getattr(model, "_READY", False):
        return
    try:
        model.load()
    except Exception as e:  # best-effort: model.py may still be a skeleton
        print("model.load() during infer failed:", e)


def _render_tag_table(filtered) -> str:
    """Render ``(tag, score)`` pairs as an HTML table: linked tag | percentage.

    Returns a Markdown placeholder when ``filtered`` is empty.
    """
    if not filtered:
        return "_(no tags)_"
    rows = []
    for tag, val in filtered:
        pct = int(round(val * 100))
        # e621 tags use underscores; percent-encode the rest so odd characters
        # (parentheses, '+', '&', ...) cannot break the query string.
        url = _TAG_SEARCH_URL.format(tag=quote(tag.replace(" ", "_"), safe=""))
        rows.append(
            "<tr>"
            f"<td class='tag-name'><a href='{html.escape(url)}' {_LINK_ATTRS}>"
            f"{html.escape(tag)}</a></td>"
            f"<td class='tag-pct'>{pct}%</td>"
            "</tr>"
        )
    return "<table class='tag-table'>" + "".join(rows) + "</table>"


def _render_neighbors_table(neighbors) -> str:
    """Render neighbors as an HTML table: ID (linked when numeric) | Styles | sim.

    Each item is either a dict with optional ``filename``/``similarity``/``styles``
    keys, or anything else, which is treated as a bare filename. Returns a Markdown
    placeholder when there is nothing to show.
    """
    rows = []
    for item in neighbors or []:
        if isinstance(item, dict):
            fname = str(item.get("filename", ""))
            sim = item.get("similarity", None)
            styles = item.get("styles", []) or []
        else:
            fname = str(item)
            sim = None
            styles = []
        if isinstance(styles, str):  # a lone style string, not a list of chars
            styles = [styles]

        # Numeric ID (strips ".png", etc.); link to e621 only if we find one.
        m = _POST_ID_RE.search(fname)
        if m:
            post_id = m.group(1)
            id_cell = (
                f"<a href='{_POST_URL.format(post_id=post_id)}' {_LINK_ATTRS}>"
                f"{post_id}</a>"
            )
        else:
            id_cell = html.escape(fname)

        styles_cell = html.escape(", ".join(str(s) for s in styles))
        sim_cell = f"{float(sim):.3f}" if sim is not None else ""
        rows.append(
            "<tr>"
            f"<td class='nn-id'>{id_cell}</td>"
            f"<td>{styles_cell}</td>"
            f"<td class='nn-sim'>{sim_cell}</td>"
            "</tr>"
        )

    if not rows:
        return "_(neighbors unavailable)_"
    header = (
        "<thead><tr>"
        "<th class='nn-id'>ID</th>"
        "<th>Styles</th>"
        "<th class='nn-sim'>sim</th>"
        "</tr></thead>"
    )
    return (
        "<table class='nn-table'>" + header + "<tbody>" + "".join(rows) + "</tbody></table>"
    )


def infer(image: Image.Image) -> Tuple[str, str]:
    """Run the model on ``image`` and return ``(tag_panel_md, neighbors_md)``.

    Both strings are Markdown (with embedded HTML tables) for the two result panels.
    ``None`` (no image) yields two empty strings. Exceptions from
    ``model.predict()`` propagate to Gradio, which shows them to the user.
    """
    if image is None:
        return "", ""

    _ensure_model_loaded()

    # predict() returns a 3-tuple; the third element (counts) is not displayed.
    scores, neighbors, _counts = model.predict(image)

    ranked = sorted((scores or {}).items(), key=lambda kv: kv[1], reverse=True)
    filtered = [(k, float(v)) for k, v in ranked if v >= TAG_THRESHOLD]

    return _render_tag_table(filtered), _render_neighbors_table(neighbors)


def clear_outputs():
    """Blank both result panels (used when the input image is cleared)."""
    return "", ""


custom_css = '''
#image_container-image { width: 100%; aspect-ratio: 1 / 1; max-height: 100%; }
#image_container img { object-fit: contain !important; }

/* card look for right-side panels */
.custom-card {
  background: rgba(255,255,255,0.05);   /* lighter than dark bg */
  border: 1px solid rgba(255,255,255,0.14);
  border-radius: 12px;
  padding: 12px 14px;
}
.custom-card .prose { margin: 0; }       /* tighter Markdown spacing */
.custom-card h3 { margin-top: 0; }       /* keep section title snug */
.custom-card:hover { box-shadow: 0 6px 20px rgba(0,0,0,0.25); }

.nn-table { width: 100%; border-collapse: collapse; }
.nn-table th, .nn-table td { padding: 4px 8px; vertical-align: middle; }
.nn-table th { text-align: left; font-weight: 600; }
.nn-table .nn-id { width: 1%; white-space: nowrap; }
.nn-table .nn-sim { text-align: right; width: 1%; white-space: nowrap; }

.tag-table { width: 100%; border-collapse: collapse; }
.tag-table td { padding: 4px 8px; vertical-align: middle; }
.tag-table .tag-name { text-align: left; }
.tag-table .tag-pct { text-align: right; width: 1%; white-space: nowrap; }
'''

# About/usage text shown under the app. Line breaks matter: each "###" heading
# and "-" list item must start its own line to render as Markdown structure.
_ABOUT_MD = """
---
### Purpose
StyleSquirrel is designed for **style exploration**, not artist identification.
It **may not report the image's artist**, even when that artist exists in the dataset.
Use it to explore images with similar **colors, structures, textures, and visual motifs**, not as an attribution tool.

### Instructions
- Drop an image in the box on the left.
- The **Style Tags** panel reports tags that are stylistically similar to the query image.
- The **Nearest Neighbors** panel shows e621 images that are stylistically similar to the query image.

### Notes
- Links go to e621.net and may not be safe for work.
- I tried to isolate style from topic and was only partly successful, so many reported tags/images might be topically rather than stylistically similar.
- The similarity metric is currently a bit naive, causing irregularities like **simple_background** being over-reported.
- This tool has a very small training set and does **not** work well as a style classifier.
"""

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(
                label="Drop an image here",
                sources=["upload", "clipboard"],
                type="pil",
                show_label=False,
                elem_id="image_container",
            )
        # One right-side column that contains both cards, stacked.
        with gr.Column():
            with gr.Column(elem_classes=["custom-card"]):
                gr.Markdown("### Style Tags")
                tag_panel = gr.Markdown()
            with gr.Column(elem_classes=["custom-card"]):
                gr.Markdown("### Nearest Neighbors")
                neighbors_text = gr.Markdown()

    image.upload(
        fn=infer,
        inputs=[image],
        outputs=[tag_panel, neighbors_text],
        show_progress="minimal",
    )
    image.clear(fn=clear_outputs, inputs=[], outputs=[tag_panel, neighbors_text])

    gr.Markdown(_ABOUT_MD)


if __name__ == "__main__":
    # Load the model once at start-up (optional; model.load() may be a no-op in
    # model.py for now). Failures are reported and infer() retries lazily.
    try:
        model.load()
    except Exception as e:
        print("Model load() raised (ok during skeleton dev):", e)
    # NOTE(review): launch(css=...) is the Gradio 6 API; on Gradio 4/5 the CSS
    # must be passed as gr.Blocks(css=...) instead — confirm the pinned version.
    demo.launch(css=custom_css)