"""Semantic Attention Explorer.

Gradio app that runs a causal language model over input text, turns its
layer- and head-averaged attention into a per-word "attention signature",
scores word pairs by cosine similarity, and can ask the same model to explain
the retained pairs in natural language.
"""

import csv
import io
import json
import os
import random
import re
import tempfile
import time
from collections import Counter, defaultdict
from itertools import combinations

import gradio as gr
import matplotlib
import numpy as np
import plotly.graph_objects as go
import torch
from datasets import load_dataset
from sklearn.decomposition import PCA
from transformers import AutoModelForCausalLM, AutoTokenizer

matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402,F401  (must follow matplotlib.use)
from matplotlib.figure import Figure  # noqa: E402

# ---------------------------------------------------------------------------
# Global Configuration (mutable via UI)
# ---------------------------------------------------------------------------
CONFIG = {
    "model_id": "Qwen/Qwen2.5-0.5B-Instruct",
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "max_new_tokens": 120,
    "temperature": 0.25,
    "top_p": 0.9,
    "attn_implementation": "eager",
    "window": 7,
    "threshold": 0.3,
    "max_explanations": 25,
    "dataset_paragraphs": 3,
    "dataset_min_words": 20,
    "dataset_max_tries": 1000,
}

_model_singleton = None
_tokenizer_singleton = None


def get_model(force_reload=False):
    """Return the cached ``(tokenizer, model)`` pair, loading it on first use.

    Both are loaded from ``CONFIG["model_id"]`` with
    ``CONFIG["attn_implementation"]``; the model is moved to
    ``CONFIG["device"]`` and put in eval mode.

    Args:
        force_reload: Reload even if a model is already cached (used after
            the configuration changed).
    """
    global _model_singleton, _tokenizer_singleton
    if _model_singleton is None or force_reload:
        print(f"[DEBUG] Loading {CONFIG['model_id']} on {CONFIG['device']} ...", flush=True)
        _tokenizer_singleton = AutoTokenizer.from_pretrained(CONFIG["model_id"])
        _model_singleton = AutoModelForCausalLM.from_pretrained(
            CONFIG["model_id"],
            device_map=None,
            attn_implementation=CONFIG["attn_implementation"],
        ).to(CONFIG["device"])
        _model_singleton.eval()
        print("[DEBUG] Model loaded.", flush=True)
    return _tokenizer_singleton, _model_singleton


def tokenizer():
    """Return the shared tokenizer (loading the model if needed)."""
    return get_model()[0]


def model():
    """Return the shared model (loading it if needed)."""
    return get_model()[1]


# ---------------------------------------------------------------------------
# Dataset loader
# ---------------------------------------------------------------------------
def load_hf_text(dataset_name, config_name, split="train"):
    """Sample cleaned paragraphs from a Hugging Face dataset's ``text`` column.

    Loading is first attempted with ``config_name``; if that fails it is
    retried without a config. At most ``CONFIG["dataset_max_tries"]`` rows are
    scanned, rows with fewer than ``CONFIG["dataset_min_words"]`` words after
    :func:`clean_text` are skipped, and up to ``CONFIG["dataset_paragraphs"]``
    survivors are randomly sampled.

    Returns:
        The sampled paragraphs joined by ``"\\n\\n---\\n\\n"``, or an error /
        "no paragraphs" message string (this function never raises on load
        failure).
    """
    try:
        ds = load_dataset(dataset_name, config_name, split=split)
    except Exception as first_err:
        # Many datasets have no named config; retry without one.
        try:
            ds = load_dataset(dataset_name, split=split)
        except Exception:
            return f"Error loading dataset: {first_err}"

    candidates = []
    for i, row in enumerate(ds):
        if i >= CONFIG["dataset_max_tries"]:
            break
        text = row.get("text", "")
        if not text or not text.strip():
            continue
        cleaned = clean_text(text)
        if len(cleaned.split()) >= CONFIG["dataset_min_words"]:
            candidates.append(cleaned)

    if not candidates:
        return "No valid paragraphs found in dataset."
    n = min(CONFIG["dataset_paragraphs"], len(candidates))
    return "\n\n---\n\n".join(random.sample(candidates, n))


# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
# Contraction -> expansion map. Applied longest-key-first (see
# expand_contractions) so e.g. "can't" wins over the generic "n't".
CONTRACTIONS = {
    "n't": " not", "'re": " are", "'s": " is", "'d": " would", "'ll": " will",
    "'ve": " have", "'m": " am", "can't": "cannot", "won't": "will not",
    "let's": "let us", "that's": "that is", "who's": "who is", "what's": "what is",
    "it's": "it is", "they're": "they are", "we're": "we are", "i'm": "i am",
    "isn't": "is not", "aren't": "are not", "wasn't": "was not",
    "haven't": "have not", "hasn't": "has not", "don't": "do not",
    "doesn't": "does not", "didn't": "did not", "wouldn't": "would not",
    "couldn't": "could not", "shouldn't": "should not", "weren't": "were not",
    "hadn't": "had not", "cannot": "can not", "i'd": "i would",
    "you'd": "you would", "he'd": "he would", "she'd": "she would",
    "it'd": "it would", "we'd": "we would", "they'd": "they would",
    "i'll": "i will", "you'll": "you will", "he'll": "he will",
    "she'll": "she will", "it'll": "it will", "we'll": "we will",
    "they'll": "they will", "i've": "i have", "you've": "you have",
    "we've": "we have", "they've": "they have", "ain't": "am not",
    "here's": "here is", "there's": "there is", "where's": "where is",
    "how's": "how is", "she's": "she is", "he's": "he is",
}


def expand_contractions(text):
    """Replace each CONTRACTIONS key in *text* (case-insensitive, word-bounded).

    Keys are processed longest first so specific forms ("can't") are expanded
    before generic suffixes ("n't").
    """
    for key in sorted(CONTRACTIONS.keys(), key=len, reverse=True):
        text = re.sub(r"\b" + re.escape(key) + r"\b", CONTRACTIONS[key], text, flags=re.IGNORECASE)
    return text


def clean_text(text):
    """Normalize *text* to lowercase, space-separated word tokens.

    Steps: lowercase, expand contractions, join hyphenated words
    ("real-time" -> "realtime"), turn every remaining non-word, non-space
    character into a space, and collapse whitespace.
    """
    text = text.lower()
    text = expand_contractions(text)
    text = re.sub(r"(\w)-(\w)", r"\1\2", text)
    # All punctuation (including any dash variants) is removed here, so no
    # separate dash normalization is needed afterwards. The former
    # `text.replace("", "-")` inserted "-" between every character.
    text = re.sub(r"[^\w\s]", " ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text


# ---------------------------------------------------------------------------
# Real attention extraction
# ---------------------------------------------------------------------------
def extract_attention_vectors(text, window_words=None):
    """Compute one windowed, L2-normalized attention signature per word.

    The cleaned words are tokenized as pre-split words. A forward pass is run
    with ``output_attentions=True``, and attention is averaged over layers and
    heads into a single ``(T, T)`` matrix. Each word's signature is the mean
    attention row of its tokens, zeroed outside the ``window_words``-word
    context around it and normalized to unit length (when non-zero).

    Returns:
        An 8-tuple ``(words, vectors, contexts, token_positions_by_word,
        attn, word_ids, error, stats)``. On failure every item except
        ``error`` (a message string) is ``None``.
    """
    if window_words is None:
        window_words = CONFIG["window"]
    window_words = int(window_words)  # UI sliders may deliver floats
    tok, mdl = get_model()

    cleaned = clean_text(text)
    words = cleaned.split()
    if len(words) < 5:
        return None, None, None, None, None, None, "Text too short after cleaning (need >=5 words).", None

    encoding = tok(words, is_split_into_words=True, return_tensors="pt", add_special_tokens=False)
    input_ids = encoding["input_ids"].to(CONFIG["device"])
    word_ids = encoding.word_ids()

    with torch.no_grad():
        outputs = mdl(input_ids, output_attentions=True)

    attentions = getattr(outputs, "attentions", None)
    if not attentions or any(a is None for a in attentions):
        return (
            None, None, None, None, None, None,
            "Model returned no attention weights; set 'Attention Implementation' to 'eager' in Advanced Settings.",
            None,
        )

    # Presumably one (batch, heads, T, T) tensor per layer (HF convention);
    # average layers, then heads, then drop the batch dim -> (T, T).
    attn = torch.stack(attentions).mean(dim=0).mean(dim=1).squeeze(0).float()
    T = attn.shape[0]

    token_positions_by_word = [[] for _ in words]
    for t, wid in enumerate(word_ids):
        if wid is not None and 0 <= wid < len(words):
            token_positions_by_word[wid].append(t)

    # Word index for every token position (-1 for tokens without a word), so
    # the context mask can be built with one vectorized comparison.
    token_word = torch.tensor([-1 if wid is None else wid for wid in word_ids], device=attn.device)

    vectors = []
    contexts = []
    for w_idx in range(len(words)):
        tok_pos = token_positions_by_word[w_idx] or [min(w_idx, T - 1)]
        v = attn[tok_pos, :].mean(dim=0)
        ctx_start = max(0, w_idx - window_words)
        ctx_end = min(len(words), w_idx + window_words + 1)
        mask = ((token_word >= ctx_start) & (token_word < ctx_end)).to(v.dtype)
        v_local = v * mask
        norm = v_local.norm()
        if norm > 1e-8:
            v_local = v_local / norm
        vectors.append(v_local.cpu())
        contexts.append(" ".join(words[ctx_start:ctx_end]))

    vocab = Counter(words)
    stats = {
        "word_count": len(words),
        "unique_words": len(vocab),
        "token_count": T,
        "top_words": vocab.most_common(10),
    }
    return words, vectors, contexts, token_positions_by_word, attn, word_ids, None, stats


def cosine_similarity(v1, v2):
    """Cosine similarity of two 1-D tensors as a Python float (epsilon-guarded)."""
    return float((torch.dot(v1, v2) / (v1.norm() * v2.norm() + 1e-8)).item())


# ---------------------------------------------------------------------------
# Pair generation
# ---------------------------------------------------------------------------
def generate_all_pairs(words, vectors, contexts, window=None):
    """Build scored word pairs from per-word attention signatures.

    Two kinds of pairs are produced:

    * ``"self"``: every pair of distinct occurrences of the same word.
    * ``"neighbor"``: each occurrence paired with every other word inside
      ``window`` positions on either side (directed; includes ``distance``).

    Returns:
        A list of pair dicts with ``word``, ``neighbor``, ``similarity_score``,
        ``occurrence_indices``, ``context_sentences``, ``pair_type`` and, for
        neighbor pairs, ``distance``.
    """
    window = CONFIG["window"] if window is None else int(window)

    word_occurrences = defaultdict(list)
    for i, w in enumerate(words):
        word_occurrences[w].append(i)
    vocab = sorted(word_occurrences)

    pairs = []
    for w in vocab:
        for idx1, idx2 in combinations(word_occurrences[w], 2):
            pairs.append({
                "word": w,
                "neighbor": w,
                "similarity_score": cosine_similarity(vectors[idx1], vectors[idx2]),
                "occurrence_indices": [idx1, idx2],
                "context_sentences": [contexts[idx1], contexts[idx2]],
                "pair_type": "self",
            })

    for w in vocab:
        for idx in word_occurrences[w]:
            ctx_start = max(0, idx - window)
            ctx_end = min(len(words), idx + window + 1)
            for n_idx in range(ctx_start, ctx_end):
                if n_idx == idx:
                    continue
                pairs.append({
                    "word": w,
                    "neighbor": words[n_idx],
                    "similarity_score": cosine_similarity(vectors[idx], vectors[n_idx]),
                    "occurrence_indices": [idx],
                    "context_sentences": [contexts[idx]],
                    "pair_type": "neighbor",
                    "distance": n_idx - idx,
                })
    return pairs


def threshold_filter(pairs, threshold=None):
    """Keep pairs whose score is strictly above *threshold*, grouped by word.

    Returns:
        ``{word: [record, ...]}``; each record gains an empty
        ``semantic_explanation`` slot and ``distance`` (``None`` for self pairs).
    """
    if threshold is None:
        threshold = CONFIG["threshold"]
    result = defaultdict(list)
    for p in pairs:
        if p["similarity_score"] > threshold:
            result[p["word"]].append({
                "neighbor": p["neighbor"],
                "similarity_score": p["similarity_score"],
                "occurrence_indices": p["occurrence_indices"],
                "context_sentences": p["context_sentences"],
                "pair_type": p.get("pair_type", "unknown"),
                "distance": p.get("distance", None),
                "semantic_explanation": "",
            })
    return dict(result)


# ---------------------------------------------------------------------------
# Auto LLM Explanation
# ---------------------------------------------------------------------------
SYSTEM_PROMPT = (
    "You are a precise semantic linguist. Given a word pair and their local context, "
    "explain in ONE or TWO concise sentences how these two words semantically relate or influence each other. "
    "Focus on semantic influence, topical association, or conceptual dependency: how the presence or meaning of "
    "one word affects the choice or interpretation of the other in this specific context. "
    "Do NOT focus on grammar, syntax, or word order. Be concrete and specific."
)


def build_llm_prompt(word, neighbor, sentences, pair_type):
    """Return chat messages (system + user) asking how *word* relates to *neighbor*."""
    ctx_block = "\n".join(f"  Context {i+1}: {s}" for i, s in enumerate(sentences))
    if pair_type == "self":
        user = (
            f"Word pair: '{word}' (two different occurrences of the same word)\n"
            f"{ctx_block}\n"
            f"Question: How do these two occurrences of '{word}' semantically relate or influence each other? "
            f"Do they reinforce the same meaning, or does context shift their semantic role?"
        )
    else:
        user = (
            f"Word pair: '{word}' and '{neighbor}'\n"
            f"{ctx_block}\n"
            f"Question: How does '{word}' semantically influence '{neighbor}' (or vice versa) in this context? "
            f"What shared topic, function, or conceptual dependency binds them?"
        )
    return [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user}]


def llm_explain(word, neighbor, sentences, pair_type):
    """Generate a short natural-language explanation for one pair.

    Uses the shared model with the tokenizer's chat template. Sampling uses
    ``CONFIG["temperature"]`` / ``CONFIG["top_p"]``; a temperature <= 0 selects
    greedy decoding, because sampling requires a strictly positive temperature.

    Returns:
        The decoded continuation, or a bracketed marker string
        (``"[No generation]"``, ``"[Empty after strip]"``, ``"[Error: ...]"``).
    """
    tok, mdl = get_model()
    try:
        messages = build_llm_prompt(word, neighbor, sentences, pair_type)
        prompt_text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        gen_inputs = tok(prompt_text, return_tensors="pt").to(CONFIG["device"])
        prompt_len = gen_inputs["input_ids"].shape[1]

        gen_kwargs = {
            "max_new_tokens": CONFIG["max_new_tokens"],
            "pad_token_id": tok.eos_token_id,
        }
        temperature = float(CONFIG["temperature"])
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=CONFIG["top_p"])
        else:
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            outputs = mdl.generate(**gen_inputs, **gen_kwargs)

        if outputs.shape[1] <= prompt_len:
            return "[No generation]"
        explanation = tok.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
        return explanation if explanation else "[Empty after strip]"
    except Exception as e:  # best effort: report the error in the table cell
        return f"[Error: {e}]"


def auto_explain(filtered_dict, max_explanations=None):
    """Fill ``semantic_explanation`` for the top-scoring retained pairs.

    Pairs from all words are ranked by score (descending); the first
    ``max_explanations`` are explained with :func:`llm_explain`. The records
    in *filtered_dict* are mutated in place, and the same dict is returned.
    """
    if max_explanations is None:
        max_explanations = CONFIG["max_explanations"]
    all_recs = [(word, rec) for word, recs in filtered_dict.items() for rec in recs]
    all_recs.sort(key=lambda x: x[1]["similarity_score"], reverse=True)
    print(f"[DEBUG] auto_explain: {len(all_recs)} total, generating {min(max_explanations, len(all_recs))}", flush=True)
    for idx, (word, rec) in enumerate(all_recs[:max_explanations]):
        print(f"[DEBUG] Explaining {idx+1}: '{word}' / '{rec['neighbor']}' score={rec['similarity_score']:.3f}", flush=True)
        expl = llm_explain(word, rec["neighbor"], rec["context_sentences"], rec["pair_type"])
        rec["semantic_explanation"] = expl
        print(f"[DEBUG] Explanation length: {len(expl)} chars", flush=True)
        time.sleep(0.02)
    return filtered_dict


# ---------------------------------------------------------------------------
# Visualisations
# ---------------------------------------------------------------------------
# Figures are built with matplotlib.figure.Figure directly (not pyplot) so they
# are not registered with pyplot and are garbage-collected after each request.
def plot_attention_heatmap(attn_matrix, words, word_ids, max_tok=50):
    """Heatmap of the top-left ``max_tok`` x ``max_tok`` corner of the attention matrix."""
    attn_np = attn_matrix[:max_tok, :max_tok].cpu().numpy()
    fig = Figure(figsize=(10, 8))
    ax = fig.subplots()
    im = ax.imshow(attn_np, cmap="viridis", aspect="auto")
    ax.set_title(f"Attention Heatmap (first {max_tok} tokens)")
    ax.set_xlabel("Key token position")
    ax.set_ylabel("Query token position")
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    return fig


def plot_attention_signature(attn_matrix, tok_positions, words, word_ids, title_word):
    """Bar chart of the mean attention row over *tok_positions* (one word's tokens)."""
    if not tok_positions:
        fig = Figure()
        ax = fig.subplots()
        ax.text(0.5, 0.5, f"No tokens found for '{title_word}'", ha="center", va="center")
        return fig
    vec = attn_matrix[tok_positions, :].mean(dim=0).cpu().numpy()
    T = len(vec)
    fig = Figure(figsize=(14, 4))
    ax = fig.subplots()
    ax.bar(range(T), vec, width=1.0, color="steelblue")
    ax.set_title(f"Attention Signature: '{title_word}' -> all tokens")
    ax.set_xlabel("Token position")
    ax.set_ylabel("Attention weight")
    step = max(1, T // 20)
    ax.set_xticks(range(0, T, step))
    ax.set_xticklabels([str(i) for i in range(0, T, step)], rotation=45)
    fig.tight_layout()
    return fig


def plot_word_frequency(words):
    """Horizontal bar chart of the 20 most common words (most common on top)."""
    top20 = Counter(words).most_common(20)
    wds = [w for w, _ in top20]
    counts = [c for _, c in top20]
    fig = Figure(figsize=(12, 5))
    ax = fig.subplots()
    ax.barh(range(len(wds)), counts[::-1], color="teal")
    ax.set_yticks(range(len(wds)))
    ax.set_yticklabels(wds[::-1])
    ax.set_xlabel("Frequency")
    ax.set_title("Top 20 Word Frequencies")
    fig.tight_layout()
    return fig


def plot_3d_semantic_space(words, vectors, filtered_dict):
    """3-D PCA scatter of all word signatures; words in retained pairs are highlighted."""
    if not vectors:
        return go.Figure()
    X = torch.stack(vectors).float().numpy()
    coords = PCA(n_components=3).fit_transform(X)

    hover_texts = []
    for i, w in enumerate(words):
        ctx = " ".join(words[max(0, i - 3):min(len(words), i + 4)])
        hover_texts.append(f"{w}<br>idx={i}<br>ctx: {ctx[:60]}...")

    filtered_words = set()
    for w, recs in filtered_dict.items():
        for rec in recs:
            filtered_words.add(w)
            filtered_words.add(rec["neighbor"])

    colors = ["#e74c3c" if w in filtered_words else "#3498db" for w in words]
    sizes = [8 if w in filtered_words else 4 for w in words]
    fig = go.Figure(data=[go.Scatter3d(
        x=coords[:, 0], y=coords[:, 1], z=coords[:, 2],
        mode="markers",
        marker=dict(size=sizes, color=colors, opacity=0.8),
        text=words,
        hovertext=hover_texts,
        hoverinfo="text",
    )])
    fig.update_layout(
        title="3D Semantic Space (PCA of attention signatures)<br>Red/large = words in retained pairs",
        scene=dict(xaxis_title="PC1", yaxis_title="PC2", zaxis_title="PC3"),
        width=800, height=600,
    )
    return fig


def plot_pair_similarity_3d(filtered_dict, vectors, words):
    """3-D PCA scatter of pair midpoints for the 100 highest-scoring retained pairs.

    Self pairs use their two occurrence indices. Neighbor pairs use the word's
    occurrence and the neighbor at ``occurrence + distance``. Records without a
    usable distance fall back to the neighbor word's first occurrence.
    """
    if not filtered_dict or not vectors:
        return go.Figure()
    all_recs = sorted(
        ((w, rec) for w, recs in filtered_dict.items() for rec in recs),
        key=lambda x: x[1]["similarity_score"],
        reverse=True,
    )[:100]
    if not all_recs:
        return go.Figure()

    X = torch.stack(vectors).float().numpy()
    coords = PCA(n_components=3).fit_transform(X)

    first_index = {}
    for i, w in enumerate(words):
        first_index.setdefault(w, i)

    midpoints, scores, labels = [], [], []
    for w, rec in all_recs:
        idxs = rec["occurrence_indices"]
        if len(idxs) >= 2:
            i1, i2 = idxs[0], idxs[1]
        else:
            i1 = idxs[0]
            distance = rec.get("distance")
            if distance is not None and 0 <= i1 + distance < len(coords):
                i2 = i1 + distance
            else:
                i2 = first_index.get(rec["neighbor"], i1)
        midpoints.append((coords[i1] + coords[i2]) / 2)
        scores.append(rec["similarity_score"])
        labels.append(f"{w}–{rec['neighbor']}<br>score={rec['similarity_score']:.3f}")

    midpoints = np.array(midpoints)
    fig = go.Figure(data=[go.Scatter3d(
        x=midpoints[:, 0], y=midpoints[:, 1], z=midpoints[:, 2],
        mode="markers",
        marker=dict(
            size=[max(3, s * 15) for s in scores],
            color=scores,
            colorscale="Plasma",
            colorbar=dict(title="Similarity"),
            opacity=0.85,
        ),
        text=labels,
        hoverinfo="text",
    )])
    fig.update_layout(
        title="3D Pair Similarity Space (PCA midpoints, top 100 pairs)",
        scene=dict(xaxis_title="PC1", yaxis_title="PC2", zaxis_title="PC3"),
        width=800, height=600,
    )
    return fig


# ---------------------------------------------------------------------------
# Download helpers
# ---------------------------------------------------------------------------
def json_download(filtered_dict):
    """Serialize the retained pairs as pretty-printed JSON text."""
    return json.dumps(filtered_dict, indent=2, ensure_ascii=False)


def csv_download(filtered_dict):
    """Serialize the retained pairs as CSV text (one row per record).

    Uses the csv module so quotes, commas and newlines in contexts or LLM
    explanations are escaped correctly. A missing distance is written as an
    empty field.
    """
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    writer.writerow([
        "word", "neighbor", "similarity_score", "pair_type",
        "distance", "context", "semantic_explanation",
    ])
    for word, recs in filtered_dict.items():
        for rec in recs:
            ctxs = rec.get("context_sentences") or []
            distance = rec.get("distance")
            writer.writerow([
                word,
                rec["neighbor"],
                f"{rec['similarity_score']:.6f}",
                rec.get("pair_type", ""),
                "" if distance is None else distance,
                ctxs[0] if ctxs else "",
                rec.get("semantic_explanation") or "",
            ])
    return buf.getvalue()


# ---------------------------------------------------------------------------
# Gradio handlers
# ---------------------------------------------------------------------------
DEFAULT_TEXT = (
    "Senjō no Valkyria 3: Unrecorded Chronicles (Japanese: 戦場のヴァルキュリア3) "
    "is a tactical role-playing video game developed by Sega and Media.Vision for the PlayStation Portable. "
    "Released in January 2011 in Japan, it is the third game in the Valkyria Chronicles series. "
    "The game uses the same fusion of tactical and real-time action as its predecessors, "
    "and introduces new characters and a darker storyline about a penal military unit."
)


def _pair_rows(filtered, placeholder=None, per_word=20):
    """Flatten retained pairs into Dataframe rows (at most *per_word* per word).

    The last column is *placeholder* when given, otherwise the record's
    ``semantic_explanation``.
    """
    rows = []
    for w, recs in filtered.items():
        for r in recs[:per_word]:
            distance = r.get("distance")
            ctxs = r.get("context_sentences") or [""]
            rows.append([
                w,
                r["neighbor"],
                f"{r['similarity_score']:.3f}",
                r["pair_type"],
                "" if distance is None else distance,
                ctxs[0][:70] + "...",
                placeholder if placeholder is not None else r.get("semantic_explanation", ""),
            ])
    return rows


def run_pipeline(text, threshold, window):
    """Extract signatures, score and filter pairs, and build all plots.

    Returns 12 values matching the UI outputs: summary, table rows, cache JSON,
    heatmap, frequency, signature, 3-D space and 3-D pair figures, and the
    words / token positions / attention matrix / word ids states.
    """
    print("[DEBUG] run_pipeline started", flush=True)
    threshold = float(threshold)
    window = int(window)
    CONFIG["threshold"] = threshold
    CONFIG["window"] = window

    result = extract_attention_vectors(text, window)
    words, vectors, contexts, tok_positions, attn_matrix, word_ids, err, stats = result
    if err:
        return err, [], "", None, None, None, None, None, None, None, None, None

    all_pairs = generate_all_pairs(words, vectors, contexts, window=window)
    filtered = threshold_filter(all_pairs, threshold=threshold)
    total = sum(len(v) for v in filtered.values())

    stats_str = (
        f"Words: {stats['word_count']} | Unique: {stats['unique_words']} | Tokens: {stats['token_count']} | "
        f"Raw pairs: {len(all_pairs)} | Retained: {total} | "
        f"Top words: {', '.join(f'{w}({c})' for w, c in stats['top_words'][:5])}"
    )
    print(f"[DEBUG] {stats_str}", flush=True)

    rows = _pair_rows(filtered, placeholder="(click Generate Explanations)")

    heatmap_fig = plot_attention_heatmap(attn_matrix, words, word_ids)
    freq_fig = plot_word_frequency(words)
    sig_fig = plot_attention_signature(
        attn_matrix,
        tok_positions[0] if tok_positions else [],
        words, word_ids,
        words[0] if words else "",
    )
    space3d_fig = plot_3d_semantic_space(words, vectors, filtered)
    pair3d_fig = plot_pair_similarity_3d(filtered, vectors, words)

    cache_json = json.dumps({
        "words": words,
        "contexts": contexts,
        "filtered": filtered,
        "stats": stats,
    }, indent=2, ensure_ascii=False)

    return (
        stats_str, rows, cache_json,
        heatmap_fig, freq_fig, sig_fig, space3d_fig, pair3d_fig,
        words, tok_positions, attn_matrix, word_ids,
    )


def _explain_cache(cache_json, max_expl):
    """Generate explanations and return ``(status, rows, updated_cache_json)``.

    The returned cache includes the new ``semantic_explanation`` values so the
    JSON/CSV downloads contain them. On failure the cache is returned unchanged.
    """
    if not cache_json or cache_json.strip() == "":
        return "Run pipeline first.", [], cache_json
    try:
        cache = json.loads(cache_json)
        filtered = cache["filtered"]
    except (json.JSONDecodeError, KeyError, TypeError):
        return "Invalid cache.", [], cache_json

    max_expl = CONFIG["max_explanations"] if max_expl is None else int(max_expl)
    CONFIG["max_explanations"] = max_expl
    print(f"[DEBUG] generate_explanations called, max={max_expl}", flush=True)

    filtered = auto_explain(filtered, max_explanations=max_expl)
    cache["filtered"] = filtered
    new_cache = json.dumps(cache, indent=2, ensure_ascii=False)
    return "Explanations generated.", _pair_rows(filtered), new_cache


def generate_explanations(cache_json, max_expl):
    """Generate explanations for the cached pairs; returns ``(status, rows)``."""
    status, rows, _ = _explain_cache(cache_json, max_expl)
    return status, rows


def update_signature(cache_json, word_input, words_state, tok_positions_state, attn_matrix_state, word_ids_state):
    """Re-plot the attention signature for the first occurrence of *word_input*.

    Returns ``None`` when the pipeline has not run or the word is not in the text.
    """
    if not cache_json:
        return None
    words = words_state or []
    try:
        words = json.loads(cache_json).get("words", words)
    except (json.JSONDecodeError, AttributeError, TypeError):
        pass  # deliberate: fall back to the words held in gr.State
    tok_positions = tok_positions_state or []
    word_ids = word_ids_state or []
    attn_matrix = attn_matrix_state

    if not words or attn_matrix is None:
        return None
    target = (word_input or "").lower().strip()
    if target not in words:
        return None
    idx = words.index(target)
    tp = tok_positions[idx] if idx < len(tok_positions) else []
    return plot_attention_signature(attn_matrix, tp, words, word_ids, target)


def fetch_dataset(dataset_name, config_name, split):
    """Load sample paragraphs; a blank config name means 'no config'."""
    return load_hf_text(dataset_name, (config_name or "").strip() or None, split)


def apply_settings(model_id, device, max_tokens, temperature, top_p, attn_impl, ds_paras, ds_min, ds_max):
    """Apply Advanced Settings to CONFIG and reload the model when needed.

    The model is reloaded whenever the model id, device or attention
    implementation changes, so the loaded weights always match CONFIG (inputs
    are moved to ``CONFIG["device"]``).
    """
    global _model_singleton, _tokenizer_singleton
    previous = (CONFIG["model_id"], CONFIG["device"], CONFIG["attn_implementation"])

    CONFIG["model_id"] = (model_id or "").strip() or previous[0]
    CONFIG["device"] = device
    CONFIG["max_new_tokens"] = int(max_tokens)
    CONFIG["temperature"] = float(temperature)
    CONFIG["top_p"] = float(top_p)
    CONFIG["attn_implementation"] = attn_impl
    CONFIG["dataset_paragraphs"] = int(ds_paras)
    CONFIG["dataset_min_words"] = int(ds_min)
    CONFIG["dataset_max_tries"] = int(ds_max)

    current = (CONFIG["model_id"], CONFIG["device"], CONFIG["attn_implementation"])
    if current != previous:
        # Drop the old references first so its memory can be released before loading.
        _model_singleton = None
        _tokenizer_singleton = None
        try:
            get_model(force_reload=True)
            return f"Settings applied. Model '{CONFIG['model_id']}' loaded on {CONFIG['device']}."
        except Exception as e:
            return f"Error loading model: {e}"
    return (
        f"Settings applied. Model: {CONFIG['model_id']} | Device: {CONFIG['device']} | "
        f"Max tokens: {CONFIG['max_new_tokens']} | Temp: {CONFIG['temperature']} | Top-p: {CONFIG['top_p']}"
    )


def download_json(cache_json):
    """Return the cached retained pairs as JSON text, or ``None``."""
    if not cache_json:
        return None
    try:
        return json_download(json.loads(cache_json).get("filtered", {}))
    except (json.JSONDecodeError, AttributeError, TypeError):
        return None


def download_csv(cache_json):
    """Return the cached retained pairs as CSV text, or ``None``."""
    if not cache_json:
        return None
    try:
        return csv_download(json.loads(cache_json).get("filtered", {}))
    except (json.JSONDecodeError, AttributeError, KeyError, TypeError):
        return None


def _write_download(content, suffix):
    """Write *content* to a fresh temp file and reveal it in a gr.File.

    A unique temp file per request avoids concurrent sessions overwriting
    one shared path.
    """
    if not content:
        return gr.update(value=None, visible=False)
    with tempfile.NamedTemporaryFile(
        "w", encoding="utf-8", suffix=suffix, prefix="results_", delete=False
    ) as f:
        f.write(content)
    return gr.update(value=f.name, visible=True)


def _download_json_file(cache_json):
    """Click handler: JSON download file for the cached results."""
    return _write_download(download_json(cache_json), ".json")


def _download_csv_file(cache_json):
    """Click handler: CSV download file for the cached results."""
    return _write_download(download_csv(cache_json), ".csv")


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="Semantic Attention Explorer", css="""
.dataframe-wrap { white-space: pre-wrap !important; }
""") as demo:
    gr.Markdown("# 🔍 Semantic Attention Explorer")
    gr.Markdown(
        "Extract **real neural attention** from a causal LM, compute **cosine similarity** between "
        "centered attention signatures, and auto-generate **LLM semantic explanations** for retained pairs."
    )

    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(label="Input Text", lines=8, value=DEFAULT_TEXT)
        with gr.Column(scale=1):
            thresh_slider = gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="Similarity Threshold")
            win_slider = gr.Slider(1, 15, value=7, step=1, label="Context Window (words)")
            run_btn = gr.Button("▶️ Run Pipeline", variant="primary")

    with gr.Accordion("⚙️ Advanced Settings", open=False):
        with gr.Row():
            model_input = gr.Textbox(
                label="Model ID",
                value=CONFIG["model_id"],
                placeholder="e.g. Qwen/Qwen2.5-0.5B-Instruct",
            )
            device_dd = gr.Dropdown(
                ["cpu", "cuda"] + (["mps"] if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else []),
                value=CONFIG["device"],
                label="Device",
            )
            attn_impl_dd = gr.Dropdown(
                ["eager", "sdpa", "flash_attention_2"],
                value=CONFIG["attn_implementation"],
                label="Attention Implementation",
            )
        with gr.Row():
            max_tokens_num = gr.Number(value=CONFIG["max_new_tokens"], label="Max New Tokens", precision=0, minimum=1, maximum=2048)
            temp_slider = gr.Slider(0.0, 2.0, value=CONFIG["temperature"], step=0.05, label="Temperature")
            top_p_slider = gr.Slider(0.0, 1.0, value=CONFIG["top_p"], step=0.05, label="Top-p")
        with gr.Row():
            ds_paras_num = gr.Number(value=CONFIG["dataset_paragraphs"], label="Dataset paragraphs", precision=0, minimum=1, maximum=100)
            ds_min_num = gr.Number(value=CONFIG["dataset_min_words"], label="Min words per paragraph", precision=0, minimum=1, maximum=500)
            ds_max_num = gr.Number(value=CONFIG["dataset_max_tries"], label="Dataset scan limit", precision=0, minimum=10, maximum=100000)
        apply_btn = gr.Button("Apply Settings & Reload Model", variant="secondary")
        settings_status = gr.Textbox(label="Settings Status", interactive=False)

    with gr.Accordion("📚 Load from HuggingFace Dataset", open=False):
        with gr.Row():
            ds_name = gr.Textbox(label="Dataset name", value="wikitext", placeholder="e.g. wikitext")
            ds_config = gr.Textbox(label="Config name (optional)", value="wikitext-2-raw-v1", placeholder="e.g. wikitext-2-raw-v1")
            ds_split = gr.Dropdown(["train", "validation", "test"], value="train", label="Split")
        ds_btn = gr.Button("Load Sample Paragraphs")
        ds_output = gr.Textbox(label="Loaded Paragraphs (copy one into Input Text)", lines=6, interactive=False)
        ds_btn.click(fetch_dataset, inputs=[ds_name, ds_config, ds_split], outputs=[ds_output])

    apply_btn.click(
        apply_settings,
        inputs=[model_input, device_dd, max_tokens_num, temp_slider, top_p_slider,
                attn_impl_dd, ds_paras_num, ds_min_num, ds_max_num],
        outputs=[settings_status],
    )

    # Hidden per-session cache states
    cache_state = gr.State()
    words_state = gr.State()
    tok_positions_state = gr.State()
    attn_matrix_state = gr.State()
    word_ids_state = gr.State()

    summary_box = gr.Textbox(label="Pipeline Summary", interactive=False)

    with gr.Row():
        with gr.Column(scale=1):
            max_expl_num = gr.Number(
                value=25,
                label="Auto LLM Explanations (top-N)",
                precision=0,
                minimum=0,
                maximum=9999,
            )
            explain_btn = gr.Button("🤖 Generate Explanations", variant="secondary")
            with gr.Row():
                json_dl_btn = gr.Button("📥 Download JSON")
                csv_dl_btn = gr.Button("📥 Download CSV")
            json_file = gr.File(label="JSON Download", visible=False)
            csv_file = gr.File(label="CSV Download", visible=False)
        with gr.Column(scale=2):
            expl_status = gr.Textbox(label="Explanation Status", interactive=False)
            pairs_df = gr.Dataframe(
                headers=["Word", "Neighbor", "Score", "Type", "Distance", "Context", "LLM Explanation"],
                interactive=False,
                wrap=True,
            )

    with gr.Tab("🌡️ Attention Heatmap"):
        heatmap_plot = gr.Plot(label="Full Attention Matrix (first 50 tokens)")
    with gr.Tab("📊 Word Frequency"):
        freq_plot = gr.Plot(label="Top 20 Word Frequencies")
    with gr.Tab("📈 Attention Signature"):
        with gr.Row():
            sig_word_input = gr.Textbox(label="Word to visualize", placeholder="e.g. game")
            sig_update_btn = gr.Button("Update Signature")
        sig_plot = gr.Plot(label="Attention Signature (word -> all tokens)")
    with gr.Tab("🌌 3D Semantic Space"):
        space3d_plot = gr.Plot(label="3D PCA of attention signatures")
    with gr.Tab("🔗 3D Pair Space"):
        pair3d_plot = gr.Plot(label="3D PCA of pair midpoints")

    # Writing the explanations back into cache_state lets the downloads include them.
    explain_btn.click(
        _explain_cache,
        inputs=[cache_state, max_expl_num],
        outputs=[expl_status, pairs_df, cache_state],
    )

    run_btn.click(
        run_pipeline,
        inputs=[input_text, thresh_slider, win_slider],
        outputs=[
            summary_box, pairs_df, cache_state,
            heatmap_plot, freq_plot, sig_plot, space3d_plot, pair3d_plot,
            words_state, tok_positions_state, attn_matrix_state, word_ids_state,
        ],
    )

    sig_update_btn.click(
        update_signature,
        inputs=[cache_state, sig_word_input, words_state, tok_positions_state, attn_matrix_state, word_ids_state],
        outputs=[sig_plot],
    )

    # gr.File expects a file path, so the serialized text is written to a temp
    # file before being handed to the component.
    json_dl_btn.click(_download_json_file, inputs=[cache_state], outputs=[json_file])
    csv_dl_btn.click(_download_csv_file, inputs=[cache_state], outputs=[csv_file])

demo.launch()