Spaces:
Running
Running
File size: 6,843 Bytes
2cd64de c5ae7ff 2cd64de c5ae7ff 2cd64de c5ae7ff 2cd64de c5ae7ff 2cd64de c5ae7ff 2cd64de c5ae7ff 2cd64de c5ae7ff 2cd64de c5ae7ff 2cd64de c5ae7ff 2cd64de c5ae7ff e6c2a72 2cd64de c5ae7ff 2cd64de c5ae7ff 2cd64de 376e833 2cd64de 376e833 c5ae7ff 376e833 2cd64de e6c2a72 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
import html
import re
from typing import Dict, Tuple
from urllib.parse import quote

import gradio as gr
from PIL import Image

# 🔧 Your model lives here.
# Implement load() and predict(image) in model.py.
import model # noqa: F401
# --------- Glue code (kept intentionally tiny) ---------
def _format_outputs(scores: Dict[str, float], neighbors: list, threshold: float):
# sort by score desc and apply threshold
filtered = {k: v for k, v in sorted(scores.items(), key=lambda kv: kv[1], reverse=True) if v >= threshold}
tag_string = ", ".join(filtered.keys())
return tag_string, filtered, "\n".join(neighbors)
def _tag_table_html(filtered: list) -> str:
    """Render ``(tag, score)`` pairs as the Style Tags HTML table.

    Each row links the tag to an e621 search and shows the score as a
    right-aligned whole-number percentage. Returns a Markdown placeholder
    when ``filtered`` is empty.
    """
    if not filtered:
        return "_(no tags)_"
    rows = []
    for tag, val in filtered:
        pct = int(round(val * 100))
        # e621 tags use underscores for spaces; percent-encode everything else
        # so characters such as '+', '&', '#' or quotes cannot corrupt the
        # query string or terminate the surrounding href attribute.
        tag_q = quote(tag.replace(" ", "_"), safe="")
        url = f"https://e621.net/posts?tags=order%3Afavcount+-animated+{tag_q}"
        rows.append(
            "<tr>"
            f"<td class='tag-name'><a href='{url}' target='_blank' rel='noopener noreferrer'>{html.escape(tag)}</a></td>"
            f"<td class='tag-pct'>{pct}%</td>"
            "</tr>"
        )
    return "<table class='tag-table'><tbody>" + "".join(rows) + "</tbody></table>"


def _neighbors_table_html(neighbors) -> str:
    """Render nearest-neighbor entries as the ID | Styles | sim HTML table.

    Dict entries may carry ``filename``, ``similarity`` and ``styles`` keys;
    any other entry is treated as a bare filename. The first run of digits in
    the filename is used as the e621 post ID and linked. Returns a Markdown
    placeholder when there are no entries (or ``neighbors`` is None).
    """
    rows = []
    for item in [] if neighbors is None else neighbors:
        if isinstance(item, dict):
            fname = str(item.get("filename", ""))
            sim = item.get("similarity", None)
            styles = item.get("styles", [])
        else:
            fname, sim, styles = str(item), None, []
        if styles is None:
            styles = []
        elif isinstance(styles, str):
            # A lone string would otherwise be joined character by character.
            styles = [styles]

        # Numeric ID (strips ".png" etc.); link to e621 only when one is found.
        m = re.search(r"\d+", fname)
        if m:
            post_id = m.group(0)
            id_cell = (
                f"<a href='https://e621.net/posts/{post_id}' target='_blank' "
                f"rel='noopener noreferrer'>{post_id}</a>"
            )
        else:
            id_cell = html.escape(fname)
        styles_cell = html.escape(", ".join(str(s) for s in styles))
        sim_cell = f"{float(sim):.3f}" if sim is not None else ""
        rows.append(
            "<tr>"
            f"<td class='nn-id'>{id_cell}</td>"
            f"<td class='nn-styles'>{styles_cell}</td>"
            f"<td class='nn-sim'>{sim_cell}</td>"
            "</tr>"
        )
    if not rows:
        return "_(neighbors unavailable)_"
    return (
        "<table class='nn-table'>"
        "<thead><tr>"
        "<th class='nn-id'>ID</th>"
        "<th class='nn-styles'>Styles</th>"
        "<th class='nn-sim'>sim</th>"
        "</tr></thead>"
        "<tbody>" + "".join(rows) + "</tbody></table>"
    )


def infer(image: Image.Image):
    """Run the model on one image and build both result panels.

    Args:
        image: The image to analyse (the UI's ``gr.Image`` is configured with
            ``type="pil"``), or ``None`` when nothing is loaded.

    Returns:
        ``(tag_panel_md, neighbors_md)``: HTML/Markdown strings for the
        Style Tags and Nearest Neighbors panels. Both are empty strings when
        ``image`` is None.
    """
    if image is None:
        return "", ""
    threshold = 0.01  # fixed cutoff; tags scoring below this are hidden

    # Lazy-load if the model has not flagged itself ready. A failure is only
    # printed; execution continues to predict() regardless.
    if not getattr(model, "_READY", False):
        try:
            model.load()
        except Exception as e:
            print("model.load() during infer failed:", e)

    # predict() yields a third value (counts) that this UI does not display.
    scores, neighbors, _counts = model.predict(image)

    ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    filtered = [(k, float(v)) for k, v in ranked if v >= threshold]
    return _tag_table_html(filtered), _neighbors_table_html(neighbors)
def clear_outputs():
    """Return blank content for the tag and neighbor panels.

    Wired to the image component's clear event so stale results disappear
    together with the image.
    """
    blank = ""
    return blank, blank
# Stylesheet passed to demo.launch(css=...): square image drop zone, "card"
# styling for the right-hand panels, and column rules for the tag and
# nearest-neighbor tables that infer() emits.
custom_css = """
#image_container-image { width: 100%; aspect-ratio: 1 / 1; max-height: 100%; }
#image_container img { object-fit: contain !important; }
/* card look for right-side panels */
.custom-card {
background: rgba(255,255,255,0.05); /* lighter than dark bg */
border: 1px solid rgba(255,255,255,0.14);
border-radius: 12px;
padding: 12px 14px;
}
.custom-card .prose { margin: 0; } /* tighter Markdown spacing */
.custom-card h3 { margin-top: 0; } /* keep section title snug */
.custom-card:hover { box-shadow: 0 6px 20px rgba(0,0,0,0.25); }
.nn-table { width: 100%; border-collapse: collapse; }
.nn-table th, .nn-table td { padding: 4px 8px; vertical-align: middle; }
.nn-table th { text-align: left; font-weight: 600; }
.nn-table .nn-id { width: 1%; white-space: nowrap; }
.nn-table .nn-sim { text-align: right; width: 1%; white-space: nowrap; }
.tag-table { width: 100%; border-collapse: collapse; }
.tag-table td { padding: 4px 8px; vertical-align: middle; }
.tag-table .tag-name { text-align: left; }
.tag-table .tag-pct { text-align: right; width: 1%; white-space: nowrap; }
"""
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: the image drop zone (upload or paste from clipboard).
        with gr.Column():
            image = gr.Image(
                elem_id="image_container",
                label="Drop an image here",
                show_label=False,
                sources=["upload", "clipboard"],
                type="pil",
            )
        # Right column: the two result cards, stacked vertically.
        with gr.Column():
            with gr.Column(elem_classes=["custom-card"]):
                gr.Markdown("### Style Tags")
                tag_panel = gr.Markdown()
            with gr.Column(elem_classes=["custom-card"]):
                gr.Markdown("### Nearest Neighbors")
                neighbors_text = gr.Markdown()

    # Uploading runs inference into both cards; clearing the image blanks them.
    image.upload(
        fn=infer,
        inputs=[image],
        outputs=[tag_panel, neighbors_text],
        show_progress="minimal",
    )
    image.clear(fn=clear_outputs, inputs=[], outputs=[tag_panel, neighbors_text])

    # Full-width explanatory footer below the two columns.
    gr.Markdown("""
---
### Purpose
StyleSquirrel is designed for **style exploration**, not artist identification.
It **may not report the image's artist**, even when that artist exists in the dataset.
Use it to explore images with similar **colors, structures, textures, and visual motifs**, not as an attribution tool.
### Instructions
- Drop an image in the box on the left.
- The **Style Tags** panel reports tags that are stylistically similar to the query image.
- The **Nearest Neighbors** panel shows e621 images that are stylistically similar to the query image.
### Notes
- Links go to e621.net and may not be safe for work.
- I tried to isolate style from topic and was only partly successful, so many reported tags/images might be topically rather than stylistically similar.
- The similarity metric is currently a bit naive, causing irregularities like **simple_background** being over-reported.
- This tool has a very small training set and does **not** work well as a style classifier.
""")
if __name__ == "__main__":
    # Eagerly load the model once before serving. Failure is tolerated here:
    # infer() calls model.load() again whenever model._READY is not truthy.
    try:
        model.load()
    except Exception as exc:
        print("Model load() raised (ok during skeleton dev):", exc)
    demo.launch(css=custom_css)
|