import html
import re
from typing import Dict, Tuple
from urllib.parse import quote

import gradio as gr
from PIL import Image

# 🔧 Your model lives here.
# Implement load() and predict(image) in model.py.
import model # noqa: F401
# --------- Glue code (kept intentionally tiny) ---------
def _format_outputs(scores: Dict[str, float], neighbors: list, threshold: float):
# sort by score desc and apply threshold
filtered = {k: v for k, v in sorted(scores.items(), key=lambda kv: kv[1], reverse=True) if v >= threshold}
tag_string = ", ".join(filtered.keys())
return tag_string, filtered, "\n".join(neighbors)
def _render_tag_table(filtered: list) -> str:
    """Render ``(tag, score)`` pairs as an HTML table of e621 search links.

    Each row shows the tag, linked to an e621 search ordered by favorite
    count that excludes animated posts, and its score as a whole
    percentage. Returns a Markdown placeholder when *filtered* is empty.
    """
    if not filtered:
        return "_(no tags)_"
    rows = []
    for tag, val in filtered:
        pct = int(round(val * 100))
        # Tags use "_" for spaces. The tag is then percent-encoded so that
        # characters such as "+", "&" or "#" cannot corrupt the query string.
        tag_q = quote(tag.replace(" ", "_"), safe="")
        url = f"https://e621.net/posts?tags=order%3Afavcount+-animated+{tag_q}"
        rows.append(
            "<tr>"
            f"<td class='tag-name'><a href='{html.escape(url)}' target='_blank' "
            f"rel='noopener noreferrer'>{html.escape(tag)}</a></td>"
            f"<td class='tag-pct'>{pct}%</td>"
            "</tr>"
        )
    return "<table class='tag-table'>" + "".join(rows) + "</table>"


def _render_neighbors_table(neighbors) -> str:
    """Render nearest-neighbor entries as an HTML table (ID | Styles | sim).

    Entries may be dicts with optional ``"filename"``, ``"similarity"`` and
    ``"styles"`` keys. Any other value is treated as a filename. The first
    run of digits in the filename is taken as the e621 post ID and linked to
    the post page. Without digits, the escaped filename is shown as-is.
    Returns a Markdown placeholder when there are no entries.
    """
    rows = []
    for item in neighbors or []:
        if isinstance(item, dict):
            fname = str(item.get("filename", ""))
            sim = item.get("similarity", None)
            styles = item.get("styles", None) or []
        else:
            fname = str(item)
            sim = None
            styles = []
        # Numeric ID (strips ".png", etc.); link to e621 when one is found.
        m = re.search(r"(\d+)", fname)
        if m:
            post_id = m.group(1)
            id_cell = (
                f"<a href='https://e621.net/posts/{post_id}' target='_blank' "
                f"rel='noopener noreferrer'>{post_id}</a>"
            )
        else:
            id_cell = html.escape(fname)
        styles_cell = html.escape(", ".join(str(s) for s in styles))
        sim_cell = f"{sim:.3f}" if sim is not None else ""
        rows.append(
            "<tr>"
            f"<td class='nn-id'>{id_cell}</td>"
            f"<td>{styles_cell}</td>"
            f"<td class='nn-sim'>{sim_cell}</td>"
            "</tr>"
        )
    if not rows:
        return "_(neighbors unavailable)_"
    return (
        "<table class='nn-table'>"
        "<thead><tr>"
        "<th class='nn-id'>ID</th>"
        "<th>Styles</th>"
        "<th class='nn-sim'>sim</th>"
        "</tr></thead>"
        "<tbody>" + "".join(rows) + "</tbody></table>"
    )


def infer(image: "Image.Image", threshold: float = 0.01):
    """Run the style model on *image* and render both result panels.

    Args:
        image: PIL image from the upload widget, or ``None`` when empty.
        threshold: Minimum score (inclusive) for a tag to be listed.
            Defaults to the fixed 0.01 cutoff.

    Returns:
        ``(tag_panel_md, neighbors_md)``: Markdown/HTML strings for the
        "Style Tags" and "Nearest Neighbors" panels.
    """
    if image is None:
        return "", ""
    # Lazy-load if the model is not marked ready. A failure here is only
    # printed; inference is still attempted afterwards.
    if not getattr(model, "_READY", False):
        try:
            model.load()
        except Exception as e:
            print("model.load() during infer failed:", e)
    # predict() yields (scores, neighbors, counts); counts is not displayed.
    scores, neighbors, _counts = model.predict(image)
    ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    filtered = [(k, float(v)) for k, v in ranked if v >= threshold]
    return _render_tag_table(filtered), _render_neighbors_table(neighbors)
def clear_outputs():
    """Blank both result panels; wired to the image widget's clear event."""
    empty_panel = ""
    return (empty_panel, empty_panel)
# Stylesheet for the app, passed to demo.launch(css=...) at start-up.
# It styles the image widget (elem_id "image_container"), the right-hand
# result cards (elem_classes "custom-card"), and the HTML tables rendered in
# the result panels (.nn-table / .tag-table and their cell classes).
custom_css = '''
#image_container-image { width: 100%; aspect-ratio: 1 / 1; max-height: 100%; }
#image_container img { object-fit: contain !important; }
/* card look for right-side panels */
.custom-card {
background: rgba(255,255,255,0.05); /* lighter than dark bg */
border: 1px solid rgba(255,255,255,0.14);
border-radius: 12px;
padding: 12px 14px;
}
.custom-card .prose { margin: 0; } /* tighter Markdown spacing */
.custom-card h3 { margin-top: 0; } /* keep section title snug */
.custom-card:hover { box-shadow: 0 6px 20px rgba(0,0,0,0.25); }
.nn-table { width: 100%; border-collapse: collapse; }
.nn-table th, .nn-table td { padding: 4px 8px; vertical-align: middle; }
.nn-table th { text-align: left; font-weight: 600; }
.nn-table .nn-id { width: 1%; white-space: nowrap; }
.nn-table .nn-sim { text-align: right; width: 1%; white-space: nowrap; }
.tag-table { width: 100%; border-collapse: collapse; }
.tag-table td { padding: 4px 8px; vertical-align: middle; }
.tag-table .tag-name { text-align: left; }
.tag-table .tag-pct { text-align: right; width: 1%; white-space: nowrap; }
'''
# Page layout: the image drop zone on the left, and a right-hand column that
# stacks the "Style Tags" and "Nearest Neighbors" cards. The structure is
# defined by how these context managers are nested, so statement order matters.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # type="pil" hands infer() a PIL image (or None when empty).
            image = gr.Image(label="Drop an image here", sources=["upload", "clipboard"],
                             type="pil", show_label=False, elem_id="image_container")
        # NEW: one right-side column that contains both cards stacked
        with gr.Column():
            with gr.Column(elem_classes=["custom-card"]):
                gr.Markdown("### Style Tags")
                tag_panel = gr.Markdown()
            with gr.Column(elem_classes=["custom-card"]):
                gr.Markdown("### Nearest Neighbors")
                neighbors_text = gr.Markdown()
    # Uploading runs inference into both panels; clearing blanks them again.
    image.upload(fn=infer, inputs=[image], outputs=[tag_panel, neighbors_text], show_progress="minimal")
    image.clear(fn=clear_outputs, inputs=[], outputs=[tag_panel, neighbors_text])
    # Static help text shown below the main row.
    gr.Markdown("""
---
### Purpose
StyleSquirrel is designed for **style exploration**, not artist identification.
It **may not report the image's artist**, even when that artist exists in the dataset.
Use it to explore images with similar **colors, structures, textures, and visual motifs**, not as an attribution tool.
### Instructions
- Drop an image in the box on the left.
- The **Style Tags** panel reports tags that are stylistically similar to the query image.
- The **Nearest Neighbors** panel shows e621 images that are stylistically similar to the query image.
### Notes
- Links go to e621.net and may not be safe for work.
- I tried to isolate style from topic and was only partly successful, so many reported tags/images might be topically rather than stylistically similar.
- The similarity metric is currently a bit naive, causing irregularities like **simple_background** being over-reported.
- This tool has a very small training set and does **not** work well as a style classifier.
""")
def _preload_model():
    """Call model.load() once up front; any exception is printed, not raised.

    Eager loading is optional: infer() retries model.load() lazily whenever
    the model is not marked ready.
    """
    try:
        model.load()
    except Exception as err:
        print("Model load() raised (ok during skeleton dev):", err)


if __name__ == "__main__":
    # Warm the model before serving, then start the UI with the custom styles.
    _preload_model()
    demo.launch(css=custom_css)