# Hugging Face Space: Word Importance Evaluator (Gradio app)
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| from matplotlib.colors import LinearSegmentedColormap | |
| from sentence_transformers import SentenceTransformer | |
| from abc import ABC, abstractmethod | |
| import io | |
| from PIL import Image | |
# ─────────────────────────────────────────────
# Core importance evaluator (unchanged logic)
# ─────────────────────────────────────────────
def create_splits(p):
    """Tokenise *p* on whitespace and build one leave-one-out prompt per word.

    Returns
    -------
    tuple[list[str], list[str]]
        The word list and, for each position j, the prompt with word j removed
        (words re-joined with single spaces).
    """
    words = p.split()
    omit_prompts = []
    for skip in range(len(words)):
        # Drop exactly one word; remaining order is preserved.
        omit_prompts.append(" ".join(words[:skip] + words[skip + 1:]))
    return words, omit_prompts
class IE(ABC):
    """Interface for word-importance evaluators.

    Fix: ``abstractmethod`` was imported at the top of the file but never
    applied, so ``IE`` could be instantiated and the "abstract" method
    silently returned ``None``. Marking it abstract makes subclasses that
    forget to implement it fail loudly at construction time.
    """

    @abstractmethod
    def get_word_importance_chunked(self, PROMPT):
        """Return ``(importances, words)`` for *PROMPT*.

        Implementations may internally chunk long prompts; importances are
        expected to be parallel to *words*.
        """
        ...
class ImportanceEvaluatorStatic(IE):
    """Drop-one word importance based on a static sentence-embedding model.

    A word's importance is how much the embedding similarity to the full
    prompt drops when that word is omitted, normalised to [0, 1].
    """

    def __init__(self):
        # NOTE(review): name says CLIP but this is a static retrieval model.
        self.CLIP_MODEL_ID = "sentence-transformers/static-retrieval-mrl-en-v1"
        self.model = SentenceTransformer(self.CLIP_MODEL_ID)

    def get_word_importance(self, PROMPT):
        """Score every word of *PROMPT*; returns ``(scores, words)``.

        ``scores`` is a tensor parallel to ``words`` with values in [0, 1].
        """
        words, omit_prompts = create_splits(PROMPT)
        embeddings = self.model.encode([PROMPT] + omit_prompts)
        # Row 0 is the full prompt vs. [full prompt, each omission].
        sims = self.model.similarity(embeddings[0:1], embeddings)[0]
        # Negative log-similarity, anchored at the full prompt's own score
        # (entry 0), floored at zero.
        scores = -sims.log()
        scores = (scores - scores[0]).clamp(0)
        peak = scores.max()
        if peak > 0:
            scores /= peak  # normalise so the most important word scores 1
        return scores[1:], words

    def get_word_importance_chunked(self, PROMPT):
        # Static model is cheap; no chunking needed — delegate directly.
        return self.get_word_importance(PROMPT)

    def get_caption_embedding(self, PROMPT):
        """Embed *PROMPT* with the underlying model."""
        return self.model.encode(PROMPT)
# ─────────────────────────────────────────────
# Load model once at startup
# ─────────────────────────────────────────────
# Module-level cache so the embedding model is constructed at most once.
_ie = None


def get_evaluator():
    """Return the shared :class:`ImportanceEvaluatorStatic`, lazily created."""
    global _ie
    if _ie is not None:
        return _ie
    _ie = ImportanceEvaluatorStatic()
    return _ie
# ─────────────────────────────────────────────
# Plotting helpers
# ─────────────────────────────────────────────
# Dark colour palette (hex strings) shared by the matplotlib figures and the
# inline HTML fragments rendered in the Gradio UI.
PALETTE = {
    "bg": "#0d0f14",        # page / figure background
    "panel": "#14171f",     # card and axes background
    "border": "#1e2330",    # hairline borders; also dropped-word text colour
    "accent": "#e8c547",    # primary highlight (threshold line, buttons)
    "accent2": "#5bc4c0",   # secondary highlight
    "text": "#d4d8e8",      # main body text
    "muted": "#5a6080",     # de-emphasised text
    "low": "#2a3a5c",       # importance scale: low
    "mid": "#4a7c8c",       # importance scale: middle
    "high": "#e8c547",      # importance scale: high
    "critical": "#e85f47",  # importance scale: extreme
}
# Continuous colormap for importance values in [0, 1]
# (dark blue → teal → yellow → red).
CMAP = LinearSegmentedColormap.from_list(
    "imp", ["#2a3a5c", "#5bc4c0", "#e8c547", "#e85f47"], N=256
)
def _fig_to_pil(fig):
    """Render *fig* to an in-memory PNG and return it as a PIL image.

    The figure is closed afterwards so repeated calls don't leak matplotlib
    state; the returned image is a copy, independent of the buffer.
    """
    with io.BytesIO() as buffer:
        fig.savefig(buffer, format="png", dpi=150, bbox_inches="tight",
                    facecolor=PALETTE["bg"])
        buffer.seek(0)
        image = Image.open(buffer).copy()
    plt.close(fig)
    return image
def plot_importance_bars(words, importances, threshold=0.3):
    """Horizontal bar chart of per-word importance with a threshold line.

    Parameters
    ----------
    words : list[str]
        Tokens, one bar each (prompt order, top to bottom).
    importances : sequence[float]
        Normalised scores in [0, 1], parallel to *words*.
    threshold : float, optional
        Cut-off shown as a dashed vertical line; annotations for values at
        or above it are highlighted.

    Returns
    -------
    PIL.Image.Image
        The rendered chart.

    Fix: the original zipped the bar artists into the annotation loop but
    never used them; the dead local is removed.
    """
    n = len(words)
    fig_h = max(3.5, n * 0.38)  # grow the figure with the word count
    fig, ax = plt.subplots(figsize=(9, fig_h), facecolor=PALETTE["bg"])
    ax.set_facecolor(PALETTE["panel"])
    vals = np.array(importances)
    colors = [CMAP(float(v)) for v in vals]
    ax.barh(range(n), vals, color=colors, edgecolor=PALETTE["border"],
            linewidth=0.6, height=0.65)
    # threshold line
    ax.axvline(threshold, color=PALETTE["accent"], linewidth=1.4,
               linestyle="--", alpha=0.85, label=f"threshold = {threshold:.2f}")
    # word labels
    ax.set_yticks(range(n))
    ax.set_yticklabels(words, fontsize=10, color=PALETTE["text"],
                       fontfamily="monospace")
    ax.invert_yaxis()  # first word of the prompt at the top
    # value annotations; clamp x so labels stay inside the axes
    for i, v in enumerate(vals):
        marker = "▶" if v >= threshold else ""
        ax.text(min(v + 0.02, 1.05), i, f"{v:.3f} {marker}",
                va="center", fontsize=8.5,
                color=PALETTE["accent"] if v >= threshold else PALETTE["muted"])
    ax.set_xlim(0, 1.18)
    ax.set_xlabel("Normalised importance", color=PALETTE["text"], fontsize=10)
    ax.set_title("Word Importance · drop-one analysis", color=PALETTE["text"],
                 fontsize=12, fontweight="bold", pad=10)
    ax.tick_params(colors=PALETTE["muted"], which="both")
    for spine in ax.spines.values():
        spine.set_edgecolor(PALETTE["border"])
    ax.legend(facecolor=PALETTE["panel"], edgecolor=PALETTE["border"],
              labelcolor=PALETTE["accent"], fontsize=9)
    fig.tight_layout(pad=1.2)
    return _fig_to_pil(fig)
def sample_prompts(words, importances, n_samples=8, seed=42):
    """
    Each word is included in a sample with probability == its importance score.
    Returns HTML showing N sampled prompts, with included words highlighted
    by their importance colour and dropped words shown as dim strikethrough.
    """
    rng = np.random.default_rng(seed)
    vals = np.array(importances, dtype=float)

    def imp_to_hex(v):
        # importance -> colormap RGBA -> "#rrggbb" for inline CSS
        r, g, b, _ = CMAP(float(v))
        return "#{:02x}{:02x}{:02x}".format(int(r*255), int(g*255), int(b*255))

    def render_word(word, keep, v):
        # Kept words carry their importance colour; dropped ones are struck out.
        if keep:
            return (
                f'<span style="color:{imp_to_hex(v)};font-weight:600;'
                f'font-family:monospace;padding:0 1px;">{word}</span>'
            )
        return (
            f'<span style="color:{PALETTE["border"]};'
            f'text-decoration:line-through;font-family:monospace;'
            f'padding:0 1px;">{word}</span>'
        )

    rows_html = []
    for idx in range(n_samples):
        mask = rng.random(len(words)) < vals  # Bernoulli draw per word
        word_spans = [render_word(w, keep, v)
                      for w, keep, v in zip(words, mask, vals)]
        kept_count = int(mask.sum())
        rows_html.append(
            f'<div style="margin-bottom:10px;padding:8px 12px;'
            f'background:{PALETTE["bg"]};border-left:3px solid {PALETTE["border"]};'
            f'border-radius:0 6px 6px 0;">'
            f'<span style="color:{PALETTE["muted"]};font-size:11px;'
            f'font-family:monospace;margin-right:10px;">#{idx+1} '
            f'({kept_count}/{len(words)})</span>'
            + " ".join(word_spans)
            + "</div>"
        )

    # legend: fixed stops across the colour scale
    legend_stops = [0.0, 0.33, 0.66, 1.0]
    legend_html = "".join(
        f'<span style="color:{imp_to_hex(v)};font-family:monospace;'
        f'font-size:11px;margin-right:8px;">▮ {v:.0%}</span>'
        for v in legend_stops
    )
    return (
        f'<div style="background:{PALETTE["panel"]};padding:16px 20px;'
        f'border-radius:8px;border:1px solid {PALETTE["border"]};">'
        f'<div style="margin-bottom:12px;color:{PALETTE["muted"]};font-size:12px;'
        f'font-family:monospace;">importance colour scale: {legend_html}</div>'
        + "".join(rows_html)
        + "</div>"
    )
def build_threshold_output(words, importances, threshold):
    """Return highlighted HTML and plain text for above-threshold words."""
    above = []
    spans = []
    for word, imp in zip(words, importances):
        if imp >= threshold:
            above.append(word)
            style = (f"background:{PALETTE['accent']}22;"
                     f"color:{PALETTE['accent']};"
                     "border-radius:3px;padding:1px 4px;"
                     "font-weight:700;font-family:monospace;")
        else:
            style = f"color:{PALETTE['muted']};font-family:monospace;"
        spans.append(f'<span style="{style}">{word}</span>')

    highlighted = (
        f'<div style="background:{PALETTE["panel"]};padding:16px 20px;'
        f'border-radius:8px;border:1px solid {PALETTE["border"]};'
        f'line-height:2.1;font-size:15px;">'
        + " ".join(spans)
        + "</div>"
    )

    # Explicit branches instead of the original precedence-sensitive ternary.
    if above:
        summary = (
            f"**{len(above)} / {len(words)} words** above threshold {threshold:.2f}:\n\n"
            + ", ".join(f"`{w}`" for w in above)
        )
    else:
        summary = "_No words exceed the threshold._"
    return highlighted, summary
# ─────────────────────────────────────────────
# Main inference function
# ─────────────────────────────────────────────
def analyse(prompt: str, threshold: float, n_samples: int):
    """Run drop-one importance on *prompt* and build all four UI outputs.

    Returns a tuple of (bar-chart image, highlighted HTML, summary markdown,
    sampled-prompts HTML); the image slot is None when the prompt is empty
    or yields no words.
    """
    prompt = prompt.strip()
    if not prompt:
        return None, "<p>Please enter a prompt.</p>", "", "<p></p>"
    ie = get_evaluator()
    all_words, all_imps = [], []
    # Score each non-blank line independently, then concatenate the results.
    for line in prompt.split("\n"):
        if not line.strip():
            continue
        result = ie.get_word_importance_chunked(line)
        if result is None:
            continue
        imps, words = result
        all_words.extend(words)
        all_imps.extend(imps.tolist())
    if not all_words:
        return None, "<p>Could not parse prompt.</p>", "", "<p></p>"
    bar_img = plot_importance_bars(all_words, all_imps, threshold)
    highlighted, summary = build_threshold_output(all_words, all_imps, threshold)
    samples_html = sample_prompts(all_words, all_imps, n_samples=n_samples)
    return bar_img, highlighted, summary, samples_html
# ─────────────────────────────────────────────
# Gradio UI
# ─────────────────────────────────────────────
# Custom dark theme injected into the Gradio page. Built as an f-string so
# PALETTE stays the single source of colours — hence the doubled {{ }} for
# literal CSS braces.
CSS = f"""
@import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;600&display=swap');
body, .gradio-container {{
background: {PALETTE['bg']} !important;
font-family: 'DM Sans', sans-serif !important;
color: {PALETTE['text']} !important;
}}
.gr-panel, .gr-box, .gr-form {{
background: {PALETTE['panel']} !important;
border: 1px solid {PALETTE['border']} !important;
border-radius: 10px !important;
}}
h1, h2, h3 {{
font-family: 'Space Mono', monospace !important;
color: {PALETTE['accent']} !important;
letter-spacing: -0.5px !important;
}}
.gr-button-primary {{
background: {PALETTE['accent']} !important;
color: {PALETTE['bg']} !important;
font-family: 'Space Mono', monospace !important;
font-weight: 700 !important;
border: none !important;
border-radius: 6px !important;
}}
.gr-button-primary:hover {{
opacity: 0.85 !important;
}}
label {{
color: {PALETTE['text']} !important;
font-size: 13px !important;
font-family: 'Space Mono', monospace !important;
}}
textarea, input[type=text] {{
background: {PALETTE['bg']} !important;
color: {PALETTE['text']} !important;
border: 1px solid {PALETTE['border']} !important;
font-family: 'Space Mono', monospace !important;
font-size: 13px !important;
}}
.markdown-text {{
color: {PALETTE['text']} !important;
}}
"""
# Markdown header rendered at the top of the page by gr.Markdown below.
DESCRIPTION = """
# 🔬 Word Importance Evaluator
Drop-one embedding analysis using **static-retrieval-mrl-en-v1**.
Each word's importance = semantic distance introduced by omitting it.
- **Bar chart** — ranked importance with threshold line
- **Threshold filter** — words above cutoff highlighted
- **Sampled prompts** — each word included with probability = its importance score
"""
# Assemble the Blocks UI: prompt + controls on the left, threshold outputs on
# the right, chart and sampled prompts below. demo.launch() runs at import
# time, as is conventional for a Spaces app script.
with gr.Blocks(css=CSS, title="Word Importance Evaluator") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=2):
            # Multi-line input; analyse() scores each non-blank line separately.
            prompt_box = gr.Textbox(
                label="Prompt",
                placeholder="a majestic lion in golden hour light, photorealistic, dramatic shadows",
                lines=4,
            )
            with gr.Row():
                threshold_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.3, step=0.01,
                    label="Importance threshold",
                )
                n_samples_slider = gr.Slider(
                    minimum=1, maximum=20, value=8, step=1,
                    label="Number of sampled prompts",
                )
            run_btn = gr.Button("Analyse →", variant="primary")
        with gr.Column(scale=1):
            threshold_html = gr.HTML(label="Threshold output")
            threshold_md = gr.Markdown(label="Summary")
    bar_img = gr.Image(label="Importance bar chart", type="pil")
    gr.Markdown("### 🎲 Sampled prompts *(each word kept with p = importance)*")
    samples_html = gr.HTML(label="Sampled prompts")
    # Wire the button to analyse(); output order must match its return tuple.
    run_btn.click(
        fn=analyse,
        inputs=[prompt_box, threshold_slider, n_samples_slider],
        outputs=[bar_img, threshold_html, threshold_md, samples_html],
    )
    # Clickable example prompts; cache_examples=False avoids running the
    # model at startup.
    gr.Examples(
        examples=[
            ["a majestic lion in golden hour light, photorealistic, dramatic shadows", 0.3, 8],
            ["cinematic portrait of a young woman, soft bokeh, rim lighting, film grain", 0.25, 8],
            ["hyperrealistic macro photograph of a dewdrop on a spider web at dawn", 0.35, 10],
            ["oil painting of a medieval castle surrounded by autumn forest", 0.3, 8],
        ],
        inputs=[prompt_box, threshold_slider, n_samples_slider],
        fn=analyse,
        outputs=[bar_img, threshold_html, threshold_md, samples_html],
        cache_examples=False,
    )
demo.launch()