File size: 6,843 Bytes
2cd64de
 
 
c5ae7ff
 
2cd64de
 
 
 
 
 
 
 
 
 
 
 
c5ae7ff
 
2cd64de
c5ae7ff
 
2cd64de
c5ae7ff
 
 
 
 
 
 
 
2cd64de
 
c5ae7ff
2cd64de
c5ae7ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cd64de
c5ae7ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cd64de
 
 
 
c5ae7ff
2cd64de
 
 
 
 
c5ae7ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cd64de
 
c5ae7ff
e6c2a72
2cd64de
 
c5ae7ff
 
 
2cd64de
c5ae7ff
 
 
 
 
 
 
 
 
 
 
 
2cd64de
 
376e833
 
 
 
 
 
2cd64de
 
376e833
 
 
c5ae7ff
 
376e833
 
 
 
2cd64de
 
 
 
 
 
 
 
 
e6c2a72
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import html
import re
from typing import Dict, Tuple
from urllib.parse import quote

import gradio as gr
from PIL import Image


# 🔧 Your model lives here.
# Implement load() and predict(image) in model.py.
import model  # noqa: F401

# --------- Glue code (kept intentionally tiny) ---------
def _format_outputs(scores: Dict[str, float], neighbors: list, threshold: float):
    # sort by score desc and apply threshold
    filtered = {k: v for k, v in sorted(scores.items(), key=lambda kv: kv[1], reverse=True) if v >= threshold}
    tag_string = ", ".join(filtered.keys())
    return tag_string, filtered, "\n".join(neighbors)


def _render_tag_table(filtered: list) -> str:
    """Render ``(tag, score)`` pairs as an HTML table: e621 search link | rounded percent.

    Returns the Markdown placeholder ``_(no tags)_`` when *filtered* is empty.
    """
    if not filtered:
        return "_(no tags)_"

    rows = []
    for tag, val in filtered:
        pct = int(round(val * 100))
        # e621 tags use underscores. Percent-encode so tags containing
        # apostrophes, '&', '+', etc. cannot end the single-quoted href
        # attribute or corrupt the search query.
        tag_q = quote(tag.replace(" ", "_"), safe="")
        url = f"https://e621.net/posts?tags=order%3Afavcount+-animated+{tag_q}"
        rows.append(
            f"<tr>"
            f"<td class='tag-name'><a href='{url}' target='_blank' rel='noopener noreferrer'>{html.escape(tag)}</a></td>"
            f"<td class='tag-pct'>{pct}%</td>"
            f"</tr>"
        )
    return "<table class='tag-table'><tbody>" + "".join(rows) + "</tbody></table>"


def _render_neighbor_table(neighbors) -> str:
    """Render neighbor entries as an HTML table (ID | Styles | sim).

    Dict entries contribute ``"filename"``, ``"similarity"`` and ``"styles"``.
    Any other entry is treated as a bare filename. Returns the Markdown
    placeholder ``_(neighbors unavailable)_`` when there is nothing to show.
    """
    rows = []
    for item in neighbors or []:
        if isinstance(item, dict):
            fname = str(item.get("filename", ""))
            sim = item.get("similarity", None)
            styles = item.get("styles") or []
        else:
            fname = str(item)
            sim = None
            styles = []

        # The first run of digits in the filename (e.g. "12345.png") is taken
        # as the e621 post ID and linked. Otherwise the escaped filename is shown.
        m = re.search(r"(\d+)", fname)
        if m:
            post_id = m.group(1)
            id_cell = f"<a href='https://e621.net/posts/{post_id}' target='_blank' rel='noopener noreferrer'>{post_id}</a>"
        else:
            id_cell = html.escape(fname)

        styles_cell = ", ".join(html.escape(str(s)) for s in styles)
        sim_cell = f"{sim:.3f}" if sim is not None else ""

        rows.append(
            f"<tr>"
            f"<td class='nn-id'>{id_cell}</td>"
            f"<td class='nn-styles'>{styles_cell}</td>"
            f"<td class='nn-sim'>{sim_cell}</td>"
            f"</tr>"
        )

    if not rows:
        return "_(neighbors unavailable)_"
    return (
        "<table class='nn-table'>"
        "<thead><tr>"
        "<th class='nn-id'>ID</th>"
        "<th class='nn-styles'>Styles</th>"
        "<th class='nn-sim'>sim</th>"
        "</tr></thead>"
        "<tbody>" + "".join(rows) + "</tbody></table>"
    )


def infer(image: Image.Image):
    """Run the model on *image* and render both result panels.

    Args:
        image: Query image from the Gradio input, or ``None`` when empty.

    Returns:
        ``(tag_panel_md, neighbors_md)``:
        - ``tag_panel_md`` is an HTML table of tags scoring at least 0.01, in
          descending score order, each linking to an e621 search.
        - ``neighbors_md`` is an HTML table of nearest neighbors.

        Both are ``""`` when *image* is ``None``.
    """
    if image is None:
        return "", ""  # (tag_panel_md, neighbors_md)
    threshold = 0.01  # fixed cutoff

    # Lazy-load if start-up loading did not mark the model ready. A failure
    # is only printed; predict() below is still attempted.
    if not getattr(model, "_READY", False):
        try:
            model.load()
        except Exception as e:
            print("model.load() during infer failed:", e)

    # predict() returns (scores, neighbors, counts); counts is not displayed.
    scores, neighbors, _counts = model.predict(image)

    ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    filtered = [(tag, float(score)) for tag, score in ranked if score >= threshold]

    return _render_tag_table(filtered), _render_neighbor_table(neighbors)



def clear_outputs():
    """Reset both result panels (style tags, nearest neighbors) to empty text."""
    empty = ""
    return (empty, empty)


# Stylesheet passed to demo.launch(css=custom_css) at start-up.
# Selectors target the elem_id / elem_classes used in the Blocks layout
# ("image_container", "custom-card") and the HTML classes emitted by infer()
# ("tag-table", "nn-table", and their cell classes).
# NOTE(review): "#image_container-image" presumably targets an element Gradio
# derives from elem_id="image_container" — confirm it still matches.
custom_css = '''
#image_container-image { width: 100%; aspect-ratio: 1 / 1; max-height: 100%; }
#image_container img { object-fit: contain !important; }

/* card look for right-side panels */
.custom-card {
  background: rgba(255,255,255,0.05);   /* lighter than dark bg */
  border: 1px solid rgba(255,255,255,0.14);
  border-radius: 12px;
  padding: 12px 14px;
}
.custom-card .prose { margin: 0; }       /* tighter Markdown spacing */
.custom-card h3 { margin-top: 0; }       /* keep section title snug */
.custom-card:hover { box-shadow: 0 6px 20px rgba(0,0,0,0.25); }
.nn-table { width: 100%; border-collapse: collapse; }
.nn-table th, .nn-table td { padding: 4px 8px; vertical-align: middle; }
.nn-table th { text-align: left; font-weight: 600; }
.nn-table .nn-id { width: 1%; white-space: nowrap; }
.nn-table .nn-sim { text-align: right; width: 1%; white-space: nowrap; }
.tag-table { width: 100%; border-collapse: collapse; }
.tag-table td { padding: 4px 8px; vertical-align: middle; }
.tag-table .tag-name { text-align: left; }
.tag-table .tag-pct { text-align: right; width: 1%; white-space: nowrap; }
'''


# UI layout: the image input sits on the left. A right-hand column holds two
# stacked "cards" (Style Tags, Nearest Neighbors) whose Markdown panels receive
# infer()'s two return values. Static usage notes follow below.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # type="pil" hands the uploaded image to callbacks as a PIL image.
            image = gr.Image(label="Drop an image here", sources=["upload", "clipboard"],
                             type="pil", show_label=False, elem_id="image_container")
        # Right-side column containing both result cards stacked vertically.
        with gr.Column():
            with gr.Column(elem_classes=["custom-card"]):
                gr.Markdown("### Style Tags")
                tag_panel = gr.Markdown()

            with gr.Column(elem_classes=["custom-card"]):
                gr.Markdown("### Nearest Neighbors")
                neighbors_text = gr.Markdown()


    # Run inference on upload. The outputs order matches infer()'s
    # (tag_panel_md, neighbors_md) return tuple.
    image.upload(fn=infer, inputs=[image], outputs=[tag_panel, neighbors_text], show_progress="minimal")
    # Clearing the image blanks both result panels.
    image.clear(fn=clear_outputs, inputs=[], outputs=[tag_panel, neighbors_text])

    # Static help text rendered beneath the app.
    gr.Markdown("""
    ---

    ### Purpose
    StyleSquirrel is designed for **style exploration**, not artist identification.  
    It **may not report the image's artist**, even when that artist exists in the dataset.  
    Use it to explore images with similar **colors, structures, textures, and visual motifs**, not as an attribution tool.

    ### Instructions
    - Drop an image in the box on the left.
    - The **Style Tags** panel reports tags that are stylistically similar to the query image.
    - The **Nearest Neighbors** panel shows e621 images that are stylistically similar to the query image.

    ### Notes
    - Links go to e621.net and may not be safe for work.
    - I tried to isolate style from topic and was only partly successful, so many reported tags/images might be topically rather than stylistically similar.
    - The similarity metric is currently a bit naive, causing irregularities like **simple_background** being over-reported.
    - This tool has a very small training set and does **not** work well as a style classifier.

    """)


if __name__ == "__main__":
    # Warm the model once before serving. A failure is tolerated here because
    # infer() calls model.load() again whenever model._READY is not set.
    try:
        model.load()
    except Exception as load_err:
        print("Model load() raised (ok during skeleton dev):", load_err)
    demo.launch(css=custom_css)