import gradio as gr
import torch
import numpy as np
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap
from sentence_transformers import SentenceTransformer
from abc import ABC, abstractmethod
import io
from PIL import Image

# ─────────────────────────────────────────────
# Core importance evaluator (unchanged logic)
# ─────────────────────────────────────────────


def create_splits(p):
    """Split *p* into words and build one variant prompt per word with that word omitted.

    Returns:
        (words, omit_prompts): the token list and the list of drop-one prompts,
        where omit_prompts[j] is the prompt with words[j] removed.
    """
    words = p.split()
    omit_prompts = [
        " ".join(w for i, w in enumerate(words) if i != j)
        for j in range(len(words))
    ]
    return words, omit_prompts


class IE(ABC):
    """Interface for word-importance evaluators."""

    @abstractmethod
    def get_word_importance_chunked(self, PROMPT):
        pass


class ImportanceEvaluatorStatic(IE):
    """Drop-one importance scoring backed by a static sentence-embedding model."""

    def __init__(self):
        self.CLIP_MODEL_ID = "sentence-transformers/static-retrieval-mrl-en-v1"
        self.model = SentenceTransformer(self.CLIP_MODEL_ID)

    def get_word_importance(self, PROMPT):
        """Score each word by the semantic distance its omission introduces.

        Returns:
            (importances, words) where importances is a 1-D tensor in [0, 1]
            aligned with words (one score per word).
        """
        words, omit_prompts = create_splits(PROMPT)
        sentences = [PROMPT] + omit_prompts
        embeddings = self.model.encode(sentences)
        # Similarity of the full prompt against itself and every drop-one variant.
        similarities = self.model.similarity(embeddings[0:1], embeddings)
        x = similarities[0]
        x = -x.log()      # turn similarity into a distance-like quantity
        x = x - x[0]      # subtract the self-similarity baseline
        x = x.clamp(0)    # numerical noise can push values slightly negative
        if x.max() > 0:
            x /= x.max()  # normalise so the most important word scores 1.0
        return x[1:], words  # drop the self-comparison entry

    def get_word_importance_chunked(self, PROMPT):
        return self.get_word_importance(PROMPT)

    def get_caption_embedding(self, PROMPT):
        return self.model.encode(PROMPT)


# ─────────────────────────────────────────────
# Load model once at startup
# ─────────────────────────────────────────────
_ie = None


def get_evaluator():
    """Lazily construct and cache the evaluator (model load is expensive)."""
    global _ie
    if _ie is None:
        _ie = ImportanceEvaluatorStatic()
    return _ie


# ─────────────────────────────────────────────
# Plotting helpers
# ─────────────────────────────────────────────
PALETTE = {
    "bg": "#0d0f14",
    "panel": "#14171f",
    "border": "#1e2330",
    "accent": "#e8c547",
    "accent2": "#5bc4c0",
    "text": "#d4d8e8",
    "muted": "#5a6080",
    "low": "#2a3a5c",
    "mid": "#4a7c8c",
    "high": "#e8c547",
    "critical": "#e85f47",
}

CMAP = LinearSegmentedColormap.from_list(
    "imp", ["#2a3a5c", "#5bc4c0", "#e8c547", "#e85f47"], N=256
)


def _fig_to_pil(fig):
    """Render a Matplotlib figure to a PIL image and close the figure."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=150, bbox_inches="tight",
                facecolor=PALETTE["bg"])
    buf.seek(0)
    img = Image.open(buf).copy()  # copy so the buffer can be closed safely
    buf.close()
    plt.close(fig)
    return img


def plot_importance_bars(words, importances, threshold=0.3):
    """Horizontal bar chart coloured by importance with threshold line."""
    n = len(words)
    fig_h = max(3.5, n * 0.38)
    fig, ax = plt.subplots(figsize=(9, fig_h), facecolor=PALETTE["bg"])
    ax.set_facecolor(PALETTE["panel"])

    vals = np.array(importances)
    colors = [CMAP(float(v)) for v in vals]
    bars = ax.barh(range(n), vals, color=colors,
                   edgecolor=PALETTE["border"], linewidth=0.6, height=0.65)

    # threshold line
    ax.axvline(threshold, color=PALETTE["accent"], linewidth=1.4,
               linestyle="--", alpha=0.85,
               label=f"threshold = {threshold:.2f}")

    # word labels
    ax.set_yticks(range(n))
    ax.set_yticklabels(words, fontsize=10, color=PALETTE["text"],
                       fontfamily="monospace")
    ax.invert_yaxis()

    # value annotations
    for i, (bar, v) in enumerate(zip(bars, vals)):
        marker = "▶" if v >= threshold else ""
        ax.text(min(v + 0.02, 1.05), i, f"{v:.3f} {marker}",
                va="center", fontsize=8.5,
                color=PALETTE["accent"] if v >= threshold else PALETTE["muted"])

    ax.set_xlim(0, 1.18)
    ax.set_xlabel("Normalised importance", color=PALETTE["text"], fontsize=10)
    ax.set_title("Word Importance · drop-one analysis", color=PALETTE["text"],
                 fontsize=12, fontweight="bold", pad=10)
    ax.tick_params(colors=PALETTE["muted"], which="both")
    for spine in ax.spines.values():
        spine.set_edgecolor(PALETTE["border"])
    ax.legend(facecolor=PALETTE["panel"], edgecolor=PALETTE["border"],
              labelcolor=PALETTE["accent"], fontsize=9)
    fig.tight_layout(pad=1.2)
    return _fig_to_pil(fig)


def sample_prompts(words, importances, n_samples=8, seed=42):
    """
    Each word is included in a sample with probability == its importance score.
    Returns HTML showing N sampled prompts, with included words highlighted by
    their importance colour and dropped words shown as dim strikethrough.
    """
    # NOTE(review): the HTML tag markup below was reconstructed — the original
    # tags were stripped by extraction; style attributes follow the surviving text.
    rng = np.random.default_rng(seed)
    vals = np.array(importances, dtype=float)

    def imp_to_hex(v):
        r, g, b, _ = CMAP(float(v))
        return "#{:02x}{:02x}{:02x}".format(int(r * 255), int(g * 255), int(b * 255))

    rows_html = []
    for s in range(n_samples):
        mask = rng.random(len(words)) < vals  # Bernoulli draw per word
        word_spans = []
        for word, keep, v in zip(words, mask, vals):
            color = imp_to_hex(v)
            if keep:
                span = (
                    f'<span style="color:{color};font-weight:600;'
                    f'font-family:monospace;">{word}</span>'
                )
            else:
                span = (
                    f'<span style="color:{PALETTE["muted"]};opacity:0.45;'
                    f'text-decoration:line-through;'
                    f'font-family:monospace;">{word}</span>'
                )
            word_spans.append(span)
        kept_count = int(mask.sum())
        row = (
            f'<div style="padding:6px 10px;'
            f'border-bottom:1px solid {PALETTE["border"]};font-size:13px;">'
            f'<span style="color:{PALETTE["accent"]};'
            f'font-family:monospace;">#{s + 1}</span> '
            f'<span style="color:{PALETTE["muted"]};'
            f'font-family:monospace;">({kept_count}/{len(words)})</span> '
            + " ".join(word_spans)
            + "</div>"
        )
        rows_html.append(row)

    # legend
    legend_stops = [0.0, 0.33, 0.66, 1.0]
    legend_html = "".join(
        f'<span style="color:{imp_to_hex(v)};margin-right:10px;">▮ {v:.0%}</span>'
        for v in legend_stops
    )
    html = (
        f'<div style="background:{PALETTE["panel"]};'
        f'border:1px solid {PALETTE["border"]};border-radius:8px;">'
        f'<div style="padding:6px 10px;font-size:12px;'
        f'color:{PALETTE["muted"]};">'
        f'importance colour scale: {legend_html}</div>'
        + "".join(rows_html)
        + "</div>"
    )
    return html


def build_threshold_output(words, importances, threshold):
    """Return highlighted HTML and plain text for above-threshold words."""
    # NOTE(review): span/div markup reconstructed — original tags were stripped.
    lines = []
    above = []
    for word, imp in zip(words, importances):
        if imp >= threshold:
            above.append(word)
            style = (f"background:{PALETTE['accent']}22;"
                     f"color:{PALETTE['accent']};"
                     "border-radius:3px;padding:1px 4px;"
                     "font-weight:700;font-family:monospace;")
        else:
            style = f"color:{PALETTE['muted']};font-family:monospace;"
        lines.append(f'<span style="{style}">{word}</span>')

    highlighted = (
        f'<div style="background:{PALETTE["panel"]};'
        f'border:1px solid {PALETTE["border"]};'
        f'border-radius:8px;padding:10px;line-height:2;">'
        + " ".join(lines)
        + "</div>"
    )
    summary = (
        f"**{len(above)} / {len(words)} words** above threshold {threshold:.2f}:\n\n"
        + ", ".join(f"`{w}`" for w in above)
        if above
        else "_No words exceed the threshold._"
    )
    return highlighted, summary


# ─────────────────────────────────────────────
# Main inference function
# ─────────────────────────────────────────────
def analyse(prompt: str, threshold: float, n_samples: int):
    """Run the full analysis: bar chart, threshold highlight, sampled prompts.

    Returns (bar_image, highlighted_html, summary_md, samples_html); the image
    is None when the prompt is empty or unparseable.
    """

    def _message(text):
        # Small helper for the two error panels shown in place of results.
        return (
            f'<div style="color:{PALETTE["muted"]};padding:12px;'
            f'font-family:monospace;">{text}</div>'
        )

    prompt = prompt.strip()
    if not prompt:
        return None, _message("Please enter a prompt."), "", ""

    ie = get_evaluator()
    # Score each non-empty line independently, then concatenate the results.
    lines = [l for l in prompt.split("\n") if l.strip()]
    all_words, all_imps = [], []
    for line in lines:
        result = ie.get_word_importance_chunked(line)
        if result is not None:
            imps, words = result
            all_words.extend(words)
            all_imps.extend(imps.tolist())

    if not all_words:
        return None, _message("Could not parse prompt."), "", ""

    bar_img = plot_importance_bars(all_words, all_imps, threshold)
    highlighted, summary = build_threshold_output(all_words, all_imps, threshold)
    # Gradio sliders may deliver floats; range() requires an int.
    samples_html = sample_prompts(all_words, all_imps, n_samples=int(n_samples))
    return bar_img, highlighted, summary, samples_html


# ─────────────────────────────────────────────
# Gradio UI
# ─────────────────────────────────────────────
CSS = f"""
@import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;600&display=swap');
body, .gradio-container {{ background: {PALETTE['bg']} !important; font-family: 'DM Sans', sans-serif !important; color: {PALETTE['text']} !important; }}
.gr-panel, .gr-box, .gr-form {{ background: {PALETTE['panel']} !important; border: 1px solid {PALETTE['border']} !important; border-radius: 10px !important; }}
h1, h2, h3 {{ font-family: 'Space Mono', monospace !important; color: {PALETTE['accent']} !important; letter-spacing: -0.5px !important; }}
.gr-button-primary {{ background: {PALETTE['accent']} !important; color: {PALETTE['bg']} !important; font-family: 'Space Mono', monospace !important; font-weight: 700 !important; border: none !important; border-radius: 6px !important; }}
.gr-button-primary:hover {{ opacity: 0.85 !important; }}
label {{ color: {PALETTE['text']} !important; font-size: 13px !important; font-family: 'Space Mono', monospace !important; }}
textarea, input[type=text] {{ background: {PALETTE['bg']} !important; color: {PALETTE['text']} !important; border: 1px solid {PALETTE['border']} !important; font-family: 'Space Mono', monospace !important; font-size: 13px !important; }}
.markdown-text {{ color: {PALETTE['text']} !important; }}
"""

DESCRIPTION = """
# 🔬 Word Importance Evaluator

Drop-one embedding analysis using **static-retrieval-mrl-en-v1**.
Each word's importance = semantic distance introduced by omitting it.

- **Bar chart** — ranked importance with threshold line
- **Threshold filter** — words above cutoff highlighted
- **Sampled prompts** — each word included with probability = its importance score
"""

with gr.Blocks(css=CSS, title="Word Importance Evaluator") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=2):
            prompt_box = gr.Textbox(
                label="Prompt",
                placeholder="a majestic lion in golden hour light, photorealistic, dramatic shadows",
                lines=4,
            )
            with gr.Row():
                threshold_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.3, step=0.01,
                    label="Importance threshold",
                )
                n_samples_slider = gr.Slider(
                    minimum=1, maximum=20, value=8, step=1,
                    label="Number of sampled prompts",
                )
            run_btn = gr.Button("Analyse →", variant="primary")
        with gr.Column(scale=1):
            threshold_html = gr.HTML(label="Threshold output")
            threshold_md = gr.Markdown(label="Summary")

    bar_img = gr.Image(label="Importance bar chart", type="pil")
    gr.Markdown("### 🎲 Sampled prompts *(each word kept with p = importance)*")
    samples_html = gr.HTML(label="Sampled prompts")

    run_btn.click(
        fn=analyse,
        inputs=[prompt_box, threshold_slider, n_samples_slider],
        outputs=[bar_img, threshold_html, threshold_md, samples_html],
    )

    gr.Examples(
        examples=[
            ["a majestic lion in golden hour light, photorealistic, dramatic shadows", 0.3, 8],
            ["cinematic portrait of a young woman, soft bokeh, rim lighting, film grain", 0.25, 8],
            ["hyperrealistic macro photograph of a dewdrop on a spider web at dawn", 0.35, 10],
            ["oil painting of a medieval castle surrounded by autumn forest", 0.3, 8],
        ],
        inputs=[prompt_box, threshold_slider, n_samples_slider],
        fn=analyse,
        outputs=[bar_img, threshold_html, threshold_md, samples_html],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()