Spaces:
Running
Running
import html
import re
from typing import Dict, Tuple
from urllib.parse import quote

import gradio as gr
from PIL import Image

# 🔧 Your model lives here.
# Implement load() and predict(image) in model.py.
import model  # noqa: F401
| # --------- Glue code (kept intentionally tiny) --------- | |
| def _format_outputs(scores: Dict[str, float], neighbors: list, threshold: float): | |
| # sort by score desc and apply threshold | |
| filtered = {k: v for k, v in sorted(scores.items(), key=lambda kv: kv[1], reverse=True) if v >= threshold} | |
| tag_string = ", ".join(filtered.keys()) | |
| return tag_string, filtered, "\n".join(neighbors) | |
# First run of digits in a neighbor filename; used as the e621 post ID.
_POST_ID_RE = re.compile(r"\d+")


def _tag_table_html(tag_scores):
    """Render ``(tag, score)`` pairs as the Style Tags HTML table."""
    rows = []
    for tag, val in tag_scores:
        tag = str(tag)
        pct = int(round(val * 100))
        # Percent-encode the tag so characters such as ' & # ( ) can neither
        # terminate the single-quoted href nor alter the query string.
        tag_q = quote(tag.replace(" ", "_"), safe="")
        url = f"https://e621.net/posts?tags=order%3Afavcount+-animated+{tag_q}"
        rows.append(
            "<tr>"
            f"<td class='tag-name'><a href='{url}' target='_blank' rel='noopener noreferrer'>"
            f"{html.escape(tag)}</a></td>"
            f"<td class='tag-pct'>{pct}%</td>"
            "</tr>"
        )
    return "<table class='tag-table'><tbody>" + "".join(rows) + "</tbody></table>"


def _neighbor_row_html(item):
    """Render one neighbor (a dict or any other object) as an HTML table row."""
    if isinstance(item, dict):
        fname = str(item.get("filename", ""))
        sim = item.get("similarity", None)
        styles = item.get("styles", [])
    else:
        fname = str(item)
        sim = None
        styles = []

    if styles is None:
        styles = []
    elif isinstance(styles, str):
        # A bare string would otherwise be joined character by character.
        styles = [styles]

    # Numeric ID (strips ".png", etc.); link to e621 only when one is found.
    m = _POST_ID_RE.search(fname)
    if m:
        post_id = m.group(0)
        id_cell = (
            f"<a href='https://e621.net/posts/{post_id}' target='_blank' "
            f"rel='noopener noreferrer'>{post_id}</a>"
        )
    else:
        id_cell = html.escape(fname)

    styles_cell = ", ".join(html.escape(str(s)) for s in styles)

    # A missing or non-numeric similarity leaves the cell blank instead of raising.
    try:
        sim_cell = f"{float(sim):.3f}"
    except (TypeError, ValueError):
        sim_cell = ""

    return (
        "<tr>"
        f"<td class='nn-id'>{id_cell}</td>"
        f"<td class='nn-styles'>{styles_cell}</td>"
        f"<td class='nn-sim'>{sim_cell}</td>"
        "</tr>"
    )


def _neighbors_table_html(neighbors):
    """Render the Nearest Neighbors table (ID | Styles | sim), or a placeholder."""
    rows = [_neighbor_row_html(item) for item in neighbors]
    if not rows:
        return "_(neighbors unavailable)_"
    return (
        "<table class='nn-table'>"
        "<thead><tr>"
        "<th class='nn-id'>ID</th>"
        "<th class='nn-styles'>Styles</th>"
        "<th class='nn-sim'>sim</th>"
        "</tr></thead>"
        "<tbody>" + "".join(rows) + "</tbody></table>"
    )


def infer(image: "Image.Image", threshold: float = 0.01) -> Tuple[str, str]:
    """Run the model on ``image`` and build the two output panels.

    Args:
        image: The query image, or ``None`` when the input is empty.
        threshold: Minimum score (inclusive) for a tag to be listed.

    Returns:
        ``(tag_panel_md, neighbors_md)``. Both are empty strings when
        ``image`` is ``None``. Otherwise the first is an HTML table of tags
        with percentages (or ``"_(no tags)_"``), and the second is an HTML
        table of neighbors (or ``"_(neighbors unavailable)_"``).

    All text that comes from the model (tags, filenames, styles) is
    HTML-escaped, and tags are percent-encoded inside the e621 search URL.
    """
    if image is None:
        return "", ""  # (tag_panel_md, neighbors_md)

    # Lazy-load if needed. Best effort: a failed load is reported and
    # predict() is still attempted.
    if not getattr(model, "_READY", False):
        try:
            model.load()
        except Exception as e:
            print("model.load() during infer failed:", e)

    # predict() yields three values; the third is not used here.
    scores, neighbors, _counts = model.predict(image)
    if scores is None:
        scores = {}
    if neighbors is None:
        neighbors = []

    # Sort by score (descending) and apply the threshold.
    sorted_scores = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    filtered = [(k, float(v)) for k, v in sorted_scores if v >= threshold]

    tag_panel_md = _tag_table_html(filtered) if filtered else "_(no tags)_"
    neighbors_md = _neighbors_table_html(neighbors)
    return tag_panel_md, neighbors_md
def clear_outputs():
    """Return empty content for both output panels (tags, neighbors)."""
    blank = ""
    return blank, blank
# Stylesheet passed to demo.launch(css=...). It covers three things:
#   * the image drop zone (#image_container...),
#   * the "card" look of containers tagged with the custom-card class,
#   * the layout of the tag table (.tag-table) and neighbor table (.nn-table)
#     whose HTML is generated in infer().
custom_css = '''
#image_container-image { width: 100%; aspect-ratio: 1 / 1; max-height: 100%; }
#image_container img { object-fit: contain !important; }
/* card look for right-side panels */
.custom-card {
background: rgba(255,255,255,0.05); /* lighter than dark bg */
border: 1px solid rgba(255,255,255,0.14);
border-radius: 12px;
padding: 12px 14px;
}
.custom-card .prose { margin: 0; } /* tighter Markdown spacing */
.custom-card h3 { margin-top: 0; } /* keep section title snug */
.custom-card:hover { box-shadow: 0 6px 20px rgba(0,0,0,0.25); }
.nn-table { width: 100%; border-collapse: collapse; }
.nn-table th, .nn-table td { padding: 4px 8px; vertical-align: middle; }
.nn-table th { text-align: left; font-weight: 600; }
.nn-table .nn-id { width: 1%; white-space: nowrap; }
.nn-table .nn-sim { text-align: right; width: 1%; white-space: nowrap; }
.tag-table { width: 100%; border-collapse: collapse; }
.tag-table td { padding: 4px 8px; vertical-align: middle; }
.tag-table .tag-name { text-align: left; }
.tag-table .tag-pct { text-align: right; width: 1%; white-space: nowrap; }
'''
# Page layout: the image input on the left, two stacked result cards on the
# right, then the usage notes. Components appear in the order they are
# created, so statement order here is significant.
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: the query image (upload or clipboard), delivered as a PIL image.
        with gr.Column():
            image = gr.Image(label="Drop an image here", sources=["upload", "clipboard"],
                             type="pil", show_label=False, elem_id="image_container")
        # Right column: contains both cards, stacked.
        with gr.Column():
            with gr.Column(elem_classes=["custom-card"]):
                gr.Markdown("### Style Tags")
                tag_panel = gr.Markdown()
            with gr.Column(elem_classes=["custom-card"]):
                gr.Markdown("### Nearest Neighbors")
                neighbors_text = gr.Markdown()
    # Run inference when an image is uploaded; blank both panels when it is cleared.
    image.upload(fn=infer, inputs=[image], outputs=[tag_panel, neighbors_text], show_progress="minimal")
    image.clear(fn=clear_outputs, inputs=[], outputs=[tag_panel, neighbors_text])
    # Static help text shown below the tool.
    gr.Markdown("""
---
### Purpose
StyleSquirrel is designed for **style exploration**, not artist identification.
It **may not report the image's artist**, even when that artist exists in the dataset.
Use it to explore images with similar **colors, structures, textures, and visual motifs**, not as an attribution tool.
### Instructions
- Drop an image in the box on the left.
- The **Style Tags** panel reports tags that are stylistically similar to the query image.
- The **Nearest Neighbors** panel shows e621 images that are stylistically similar to the query image.
### Notes
- Links go to e621.net and may not be safe for work.
- I tried to isolate style from topic and was only partly successful, so many reported tags/images might be topically rather than stylistically similar.
- The similarity metric is currently a bit naive, causing irregularities like **simple_background** being over-reported.
- This tool has a very small training set and does **not** work well as a style classifier.
""")
# Script entry point: try to load the model, then start the Gradio app.
if __name__ == "__main__":
    # Load your model once at start-up (optional; define as no-op in model.py for now)
    try:
        model.load()
    except Exception as e:
        # Best effort: a failed load is only reported here. infer() attempts
        # model.load() again when model._READY is not set.
        print("Model load() raised (ok during skeleton dev):", e)
    # The custom stylesheet is supplied at launch time.
    demo.launch(css=custom_css)