# Hugging Face Spaces page header (status: Sleeping)
import csv
import io
import json
import os
import random
import re
import tempfile
import time
from collections import Counter, defaultdict

import gradio as gr
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import numpy as np
import plotly.graph_objects as go
import torch
from datasets import load_dataset
from sklearn.decomposition import PCA
from transformers import AutoModelForCausalLM, AutoTokenizer
# ---------------------------------------------------------------------------
# Global Configuration (mutable via UI)
# ---------------------------------------------------------------------------
# Process-wide settings shared by every handler below; the Gradio callbacks
# overwrite individual keys in place.
CONFIG = dict(
    # Model / generation settings
    model_id="Qwen/Qwen2.5-0.5B-Instruct",
    device="cuda" if torch.cuda.is_available() else "cpu",
    max_new_tokens=120,
    temperature=0.25,
    top_p=0.9,
    attn_implementation="eager",
    # Pair extraction settings
    window=7,
    threshold=0.3,
    max_explanations=25,
    # HuggingFace dataset sampling settings
    dataset_paragraphs=3,
    dataset_min_words=20,
    dataset_max_tries=1000,
)

# Filled lazily by get_model(); both stay None until the first load.
_model_singleton = None
_tokenizer_singleton = None
def get_model(force_reload=False):
    """Return the shared ``(tokenizer, model)`` pair, loading it on first use.

    The pair is cached in module globals so every handler reuses one copy.
    ``CONFIG["model_id"]``, ``CONFIG["device"]`` and
    ``CONFIG["attn_implementation"]`` are read at load time.

    Args:
        force_reload: Load again even when a pair is already cached.

    Returns:
        Tuple ``(tokenizer, model)``; the model is in eval mode.

    Raises:
        Whatever ``from_pretrained`` or ``.to()`` raise.  In that case the
        previously cached pair (if any) is left untouched.
    """
    global _model_singleton, _tokenizer_singleton
    if not force_reload and _model_singleton is not None and _tokenizer_singleton is not None:
        return _tokenizer_singleton, _model_singleton

    print(f"[DEBUG] Loading {CONFIG['model_id']} on {CONFIG['device']} ...", flush=True)
    # Load into locals and publish both at once: if the model load fails after
    # the tokenizer succeeded, the cache must not end up pairing a tokenizer
    # from one checkpoint with a model from another.
    new_tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_id"])
    new_model = AutoModelForCausalLM.from_pretrained(
        CONFIG["model_id"],
        device_map=None,
        attn_implementation=CONFIG["attn_implementation"],
    ).to(CONFIG["device"])
    new_model.eval()
    _tokenizer_singleton, _model_singleton = new_tokenizer, new_model
    print("[DEBUG] Model loaded.", flush=True)
    return _tokenizer_singleton, _model_singleton
def tokenizer():
    """Shortcut for the cached tokenizer (triggers the model load on first call)."""
    tok, _unused_model = get_model()
    return tok


def model():
    """Shortcut for the cached causal LM (triggers the model load on first call)."""
    _unused_tok, mdl = get_model()
    return mdl
# ---------------------------------------------------------------------------
# Dataset loader
# ---------------------------------------------------------------------------
def load_hf_text(dataset_name, config_name, split="train"):
    """Sample cleaned paragraphs from the ``text`` column of a HF dataset.

    At most ``CONFIG["dataset_max_tries"]`` rows are scanned.  Rows whose
    cleaned text (``clean_text``) has at least ``CONFIG["dataset_min_words"]``
    words become candidates, and ``CONFIG["dataset_paragraphs"]`` of them are
    picked at random and joined with ``---`` separator lines.

    Args:
        dataset_name: HuggingFace dataset id.
        config_name: Dataset config name, or ``None``.
        split: Split to read.

    Returns:
        The joined paragraphs.  On failure, a message string is returned
        instead: it starts with ``"Error loading dataset:"`` when loading
        failed, or is ``"No valid paragraphs found in dataset."``.
    """
    try:
        ds = load_dataset(dataset_name, config_name, split=split)
    except Exception as first_err:
        # Retrying without a config only makes sense when one was given;
        # otherwise it would repeat exactly the same failing call.
        if config_name is None:
            return f"Error loading dataset: {first_err}"
        try:
            ds = load_dataset(dataset_name, split=split)
        except Exception as retry_err:
            return (
                f"Error loading dataset: {first_err} "
                f"(retry without config also failed: {retry_err})"
            )

    max_rows = CONFIG["dataset_max_tries"]
    min_words = CONFIG["dataset_min_words"]
    candidates = []
    for i, row in enumerate(ds):
        if i >= max_rows:
            break
        text = row.get("text", "")
        # Skip missing/null/non-string and blank rows.
        if not isinstance(text, str) or not text.strip():
            continue
        cleaned = clean_text(text)
        if len(cleaned.split()) >= min_words:
            candidates.append(cleaned)

    if not candidates:
        return "No valid paragraphs found in dataset."
    n = min(CONFIG["dataset_paragraphs"], len(candidates))
    return "\n\n---\n\n".join(random.sample(candidates, n))
# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
# Contraction -> expansion table used by expand_contractions().  Keys are
# lowercase and are applied longest-first, so specific entries ("can't",
# "won't", ...) take effect before the generic suffix rules ("n't", "'re",
# "'s", ...).  Each key appears exactly once: repeated keys in a dict literal
# silently overwrite each other and only hide mistakes.
# NOTE: "cannot" -> "can not" runs before the shorter "can't", so a "can't"
# rewritten to "cannot" is not expanded a second time.
CONTRACTIONS = {
    # Generic suffixes
    "n't": " not", "'re": " are", "'s": " is", "'d": " would",
    "'ll": " will", "'ve": " have", "'m": " am",
    # Specific full forms
    "can't": "cannot", "won't": "will not", "let's": "let us",
    "that's": "that is", "who's": "who is", "what's": "what is",
    "it's": "it is", "they're": "they are", "we're": "we are",
    "i'm": "i am", "isn't": "is not", "aren't": "are not",
    "wasn't": "was not", "haven't": "have not", "hasn't": "has not",
    "don't": "do not", "doesn't": "does not", "didn't": "did not",
    "wouldn't": "would not", "couldn't": "could not", "shouldn't": "should not",
    "weren't": "were not", "hadn't": "had not", "cannot": "can not",
    "i'd": "i would", "you'd": "you would", "he'd": "he would",
    "she'd": "she would", "it'd": "it would", "we'd": "we would",
    "they'd": "they would", "i'll": "i will", "you'll": "you will",
    "he'll": "he will", "she'll": "she will", "it'll": "it will",
    "we'll": "we will", "they'll": "they will", "i've": "i have",
    "you've": "you have", "we've": "we have", "they've": "they have",
    "ain't": "am not", "here's": "here is", "there's": "there is",
    "where's": "where is", "how's": "how is", "she's": "she is",
    "he's": "he is",
}
def expand_contractions(text, mapping=None):
    """Expand English contractions in *text* (e.g. ``"don't"`` -> ``"do not"``).

    Keys of *mapping* are applied one after another, longest key first and
    case-insensitively.  This lets specific forms such as ``"can't"`` win
    over the generic ``"n't"`` suffix.  Typographic apostrophes (U+2019) are
    first normalised to ``'`` so curly-quoted text expands too.

    Args:
        text: Input string.
        mapping: Optional ``{contraction: expansion}`` dict.  Defaults to the
            module-level ``CONTRACTIONS`` table.

    Returns:
        The text with every matching contraction replaced.
    """
    if mapping is None:
        mapping = CONTRACTIONS
    text = text.replace("\u2019", "'")
    for key in sorted(mapping, key=len, reverse=True):
        # "n't" is a suffix ("needn't", or pre-tokenised "did n't").  A leading
        # \b would require a non-word character right before the "n", so the
        # rule could never fire inside a word; only anchor its end.
        lead = "" if key.startswith("n'") else r"\b"
        pattern = lead + re.escape(key) + r"\b"
        text = re.sub(pattern, mapping[key], text, flags=re.IGNORECASE)
    return text
def clean_text(text):
    """Normalise *text* into lowercase, single-space-separated word tokens.

    The steps are:

    1. Lowercase the text.
    2. Expand contractions.
    3. Replace every character that is not a word character, whitespace or
       an intra-word hyphen with a space.
    4. Collapse whitespace runs.

    A hyphen is kept only when both neighbours are word characters:
    ``"real-time"`` and ``"a-b-c"`` stay intact, while ``"a - b"`` and
    ``"x--y"`` lose their hyphens.

    Returns:
        The cleaned string (may be empty).
    """
    text = expand_contractions(text.lower())
    # Drop punctuation but keep '-' for now.  The class must list nothing
    # else, or those extra characters (e.g. '<', '>') would survive cleaning.
    text = re.sub(r"[^\w\s-]", " ", text)
    # Lookarounds instead of consuming the neighbours, so chained hyphens
    # with single-character segments ("a-b-c") are all kept.
    text = re.sub(r"(?<!\w)-|-(?!\w)", " ", text)
    return re.sub(r"\s+", " ", text).strip()
# ---------------------------------------------------------------------------
# Real attention extraction
# ---------------------------------------------------------------------------
def extract_attention_vectors(text, window_words=None):
    """Build one windowed attention signature per word of *text*.

    The text goes through the following steps:

    1. Clean it and split it into words.
    2. Tokenize the words as pre-split input.
    3. Run the model with ``output_attentions=True``.
    4. Average attention over all layers and heads into one token-by-token
       matrix.

    For each word, the attention rows of its tokens are then averaged.
    Columns belonging to words more than *window_words* words away are
    zeroed, and the result is L2-normalised when its norm is non-negligible.

    Args:
        text: Raw input text.
        window_words: Context radius in words; defaults to ``CONFIG["window"]``.

    Returns:
        An 8-tuple ``(words, vectors, contexts, token_positions_by_word,
        attn, word_ids, error, stats)``.  On failure every slot except
        ``error`` (a message string) is ``None``.
    """
    if window_words is None:
        window_words = CONFIG["window"]
    # UI sliders may deliver floats; slicing below needs an int.
    window_words = int(window_words)

    def _fail(message):
        return None, None, None, None, None, None, message, None

    words = clean_text(text).split()
    if len(words) < 5:
        return _fail("Text too short after cleaning (need >=5 words).")
    n_words = len(words)

    tok = tokenizer()
    mdl = model()
    encoding = tok(words, is_split_into_words=True, return_tensors="pt", add_special_tokens=False)
    input_ids = encoding["input_ids"].to(CONFIG["device"])
    word_ids = encoding.word_ids()

    with torch.no_grad():
        outputs = mdl(input_ids, output_attentions=True)
    layer_attns = outputs.attentions
    if not layer_attns or any(a is None for a in layer_attns):
        return _fail(
            "Model returned no attention weights; "
            "try the 'eager' attention implementation in Advanced Settings."
        )

    # Average heads within each layer, then across layers, one layer at a
    # time: stacking every layer first would hold all maps in memory at once.
    attn = sum(a.float().mean(dim=1) for a in layer_attns) / len(layer_attns)
    attn = attn.squeeze(0)
    T = attn.shape[0]

    token_positions_by_word = [[] for _ in range(n_words)]
    for t, wid in enumerate(word_ids):
        if wid is not None and 0 <= wid < n_words:
            token_positions_by_word[wid].append(t)

    # Word index of every token column (-1 when a token has no word), so each
    # context mask is one vectorised comparison instead of a Python loop.
    token_word = torch.tensor([-1 if wid is None else wid for wid in word_ids], device=attn.device)

    vectors = []
    contexts = []
    for w_idx in range(n_words):
        # Fallback for a word that produced no tokens: borrow a nearby row.
        tok_pos = token_positions_by_word[w_idx] or [min(w_idx, T - 1)]
        v = attn[tok_pos, :].mean(dim=0)
        ctx_start = max(0, w_idx - window_words)
        ctx_end = min(n_words, w_idx + window_words + 1)
        mask = ((token_word >= ctx_start) & (token_word < ctx_end)).to(v.dtype)
        v_local = v * mask
        norm = v_local.norm()
        if norm > 1e-8:
            v_local = v_local / norm
        vectors.append(v_local.cpu())
        contexts.append(" ".join(words[ctx_start:ctx_end]))

    vocab = Counter(words)
    stats = {
        "word_count": n_words,
        "unique_words": len(vocab),
        "token_count": T,
        "top_words": vocab.most_common(10),
    }
    return words, vectors, contexts, token_positions_by_word, attn, word_ids, None, stats
def cosine_similarity(v1, v2):
    """Return the cosine similarity of 1-D tensors *v1* and *v2* as a Python float.

    ``1e-8`` is added to the norm product, so a zero vector yields ``0.0``
    rather than NaN.
    """
    denominator = v1.norm() * v2.norm() + 1e-8
    similarity = (v1 @ v2) / denominator
    return float(similarity.item())
# ---------------------------------------------------------------------------
# Pair generation
# ---------------------------------------------------------------------------
def generate_all_pairs(words, vectors, contexts, window=None):
    """Score word pairs by the cosine similarity of their attention signatures.

    Two kinds of pair are produced:

    * ``"self"``: every pair of distinct occurrences of the same word.
    * ``"neighbor"``: each word occurrence against every other position within
      ``±window`` words, with ``distance = neighbor_index - index``.  Each
      unordered neighbour pair therefore appears once from each side.

    Args:
        words: Word list, one entry per position.
        vectors: Per-position signature tensors aligned with *words*.
        contexts: Per-position context strings aligned with *words*.
        window: Neighbour radius in words; defaults to ``CONFIG["window"]``.

    Returns:
        A list of pair dicts with keys ``word``, ``neighbor``,
        ``similarity_score``, ``occurrence_indices``, ``context_sentences``
        and ``pair_type``; neighbour pairs also carry ``distance``.
    """
    if window is None:
        window = CONFIG["window"]
    # UI sliders may deliver floats, which range() below rejects.
    window = int(window)

    positions = defaultdict(list)
    for i, w in enumerate(words):
        positions[w].append(i)
    vocab = sorted(positions)

    pairs = []
    for w in vocab:
        occs = positions[w]
        for a, idx1 in enumerate(occs):
            for idx2 in occs[a + 1:]:
                pairs.append({
                    "word": w, "neighbor": w,
                    "similarity_score": cosine_similarity(vectors[idx1], vectors[idx2]),
                    "occurrence_indices": [idx1, idx2],
                    "context_sentences": [contexts[idx1], contexts[idx2]],
                    "pair_type": "self",
                })

    n_words = len(words)
    for w in vocab:
        for idx in positions[w]:
            lo = max(0, idx - window)
            hi = min(n_words, idx + window + 1)
            for n_idx in range(lo, hi):
                if n_idx == idx:
                    continue
                pairs.append({
                    "word": w, "neighbor": words[n_idx],
                    "similarity_score": cosine_similarity(vectors[idx], vectors[n_idx]),
                    "occurrence_indices": [idx],
                    "context_sentences": [contexts[idx]],
                    "pair_type": "neighbor", "distance": n_idx - idx,
                })
    return pairs
def threshold_filter(pairs, threshold=None):
    """Keep pairs scoring strictly above *threshold*, grouped by their ``word``.

    Args:
        pairs: Pair dicts as produced by ``generate_all_pairs``.
        threshold: Exclusive lower bound on ``similarity_score``; defaults to
            ``CONFIG["threshold"]``.

    Returns:
        A plain dict ``{word: [record, ...]}`` in first-seen order.  Each record
        contains the following keys:

        * ``neighbor``, the score, ``occurrence_indices`` and
          ``context_sentences``, copied from the pair.
        * ``pair_type``, or ``"unknown"`` when the pair has none.
        * ``distance``, or ``None`` when the pair has none.
        * An empty ``semantic_explanation``.
    """
    cutoff = CONFIG["threshold"] if threshold is None else threshold
    grouped = {}
    for pair in pairs:
        if not pair["similarity_score"] > cutoff:
            continue
        record = {
            "neighbor": pair["neighbor"],
            "similarity_score": pair["similarity_score"],
            "occurrence_indices": pair["occurrence_indices"],
            "context_sentences": pair["context_sentences"],
            "pair_type": pair.get("pair_type", "unknown"),
            "distance": pair.get("distance", None),
            "semantic_explanation": "",
        }
        grouped.setdefault(pair["word"], []).append(record)
    return grouped
# ---------------------------------------------------------------------------
# Auto LLM Explanation
# ---------------------------------------------------------------------------
# System message sent with every explanation request.
SYSTEM_PROMPT = " ".join([
    "You are a precise semantic linguist. Given a word pair and their local context,",
    "explain in ONE or TWO concise sentences how these two words semantically relate or influence each other.",
    "Focus on semantic influence, topical association, or conceptual dependency: how the presence or meaning of",
    "one word affects the choice or interpretation of the other in this specific context.",
    "Do NOT focus on grammar, syntax, or word order. Be concrete and specific.",
])


def build_llm_prompt(word, neighbor, sentences, pair_type):
    """Build the chat messages that ask the LLM to explain one word pair.

    Args:
        word: The pair's main word.
        neighbor: The other word (ignored for ``"self"`` pairs).
        sentences: Context strings, listed as numbered "Context" lines.
        pair_type: ``"self"`` for two occurrences of the same word; any other
            value uses the two-word question.

    Returns:
        ``[system_message, user_message]`` as ``{"role", "content"}`` dicts.
    """
    context_lines = [f" Context {n}: {s}" for n, s in enumerate(sentences, start=1)]
    if pair_type == "self":
        header = f"Word pair: '{word}' (two different occurrences of the same word)"
        question = (
            f"Question: How do these two occurrences of '{word}' semantically relate or influence each other? "
            "Do they reinforce the same meaning, or does context shift their semantic role?"
        )
    else:
        header = f"Word pair: '{word}' and '{neighbor}'"
        question = (
            f"Question: How does '{word}' semantically influence '{neighbor}' (or vice versa) in this context? "
            "What shared topic, function, or conceptual dependency binds them?"
        )
    user_content = "\n".join([header, "\n".join(context_lines), question])
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]
def llm_explain(word, neighbor, sentences, pair_type):
    """Ask the cached chat model to explain one word pair.

    The prompt is built with ``build_llm_prompt`` and rendered through the
    tokenizer's chat template.  Up to ``CONFIG["max_new_tokens"]`` tokens are
    generated, and only the newly generated part is decoded.  Sampling uses
    ``CONFIG["temperature"]`` and ``CONFIG["top_p"]``.  A temperature of 0
    (allowed by the UI) switches to greedy decoding instead, because sampling
    needs a strictly positive temperature.

    Returns:
        The explanation text, or one of the markers ``"[No generation]"``,
        ``"[Empty after strip]"`` or ``"[Error: ...]"``.  Errors while
        prompting or generating are returned, not raised; a failure to load
        the model propagates.
    """
    tok, mdl = get_model()
    try:
        messages = build_llm_prompt(word, neighbor, sentences, pair_type)
        prompt_text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # Tokenize once; the prompt length marks where generated tokens start.
        inputs = tok(prompt_text, return_tensors="pt").to(CONFIG["device"])
        prompt_len = inputs["input_ids"].shape[1]

        temperature = float(CONFIG["temperature"])
        gen_kwargs = {
            "max_new_tokens": CONFIG["max_new_tokens"],
            "pad_token_id": tok.eos_token_id,
        }
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=CONFIG["top_p"])
        else:
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            outputs = mdl.generate(**inputs, **gen_kwargs)
        if outputs.shape[1] <= prompt_len:
            return "[No generation]"
        explanation = tok.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
        return explanation if explanation else "[Empty after strip]"
    except Exception as e:
        return f"[Error: {e}]"
def auto_explain(filtered_dict, max_explanations=None):
    """Fill ``semantic_explanation`` for the highest-scoring retained pairs.

    All records are ranked by ``similarity_score`` (descending; ties keep
    their original order).  The top *max_explanations* are sent to
    ``llm_explain``, and their ``semantic_explanation`` is overwritten in place.

    Args:
        filtered_dict: ``{word: [record, ...]}`` as built by ``threshold_filter``.
        max_explanations: How many pairs to explain; defaults to
            ``CONFIG["max_explanations"]``.  Negative values count as 0, since a
            negative slice would otherwise explain all but the last few pairs.

    Returns:
        The same *filtered_dict* object, mutated in place.
    """
    if max_explanations is None:
        max_explanations = CONFIG["max_explanations"]
    max_explanations = max(0, int(max_explanations))

    ranked = [(word, rec) for word, recs in filtered_dict.items() for rec in recs]
    ranked.sort(key=lambda item: item[1]["similarity_score"], reverse=True)
    selected = ranked[:max_explanations]
    print(f"[DEBUG] auto_explain: {len(ranked)} total, generating {len(selected)}", flush=True)

    for n, (word, rec) in enumerate(selected, start=1):
        print(f"[DEBUG] Explaining {n}: '{word}' / '{rec['neighbor']}' score={rec['similarity_score']:.3f}", flush=True)
        expl = llm_explain(word, rec["neighbor"], rec["context_sentences"], rec["pair_type"])
        rec["semantic_explanation"] = expl
        print(f"[DEBUG] Explanation length: {len(expl)} chars", flush=True)
    return filtered_dict
# ---------------------------------------------------------------------------
# Visualisations
# ---------------------------------------------------------------------------
def plot_attention_heatmap(attn_matrix, words, word_ids, max_tok=50):
    """Render the top-left ``max_tok x max_tok`` corner of the attention matrix.

    Args:
        attn_matrix: 2-D attention tensor (rows = query tokens, columns = key
            tokens).
        words: Accepted for interface compatibility; unused.
        word_ids: Accepted for interface compatibility; unused.
        max_tok: Maximum number of token positions shown on each axis.

    Returns:
        A matplotlib ``Figure``.
    """
    attn_np = attn_matrix[:max_tok, :max_tok].detach().cpu().numpy()
    shown = attn_np.shape[0]
    # Build the figure without pyplot.  Figures made via plt.subplots stay
    # registered in pyplot's global figure manager until closed, and nothing
    # closes them here, so every request would leak one.
    fig = Figure(figsize=(10, 8))
    ax = fig.subplots()
    im = ax.imshow(attn_np, cmap="viridis", aspect="auto")
    ax.set_title(f"Attention Heatmap (first {shown} tokens)")
    ax.set_xlabel("Key token position")
    ax.set_ylabel("Query token position")
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    return fig
def plot_attention_signature(attn_matrix, tok_positions, words, word_ids, title_word):
    """Bar-plot one word's attention over every token position.

    The attention rows at *tok_positions* are averaged.  When *tok_positions*
    is empty, a placeholder figure carrying a message is returned instead.

    Args:
        attn_matrix: 2-D attention tensor (rows = query tokens).
        tok_positions: Token row indices belonging to the word.
        words: Accepted for interface compatibility; unused.
        word_ids: Accepted for interface compatibility; unused.
        title_word: Word shown in the title.

    Returns:
        A matplotlib ``Figure``.
    """
    # Figures are created without pyplot so they are not kept alive by its
    # global figure manager across requests.
    if not tok_positions:
        fig = Figure()
        ax = fig.subplots()
        ax.text(0.5, 0.5, f"No tokens found for '{title_word}'", ha="center", va="center")
        return fig

    vec = attn_matrix[tok_positions, :].mean(dim=0).detach().cpu().numpy()
    n_tok = len(vec)
    fig = Figure(figsize=(14, 4))
    ax = fig.subplots()
    ax.bar(range(n_tok), vec, width=1.0, color="steelblue")
    ax.set_title(f"Attention Signature: '{title_word}' -> all tokens")
    ax.set_xlabel("Token position")
    ax.set_ylabel("Attention weight")
    ticks = list(range(0, n_tok, max(1, n_tok // 20)))  # ~20 labels at most
    ax.set_xticks(ticks)
    ax.set_xticklabels([str(t) for t in ticks], rotation=45)
    fig.tight_layout()
    return fig
def plot_word_frequency(words):
    """Horizontal bar chart of the 20 most frequent words (most frequent on top).

    Returns:
        A matplotlib ``Figure``; with no words the chart is simply empty.
    """
    top20 = Counter(words).most_common(20)
    labels = [w for w, _ in top20]
    counts = [c for _, c in top20]
    # Built without pyplot so the figure isn't retained by pyplot's global
    # figure manager after the request finishes.
    fig = Figure(figsize=(12, 5))
    ax = fig.subplots()
    # Reverse so the most frequent word ends up at the top of the chart.
    ax.barh(range(len(labels)), counts[::-1], color="teal")
    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels[::-1])
    ax.set_xlabel("Frequency")
    ax.set_title("Top 20 Word Frequencies")
    fig.tight_layout()
    return fig
def plot_3d_semantic_space(words, vectors, filtered_dict):
    """Scatter every word position in the 3-D PCA space of its attention signature.

    Positions whose word takes part in any retained pair (as ``word`` with at
    least one record, or as a ``neighbor``) are drawn red and larger.  All
    others are drawn blue and small.  Hovering shows the word, its index and
    a short context snippet.

    Returns:
        A Plotly ``Figure``, which is empty when *vectors* is empty.
    """
    if not vectors:
        return go.Figure()
    signatures = torch.stack(vectors).float().numpy()
    coords = PCA(n_components=3).fit_transform(signatures)

    highlighted = {w for w, recs in filtered_dict.items() if recs}
    highlighted.update(rec["neighbor"] for recs in filtered_dict.values() for rec in recs)

    n_words = len(words)
    hover_texts = []
    for i, w in enumerate(words):
        snippet = " ".join(words[max(0, i - 3):min(n_words, i + 4)])
        hover_texts.append(f"{w}<br>idx={i}<br>ctx: {snippet[:60]}...")

    hits = [w in highlighted for w in words]
    marker = dict(
        size=[8 if hit else 4 for hit in hits],
        color=["#e74c3c" if hit else "#3498db" for hit in hits],
        opacity=0.8,
    )
    scatter = go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode="markers",
        marker=marker,
        text=words,
        hovertext=hover_texts,
        hoverinfo="text",
    )
    fig = go.Figure(data=[scatter])
    fig.update_layout(
        title="3D Semantic Space (PCA of attention signatures)<br>Red/large = words in retained pairs",
        scene=dict(xaxis_title="PC1", yaxis_title="PC2", zaxis_title="PC3"),
        width=800,
        height=600,
    )
    return fig
def plot_pair_similarity_3d(filtered_dict, vectors, words):
    """Plot the 100 best retained pairs as midpoints in 3-D PCA space.

    PCA is fitted on all word signatures.  Each pair is drawn at the midpoint
    of its two positions:

    * ``"self"`` pairs (two occurrence indices) use both occurrences.
    * Neighbour pairs use the word's position and the position ``distance``
      words away.

    Marker size and colour encode the similarity score.

    Returns:
        A Plotly ``Figure``, which is empty when there is nothing to plot.
    """
    if not filtered_dict or not vectors:
        return go.Figure()
    ranked = [(w, rec) for w, recs in filtered_dict.items() for rec in recs]
    ranked.sort(key=lambda item: item[1]["similarity_score"], reverse=True)
    ranked = ranked[:100]
    if not ranked:
        return go.Figure()

    # Fit once and project every signature once; pairs index into the result.
    coords = PCA(n_components=3).fit_transform(torch.stack(vectors).float().numpy())
    n_pos = len(coords)

    midpoints, scores, labels = [], [], []
    for w, rec in ranked:
        idxs = rec["occurrence_indices"]
        first = idxs[0]
        if len(idxs) >= 2:
            second = idxs[1]
        else:
            # Neighbour records store only the word's own index.  The partner
            # sits `distance` positions away; looking the neighbour up by its
            # text would pick its first occurrence, usually another position.
            dist = rec.get("distance")
            if dist is not None and 0 <= first + int(dist) < n_pos:
                second = first + int(dist)
            elif rec["neighbor"] in words:
                second = words.index(rec["neighbor"])
            else:
                second = first
        midpoints.append((coords[first] + coords[second]) / 2)
        scores.append(rec["similarity_score"])
        labels.append(f"{w}–{rec['neighbor']}<br>score={rec['similarity_score']:.3f}")

    midpoints = np.array(midpoints)
    fig = go.Figure(data=[go.Scatter3d(
        x=midpoints[:, 0],
        y=midpoints[:, 1],
        z=midpoints[:, 2],
        mode="markers",
        marker=dict(
            size=[max(3, s * 15) for s in scores],
            color=scores,
            colorscale="Plasma",
            colorbar=dict(title="Similarity"),
            opacity=0.85,
        ),
        text=labels,
        hoverinfo="text",
    )])
    fig.update_layout(
        title="3D Pair Similarity Space (PCA midpoints, top 100 pairs)",
        scene=dict(xaxis_title="PC1", yaxis_title="PC2", zaxis_title="PC3"),
        width=800,
        height=600,
    )
    return fig
# ---------------------------------------------------------------------------
# Download helpers
# ---------------------------------------------------------------------------
def json_download(filtered_dict):
    """Serialise the retained pairs as 2-space-indented JSON, keeping non-ASCII text as-is."""
    return json.dumps(filtered_dict, ensure_ascii=False, indent=2)
def csv_download(filtered_dict):
    """Render the retained pairs as CSV text, one row per record.

    The columns are:

    * ``word`` and ``neighbor``.
    * ``similarity_score``, with 6 decimals.
    * ``pair_type``.
    * ``distance``, left empty for self pairs.
    * ``context``, the first context sentence.
    * ``semantic_explanation``.

    Quoting and escaping are left to the ``csv`` module, so commas, quotes
    and newlines in LLM explanations round-trip through any CSV reader.
    """
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    writer.writerow([
        "word", "neighbor", "similarity_score", "pair_type",
        "distance", "context", "semantic_explanation",
    ])
    for word, recs in filtered_dict.items():
        for rec in recs:
            contexts = rec.get("context_sentences") or []
            distance = rec.get("distance")
            writer.writerow([
                word,
                rec["neighbor"],
                f"{rec['similarity_score']:.6f}",
                rec.get("pair_type", ""),
                # Self pairs carry distance=None; write an empty cell, not "None".
                "" if distance is None else distance,
                contexts[0] if contexts else "",
                rec.get("semantic_explanation") or "",
            ])
    return buf.getvalue()
# ---------------------------------------------------------------------------
# Gradio handlers
# ---------------------------------------------------------------------------
# Sample paragraph pre-filled in the "Input Text" box.
DEFAULT_TEXT = " ".join([
    "Senjō no Valkyria 3: Unrecorded Chronicles (Japanese: 戦場のヴァルキュリア3)",
    "is a tactical role-playing video game developed by Sega and Media.Vision for the PlayStation Portable.",
    "Released in January 2011 in Japan, it is the third game in the Valkyria Chronicles series.",
    "The game uses the same fusion of tactical and real-time action as its predecessors,",
    "and introduces new characters and a darker storyline about a penal military unit.",
])
def run_pipeline(text, threshold, window):
    """Gradio handler for "Run Pipeline": extract, pair, filter and plot.

    Stores *threshold* and *window* in ``CONFIG``, then computes the
    attention signatures, scores and filters the pairs, and builds every
    figure.

    Returns:
        A 12-tuple matching the ``run_btn.click`` outputs, in this order:

        * the summary string;
        * the table rows (at most 20 per word);
        * the cache JSON string;
        * five figures (heatmap, frequency, signature, 3-D space, 3-D pairs);
        * the words, per-word token positions, attention matrix and word ids
          kept in ``gr.State``.

        On extraction failure the summary carries the error message and the
        remaining slots are empty.
    """
    print("[DEBUG] run_pipeline started", flush=True)
    # UI sliders may deliver floats; use one int window everywhere downstream.
    window = int(window)
    CONFIG["threshold"] = threshold
    CONFIG["window"] = window

    (words, vectors, contexts, tok_positions,
     attn_matrix, word_ids, err, stats) = extract_attention_vectors(text, window)
    if err:
        return err, [], "", None, None, None, None, None, None, None, None, None

    all_pairs = generate_all_pairs(words, vectors, contexts, window=window)
    filtered = threshold_filter(all_pairs, threshold=threshold)
    total = sum(len(v) for v in filtered.values())
    top_words = ", ".join(f"{w}({c})" for w, c in stats["top_words"][:5])
    stats_str = (
        f"Words: {stats['word_count']} | Unique: {stats['unique_words']} | Tokens: {stats['token_count']} | "
        f"Raw pairs: {len(all_pairs)} | Retained: {total} | "
        f"Top words: {top_words}"
    )
    print(f"[DEBUG] {stats_str}", flush=True)

    rows = []
    for w, recs in filtered.items():
        for r in recs[:20]:
            distance = r.get("distance")
            rows.append([
                w,
                r["neighbor"],
                f"{r['similarity_score']:.3f}",
                r["pair_type"],
                "" if distance is None else distance,  # self pairs have no distance
                r["context_sentences"][0][:70] + "...",
                "(click Generate Explanations)",
            ])

    heatmap_fig = plot_attention_heatmap(attn_matrix, words, word_ids)
    freq_fig = plot_word_frequency(words)
    sig_fig = plot_attention_signature(
        attn_matrix,
        tok_positions[0] if tok_positions else [],
        words, word_ids, words[0] if words else "",
    )
    space3d_fig = plot_3d_semantic_space(words, vectors, filtered)
    pair3d_fig = plot_pair_similarity_3d(filtered, vectors, words)

    cache_json = json.dumps({
        "words": words,
        "contexts": contexts,
        "filtered": filtered,
        "stats": stats,
    }, indent=2, ensure_ascii=False)
    return (
        stats_str, rows, cache_json,
        heatmap_fig, freq_fig, sig_fig, space3d_fig, pair3d_fig,
        words, tok_positions, attn_matrix, word_ids,
    )
def generate_explanations(cache_json, max_expl):
    """Gradio handler for "Generate Explanations".

    Parses the pipeline cache, asks the LLM to explain the top *max_expl*
    retained pairs (stored in ``CONFIG["max_explanations"]``), and rebuilds
    the results table with at most 20 rows per word.

    Returns:
        ``(status_message, table_rows)``.  The rows are empty when the cache
        is missing or invalid.

    NOTE(review): the explanations only reach the table.  ``cache_state`` is
    not an output of this handler, so JSON/CSV downloads built from the cache
    do not contain them.
    """
    if not cache_json or not cache_json.strip():
        return "Run pipeline first.", []
    try:
        cache = json.loads(cache_json)
        filtered = cache["filtered"]
    except (ValueError, KeyError, TypeError):
        return "Invalid cache.", []

    # A cleared Number field arrives as None; fall back to the configured value.
    if max_expl is None:
        max_expl = CONFIG["max_explanations"]
    max_expl = max(0, int(max_expl))
    CONFIG["max_explanations"] = max_expl
    print(f"[DEBUG] generate_explanations called, max={max_expl}", flush=True)
    filtered = auto_explain(filtered, max_explanations=max_expl)

    rows = []
    for w, recs in filtered.items():
        for r in recs[:20]:
            distance = r.get("distance")
            rows.append([
                w,
                r["neighbor"],
                f"{r['similarity_score']:.3f}",
                r["pair_type"],
                "" if distance is None else distance,
                r["context_sentences"][0][:70] + "...",
                r.get("semantic_explanation", ""),
            ])
    return "Explanations generated.", rows
def update_signature(cache_json, word_input, words_state, tok_positions_state, attn_matrix_state, word_ids_state):
    """Gradio handler: plot the attention signature of the first occurrence of a word.

    Data sources:

    * The word list comes from the cached pipeline JSON when it parses as an
      object, and from ``words_state`` otherwise.
    * Token positions, the attention matrix and the word ids come from the
      ``gr.State`` values stored by ``run_pipeline``.

    Returns:
        A matplotlib figure, or ``None`` (clearing the plot) when the pipeline
        has not run, no attention matrix is stored, or the word is absent.
    """
    if not cache_json:
        return None
    words = words_state or []
    try:
        words = json.loads(cache_json).get("words", words)
    except (ValueError, TypeError, AttributeError):
        pass  # unusable cache: keep the words from gr.State
    tok_positions = tok_positions_state or []
    word_ids = word_ids_state or []
    if not words or attn_matrix_state is None:
        return None

    target = (word_input or "").lower().strip()
    if target not in words:
        return None
    idx = words.index(target)
    tp = tok_positions[idx] if idx < len(tok_positions) else []
    return plot_attention_signature(attn_matrix_state, tp, words, word_ids, target)
def fetch_dataset(dataset_name, config_name, split):
    """Gradio handler: load sample paragraphs through ``load_hf_text``.

    Surrounding whitespace is stripped from both names.  A blank config name
    is passed as ``None`` (no config).  The loader's result, whether
    paragraphs or an error/status message, is returned unchanged.
    """
    name = (dataset_name or "").strip()
    config = (config_name or "").strip() or None
    return load_hf_text(name, config, split)
def apply_settings(model_id, device, max_tokens, temperature, top_p, attn_impl, ds_paras, ds_min, ds_max):
    """Gradio handler: write the Advanced Settings into ``CONFIG``.

    The model is reloaded when the model id, device or attention
    implementation changes, because the cached model was built with the old
    values.  For example, a model left on the old device would receive
    inputs moved to the new device.  If the reload fails, those three
    settings are restored and the cache is cleared, so the previous model is
    reloaded lazily on next use.

    Returns:
        A status string for the UI.
    """
    global _model_singleton, _tokenizer_singleton
    load_keys = ("model_id", "device", "attn_implementation")
    previous = {k: CONFIG[k] for k in load_keys}

    CONFIG["model_id"] = (model_id or "").strip() or previous["model_id"]
    CONFIG["device"] = device or previous["device"]
    CONFIG["attn_implementation"] = attn_impl or previous["attn_implementation"]
    CONFIG["max_new_tokens"] = int(max_tokens)
    CONFIG["temperature"] = float(temperature)
    CONFIG["top_p"] = float(top_p)
    CONFIG["dataset_paragraphs"] = int(ds_paras)
    CONFIG["dataset_min_words"] = int(ds_min)
    CONFIG["dataset_max_tries"] = int(ds_max)

    if any(CONFIG[k] != previous[k] for k in load_keys):
        # Drop the old references first so the old model can be freed
        # before the new one is loaded.
        _model_singleton = None
        _tokenizer_singleton = None
        try:
            get_model(force_reload=True)
        except Exception as e:
            CONFIG.update(previous)
            _model_singleton = None
            _tokenizer_singleton = None
            return f"Error loading model: {e} (previous model settings restored)"
        return f"Settings applied. Model '{CONFIG['model_id']}' loaded on {CONFIG['device']}."

    return (
        f"Settings applied. Model: {CONFIG['model_id']} | Device: {CONFIG['device']} | "
        f"Max tokens: {CONFIG['max_new_tokens']} | Temp: {CONFIG['temperature']} | Top-p: {CONFIG['top_p']}"
    )
def _filtered_from_cache(cache_json):
    """Return the ``filtered`` mapping stored in the pipeline cache JSON ({} if absent)."""
    return json.loads(cache_json).get("filtered", {})


def download_json(cache_json):
    """Return the cached retained pairs as JSON text.

    ``None`` is returned for an empty cache or for any failure while parsing
    or serialising.
    """
    if not cache_json:
        return None
    try:
        return json_download(_filtered_from_cache(cache_json))
    except Exception:
        return None


def download_csv(cache_json):
    """Return the cached retained pairs as CSV text.

    ``None`` is returned for an empty cache or for any failure while parsing
    or serialising.
    """
    if not cache_json:
        return None
    try:
        return csv_download(_filtered_from_cache(cache_json))
    except Exception:
        return None
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
def _write_temp_file(text, suffix):
    """Write *text* to a new unique temp file and return its path (``None`` if no text).

    A fresh file per request keeps concurrent users from overwriting each
    other's downloads, which a fixed path would allow.
    """
    if not text:
        return None
    fd, path = tempfile.mkstemp(prefix="semantic_attention_", suffix=suffix)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(text)
    return path


def _json_file_for(cache_json):
    """Gradio handler: build the JSON export and return its file path (or None)."""
    return _write_temp_file(download_json(cache_json), ".json")


def _csv_file_for(cache_json):
    """Gradio handler: build the CSV export and return its file path (or None)."""
    return _write_temp_file(download_csv(cache_json), ".csv")


# Only offer devices that torch reports as usable on this machine.
_DEVICE_CHOICES = ["cpu"]
if torch.cuda.is_available():
    _DEVICE_CHOICES.append("cuda")
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    _DEVICE_CHOICES.append("mps")


with gr.Blocks(title="Semantic Attention Explorer", css="""
.dataframe-wrap { white-space: pre-wrap !important; }
""") as demo:
    gr.Markdown("# 🔍 Semantic Attention Explorer")
    gr.Markdown(
        "Extract **real neural attention** from a causal LM, compute **cosine similarity** between "
        "windowed attention signatures, and auto-generate **LLM semantic explanations** for retained pairs."
    )

    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(label="Input Text", lines=8, value=DEFAULT_TEXT)
        with gr.Column(scale=1):
            # Initial values come from CONFIG so the UI and backend defaults agree.
            thresh_slider = gr.Slider(0.0, 1.0, value=CONFIG["threshold"], step=0.05, label="Similarity Threshold")
            win_slider = gr.Slider(1, 15, value=CONFIG["window"], step=1, label="Context Window (words)")
            run_btn = gr.Button("▶️ Run Pipeline", variant="primary")

    with gr.Accordion("⚙️ Advanced Settings", open=False):
        with gr.Row():
            model_input = gr.Textbox(
                label="Model ID",
                value=CONFIG["model_id"],
                placeholder="e.g. Qwen/Qwen2.5-0.5B-Instruct",
            )
            device_dd = gr.Dropdown(_DEVICE_CHOICES, value=CONFIG["device"], label="Device")
            attn_impl_dd = gr.Dropdown(
                ["eager", "sdpa", "flash_attention_2"],
                value=CONFIG["attn_implementation"],
                label="Attention Implementation",
                info="The pipeline needs attention weights (output_attentions); 'eager' returns them.",
            )
        with gr.Row():
            max_tokens_num = gr.Number(value=CONFIG["max_new_tokens"], label="Max New Tokens", precision=0, minimum=1, maximum=2048)
            temp_slider = gr.Slider(0.0, 2.0, value=CONFIG["temperature"], step=0.05, label="Temperature")
            top_p_slider = gr.Slider(0.0, 1.0, value=CONFIG["top_p"], step=0.05, label="Top-p")
        with gr.Row():
            ds_paras_num = gr.Number(value=CONFIG["dataset_paragraphs"], label="Dataset paragraphs", precision=0, minimum=1, maximum=100)
            ds_min_num = gr.Number(value=CONFIG["dataset_min_words"], label="Min words per paragraph", precision=0, minimum=1, maximum=500)
            ds_max_num = gr.Number(value=CONFIG["dataset_max_tries"], label="Dataset scan limit", precision=0, minimum=10, maximum=100000)
        apply_btn = gr.Button("Apply Settings & Reload Model", variant="secondary")
        settings_status = gr.Textbox(label="Settings Status", interactive=False)

    with gr.Accordion("📚 Load from HuggingFace Dataset", open=False):
        with gr.Row():
            ds_name = gr.Textbox(label="Dataset name", value="wikitext", placeholder="e.g. wikitext")
            ds_config = gr.Textbox(label="Config name (optional)", value="wikitext-2-raw-v1", placeholder="e.g. wikitext-2-raw-v1")
            ds_split = gr.Dropdown(["train", "validation", "test"], value="train", label="Split")
        ds_btn = gr.Button("Load Sample Paragraphs")
        ds_output = gr.Textbox(label="Loaded Paragraphs (copy one into Input Text)", lines=6, interactive=False)

    ds_btn.click(fetch_dataset, inputs=[ds_name, ds_config, ds_split], outputs=[ds_output])
    apply_btn.click(
        apply_settings,
        inputs=[model_input, device_dd, max_tokens_num, temp_slider, top_p_slider, attn_impl_dd, ds_paras_num, ds_min_num, ds_max_num],
        outputs=[settings_status],
    )

    # Per-session state written by run_pipeline and read by later handlers.
    cache_state = gr.State()
    words_state = gr.State()
    tok_positions_state = gr.State()
    attn_matrix_state = gr.State()
    word_ids_state = gr.State()

    summary_box = gr.Textbox(label="Pipeline Summary", interactive=False)
    with gr.Row():
        with gr.Column(scale=1):
            max_expl_num = gr.Number(
                value=CONFIG["max_explanations"],
                label="Auto LLM Explanations (top-N)",
                precision=0,
                minimum=0,
                maximum=9999,
            )
            explain_btn = gr.Button("🤖 Generate Explanations", variant="secondary")
            with gr.Row():
                json_dl_btn = gr.Button("📥 Download JSON")
                csv_dl_btn = gr.Button("📥 Download CSV")
            # Visible so the generated file can actually be downloaded.
            json_file = gr.File(label="JSON Download")
            csv_file = gr.File(label="CSV Download")
        with gr.Column(scale=2):
            expl_status = gr.Textbox(label="Explanation Status", interactive=False)
            pairs_df = gr.Dataframe(
                headers=["Word", "Neighbor", "Score", "Type", "Distance", "Context", "LLM Explanation"],
                interactive=False,
                wrap=True,
            )

    with gr.Tab("🌡️ Attention Heatmap"):
        heatmap_plot = gr.Plot(label="Full Attention Matrix (first 50 tokens)")
    with gr.Tab("📊 Word Frequency"):
        freq_plot = gr.Plot(label="Top 20 Word Frequencies")
    with gr.Tab("📈 Attention Signature"):
        with gr.Row():
            sig_word_input = gr.Textbox(label="Word to visualize", placeholder="e.g. game")
            sig_update_btn = gr.Button("Update Signature")
        sig_plot = gr.Plot(label="Attention Signature (word -> all tokens)")
    with gr.Tab("🌌 3D Semantic Space"):
        space3d_plot = gr.Plot(label="3D PCA of attention signatures")
    with gr.Tab("🔗 3D Pair Space"):
        pair3d_plot = gr.Plot(label="3D PCA of pair midpoints")

    explain_btn.click(
        generate_explanations,
        inputs=[cache_state, max_expl_num],
        outputs=[expl_status, pairs_df],
    )
    run_btn.click(
        run_pipeline,
        inputs=[input_text, thresh_slider, win_slider],
        outputs=[
            summary_box, pairs_df, cache_state,
            heatmap_plot, freq_plot, sig_plot, space3d_plot, pair3d_plot,
            words_state, tok_positions_state, attn_matrix_state, word_ids_state,
        ],
    )
    sig_update_btn.click(
        update_signature,
        inputs=[cache_state, sig_word_input, words_state, tok_positions_state, attn_matrix_state, word_ids_state],
        outputs=[sig_plot],
    )
    # gr.File outputs expect a file path, so each handler writes the export to
    # a temp file and returns its path in a single step.
    json_dl_btn.click(_json_file_for, inputs=[cache_state], outputs=[json_file])
    csv_dl_btn.click(_csv_file_for, inputs=[cache_state], outputs=[csv_file])

demo.launch()