Your Name committed on
Commit
37e5bdb
·
1 Parent(s): a70668e
Files changed (8) hide show
  1. .gitattributes +0 -35
  2. .gitignore +2 -0
  3. README.md +33 -8
  4. _orig.py +109 -0
  5. app.py +392 -0
  6. app_v1.py +382 -0
  7. prompt.txt +3 -0
  8. requirements.txt +6 -0
.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .devcontainer
2
+ .vscode
README.md CHANGED
@@ -1,14 +1,39 @@
1
  ---
2
- title: Image Caption Trimmer
3
- emoji: 🚀
4
- colorFrom: purple
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 6.10.0
8
  app_file: app.py
9
  pinned: false
10
- license: apache-2.0
11
- short_description: shorten a text by dropping unimportant words
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Word Importance Evaluator
3
+ emoji: 🔬
4
+ colorFrom: yellow
5
+ colorTo: teal
6
  sdk: gradio
7
+ sdk_version: "4.44.0"
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
 
11
  ---
12
 
13
+ # Word Importance Evaluator
14
+
15
+ Drop-one embedding analysis using `sentence-transformers/static-retrieval-mrl-en-v1`.
16
+
17
+ Each word's importance score = the semantic distance introduced by omitting that word
18
+ from the prompt (higher = more critical to the meaning).
19
+
20
+ ## Features
21
+
22
+ - **Importance bar chart** — horizontal bars coloured by a hot→cold colormap, with a dashed threshold line set by the slider
23
+ - **Distribution sampling** — sampled prompt variants in which each word is kept with probability equal to its importance score (`app_v1.py` instead plots a violin-style jittered spread per word)
24
+ - **Threshold filter** — highlighted HTML output and summary of words above the cutoff
25
+ - **Multi-line prompt support** — all lines are concatenated into a single word list
26
+
27
+ ## Usage
28
+
29
+ 1. Paste a prompt (e.g. a Stable Diffusion caption)
30
+ 2. Adjust the importance threshold (default 0.30)
31
+ 3. Adjust the number of sampled prompts if desired
32
+ 4. Click **Analyse →**
33
+
34
+ ## Files
35
+
36
+ | File | Purpose |
37
+ |---|---|
38
+ | `app.py` | Full Gradio Space — core evaluator code is unchanged |
39
+ | `requirements.txt` | Python dependencies |
_orig.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from sentence_transformers import SentenceTransformer
6
+
7
+
8
+ # %%
9
+
10
+
11
def create_splits(p):
    """Split a prompt into words and build its leave-one-out variants.

    Args:
        p: Prompt text; it is tokenised on whitespace with ``str.split()``.

    Returns:
        A ``(words, omit_prompts)`` tuple where ``omit_prompts[j]`` is the
        prompt re-joined with single spaces and ``words[j]`` left out. Both
        lists are empty for an empty or whitespace-only prompt.
    """
    words = p.split()
    # Create prompts with each word omitted
    omit_prompts = [
        " ".join(words[:j] + words[j + 1:]) for j in range(len(words))
    ]
    return words, omit_prompts
18
+
19
+
20
+ # %%
21
+ from abc import ABC, abstractmethod
22
+
23
+
24
class IE(ABC):
    """Abstract interface for word-importance evaluators."""

    @abstractmethod
    def get_word_importance_chunked(self, PROMPT):
        """Score the words of ``PROMPT``; the return value is defined by subclasses."""
28
+
29
+
30
class ImportanceEvaluatorStatic(IE):
    """Drop-one word-importance evaluator backed by a SentenceTransformer model."""

    # Floor applied to similarities before taking the log, so a zero or
    # negative similarity yields a large finite score instead of inf/NaN.
    _MIN_SIMILARITY = 1e-12

    def __init__(self):
        # Download from the Hub
        self.CLIP_MODEL_ID = "sentence-transformers/static-retrieval-mrl-en-v1"
        self.model = SentenceTransformer(self.CLIP_MODEL_ID)

    def get_word_importance(self, PROMPT):
        """Score each word by how much dropping it changes the prompt embedding.

        The full prompt and every leave-one-out variant are embedded together.
        A word's raw score is ``-log(similarity)`` between the full prompt and
        the variant without that word, taken relative to the full prompt's
        self-similarity, clamped at zero and scaled so the largest score is 1.

        Args:
            PROMPT: Prompt text, split on whitespace by ``create_splits``.

        Returns:
            ``(importances, words)``: the per-word scores (one entry per word,
            as returned by the model's similarity call) and the word list.
        """
        words, omit_prompts = create_splits(PROMPT)

        sentences = [PROMPT] + omit_prompts

        embeddings = self.model.encode(sentences)

        similarities = self.model.similarity(embeddings[0:1], embeddings)

        x = similarities[0].clamp(min=self._MIN_SIMILARITY)
        x = -x.log()  # importance of a word is the inverse of similarity-when-dropped
        x = x - x[0]  # subtract self-similarity as the baseline

        x = x.clamp(0)
        peak = x.max()
        if peak > 0:
            # Skip the scaling when every score is zero; dividing would give NaN.
            x = x / peak
        return x[1:], words

    def get_word_importance_chunked(self, PROMPT):
        """Implement the ``IE`` interface by scoring the whole prompt at once."""
        return self.get_word_importance(PROMPT)
55
+
56
+
57
+ # %%
58
+
59
+
60
def compute_static_word_importances(
    f: Path, ie: "ImportanceEvaluatorStatic", overwrite=False
):
    """Cache per-line word importances for every caption file under ``f``.

    For each directory ``f/.captions/<name>`` every ``*.txt`` file is read and
    each non-empty line is scored with ``ie.get_word_importance_chunked``
    (empty lines are stored as ``None``). The list is saved with ``torch.save``
    to ``<name>/.meta/<stem>.pth`` inside a dict keyed by ``ie.CLIP_MODEL_ID``,
    so results stored under other keys in the same file are kept.

    Args:
        f: Dataset directory that contains a ``.captions`` folder.
        ie: Evaluator providing ``CLIP_MODEL_ID`` and
            ``get_word_importance_chunked``.
        overwrite: Recompute even when the cache already has this model's entry.

    Failures on a single file are printed and skipped (best effort).
    """
    model_id = ie.CLIP_MODEL_ID
    for c in f.glob(".captions/*"):
        if not c.is_dir():
            # A stray file next to the caption folders cannot be iterated.
            continue
        metadir = c / ".meta"
        for file in c.iterdir():
            if file.suffix != ".txt" or not file.is_file():
                continue
            # Resolved before the try so the error report can always name it.
            out = metadir / file.with_suffix(".pth").name
            try:
                r = {}
                if out.exists():
                    # NOTE: weights_only=False unpickles arbitrary objects;
                    # only load cache files this tool wrote itself.
                    r = torch.load(out, weights_only=False)
                    if not isinstance(r, dict):
                        raise TypeError("corrupt format")
                    if (not overwrite) and (model_id in r):
                        continue

                caption = file.read_text()
                r[model_id] = [
                    ie.get_word_importance_chunked(line) if line else None
                    for line in caption.split("\n")
                ]

                metadir.mkdir(exist_ok=True)
                torch.save(r, out)
            except Exception as e:
                print("ERROR", out, e)
90
+
91
+
92
def yield_dirs(root: Path):
    """Yield the non-hidden sub-directories located directly under ``root``.

    Plain files and entries whose name starts with ``.`` are skipped. The
    original test was inverted: it only ever filtered hidden non-directories.
    """
    for subset in root.iterdir():
        if not subset.is_dir():
            continue
        if subset.name.startswith("."):
            continue
        yield subset
98
+
99
+
100
if __name__ == "__main__":
    # Batch entry point: score the captions of every dataset directory under root.
    from tqdm import tqdm

    ies = ImportanceEvaluatorStatic()
    root = Path("/path_to_my_files")
    # yield_dirs() accepts only the root; the extra positional argument that
    # used to be passed here raised TypeError before any work was done.
    for f in tqdm(yield_dirs(root)):
        print(f)
        compute_static_word_importances(f, ies, overwrite=False)
app.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import matplotlib
5
+ matplotlib.use("Agg")
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib.patches as mpatches
8
+ from matplotlib.colors import LinearSegmentedColormap
9
+ from sentence_transformers import SentenceTransformer
10
+ from abc import ABC, abstractmethod
11
+ import io
12
+ from PIL import Image
13
+
14
+
15
+ # ─────────────────────────────────────────────
16
+ # Core importance evaluator (unchanged logic)
17
+ # ─────────────────────────────────────────────
18
+
19
def create_splits(p):
    """Split a prompt into words and build its leave-one-out variants.

    Args:
        p: Prompt text; it is tokenised on whitespace with ``str.split()``.

    Returns:
        A ``(words, omit_prompts)`` tuple where ``omit_prompts[j]`` is the
        prompt re-joined with single spaces and ``words[j]`` left out. Both
        lists are empty for an empty or whitespace-only prompt.
    """
    words = p.split()
    omit_prompts = [
        " ".join(words[:j] + words[j + 1:]) for j in range(len(words))
    ]
    return words, omit_prompts
25
+
26
+
27
class IE(ABC):
    """Abstract interface for word-importance evaluators."""

    @abstractmethod
    def get_word_importance_chunked(self, PROMPT):
        """Score the words of ``PROMPT``; the return value is defined by subclasses."""
31
+
32
+
33
class ImportanceEvaluatorStatic(IE):
    """Drop-one word-importance evaluator backed by a SentenceTransformer model."""

    # Floor applied to similarities before taking the log, so a zero or
    # negative similarity (e.g. against an empty drop-one prompt) yields a
    # large finite score instead of inf/NaN.
    _MIN_SIMILARITY = 1e-12

    def __init__(self):
        self.CLIP_MODEL_ID = "sentence-transformers/static-retrieval-mrl-en-v1"
        self.model = SentenceTransformer(self.CLIP_MODEL_ID)

    def get_word_importance(self, PROMPT):
        """Score each word by how much dropping it changes the prompt embedding.

        The full prompt and every leave-one-out variant are embedded together.
        A word's raw score is ``-log(similarity)`` between the full prompt and
        the variant without that word, taken relative to the full prompt's
        self-similarity, clamped at zero and scaled so the largest score is 1
        (scaling is skipped when every score is zero).

        Args:
            PROMPT: Prompt text, split on whitespace by ``create_splits``.

        Returns:
            ``(importances, words)``: the per-word scores (one entry per word,
            as returned by the model's similarity call) and the word list.
        """
        words, omit_prompts = create_splits(PROMPT)
        sentences = [PROMPT] + omit_prompts
        embeddings = self.model.encode(sentences)
        similarities = self.model.similarity(embeddings[0:1], embeddings)
        x = similarities[0].clamp(min=self._MIN_SIMILARITY)
        x = -x.log()
        x = x - x[0]  # index 0 is the full prompt compared with itself
        x = x.clamp(0)
        peak = x.max()
        if peak > 0:
            x = x / peak
        return x[1:], words

    def get_word_importance_chunked(self, PROMPT):
        """Implement the ``IE`` interface by scoring the whole prompt at once."""
        return self.get_word_importance(PROMPT)

    def get_caption_embedding(self, PROMPT):
        """Return the model's embedding of ``PROMPT``."""
        return self.model.encode(PROMPT)
56
+
57
+
58
+ # ─────────────────────────────────────────────
59
+ # Load model lazily on first use, then reuse it
60
+ # ─────────────────────────────────────────────
61
+
62
# Process-wide evaluator instance; stays None until the first request needs it.
_ie = None


def get_evaluator():
    """Return the shared evaluator, constructing it on the first call."""
    global _ie
    if _ie is None:
        _ie = ImportanceEvaluatorStatic()
    return _ie
69
+
70
+
71
+ # ─────────────────────────────────────────────
72
+ # Plotting helpers
73
+ # ─────────────────────────────────────────────
74
+
75
# Dark-theme colours shared by the plots, the generated HTML and the CSS.
PALETTE = {
    "bg": "#0d0f14",
    "panel": "#14171f",
    "border": "#1e2330",
    "accent": "#e8c547",
    "accent2": "#5bc4c0",
    "text": "#d4d8e8",
    "muted": "#5a6080",
    "low": "#2a3a5c",
    "mid": "#4a7c8c",
    "high": "#e8c547",
    "critical": "#e85f47",
}

# Importance colormap. The stops are taken from PALETTE (same hex values as
# before) so the two definitions cannot drift apart.
CMAP = LinearSegmentedColormap.from_list(
    "imp",
    [PALETTE["low"], PALETTE["accent2"], PALETTE["high"], PALETTE["critical"]],
    N=256,
)
92
+
93
def _fig_to_pil(fig):
    """Render a Matplotlib figure to a PIL image and close the figure.

    The figure is saved as a 150-dpi PNG on the palette background. It is
    closed even when rendering fails, so failed requests do not leak figures.
    """
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=150, bbox_inches="tight",
                        facecolor=PALETTE["bg"])
            buf.seek(0)
            # copy() detaches the image from the buffer before it is closed.
            return Image.open(buf).copy()
    finally:
        plt.close(fig)
102
+
103
+
104
def plot_importance_bars(words, importances, threshold=0.3):
    """Horizontal bar chart coloured by importance with threshold line.

    Args:
        words: Bar labels; the first word is drawn at the top.
        importances: One score per word. Scores are mapped through ``CMAP``
            and the x-axis is fixed to 0..1.18, so values in [0, 1] are expected.
        threshold: Position of the dashed cut-off line. Bars at or above it get
            an accent-coloured value label followed by a marker.

    Returns:
        The rendered chart, as returned by ``_fig_to_pil``.
    """
    n = len(words)
    fig_h = max(3.5, n * 0.38)  # taller for more words, never below 3.5
    fig, ax = plt.subplots(figsize=(9, fig_h), facecolor=PALETTE["bg"])
    ax.set_facecolor(PALETTE["panel"])

    vals = np.array(importances)
    colors = [CMAP(float(v)) for v in vals]

    ax.barh(range(n), vals, color=colors, edgecolor=PALETTE["border"],
            linewidth=0.6, height=0.65)

    # threshold line
    ax.axvline(threshold, color=PALETTE["accent"], linewidth=1.4,
               linestyle="--", alpha=0.85, label=f"threshold = {threshold:.2f}")

    # word labels
    ax.set_yticks(range(n))
    ax.set_yticklabels(words, fontsize=10, color=PALETTE["text"],
                       fontfamily="monospace")
    ax.invert_yaxis()

    # value annotations
    for i, v in enumerate(vals):
        above = v >= threshold
        marker = "▶" if above else ""
        # The label x-position is capped so it stays inside the axes.
        ax.text(min(v + 0.02, 1.05), i, f"{v:.3f} {marker}",
                va="center", fontsize=8.5,
                color=PALETTE["accent"] if above else PALETTE["muted"])

    ax.set_xlim(0, 1.18)
    ax.set_xlabel("Normalised importance", color=PALETTE["text"], fontsize=10)
    ax.set_title("Word Importance · drop-one analysis", color=PALETTE["text"],
                 fontsize=12, fontweight="bold", pad=10)

    ax.tick_params(colors=PALETTE["muted"], which="both")
    for spine in ax.spines.values():
        spine.set_edgecolor(PALETTE["border"])

    ax.legend(facecolor=PALETTE["panel"], edgecolor=PALETTE["border"],
              labelcolor=PALETTE["accent"], fontsize=9)

    fig.tight_layout(pad=1.2)
    return _fig_to_pil(fig)
148
+
149
+
150
def sample_prompts(words, importances, n_samples=8, seed=42):
    """Render randomly thinned variants of the prompt as HTML.

    Each word is included in a sample with probability == its importance score.
    Included words are coloured by their importance; dropped words are shown
    as dim strikethrough.

    Args:
        words: Prompt words. They are HTML-escaped before being embedded.
        importances: One keep-probability per word (same length as ``words``).
        n_samples: Number of variants to draw; coerced to ``int`` because UI
            sliders may deliver a float.
        seed: Seed for the random generator, so equal inputs give equal output.

    Returns:
        An HTML string with a colour legend followed by one row per sample.
    """
    from html import escape  # local import keeps this block self-contained

    rng = np.random.default_rng(seed)
    vals = np.array(importances, dtype=float)

    def imp_to_hex(v):
        r, g, b, _ = CMAP(float(v))
        return "#{:02x}{:02x}{:02x}".format(int(r*255), int(g*255), int(b*255))

    # Colours and escaped text are the same for every sample: compute them once.
    colors = [imp_to_hex(v) for v in vals]
    safe_words = [escape(str(w)) for w in words]
    total = len(words)

    rows_html = []
    for s in range(int(n_samples)):
        mask = rng.random(total) < vals  # Bernoulli draw
        word_spans = []
        for word, keep, color in zip(safe_words, mask, colors):
            if keep:
                span = (
                    f'<span style="color:{color};font-weight:600;'
                    f'font-family:monospace;padding:0 1px;">{word}</span>'
                )
            else:
                span = (
                    f'<span style="color:{PALETTE["border"]};'
                    f'text-decoration:line-through;font-family:monospace;'
                    f'padding:0 1px;">{word}</span>'
                )
            word_spans.append(span)

        kept_count = int(mask.sum())
        row = (
            f'<div style="margin-bottom:10px;padding:8px 12px;'
            f'background:{PALETTE["bg"]};border-left:3px solid {PALETTE["border"]};'
            f'border-radius:0 6px 6px 0;">'
            f'<span style="color:{PALETTE["muted"]};font-size:11px;'
            f'font-family:monospace;margin-right:10px;">#{s+1} '
            f'({kept_count}/{total})</span>'
            + " ".join(word_spans)
            + "</div>"
        )
        rows_html.append(row)

    # legend
    legend_stops = [0.0, 0.33, 0.66, 1.0]
    legend_html = "".join(
        f'<span style="color:{imp_to_hex(v)};font-family:monospace;'
        f'font-size:11px;margin-right:8px;">▮ {v:.0%}</span>'
        for v in legend_stops
    )

    panel_html = (
        f'<div style="background:{PALETTE["panel"]};padding:16px 20px;'
        f'border-radius:8px;border:1px solid {PALETTE["border"]};">'
        f'<div style="margin-bottom:12px;color:{PALETTE["muted"]};font-size:12px;'
        f'font-family:monospace;">importance colour scale: {legend_html}</div>'
        + "".join(rows_html)
        + "</div>"
    )
    return panel_html
212
+
213
+
214
def build_threshold_output(words, importances, threshold):
    """Return highlighted HTML and a Markdown summary of above-threshold words.

    Args:
        words: Prompt words. They are HTML-escaped before being embedded.
        importances: One score per word.
        threshold: Words whose score is ``>= threshold`` are highlighted and
            listed in the summary.

    Returns:
        ``(highlighted, summary)``: an HTML panel showing every word, and a
        Markdown string listing the words at or above the threshold.
    """
    from html import escape  # local import keeps this block self-contained

    strong_style = (f"background:{PALETTE['accent']}22;"
                    f"color:{PALETTE['accent']};"
                    "border-radius:3px;padding:1px 4px;"
                    "font-weight:700;font-family:monospace;")
    weak_style = f"color:{PALETTE['muted']};font-family:monospace;"

    lines = []
    above = []
    for word, imp in zip(words, importances):
        if imp >= threshold:
            above.append(word)
            style = strong_style
        else:
            style = weak_style
        lines.append(f'<span style="{style}">{escape(str(word))}</span>')

    highlighted = (
        f'<div style="background:{PALETTE["panel"]};padding:16px 20px;'
        f'border-radius:8px;border:1px solid {PALETTE["border"]};'
        f'line-height:2.1;font-size:15px;">'
        + " ".join(lines)
        + "</div>"
    )

    # Written as an explicit branch; the original relied on the conditional
    # expression binding more loosely than ``+``.
    if above:
        summary = (
            f"**{len(above)} / {len(words)} words** above threshold {threshold:.2f}:\n\n"
            + ", ".join(f"`{w}`" for w in above)
        )
    else:
        summary = "_No words exceed the threshold._"
    return highlighted, summary
243
+
244
+
245
+ # ─────────────────────────────────────────────
246
+ # Main inference function
247
+ # ─────────────────────────────────────────────
248
+
249
def analyse(prompt: str, threshold: float, n_samples: int):
    """Run the full analysis for the UI.

    Every non-blank line of the prompt is scored separately and the words of
    all lines are concatenated into one list. Scores are therefore scaled per
    line, not across the whole prompt.

    Args:
        prompt: Prompt text; may span several lines. ``None`` is treated as empty.
        threshold: Importance cut-off for the chart and the highlighted output.
        n_samples: Number of sampled prompts; coerced to ``int`` because the
            slider may deliver a float.

    Returns:
        ``(bar_img, highlighted_html, summary_md, samples_html)``. When there
        is nothing to analyse the image is ``None`` and the HTML slot carries a
        short message.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        return None, "<p>Please enter a prompt.</p>", "", "<p></p>"

    ie = get_evaluator()

    all_words, all_imps = [], []
    for line in prompt.split("\n"):
        if not line.strip():
            continue
        result = ie.get_word_importance_chunked(line)
        if result is not None:
            imps, words = result
            all_words.extend(words)
            all_imps.extend(imps.tolist())

    if not all_words:
        return None, "<p>Could not parse prompt.</p>", "", "<p></p>"

    bar_img = plot_importance_bars(all_words, all_imps, threshold)
    highlighted, summary = build_threshold_output(all_words, all_imps, threshold)
    samples_html = sample_prompts(all_words, all_imps, n_samples=int(n_samples))

    return bar_img, highlighted, summary, samples_html
273
+
274
+
275
+ # ─────────────────────────────────────────────
276
+ # Gradio UI
277
+ # ─────────────────────────────────────────────
278
+
279
# Stylesheet handed to gr.Blocks. It is an f-string so the colours come from
# PALETTE; literal CSS braces are therefore doubled.
CSS = f"""
@import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;600&display=swap');

body, .gradio-container {{
    background: {PALETTE['bg']} !important;
    font-family: 'DM Sans', sans-serif !important;
    color: {PALETTE['text']} !important;
}}

.gr-panel, .gr-box, .gr-form {{
    background: {PALETTE['panel']} !important;
    border: 1px solid {PALETTE['border']} !important;
    border-radius: 10px !important;
}}

h1, h2, h3 {{
    font-family: 'Space Mono', monospace !important;
    color: {PALETTE['accent']} !important;
    letter-spacing: -0.5px !important;
}}

.gr-button-primary {{
    background: {PALETTE['accent']} !important;
    color: {PALETTE['bg']} !important;
    font-family: 'Space Mono', monospace !important;
    font-weight: 700 !important;
    border: none !important;
    border-radius: 6px !important;
}}

.gr-button-primary:hover {{
    opacity: 0.85 !important;
}}

label {{
    color: {PALETTE['text']} !important;
    font-size: 13px !important;
    font-family: 'Space Mono', monospace !important;
}}

textarea, input[type=text] {{
    background: {PALETTE['bg']} !important;
    color: {PALETTE['text']} !important;
    border: 1px solid {PALETTE['border']} !important;
    font-family: 'Space Mono', monospace !important;
    font-size: 13px !important;
}}

.markdown-text {{
    color: {PALETTE['text']} !important;
}}
"""

# Markdown shown at the top of the page.
DESCRIPTION = """
# 🔬 Word Importance Evaluator

Drop-one embedding analysis using **static-retrieval-mrl-en-v1**.
Each word's importance = semantic distance introduced by omitting it.

- **Bar chart** — ranked importance with threshold line
- **Threshold filter** — words above cutoff highlighted
- **Sampled prompts** — each word included with probability = its importance score
"""
342
+
343
# Page layout and event wiring. NOTE(review): the nesting below was rebuilt
# from a listing that had lost its indentation — confirm the layout.
with gr.Blocks(css=CSS, title="Word Importance Evaluator") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=2):
            prompt_box = gr.Textbox(
                label="Prompt",
                placeholder="a majestic lion in golden hour light, photorealistic, dramatic shadows",
                lines=4,
            )
            with gr.Row():
                threshold_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.3, step=0.01,
                    label="Importance threshold",
                )
                n_samples_slider = gr.Slider(
                    minimum=1, maximum=20, value=8, step=1,
                    label="Number of sampled prompts",
                )
            run_btn = gr.Button("Analyse →", variant="primary")

        with gr.Column(scale=1):
            threshold_html = gr.HTML(label="Threshold output")
            threshold_md = gr.Markdown(label="Summary")

    bar_img = gr.Image(label="Importance bar chart", type="pil")

    gr.Markdown("### 🎲 Sampled prompts *(each word kept with p = importance)*")
    samples_html = gr.HTML(label="Sampled prompts")

    # One definition of the wiring, shared by the button and the examples, so
    # the two cannot get out of step with what analyse() takes and returns.
    analyse_inputs = [prompt_box, threshold_slider, n_samples_slider]
    analyse_outputs = [bar_img, threshold_html, threshold_md, samples_html]

    run_btn.click(
        fn=analyse,
        inputs=analyse_inputs,
        outputs=analyse_outputs,
    )

    gr.Examples(
        examples=[
            ["a majestic lion in golden hour light, photorealistic, dramatic shadows", 0.3, 8],
            ["cinematic portrait of a young woman, soft bokeh, rim lighting, film grain", 0.25, 8],
            ["hyperrealistic macro photograph of a dewdrop on a spider web at dawn", 0.35, 10],
            ["oil painting of a medieval castle surrounded by autumn forest", 0.3, 8],
        ],
        inputs=analyse_inputs,
        fn=analyse,
        outputs=analyse_outputs,
        cache_examples=False,
    )

demo.launch()
app_v1.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import matplotlib
5
+ matplotlib.use("Agg")
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib.patches as mpatches
8
+ from matplotlib.colors import LinearSegmentedColormap
9
+ from sentence_transformers import SentenceTransformer
10
+ from abc import ABC, abstractmethod
11
+ import io
12
+ from PIL import Image
13
+
14
+
15
+ # ─────────────────────────────────────────────
16
+ # Core importance evaluator (unchanged logic)
17
+ # ─────────────────────────────────────────────
18
+
19
def create_splits(p):
    """Split a prompt into words and build its leave-one-out variants.

    Args:
        p: Prompt text; it is tokenised on whitespace with ``str.split()``.

    Returns:
        A ``(words, omit_prompts)`` tuple where ``omit_prompts[j]`` is the
        prompt re-joined with single spaces and ``words[j]`` left out. Both
        lists are empty for an empty or whitespace-only prompt.
    """
    words = p.split()
    omit_prompts = [
        " ".join(words[:j] + words[j + 1:]) for j in range(len(words))
    ]
    return words, omit_prompts
25
+
26
+
27
class IE(ABC):
    """Abstract interface for word-importance evaluators."""

    @abstractmethod
    def get_word_importance_chunked(self, PROMPT):
        """Score the words of ``PROMPT``; the return value is defined by subclasses."""
31
+
32
+
33
class ImportanceEvaluatorStatic(IE):
    """Drop-one word-importance evaluator backed by a SentenceTransformer model."""

    # Floor applied to similarities before taking the log, so a zero or
    # negative similarity (e.g. against an empty drop-one prompt) yields a
    # large finite score instead of inf/NaN.
    _MIN_SIMILARITY = 1e-12

    def __init__(self):
        self.CLIP_MODEL_ID = "sentence-transformers/static-retrieval-mrl-en-v1"
        self.model = SentenceTransformer(self.CLIP_MODEL_ID)

    def get_word_importance(self, PROMPT):
        """Score each word by how much dropping it changes the prompt embedding.

        The full prompt and every leave-one-out variant are embedded together.
        A word's raw score is ``-log(similarity)`` between the full prompt and
        the variant without that word, taken relative to the full prompt's
        self-similarity, clamped at zero and scaled so the largest score is 1
        (scaling is skipped when every score is zero).

        Args:
            PROMPT: Prompt text, split on whitespace by ``create_splits``.

        Returns:
            ``(importances, words)``: the per-word scores (one entry per word,
            as returned by the model's similarity call) and the word list.
        """
        words, omit_prompts = create_splits(PROMPT)
        sentences = [PROMPT] + omit_prompts
        embeddings = self.model.encode(sentences)
        similarities = self.model.similarity(embeddings[0:1], embeddings)
        x = similarities[0].clamp(min=self._MIN_SIMILARITY)
        x = -x.log()
        x = x - x[0]  # index 0 is the full prompt compared with itself
        x = x.clamp(0)
        peak = x.max()
        if peak > 0:
            x = x / peak
        return x[1:], words

    def get_word_importance_chunked(self, PROMPT):
        """Implement the ``IE`` interface by scoring the whole prompt at once."""
        return self.get_word_importance(PROMPT)

    def get_caption_embedding(self, PROMPT):
        """Return the model's embedding of ``PROMPT``."""
        return self.model.encode(PROMPT)
56
+
57
+
58
+ # ─────────────────────────────────────────────
59
+ # Load model lazily on first use, then reuse it
60
+ # ─────────────────────────────────────────────
61
+
62
# Process-wide evaluator instance; stays None until the first request needs it.
_ie = None


def get_evaluator():
    """Return the shared evaluator, constructing it on the first call."""
    global _ie
    if _ie is None:
        _ie = ImportanceEvaluatorStatic()
    return _ie
69
+
70
+
71
+ # ─────────────────────────────────────────────
72
+ # Plotting helpers
73
+ # ─────────────────────────────────────────────
74
+
75
# Dark-theme colours shared by the plots, the generated HTML and the CSS.
PALETTE = {
    "bg": "#0d0f14",
    "panel": "#14171f",
    "border": "#1e2330",
    "accent": "#e8c547",
    "accent2": "#5bc4c0",
    "text": "#d4d8e8",
    "muted": "#5a6080",
    "low": "#2a3a5c",
    "mid": "#4a7c8c",
    "high": "#e8c547",
    "critical": "#e85f47",
}

# Importance colormap. The stops are taken from PALETTE (same hex values as
# before) so the two definitions cannot drift apart.
CMAP = LinearSegmentedColormap.from_list(
    "imp",
    [PALETTE["low"], PALETTE["accent2"], PALETTE["high"], PALETTE["critical"]],
    N=256,
)
92
+
93
def _fig_to_pil(fig):
    """Render a Matplotlib figure to a PIL image and close the figure.

    The figure is saved as a 150-dpi PNG on the palette background. It is
    closed even when rendering fails, so failed requests do not leak figures.
    """
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=150, bbox_inches="tight",
                        facecolor=PALETTE["bg"])
            buf.seek(0)
            # copy() detaches the image from the buffer before it is closed.
            return Image.open(buf).copy()
    finally:
        plt.close(fig)
102
+
103
+
104
def plot_importance_bars(words, importances, threshold=0.3):
    """Horizontal bar chart coloured by importance with threshold line.

    Args:
        words: Bar labels; the first word is drawn at the top.
        importances: One score per word. Scores are mapped through ``CMAP``
            and the x-axis is fixed to 0..1.18, so values in [0, 1] are expected.
        threshold: Position of the dashed cut-off line. Bars at or above it get
            an accent-coloured value label followed by a marker.

    Returns:
        The rendered chart, as returned by ``_fig_to_pil``.
    """
    n = len(words)
    fig_h = max(3.5, n * 0.38)  # taller for more words, never below 3.5
    fig, ax = plt.subplots(figsize=(9, fig_h), facecolor=PALETTE["bg"])
    ax.set_facecolor(PALETTE["panel"])

    vals = np.array(importances)
    colors = [CMAP(float(v)) for v in vals]

    ax.barh(range(n), vals, color=colors, edgecolor=PALETTE["border"],
            linewidth=0.6, height=0.65)

    # threshold line
    ax.axvline(threshold, color=PALETTE["accent"], linewidth=1.4,
               linestyle="--", alpha=0.85, label=f"threshold = {threshold:.2f}")

    # word labels
    ax.set_yticks(range(n))
    ax.set_yticklabels(words, fontsize=10, color=PALETTE["text"],
                       fontfamily="monospace")
    ax.invert_yaxis()

    # value annotations
    for i, v in enumerate(vals):
        above = v >= threshold
        marker = "▶" if above else ""
        # The label x-position is capped so it stays inside the axes.
        ax.text(min(v + 0.02, 1.05), i, f"{v:.3f} {marker}",
                va="center", fontsize=8.5,
                color=PALETTE["accent"] if above else PALETTE["muted"])

    ax.set_xlim(0, 1.18)
    ax.set_xlabel("Normalised importance", color=PALETTE["text"], fontsize=10)
    ax.set_title("Word Importance · drop-one analysis", color=PALETTE["text"],
                 fontsize=12, fontweight="bold", pad=10)

    ax.tick_params(colors=PALETTE["muted"], which="both")
    for spine in ax.spines.values():
        spine.set_edgecolor(PALETTE["border"])

    ax.legend(facecolor=PALETTE["panel"], edgecolor=PALETTE["border"],
              labelcolor=PALETTE["accent"], fontsize=9)

    fig.tight_layout(pad=1.2)
    return _fig_to_pil(fig)
148
+
149
+
150
def plot_distribution(words, importances, n_samples=2000, seed=42):
    """Draw a violin-style strip of synthetic spread around each word's score.

    The spread is simulated, not measured: for every word, ``n_samples``
    values are drawn from a normal distribution centred on its score (standard
    deviation ``0.04 + 0.08 * score``) and clipped to [0, 1]. A mirrored
    histogram of those draws forms the violin.

    Args:
        words: X-axis labels, one violin per word.
        importances: One score per word.
        n_samples: Draws per word; coerced to ``int`` because UI sliders may
            deliver a float.
        seed: Seed for the random generator, so equal inputs give equal plots.

    Returns:
        The rendered chart, as returned by ``_fig_to_pil``.
    """
    rng = np.random.default_rng(seed)
    n = len(words)
    n_samples = int(n_samples)

    fig, ax = plt.subplots(figsize=(max(6, n * 0.7 + 1), 5),
                           facecolor=PALETTE["bg"])
    ax.set_facecolor(PALETTE["panel"])

    vals = np.array(importances, dtype=float)

    for i, v in enumerate(vals):
        # Jitter width proportional to value (higher = wider spread)
        sigma = 0.04 + 0.08 * v
        samples = rng.normal(loc=v, scale=sigma, size=n_samples).clip(0, 1)

        # violin-like fill via histogram, scaled to a half-width of 0.38
        hist, edges = np.histogram(samples, bins=40, density=True)
        hist_norm = hist / hist.max() * 0.38
        centers = (edges[:-1] + edges[1:]) / 2

        color = CMAP(float(v))
        ax.fill_betweenx(centers, i - hist_norm, i + hist_norm,
                         color=color, alpha=0.55, linewidth=0)
        ax.plot([i - hist_norm, i + hist_norm],
                [centers, centers], color=color, alpha=0.05, linewidth=0.3)

        # horizontal line at the word's score (the centre of the jitter)
        ax.hlines(v, i - 0.35, i + 0.35, colors=PALETTE["accent"],
                  linewidth=1.6, zorder=5)
        # dot
        ax.scatter([i], [v], color=PALETTE["accent"], s=28, zorder=6)

    ax.set_xticks(range(n))
    ax.set_xticklabels(words, rotation=35, ha="right", fontsize=9,
                       color=PALETTE["text"], fontfamily="monospace")
    ax.set_ylabel("Importance", color=PALETTE["text"], fontsize=10)
    ax.set_title("Per-word Importance Distribution (sampled spread)",
                 color=PALETTE["text"], fontsize=12, fontweight="bold", pad=10)
    ax.set_ylim(-0.05, 1.12)

    ax.tick_params(colors=PALETTE["muted"])
    for spine in ax.spines.values():
        spine.set_edgecolor(PALETTE["border"])

    fig.tight_layout(pad=1.2)
    return _fig_to_pil(fig)
201
+
202
+
203
def build_threshold_output(words, importances, threshold):
    """Return highlighted HTML and a Markdown summary of above-threshold words.

    Args:
        words: Prompt words. They are HTML-escaped before being embedded.
        importances: One score per word.
        threshold: Words whose score is ``>= threshold`` are highlighted and
            listed in the summary.

    Returns:
        ``(highlighted, summary)``: an HTML panel showing every word, and a
        Markdown string listing the words at or above the threshold.
    """
    from html import escape  # local import keeps this block self-contained

    strong_style = (f"background:{PALETTE['accent']}22;"
                    f"color:{PALETTE['accent']};"
                    "border-radius:3px;padding:1px 4px;"
                    "font-weight:700;font-family:monospace;")
    weak_style = f"color:{PALETTE['muted']};font-family:monospace;"

    lines = []
    above = []
    for word, imp in zip(words, importances):
        if imp >= threshold:
            above.append(word)
            style = strong_style
        else:
            style = weak_style
        lines.append(f'<span style="{style}">{escape(str(word))}</span>')

    highlighted = (
        f'<div style="background:{PALETTE["panel"]};padding:16px 20px;'
        f'border-radius:8px;border:1px solid {PALETTE["border"]};'
        f'line-height:2.1;font-size:15px;">'
        + " ".join(lines)
        + "</div>"
    )

    # Written as an explicit branch; the original relied on the conditional
    # expression binding more loosely than ``+``.
    if above:
        summary = (
            f"**{len(above)} / {len(words)} words** above threshold {threshold:.2f}:\n\n"
            + ", ".join(f"`{w}`" for w in above)
        )
    else:
        summary = "_No words exceed the threshold._"
    return highlighted, summary
232
+
233
+
234
+ # ─────────────────────────────────────────────
235
+ # Main inference function
236
+ # ─────────────────────────────────────────────
237
+
238
def analyse(prompt: str, threshold: float, n_dist_samples: int):
    """Run the full analysis for the UI.

    Every non-blank line of the prompt is scored separately and the words of
    all lines are concatenated into one list. Scores are therefore scaled per
    line, not across the whole prompt.

    Args:
        prompt: Prompt text; may span several lines. ``None`` is treated as empty.
        threshold: Importance cut-off for the chart and the highlighted output.
        n_dist_samples: Draws per word for the distribution plot; coerced to
            ``int`` because the slider may deliver a float.

    Returns:
        ``(bar_img, dist_img, highlighted_html, summary_md)``. When there is
        nothing to analyse both images are ``None`` and the HTML slot carries
        a short message.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        return None, None, "<p>Please enter a prompt.</p>", ""

    ie = get_evaluator()

    # Compute per-line importances (multi-line support)
    all_words, all_imps = [], []
    for line in prompt.split("\n"):
        if not line.strip():
            continue
        result = ie.get_word_importance_chunked(line)
        if result is not None:
            imps, words = result
            all_words.extend(words)
            all_imps.extend(imps.tolist())

    if not all_words:
        return None, None, "<p>Could not parse prompt.</p>", ""

    bar_img = plot_importance_bars(all_words, all_imps, threshold)
    dist_img = plot_distribution(all_words, all_imps, n_samples=int(n_dist_samples))
    highlighted, summary = build_threshold_output(all_words, all_imps, threshold)

    return bar_img, dist_img, highlighted, summary
263
+
264
+
265
+ # ─────────────────────────────────────────────
266
+ # Gradio UI
267
+ # ─────────────────────────────────────────────
268
+
269
# Stylesheet handed to gr.Blocks. It is an f-string so the colours come from
# PALETTE; literal CSS braces are therefore doubled.
CSS = f"""
@import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;600&display=swap');

body, .gradio-container {{
    background: {PALETTE['bg']} !important;
    font-family: 'DM Sans', sans-serif !important;
    color: {PALETTE['text']} !important;
}}

.gr-panel, .gr-box, .gr-form {{
    background: {PALETTE['panel']} !important;
    border: 1px solid {PALETTE['border']} !important;
    border-radius: 10px !important;
}}

h1, h2, h3 {{
    font-family: 'Space Mono', monospace !important;
    color: {PALETTE['accent']} !important;
    letter-spacing: -0.5px !important;
}}

.gr-button-primary {{
    background: {PALETTE['accent']} !important;
    color: {PALETTE['bg']} !important;
    font-family: 'Space Mono', monospace !important;
    font-weight: 700 !important;
    border: none !important;
    border-radius: 6px !important;
}}

.gr-button-primary:hover {{
    opacity: 0.85 !important;
}}

label {{
    color: {PALETTE['text']} !important;
    font-size: 13px !important;
    font-family: 'Space Mono', monospace !important;
}}

textarea, input[type=text] {{
    background: {PALETTE['bg']} !important;
    color: {PALETTE['text']} !important;
    border: 1px solid {PALETTE['border']} !important;
    font-family: 'Space Mono', monospace !important;
    font-size: 13px !important;
}}

.markdown-text {{
    color: {PALETTE['text']} !important;
}}
"""

# Markdown shown at the top of the page.
DESCRIPTION = """
# 🔬 Word Importance Evaluator

Drop-one embedding analysis using **static-retrieval-mrl-en-v1**.
Each word's importance = semantic distance introduced by omitting it.

Enter a prompt (multi-line supported), adjust the threshold, and explore:
- **Bar chart** — ranked importance per word
- **Distribution** — sampled spread per word
- **Threshold filter** — highlight words above cutoff
"""
333
+
334
# Page layout and event wiring. NOTE(review): the nesting below was rebuilt
# from a listing that had lost its indentation — confirm the layout.
with gr.Blocks(css=CSS, title="Word Importance Evaluator") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=2):
            prompt_box = gr.Textbox(
                label="Prompt",
                placeholder="a majestic lion in golden hour light, photorealistic, dramatic shadows",
                lines=4,
            )
            with gr.Row():
                threshold_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.3, step=0.01,
                    label="Importance threshold",
                )
                n_samples_slider = gr.Slider(
                    minimum=200, maximum=5000, value=1500, step=100,
                    label="Distribution samples per word",
                )
            run_btn = gr.Button("Analyse →", variant="primary")

        with gr.Column(scale=1):
            threshold_html = gr.HTML(label="Threshold output")
            threshold_md = gr.Markdown(label="Summary")

    with gr.Row():
        bar_img = gr.Image(label="Importance bar chart", type="pil", height=500)
        dist_img = gr.Image(label="Distribution per word", type="pil", height=500)

    # One definition of the wiring, shared by the button and the examples, so
    # the two cannot get out of step with what analyse() takes and returns.
    analyse_inputs = [prompt_box, threshold_slider, n_samples_slider]
    analyse_outputs = [bar_img, dist_img, threshold_html, threshold_md]

    run_btn.click(
        fn=analyse,
        inputs=analyse_inputs,
        outputs=analyse_outputs,
    )

    gr.Examples(
        examples=[
            ["a majestic lion in golden hour light, photorealistic, dramatic shadows", 0.3, 1500],
            ["cinematic portrait of a young woman, soft bokeh, rim lighting, film grain", 0.25, 1500],
            ["hyperrealistic macro photograph of a dewdrop on a spider web at dawn", 0.35, 2000],
            ["oil painting of a medieval castle surrounded by autumn forest", 0.3, 1500],
        ],
        inputs=analyse_inputs,
        fn=analyse,
        outputs=analyse_outputs,
        cache_examples=False,
    )

demo.launch()
prompt.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ turn the importance evaluator into a huggingface space. keep the relevant code unchanged. output should be importance barcharts and sample outputs with thresholding as well as distribution sampling per word
2
+ --------------
3
+ by distribution sampling i mean an output text where the importances are used as probabilities and they are included randomly according to that probability
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=4.44.0
2
+ torch>=2.0.0
3
+ sentence-transformers>=3.0.0
4
+ numpy>=1.24.0
5
+ matplotlib>=3.7.0
6
+ Pillow>=10.0.0