Food Desert committed on
Commit
c5ae7ff
·
1 Parent(s): 11f3316

UI polish: cards + tables; clickable tags; cleanup

Browse files
Files changed (2) hide show
  1. README.md +14 -0
  2. app.py +122 -34
README.md CHANGED
@@ -16,3 +16,17 @@ Drop or paste an image → get style tags and see nearest training images.
16
  - Trained projector maps to “style space”
17
  - FAISS finds nearest training images
18
  - We tally their style tags and normalize to scores in [0,1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  - Trained projector maps to “style space”
17
  - FAISS finds nearest training images
18
  - We tally their style tags and normalize to scores in [0,1]
19
+
20
+ ---
21
+
22
+ # StyleSquirrel
23
+
24
+ Nearest-neighbor style tagger demo.
25
+ Drag an image to see predicted style tags and similar training images.
26
+
27
+ ⚠️ **Note on large model files**
28
+ If you clone this repository locally, you must pull the big model and FAISS index files with Git LFS before running:
29
+
30
+ ```bash
31
+ git lfs install
32
+ git lfs pull
app.py CHANGED
@@ -1,7 +1,8 @@
1
-
2
  import gradio as gr
3
  from PIL import Image
4
  from typing import Dict, Tuple
 
 
5
 
6
  # 🔧 Your model lives here.
7
  # Implement load() and predict(image) in model.py.
@@ -14,65 +15,152 @@ def _format_outputs(scores: Dict[str, float], neighbors: list, threshold: float)
14
  tag_string = ", ".join(filtered.keys())
15
  return tag_string, filtered, "\n".join(neighbors)
16
 
17
- def infer(image: Image.Image, threshold: float):
 
18
  if image is None:
19
- return "", {}, ""
 
20
 
21
- # model.predict now returns (scores_norm, neighbors, counts_raw)
 
 
 
 
 
 
 
22
  scores, neighbors, counts = model.predict(image)
23
 
24
- # Sort and threshold the display dict
25
  sorted_scores = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
26
- filtered = {k: float(v) for k, v in sorted_scores if v >= threshold}
27
- tag_text = ", ".join(filtered.keys())
28
-
29
- # Pretty-print neighbors. Each neighbor dict has:
30
- # { "filename": str, "similarity": float, "distance": float, "styles": [str, ...] }
31
- lines = []
32
- for i, d in enumerate(neighbors, 1):
33
- styles_str = ", ".join(d.get("styles", []))
34
- sim = d.get("similarity", None)
35
- dist = d.get("distance", None)
36
- if sim is not None and dist is not None:
37
- lines.append(f"{i}. {d['filename']} sim={sim:.3f} dist={dist:.3f} styles=[{styles_str}]")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  else:
39
- # (just in case similarity/distance are not present)
40
- lines.append(f"{i}. {d['filename']} styles=[{styles_str}]")
41
- neighbors_text = "\n".join(lines) if lines else "(neighbors unavailable)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- # Return three outputs: tag text, scores dict, and neighbors textbox text
44
- return tag_text, filtered, neighbors_text
45
 
46
 
47
  def clear_outputs():
48
- return "", {}, ""
49
 
50
 
51
  custom_css = '''
52
  #image_container-image { width: 100%; aspect-ratio: 1 / 1; max-height: 100%; }
53
  #image_container img { object-fit: contain !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  '''
55
 
 
56
  with gr.Blocks(css=custom_css) as demo:
57
  gr.Markdown("## Style Tagger — Skeleton (local dev first)")
58
  with gr.Row():
59
  with gr.Column():
60
- image = gr.Image(label="Drop an image here", sources=["upload", "clipboard"], type="pil", show_label=False, elem_id="image_container")
 
 
61
  with gr.Column():
62
- threshold = gr.Slider(0.0, 1.0, value=0.01, step=0.01, label="Confidence threshold")
63
- tag_text = gr.Textbox(label="Style tags (comma-separated)")
64
- tag_scores = gr.Label(label="Scores", num_top_classes=250, show_label=False)
65
- neighbors_text = gr.Textbox(label="Nearest training images", lines=8, interactive=False)
66
-
67
- image.upload(fn=infer, inputs=[image, threshold], outputs=[tag_text, tag_scores, neighbors_text], show_progress="minimal")
68
- image.clear(fn=clear_outputs, inputs=[], outputs=[tag_text, tag_scores, neighbors_text])
69
- threshold.input(fn=infer, inputs=[image, threshold], outputs=[tag_text, tag_scores, neighbors_text], show_progress="hidden")
 
 
 
 
70
  gr.Markdown("""
71
  ---
72
  ### Instructions
73
  - Drop an image in the box on the left.
74
- - Tags that are stylistically similar are returned, along with some statistics about them.
75
- - I tried to isolate style from topic and was only partly successful. So many reported tags might be topically rather than stylistically similar.
 
 
 
 
76
  """)
77
 
78
 
 
 
1
  import gradio as gr
2
  from PIL import Image
3
  from typing import Dict, Tuple
4
+ import re
5
+
6
 
7
  # 🔧 Your model lives here.
8
  # Implement load() and predict(image) in model.py.
 
15
  tag_string = ", ".join(filtered.keys())
16
  return tag_string, filtered, "\n".join(neighbors)
17
 
18
+
19
+ def infer(image: Image.Image):
20
  if image is None:
21
+ return "", "" # (tag_panel_md, neighbors_md)
22
+ threshold = 0.01 # fixed cutoff
23
 
24
+ # Lazy-load if needed
25
+ if not getattr(model, "_READY", False):
26
+ try:
27
+ model.load()
28
+ except Exception as e:
29
+ print("model.load() during infer failed:", e)
30
+
31
+ # Predict
32
  scores, neighbors, counts = model.predict(image)
33
 
34
+ # Sort & threshold
35
  sorted_scores = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
36
+ filtered = [(k, float(v)) for k, v in sorted_scores if v >= threshold]
37
+
38
+ # ---------- Style Tags: HTML table (link | % right-aligned) ----------
39
+ if filtered:
40
+ rows = []
41
+ for tag, val in filtered:
42
+ pct = int(round(val * 100))
43
+ tag_q = tag.replace(" ", "_")
44
+ url = f"https://e621.net/posts?tags=order%3Afavcount+-animated+{tag_q}"
45
+ rows.append(
46
+ f"<tr>"
47
+ f"<td class='tag-name'><a href='{url}' target='_blank' rel='noopener noreferrer'>{tag}</a></td>"
48
+ f"<td class='tag-pct'>{pct}%</td>"
49
+ f"</tr>"
50
+ )
51
+ tag_panel_md = "<table class='tag-table'><tbody>" + "".join(rows) + "</tbody></table>"
52
+ else:
53
+ tag_panel_md = "_(no tags)_"
54
+
55
+ # ---------- Nearest Neighbors: Markdown list (no 'dist') ----------
56
+ # ---------- Nearest Neighbors: HTML table (ID | Styles | sim) ----------
57
+ rows = []
58
+ for item in neighbors:
59
+ if isinstance(item, dict):
60
+ fname = str(item.get("filename", ""))
61
+ sim = item.get("similarity", None)
62
+ styles = item.get("styles", [])
63
  else:
64
+ fname = str(item)
65
+ sim = None
66
+ styles = []
67
+
68
+ # numeric ID (strip ".png", etc.); link to e621 if we find one
69
+ m = re.search(r"(\d+)", fname)
70
+ post_id = m.group(1) if m else fname
71
+ id_cell = (
72
+ f"<a href='https://e621.net/posts/{post_id}' target='_blank' rel='noopener noreferrer'>{post_id}</a>"
73
+ if m else post_id
74
+ )
75
+
76
+ styles_cell = ", ".join(styles)
77
+ sim_cell = f"{sim:.3f}" if sim is not None else ""
78
+
79
+ rows.append(
80
+ f"<tr>"
81
+ f"<td class='nn-id'>{id_cell}</td>"
82
+ f"<td class='nn-styles'>{styles_cell}</td>"
83
+ f"<td class='nn-sim'>{sim_cell}</td>"
84
+ f"</tr>"
85
+ )
86
+
87
+ if rows:
88
+ neighbors_md = (
89
+ "<table class='nn-table'>"
90
+ "<thead><tr>"
91
+ "<th class='nn-id'>ID</th>"
92
+ "<th class='nn-styles'>Styles</th>"
93
+ "<th class='nn-sim'>sim</th>"
94
+ "</tr></thead>"
95
+ "<tbody>" + "".join(rows) + "</tbody></table>"
96
+ )
97
+ else:
98
+ neighbors_md = "_(neighbors unavailable)_"
99
+
100
+ return tag_panel_md, neighbors_md
101
 
 
 
102
 
103
 
104
  def clear_outputs():
105
+ return "", ""
106
 
107
 
108
  custom_css = '''
109
  #image_container-image { width: 100%; aspect-ratio: 1 / 1; max-height: 100%; }
110
  #image_container img { object-fit: contain !important; }
111
+
112
+ /* card look for right-side panels */
113
+ .custom-card {
114
+ background: rgba(255,255,255,0.05); /* lighter than dark bg */
115
+ border: 1px solid rgba(255,255,255,0.14);
116
+ border-radius: 12px;
117
+ padding: 12px 14px;
118
+ }
119
+ .custom-card .prose { margin: 0; } /* tighter Markdown spacing */
120
+ .custom-card h3 { margin-top: 0; } /* keep section title snug */
121
+ .custom-card:hover { box-shadow: 0 6px 20px rgba(0,0,0,0.25); }
122
+ .nn-table { width: 100%; border-collapse: collapse; }
123
+ .nn-table th, .nn-table td { padding: 4px 8px; vertical-align: middle; }
124
+ .nn-table th { text-align: left; font-weight: 600; }
125
+ .nn-table .nn-id { width: 1%; white-space: nowrap; }
126
+ .nn-table .nn-sim { text-align: right; width: 1%; white-space: nowrap; }
127
+ .tag-table { width: 100%; border-collapse: collapse; }
128
+ .tag-table td { padding: 4px 8px; vertical-align: middle; }
129
+ .tag-table .tag-name { text-align: left; }
130
+ .tag-table .tag-pct { text-align: right; width: 1%; white-space: nowrap; }
131
  '''
132
 
133
+
134
  with gr.Blocks(css=custom_css) as demo:
135
  gr.Markdown("## Style Tagger — Skeleton (local dev first)")
136
  with gr.Row():
137
  with gr.Column():
138
+ image = gr.Image(label="Drop an image here", sources=["upload", "clipboard"],
139
+ type="pil", show_label=False, elem_id="image_container")
140
+ # NEW: one right-side column that contains both cards stacked
141
  with gr.Column():
142
+ with gr.Column(elem_classes=["custom-card"]):
143
+ gr.Markdown("### Style Tags")
144
+ tag_panel = gr.Markdown()
145
+
146
+ with gr.Column(elem_classes=["custom-card"]):
147
+ gr.Markdown("### Nearest Neighbors")
148
+ neighbors_text = gr.Markdown()
149
+
150
+
151
+ image.upload(fn=infer, inputs=[image], outputs=[tag_panel, neighbors_text], show_progress="minimal")
152
+ image.clear(fn=clear_outputs, inputs=[], outputs=[tag_panel, neighbors_text])
153
+
154
  gr.Markdown("""
155
  ---
156
  ### Instructions
157
  - Drop an image in the box on the left.
158
+ - The "Style Tags" panel reports on tags that are stylistically similar to the query image.
159
+ - The "Nearest Neighbors" panel reports on e621 images that are stylistically similar to the query image.
160
+ ### Notes
161
+ - Links go to e621.net and may not be safe for work.
162
+ - I tried to isolate style from topic and was only partly successful. So many reported tags and images might be topically rather than stylistically similar.
163
+ - The similarity metric is currently a bit naive, leading to irregularities like the "simple_background" tag being overreported due to its frequency.
164
  """)
165
 
166