# Source: app.py, uploaded by hbe (commit 030c25b, verified).
import gradio as gr
import re
import json
import random
import os
from collections import defaultdict, Counter
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from sklearn.decomposition import PCA
# ---------------------------------------------------------------------------
# Global Configuration (mutable via UI)
# ---------------------------------------------------------------------------
# Process-wide settings. The Gradio handlers below overwrite individual
# entries in place, so every function reads the current value at call time.
CONFIG = {
    # Model loading / text generation.
    "model_id": "Qwen/Qwen2.5-0.5B-Instruct",
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "max_new_tokens": 120,
    "temperature": 0.25,
    "top_p": 0.9,
    "attn_implementation": "eager",
    # Pair extraction: context half-width (in words) and similarity cut-off.
    "window": 7,
    "threshold": 0.3,
    # How many of the best-scoring pairs receive an LLM explanation.
    "max_explanations": 25,
    # Dataset sampling: paragraphs returned, minimum words per paragraph,
    # and how many rows are scanned at most.
    "dataset_paragraphs": 3,
    "dataset_min_words": 20,
    "dataset_max_tries": 1000,
}

# Cache slots for the loaded model and tokenizer; both start out empty.
_model_singleton = None
_tokenizer_singleton = None
def get_model(force_reload=False):
    """Return the cached ``(tokenizer, model)`` pair, loading it when needed.

    A load happens when no model is cached yet or when ``force_reload`` is
    true. The tokenizer and model named by ``CONFIG["model_id"]`` are fetched,
    the model is moved to ``CONFIG["device"]`` and switched to eval mode.
    """
    global _model_singleton, _tokenizer_singleton

    have_model = _model_singleton is not None
    if have_model and not force_reload:
        return _tokenizer_singleton, _model_singleton

    print(f"[DEBUG] Loading {CONFIG['model_id']} on {CONFIG['device']} ...", flush=True)
    _tokenizer_singleton = AutoTokenizer.from_pretrained(CONFIG["model_id"])
    loaded = AutoModelForCausalLM.from_pretrained(
        CONFIG["model_id"],
        device_map=None,
        attn_implementation=CONFIG["attn_implementation"],
    )
    # Only publish the model once it has actually reached the target device.
    _model_singleton = loaded.to(CONFIG["device"])
    _model_singleton.eval()
    print("[DEBUG] Model loaded.", flush=True)
    return _tokenizer_singleton, _model_singleton
def tokenizer():
    """Return the shared tokenizer (first element of ``get_model()``)."""
    shared_tokenizer, _ = get_model()
    return shared_tokenizer
def model():
    """Return the shared causal LM (second element of ``get_model()``)."""
    _, shared_model = get_model()
    return shared_model
# ---------------------------------------------------------------------------
# Dataset loader
# ---------------------------------------------------------------------------
def load_hf_text(dataset_name, config_name, split="train"):
    """Sample cleaned paragraphs from the ``text`` column of a HF dataset.

    The dataset is first requested with ``config_name``; if that raises, it is
    requested again without a config. When both attempts fail, a string
    starting with ``"Error loading dataset:"`` (carrying the first error) is
    returned instead of raising.

    Returns:
        A string: either the sampled paragraphs joined by ``---`` separators,
        or a human-readable message when nothing usable was found.
    """
    try:
        dataset = load_dataset(dataset_name, config_name, split=split)
    except Exception as first_error:
        try:
            dataset = load_dataset(dataset_name, split=split)
        except Exception:
            return f"Error loading dataset: {first_error}"

    paragraphs = []
    for position, record in enumerate(dataset):
        # Only the leading rows are scanned, bounded by the configured limit.
        if position >= CONFIG["dataset_max_tries"]:
            break
        raw = record.get("text", "")
        if raw and raw.strip():
            normalized = clean_text(raw)
            if len(normalized.split()) >= CONFIG["dataset_min_words"]:
                paragraphs.append(normalized)

    if paragraphs:
        sample_size = min(CONFIG["dataset_paragraphs"], len(paragraphs))
        chosen = random.sample(paragraphs, sample_size)
        return "\n\n---\n\n".join(chosen)
    return "No valid paragraphs found in dataset."
# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
# Lower-case contraction -> expansion table.
# Entries beginning with an apostrophe, plus "n't", are *suffix* rules that
# apply to the tail of any word; every other entry is a whole-word rule.
# NOTE: the "'s" suffix rule also rewrites possessives ("dog's" -> "dog is").
CONTRACTIONS = {
    # Generic suffix rules.
    "n't": " not", "'re": " are", "'s": " is", "'d": " would",
    "'ll": " will", "'ve": " have", "'m": " am",
    # Negations.
    "can't": "cannot", "cannot": "can not", "won't": "will not",
    "shan't": "shall not", "ain't": "am not",
    "isn't": "is not", "aren't": "are not", "wasn't": "was not",
    "weren't": "were not", "haven't": "have not", "hasn't": "has not",
    "hadn't": "had not", "don't": "do not", "doesn't": "does not",
    "didn't": "did not", "wouldn't": "would not", "couldn't": "could not",
    "shouldn't": "should not",
    # "... is" / "let us".
    "let's": "let us", "that's": "that is", "who's": "who is",
    "what's": "what is", "it's": "it is", "here's": "here is",
    "there's": "there is", "where's": "where is", "how's": "how is",
    "she's": "she is", "he's": "he is",
    # "... are" / "... am".
    "they're": "they are", "we're": "we are", "i'm": "i am",
    # "... would".
    "i'd": "i would", "you'd": "you would", "he'd": "he would",
    "she'd": "she would", "it'd": "it would", "we'd": "we would",
    "they'd": "they would",
    # "... will".
    "i'll": "i will", "you'll": "you will", "he'll": "he will",
    "she'll": "she will", "it'll": "it will", "we'll": "we will",
    "they'll": "they will",
    # "... have".
    "i've": "i have", "you've": "you have", "we've": "we have",
    "they've": "they have",
}


def expand_contractions(text):
    """Replace English contractions in ``text`` with their expanded forms.

    Rules are applied longest key first so that specific whole-word entries
    ("can't") win over the generic suffix rules ("n't"). Matching is
    case-insensitive; replacements are always the lower-case table values.

    Args:
        text: Input string.

    Returns:
        The string with every matching contraction expanded.
    """
    for key in sorted(CONTRACTIONS, key=len, reverse=True):
        is_suffix_rule = key.startswith("'") or key == "n't"
        # A suffix rule must be glued to a preceding word character. Using
        # ``\b`` in front of "n't" can never match inside a word such as
        # "mustn't", which silently disabled that rule.
        prefix = r"(?<=\w)" if is_suffix_rule else r"\b"
        pattern = prefix + re.escape(key) + r"\b"
        text = re.sub(pattern, CONTRACTIONS[key], text, flags=re.IGNORECASE)
    return text
def clean_text(text):
    """Normalise raw text into lower-case, space-separated words.

    Steps: lower-case, expand contractions, replace punctuation with spaces
    while keeping hyphens that sit directly between two word characters, then
    collapse whitespace runs.

    Args:
        text: Raw input string.

    Returns:
        The cleaned string (possibly empty).
    """
    text = expand_contractions(text.lower())
    # One pass, decided on the untouched string via look-arounds:
    #   * any character that is not a word char, whitespace or hyphen, and
    #   * any hyphen missing a word character on either side
    # becomes a space. Unlike a placeholder round-trip, this keeps every
    # hyphen in chains such as "state-of-the-art" and no longer lets literal
    # "<" / ">" characters slip through.
    text = re.sub(r"[^\w\s-]|(?<!\w)-|-(?!\w)", " ", text)
    return re.sub(r"\s+", " ", text).strip()
# ---------------------------------------------------------------------------
# Real attention extraction
# ---------------------------------------------------------------------------
def extract_attention_vectors(text, window_words=None):
    """Compute one windowed, normalised attention vector per word of ``text``.

    The text is cleaned and split into words, run through the model with
    ``output_attentions=True``, and the returned attention tensors are
    averaged into a single square token-by-token matrix. For each word, the
    rows of its tokens are averaged, masked to the tokens of words within
    ``window_words`` positions, and L2-normalised.

    Args:
        text: Raw input text.
        window_words: Context half-width in words; defaults to
            ``CONFIG["window"]``.

    Returns:
        An 8-tuple ``(words, vectors, contexts, token_positions_by_word,
        attn, word_ids, error, stats)``. On failure everything except
        ``error`` (a message string) is ``None``; on success ``error`` is
        ``None``.
    """
    failure = (None, None, None, None, None, None)
    if window_words is None:
        window_words = CONFIG["window"]
    # Slider values may arrive as floats; slicing below needs an int.
    window_words = int(window_words)

    tok = tokenizer()
    mdl = model()
    words = clean_text(text).split()
    if len(words) < 5:
        return failure + ("Text too short after cleaning (need >=5 words).", None)

    encoding = tok(words, is_split_into_words=True, return_tensors="pt", add_special_tokens=False)
    word_ids = encoding.word_ids()
    # Follow the model's real device: CONFIG["device"] can be edited in the
    # UI without the already-loaded model being moved.
    input_ids = encoding["input_ids"].to(mdl.device)
    with torch.no_grad():
        outputs = mdl(input_ids, output_attentions=True)

    attentions = outputs.attentions
    if not attentions or any(layer is None for layer in attentions):
        return failure + (
            "Model returned no attention weights; set Attention Implementation to 'eager'.",
            None,
        )

    # Average over the stacked tensors, then over dim 1, and drop the leading
    # singleton dim -> square matrix (presumably layers, then heads, then the
    # batch axis -- TODO confirm against the model's output layout).
    attn = torch.stack(attentions).mean(dim=0).mean(dim=1).squeeze(0).float()
    T = attn.shape[0]

    token_positions_by_word = [[] for _ in words]
    for t, wid in enumerate(word_ids):
        if wid is not None and 0 <= wid < len(words):
            token_positions_by_word[wid].append(t)

    # Word index of every token (-1 when the tokenizer reports none), so each
    # window mask is a single vectorised comparison.
    token_word = torch.tensor(
        [-1 if wid is None else wid for wid in word_ids],
        device=attn.device,
    )

    vectors = []
    contexts = []
    for w_idx in range(len(words)):
        tok_pos = token_positions_by_word[w_idx]
        if not tok_pos:
            # Word without tokens: fall back to a clamped positional guess.
            tok_pos = [min(w_idx, T - 1)]
        v = attn[tok_pos, :].mean(dim=0)
        ctx_start = max(0, w_idx - window_words)
        ctx_end = min(len(words), w_idx + window_words + 1)
        mask = ((token_word >= ctx_start) & (token_word < ctx_end)).to(attn.dtype)
        v_local = v * mask
        norm = v_local.norm()
        if norm > 1e-8:
            v_local = v_local / norm
        vectors.append(v_local.cpu())
        contexts.append(" ".join(words[ctx_start:ctx_end]))

    vocab = Counter(words)
    stats = {
        "word_count": len(words),
        "unique_words": len(vocab),
        "token_count": T,
        "top_words": vocab.most_common(10),
    }
    return words, vectors, contexts, token_positions_by_word, attn, word_ids, None, stats
def cosine_similarity(v1, v2):
    """Return the cosine similarity of two 1-D tensors as a Python float.

    A small epsilon is added to the denominator, so a zero vector gives 0.0
    rather than a division error.
    """
    denominator = v1.norm() * v2.norm() + 1e-8
    ratio = torch.dot(v1, v2) / denominator
    return float(ratio.item())
# ---------------------------------------------------------------------------
# Pair generation
# ---------------------------------------------------------------------------
def generate_all_pairs(words, vectors, contexts, window=None):
    """Score every candidate word pair by cosine similarity.

    Two families of pairs are produced, in this order:

    * ``"self"`` pairs: every two occurrences of the same word.
    * ``"neighbor"`` pairs: each occurrence paired with every other position
      within ``window`` words on either side.

    Args:
        words: Word list.
        vectors: One vector per word position.
        contexts: One context string per word position.
        window: Neighbour half-width; defaults to ``CONFIG["window"]``.

    Returns:
        A list of pair dicts.
    """
    if window is None:
        window = CONFIG["window"]

    positions_of = defaultdict(list)
    for position, token in enumerate(words):
        positions_of[token].append(position)
    ordered_vocab = sorted(positions_of)
    total = len(words)
    pairs = []

    # Pass 1: repeated occurrences of the same word.
    for token in ordered_vocab:
        hits = positions_of[token]
        for rank, first in enumerate(hits):
            for second in hits[rank + 1:]:
                pairs.append({
                    "word": token, "neighbor": token,
                    "similarity_score": cosine_similarity(vectors[first], vectors[second]),
                    "occurrence_indices": [first, second],
                    "context_sentences": [contexts[first], contexts[second]],
                    "pair_type": "self"
                })

    # Pass 2: each occurrence against its surrounding window.
    for token in ordered_vocab:
        for center in positions_of[token]:
            lower = max(0, center - window)
            upper = min(total, center + window + 1)
            for other in range(lower, upper):
                if other != center:
                    pairs.append({
                        "word": token, "neighbor": words[other],
                        "similarity_score": cosine_similarity(vectors[center], vectors[other]),
                        "occurrence_indices": [center],
                        "context_sentences": [contexts[center]],
                        "pair_type": "neighbor", "distance": other - center
                    })
    return pairs
def threshold_filter(pairs, threshold=None):
    """Group pairs scoring strictly above ``threshold`` by their head word.

    Args:
        pairs: Pair dicts as produced by the pair generator.
        threshold: Minimum (exclusive) similarity; defaults to
            ``CONFIG["threshold"]``.

    Returns:
        A plain dict mapping each word to its retained records. Every record
        gets an empty ``semantic_explanation`` slot.
    """
    if threshold is None:
        threshold = CONFIG["threshold"]

    kept = {}
    for pair in pairs:
        if not pair["similarity_score"] > threshold:
            continue
        record = {
            "neighbor": pair["neighbor"],
            "similarity_score": pair["similarity_score"],
            "occurrence_indices": pair["occurrence_indices"],
            "context_sentences": pair["context_sentences"],
            "pair_type": pair.get("pair_type", "unknown"),
            "distance": pair.get("distance", None),
            "semantic_explanation": ""
        }
        kept.setdefault(pair["word"], []).append(record)
    return kept
# ---------------------------------------------------------------------------
# Auto LLM Explanation
# ---------------------------------------------------------------------------
# System instruction sent with every explanation request.
SYSTEM_PROMPT = " ".join([
    "You are a precise semantic linguist.",
    "Given a word pair and their local context, explain in ONE or TWO concise "
    "sentences how these two words semantically relate or influence each other.",
    "Focus on semantic influence, topical association, or conceptual dependency: "
    "how the presence or meaning of one word affects the choice or interpretation "
    "of the other in this specific context.",
    "Do NOT focus on grammar, syntax, or word order.",
    "Be concrete and specific.",
])


def build_llm_prompt(word, neighbor, sentences, pair_type):
    """Build the chat messages asking the LLM to explain one word pair.

    Args:
        word: Head word of the pair.
        neighbor: Second word (ignored in the wording for ``"self"`` pairs).
        sentences: Context strings, listed as numbered ``Context N`` lines.
        pair_type: ``"self"`` selects the same-word wording; anything else
            selects the two-word wording.

    Returns:
        A two-element list: the system message followed by the user message.
    """
    context_lines = [
        f" Context {number}: {sentence}"
        for number, sentence in enumerate(sentences, start=1)
    ]
    ctx_block = "\n".join(context_lines)

    if pair_type == "self":
        heading = f"Word pair: '{word}' (two different occurrences of the same word)"
        question = (
            f"Question: How do these two occurrences of '{word}' semantically relate or influence each other? "
            f"Do they reinforce the same meaning, or does context shift their semantic role?"
        )
    else:
        heading = f"Word pair: '{word}' and '{neighbor}'"
        question = (
            f"Question: How does '{word}' semantically influence '{neighbor}' (or vice versa) in this context? "
            f"What shared topic, function, or conceptual dependency binds them?"
        )

    user = "\n".join([heading, ctx_block, question])
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user},
    ]
def llm_explain(word, neighbor, sentences, pair_type):
    """Generate a short LLM explanation for one word pair.

    Args:
        word: Head word of the pair.
        neighbor: Second word of the pair.
        sentences: Context strings shown to the model.
        pair_type: Pair kind forwarded to the prompt builder.

    Returns:
        The decoded completion, or a bracketed placeholder string
        (``"[No generation]"``, ``"[Empty after strip]"``, ``"[Error: ...]"``)
        when nothing usable was produced. Generation errors are reported in
        the return value rather than raised; a model-loading failure still
        propagates.
    """
    tok, mdl = get_model()
    try:
        messages = build_llm_prompt(word, neighbor, sentences, pair_type)
        prompt_text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # Tokenise once and place the tensors on the device the model really
        # lives on (CONFIG["device"] may have been edited since loading).
        gen_inputs = tok(prompt_text, return_tensors="pt").to(mdl.device)
        prompt_len = gen_inputs["input_ids"].shape[1]

        gen_kwargs = {
            "max_new_tokens": CONFIG["max_new_tokens"],
            "pad_token_id": tok.eos_token_id,
        }
        temperature = float(CONFIG["temperature"])
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=CONFIG["top_p"])
        else:
            # The UI allows temperature 0, which sampling rejects; treat it
            # as a request for greedy decoding instead.
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            outputs = mdl.generate(**gen_inputs, **gen_kwargs)

        if outputs.shape[1] <= prompt_len:
            return "[No generation]"
        # Keep only the newly generated tail, not the echoed prompt.
        generated_ids = outputs[0][prompt_len:]
        explanation = tok.decode(generated_ids, skip_special_tokens=True).strip()
        return explanation if explanation else "[Empty after strip]"
    except Exception as e:
        return f"[Error: {e}]"
def auto_explain(filtered_dict, max_explanations=None):
    """Fill ``semantic_explanation`` for the highest-scoring retained pairs.

    Records from all words are ranked together by ``similarity_score``
    (descending); the first ``max_explanations`` are explained in place.

    Args:
        filtered_dict: Mapping of word -> list of pair records (mutated).
        max_explanations: How many records to explain; defaults to
            ``CONFIG["max_explanations"]``.

    Returns:
        The same ``filtered_dict`` object.
    """
    if max_explanations is None:
        max_explanations = CONFIG["max_explanations"]

    ranked = [(word, rec) for word, recs in filtered_dict.items() for rec in recs]
    ranked.sort(key=lambda item: item[1]["similarity_score"], reverse=True)
    print(f"[DEBUG] auto_explain: {len(ranked)} total, generating {min(max_explanations, len(ranked))}", flush=True)

    for number, (word, rec) in enumerate(ranked[:max_explanations], start=1):
        print(f"[DEBUG] Explaining {number}: '{word}' / '{rec['neighbor']}' score={rec['similarity_score']:.3f}", flush=True)
        expl = llm_explain(word, rec["neighbor"], rec["context_sentences"], rec["pair_type"])
        rec["semantic_explanation"] = expl
        print(f"[DEBUG] Explanation length: {len(expl)} chars", flush=True)
        time.sleep(0.02)
    return filtered_dict
# ---------------------------------------------------------------------------
# Visualisations
# ---------------------------------------------------------------------------
def plot_attention_heatmap(attn_matrix, words, word_ids, max_tok=50):
    """Draw the top-left corner of the attention matrix as a heatmap.

    Args:
        attn_matrix: Square attention tensor.
        words: Unused; kept for a uniform plotting signature.
        word_ids: Unused; kept for a uniform plotting signature.
        max_tok: Number of leading rows and columns to show.

    Returns:
        The matplotlib figure.
    """
    attn_np = attn_matrix[:max_tok, :max_tok].cpu().numpy()
    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(attn_np, cmap="viridis", aspect="auto")
    # The title follows the actual crop size instead of a hard-coded 50.
    ax.set_title(f"Attention Heatmap (first {max_tok} tokens)")
    ax.set_xlabel("Key token position")
    ax.set_ylabel("Query token position")
    fig.colorbar(im, ax=ax)
    plt.tight_layout()
    return fig
def plot_attention_signature(attn_matrix, tok_positions, words, word_ids, title_word):
    """Bar-plot the mean attention row of one word over all token positions.

    Args:
        attn_matrix: Square attention tensor.
        tok_positions: Row indices (token positions) belonging to the word.
        words: Unused; kept for a uniform plotting signature.
        word_ids: Unused; kept for a uniform plotting signature.
        title_word: Word shown in the title or placeholder message.

    Returns:
        The matplotlib figure; a text-only placeholder when ``tok_positions``
        is empty.
    """
    if not tok_positions:
        placeholder, axis = plt.subplots()
        axis.text(0.5, 0.5, f"No tokens found for '{title_word}'", ha="center", va="center")
        return placeholder

    signature = attn_matrix[tok_positions, :].mean(dim=0).cpu().numpy()
    length = len(signature)
    figure, axis = plt.subplots(figsize=(14, 4))
    axis.bar(range(length), signature, width=1.0, color="steelblue")
    axis.set_title(f"Attention Signature: '{title_word}' -> all tokens")
    axis.set_xlabel("Token position")
    axis.set_ylabel("Attention weight")

    # Thin the x ticks so that roughly twenty labels remain.
    stride = max(1, length // 20)
    tick_positions = range(0, length, stride)
    axis.set_xticks(tick_positions)
    axis.set_xticklabels([str(tick) for tick in tick_positions], rotation=45)
    plt.tight_layout()
    return figure
def plot_word_frequency(words):
    """Plot the 20 most frequent words as a horizontal bar chart.

    Labels and counts are both reversed before plotting, so the first bar
    drawn belongs to the least frequent of the selected words.
    """
    ranked = Counter(words).most_common(20)
    if ranked:
        labels, frequencies = zip(*ranked)
    else:
        labels, frequencies = [], []

    figure, axis = plt.subplots(figsize=(12, 5))
    slots = range(len(labels))
    axis.barh(slots, frequencies[::-1], color="teal")
    axis.set_yticks(slots)
    axis.set_yticklabels(labels[::-1])
    axis.set_xlabel("Frequency")
    axis.set_title("Top 20 Word Frequencies")
    plt.tight_layout()
    return figure
def plot_3d_semantic_space(words, vectors, filtered_dict):
    """Scatter every word position in the first three PCA components.

    Words that appear in a retained pair (as head word or neighbour) are
    drawn red and larger; all others are blue and small.

    Args:
        words: Word list.
        vectors: One torch vector per word position.
        filtered_dict: Mapping of word -> retained pair records.

    Returns:
        A plotly figure (empty when ``vectors`` is empty).
    """
    if not vectors:
        return go.Figure()

    matrix = torch.stack(vectors).float().numpy()
    projected = PCA(n_components=3).fit_transform(matrix)

    total = len(words)
    hover_texts = []
    for position, token in enumerate(words):
        snippet = " ".join(words[max(0, position - 3):min(total, position + 4)])
        hover_texts.append(f"{token}<br>idx={position}<br>ctx: {snippet[:60]}...")

    # A head word only counts as highlighted when it has at least one record.
    highlighted = set()
    for head, recs in filtered_dict.items():
        if recs:
            highlighted.add(head)
            highlighted.update(rec["neighbor"] for rec in recs)

    in_pair = [token in highlighted for token in words]
    colors = ["#e74c3c" if flag else "#3498db" for flag in in_pair]
    sizes = [8 if flag else 4 for flag in in_pair]

    trace = go.Scatter3d(
        x=projected[:, 0],
        y=projected[:, 1],
        z=projected[:, 2],
        mode="markers",
        marker=dict(size=sizes, color=colors, opacity=0.8),
        text=words,
        hovertext=hover_texts,
        hoverinfo="text",
    )
    fig = go.Figure(data=[trace])
    fig.update_layout(
        title="3D Semantic Space (PCA of attention signatures)<br>Red/large = words in retained pairs",
        scene=dict(xaxis_title="PC1", yaxis_title="PC2", zaxis_title="PC3"),
        width=800,
        height=600,
    )
    return fig
def _neighbor_position(rec, anchor, words):
    """Return the word index of a neighbour record's second word."""
    distance = rec.get("distance")
    if isinstance(distance, int):
        candidate = anchor + distance
        # Trust the recorded offset only when it really lands on the word.
        if 0 <= candidate < len(words) and words[candidate] == rec["neighbor"]:
            return candidate
    # Fallback for records without a usable offset: first occurrence of the
    # neighbour word, or the anchor itself when the word is absent.
    try:
        return words.index(rec["neighbor"])
    except ValueError:
        return anchor


def plot_pair_similarity_3d(filtered_dict, vectors, words):
    """Plot the 100 best pairs as midpoints in 3-D PCA space.

    Each marker sits halfway between the projected vectors of the two word
    positions forming the pair; marker size and colour encode the score.

    Args:
        filtered_dict: Mapping of word -> retained pair records.
        vectors: One torch vector per word position.
        words: Word list (used to locate neighbour positions).

    Returns:
        A plotly figure (empty when there are no pairs).
    """
    if not filtered_dict:
        return go.Figure()

    all_recs = [(w, rec) for w, recs in filtered_dict.items() for rec in recs]
    all_recs.sort(key=lambda item: item[1]["similarity_score"], reverse=True)
    all_recs = all_recs[:100]
    if not all_recs:
        return go.Figure()

    X = torch.stack(vectors).float().numpy()
    pca = PCA(n_components=3)
    pca.fit(X)
    # Project every position once instead of calling transform per pair.
    coords = pca.transform(X)

    midpoints = []
    scores = []
    labels = []
    for w, rec in all_recs:
        idxs = rec["occurrence_indices"]
        first = idxs[0]
        if len(idxs) >= 2:
            second = idxs[1]
        else:
            # Neighbour records store only the anchor; the partner is the
            # position at the recorded offset, not the first occurrence of
            # the neighbour word anywhere in the text.
            second = _neighbor_position(rec, first, words)
        midpoints.append((coords[first] + coords[second]) / 2)
        scores.append(rec["similarity_score"])
        labels.append(f"{w} ↔ {rec['neighbor']}<br>score={rec['similarity_score']:.3f}")

    midpoints = np.array(midpoints)
    fig = go.Figure(data=[go.Scatter3d(
        x=midpoints[:, 0],
        y=midpoints[:, 1],
        z=midpoints[:, 2],
        mode="markers",
        marker=dict(
            size=[max(3, s * 15) for s in scores],
            color=scores,
            colorscale="Plasma",
            colorbar=dict(title="Similarity"),
            opacity=0.85,
        ),
        text=labels,
        hoverinfo="text",
    )])
    fig.update_layout(
        title="3D Pair Similarity Space (PCA midpoints, top 100 pairs)",
        scene=dict(xaxis_title="PC1", yaxis_title="PC2", zaxis_title="PC3"),
        width=800,
        height=600,
    )
    return fig
# ---------------------------------------------------------------------------
# Download helpers
# ---------------------------------------------------------------------------
def json_download(filtered_dict):
    """Serialise the retained pairs as 2-space-indented JSON text.

    Non-ASCII characters are written as-is rather than escaped.
    """
    serialized = json.dumps(filtered_dict, ensure_ascii=False, indent=2)
    return serialized
def csv_download(filtered_dict):
    """Render the retained pairs as CSV text.

    Text columns are always double-quoted with embedded quotes doubled; the
    score is written with six decimals; a missing distance becomes an empty
    field. Only the first context sentence of each record is exported.

    Args:
        filtered_dict: Mapping of word -> list of pair records.

    Returns:
        The CSV document (header line first), lines joined by ``\\n``.
    """
    def quoted(value):
        # None renders as an empty quoted field rather than the text "None".
        text = "" if value is None else str(value)
        return '"' + text.replace('"', '""') + '"'

    lines = ["word,neighbor,similarity_score,pair_type,distance,context,semantic_explanation"]
    for word, recs in filtered_dict.items():
        for rec in recs:
            contexts = rec["context_sentences"]
            distance = rec.get("distance")
            fields = [
                quoted(word),
                quoted(rec["neighbor"]),
                f'{rec["similarity_score"]:.6f}',
                quoted(rec["pair_type"]),
                "" if distance is None else str(distance),
                quoted(contexts[0] if contexts else ""),
                quoted(rec.get("semantic_explanation", "")),
            ]
            lines.append(",".join(fields))
    return "\n".join(lines)
# ---------------------------------------------------------------------------
# Gradio handlers
# ---------------------------------------------------------------------------
# Sample paragraph pre-filled into the input box.
DEFAULT_TEXT = " ".join((
    "Senjō no Valkyria 3: Unrecorded Chronicles (Japanese: 戦場のヴァルキュリア3) "
    "is a tactical role-playing video game developed by Sega and Media.Vision "
    "for the PlayStation Portable.",
    "Released in January 2011 in Japan, it is the third game in the Valkyria "
    "Chronicles series.",
    "The game uses the same fusion of tactical and real-time action as its "
    "predecessors, and introduces new characters and a darker storyline about "
    "a penal military unit.",
))
def run_pipeline(text, threshold, window):
    """Gradio handler: run extraction, pairing, filtering and plotting.

    Args:
        text: Raw input text from the textbox.
        threshold: Similarity cut-off from the slider.
        window: Context half-width (in words) from the slider.

    Returns:
        A 12-tuple matching the ``run_btn`` outputs: summary string, table
        rows, cache JSON, five figures, then ``words``, token positions, the
        attention matrix and ``word_ids`` for the state slots. On an
        extraction error the summary holds the message and the rest is
        empty / ``None``.
    """
    print("[DEBUG] run_pipeline started", flush=True)
    # Normalise the widget values once so every stage sees the same types;
    # a float window would break the slicing and range() calls downstream.
    threshold = float(threshold)
    window = int(window)
    CONFIG["threshold"] = threshold
    CONFIG["window"] = window

    # An empty textbox may arrive as None; treat it as empty text.
    result = extract_attention_vectors(text or "", window)
    words, vectors, contexts, tok_positions, attn_matrix, word_ids, err, stats = result
    if err:
        return err, [], "", None, None, None, None, None, None, None, None, None

    all_pairs = generate_all_pairs(words, vectors, contexts, window=window)
    filtered = threshold_filter(all_pairs, threshold=threshold)
    total = sum(len(v) for v in filtered.values())
    stats_str = (
        f"Words: {stats['word_count']} | Unique: {stats['unique_words']} | Tokens: {stats['token_count']} | "
        f"Raw pairs: {len(all_pairs)} | Retained: {total} | "
        f"Top words: {', '.join(f'{w}({c})' for w, c in stats['top_words'][:5])}"
    )
    print(f"[DEBUG] {stats_str}", flush=True)

    # The table shows at most 20 records per head word.
    rows = [
        [
            w,
            r["neighbor"],
            f"{r['similarity_score']:.3f}",
            r["pair_type"],
            r.get("distance", ""),
            r["context_sentences"][0][:70] + "...",
            "(click Generate Explanations)",
        ]
        for w, recs in filtered.items()
        for r in recs[:20]
    ]

    heatmap_fig = plot_attention_heatmap(attn_matrix, words, word_ids)
    freq_fig = plot_word_frequency(words)
    # The signature tab starts on the first word of the text.
    sig_fig = plot_attention_signature(
        attn_matrix,
        tok_positions[0] if tok_positions else [],
        words, word_ids, words[0] if words else ""
    )
    space3d_fig = plot_3d_semantic_space(words, vectors, filtered)
    pair3d_fig = plot_pair_similarity_3d(filtered, vectors, words)

    cache_json = json.dumps({
        "words": words,
        "contexts": contexts,
        "filtered": filtered,
        "stats": stats,
    }, indent=2, ensure_ascii=False)
    return (
        stats_str, rows, cache_json,
        heatmap_fig, freq_fig, sig_fig, space3d_fig, pair3d_fig,
        words, tok_positions, attn_matrix, word_ids
    )
def generate_explanations(cache_json, max_expl):
    """Gradio handler: explain the top-N cached pairs and rebuild the table.

    Args:
        cache_json: JSON cache string produced by the pipeline run.
        max_expl: Number of pairs to explain; an empty field falls back to
            ``CONFIG["max_explanations"]``.

    Returns:
        ``(status_message, table_rows)``.

    NOTE(review): the explanations are written into a copy parsed from the
    cache and are not stored back, so later JSON/CSV exports read the
    unexplained cache. Fixing that needs an extra output in the UI wiring.
    """
    if not cache_json or cache_json.strip() == "":
        return "Run pipeline first.", []
    try:
        cache = json.loads(cache_json)
    except (TypeError, ValueError):
        return "Invalid cache.", []

    # Valid JSON is not enough: the cache must be an object carrying a
    # "filtered" mapping, otherwise the lookups below would raise.
    filtered = cache.get("filtered") if isinstance(cache, dict) else None
    if not isinstance(filtered, dict):
        return "Invalid cache.", []

    # A cleared number field arrives as None; negatives are clamped to zero.
    max_expl = CONFIG["max_explanations"] if max_expl is None else max(0, int(max_expl))
    CONFIG["max_explanations"] = max_expl
    print(f"[DEBUG] generate_explanations called, max={max_expl}", flush=True)
    filtered = auto_explain(filtered, max_explanations=max_expl)

    rows = [
        [
            w,
            r["neighbor"],
            f"{r['similarity_score']:.3f}",
            r["pair_type"],
            r.get("distance", ""),
            r["context_sentences"][0][:70] + "...",
            r.get("semantic_explanation", ""),
        ]
        for w, recs in filtered.items()
        for r in recs[:20]
    ]
    return "Explanations generated.", rows
def update_signature(cache_json, word_input, words_state, tok_positions_state, attn_matrix_state, word_ids_state):
    """Gradio handler: redraw the attention signature for a chosen word.

    The word list comes from the JSON cache when it parses, otherwise from
    ``words_state``. The first occurrence of the (lower-cased, stripped) word
    is plotted.

    Returns:
        A matplotlib figure, or ``None`` when there is no cache, no attention
        matrix, or the word does not occur in the text.
    """
    if not cache_json:
        return None

    fallback_words = words_state or []
    try:
        words = json.loads(cache_json).get("words", fallback_words)
    except Exception:
        # Unparseable or unexpected cache: rely on the state copy.
        words = fallback_words
    tok_positions = tok_positions_state or []
    attn_matrix = attn_matrix_state
    word_ids = word_ids_state or []

    if not words or attn_matrix is None:
        return None

    word_input = word_input.lower().strip()
    if word_input not in words:
        return None

    first_hit = words.index(word_input)
    if first_hit < len(tok_positions):
        token_rows = tok_positions[first_hit]
    else:
        token_rows = []
    return plot_attention_signature(attn_matrix, token_rows, words, word_ids, word_input)
def fetch_dataset(dataset_name, config_name, split):
    """Gradio handler: load sample paragraphs from a Hugging Face dataset.

    An empty config name is passed on as ``None``. The loader's result is
    returned unchanged, whether it holds paragraphs or an error message; the
    former conditional returned the same value on both branches.
    """
    return load_hf_text(dataset_name, config_name or None, split)
def apply_settings(model_id, device, max_tokens, temperature, top_p, attn_impl, ds_paras, ds_min, ds_max):
    """Gradio handler: store the advanced settings and reload if required.

    The model is reloaded when the model id changes, or when the device or
    attention implementation changes while a model is already loaded (an
    unloaded model picks the new values up on its first lazy load). If the
    reload fails, the previous model settings are restored so the next lazy
    load returns to the last configuration.

    Returns:
        A status message for the settings textbox.
    """
    global _model_singleton, _tokenizer_singleton

    load_keys = ("model_id", "device", "attn_implementation")
    previous = {key: CONFIG[key] for key in load_keys}

    CONFIG["model_id"] = model_id.strip() or previous["model_id"]
    CONFIG["device"] = device
    CONFIG["max_new_tokens"] = int(max_tokens)
    CONFIG["temperature"] = float(temperature)
    CONFIG["top_p"] = float(top_p)
    CONFIG["attn_implementation"] = attn_impl
    CONFIG["dataset_paragraphs"] = int(ds_paras)
    CONFIG["dataset_min_words"] = int(ds_min)
    CONFIG["dataset_max_tries"] = int(ds_max)

    model_changed = CONFIG["model_id"] != previous["model_id"]
    placement_changed = (
        CONFIG["device"] != previous["device"]
        or CONFIG["attn_implementation"] != previous["attn_implementation"]
    )
    # Without this, a loaded model stays on its old device / attention
    # backend while CONFIG claims otherwise.
    if model_changed or (placement_changed and _model_singleton is not None):
        # Drop the old references before loading the replacement.
        _model_singleton = None
        _tokenizer_singleton = None
        try:
            get_model(force_reload=True)
        except Exception as e:
            CONFIG.update(previous)
            _model_singleton = None
            _tokenizer_singleton = None
            return f"Error loading model: {e} (previous model settings restored)"
        return f"Settings applied. Model '{CONFIG['model_id']}' loaded on {CONFIG['device']}."

    return (
        f"Settings applied. Model: {CONFIG['model_id']} | Device: {CONFIG['device']} | "
        f"Max tokens: {CONFIG['max_new_tokens']} | Temp: {CONFIG['temperature']} | Top-p: {CONFIG['top_p']}"
    )
def download_json(cache_json):
    """Return the cached retained pairs as JSON text, or ``None``.

    ``None`` is returned for an empty cache and whenever parsing or
    serialisation raises.
    """
    if cache_json:
        try:
            retained = json.loads(cache_json).get("filtered", {})
            return json_download(retained)
        except Exception:
            return None
    return None
def download_csv(cache_json):
    """Return the cached retained pairs as CSV text, or ``None``.

    ``None`` is returned for an empty cache and whenever parsing or
    rendering raises.
    """
    if cache_json:
        try:
            retained = json.loads(cache_json).get("filtered", {})
            return csv_download(retained)
        except Exception:
            return None
    return None
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Layout and event wiring for the Gradio app. Event listeners must be
# registered inside the Blocks context, so the export helpers live here too.
with gr.Blocks(title="Semantic Attention Explorer", css="""
.dataframe-wrap { white-space: pre-wrap !important; }
""") as demo:
    gr.Markdown("# 🔍 Semantic Attention Explorer")
    gr.Markdown(
        "Extract **real neural attention** from a causal LM, compute **cosine similarity** between "
        "centered attention signatures, and auto-generate **LLM semantic explanations** for retained pairs."
    )

    # --- Main input row ---------------------------------------------------
    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(label="Input Text", lines=8, value=DEFAULT_TEXT)
        with gr.Column(scale=1):
            thresh_slider = gr.Slider(0.0, 1.0, value=CONFIG["threshold"], step=0.05, label="Similarity Threshold")
            win_slider = gr.Slider(1, 15, value=CONFIG["window"], step=1, label="Context Window (words)")
            run_btn = gr.Button("▶️ Run Pipeline", variant="primary")

    # --- Advanced settings --------------------------------------------------
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        with gr.Row():
            model_input = gr.Textbox(
                label="Model ID",
                value=CONFIG["model_id"],
                placeholder="e.g. Qwen/Qwen2.5-0.5B-Instruct",
            )
            device_dd = gr.Dropdown(
                ["cpu", "cuda"] + (["mps"] if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else []),
                value=CONFIG["device"],
                label="Device",
            )
            attn_impl_dd = gr.Dropdown(["eager", "sdpa", "flash_attention_2"], value=CONFIG["attn_implementation"], label="Attention Implementation")
        with gr.Row():
            max_tokens_num = gr.Number(value=CONFIG["max_new_tokens"], label="Max New Tokens", precision=0, minimum=1, maximum=2048)
            temp_slider = gr.Slider(0.0, 2.0, value=CONFIG["temperature"], step=0.05, label="Temperature")
            top_p_slider = gr.Slider(0.0, 1.0, value=CONFIG["top_p"], step=0.05, label="Top-p")
        with gr.Row():
            ds_paras_num = gr.Number(value=CONFIG["dataset_paragraphs"], label="Dataset paragraphs", precision=0, minimum=1, maximum=100)
            ds_min_num = gr.Number(value=CONFIG["dataset_min_words"], label="Min words per paragraph", precision=0, minimum=1, maximum=500)
            ds_max_num = gr.Number(value=CONFIG["dataset_max_tries"], label="Dataset scan limit", precision=0, minimum=10, maximum=100000)
        apply_btn = gr.Button("Apply Settings & Reload Model", variant="secondary")
        settings_status = gr.Textbox(label="Settings Status", interactive=False)

    # --- Dataset loader -----------------------------------------------------
    with gr.Accordion("📚 Load from HuggingFace Dataset", open=False):
        with gr.Row():
            ds_name = gr.Textbox(label="Dataset name", value="wikitext", placeholder="e.g. wikitext")
            ds_config = gr.Textbox(label="Config name (optional)", value="wikitext-2-raw-v1", placeholder="e.g. wikitext-2-raw-v1")
            ds_split = gr.Dropdown(["train", "validation", "test"], value="train", label="Split")
        ds_btn = gr.Button("Load Sample Paragraphs")
        ds_output = gr.Textbox(label="Loaded Paragraphs (copy one into Input Text)", lines=6, interactive=False)

    ds_btn.click(fetch_dataset, inputs=[ds_name, ds_config, ds_split], outputs=[ds_output])
    apply_btn.click(
        apply_settings,
        inputs=[model_input, device_dd, max_tokens_num, temp_slider, top_p_slider, attn_impl_dd, ds_paras_num, ds_min_num, ds_max_num],
        outputs=[settings_status],
    )

    # Hidden cache states
    cache_state = gr.State()
    words_state = gr.State()
    tok_positions_state = gr.State()
    attn_matrix_state = gr.State()
    word_ids_state = gr.State()

    summary_box = gr.Textbox(label="Pipeline Summary", interactive=False)

    # --- Explanations and exports --------------------------------------------
    with gr.Row():
        with gr.Column(scale=1):
            max_expl_num = gr.Number(
                value=CONFIG["max_explanations"],
                label="Auto LLM Explanations (top-N)",
                precision=0,
                minimum=0,
                maximum=9999,
            )
            explain_btn = gr.Button("🤖 Generate Explanations", variant="secondary")
            with gr.Row():
                json_dl_btn = gr.Button("📥 Download JSON")
                csv_dl_btn = gr.Button("📥 Download CSV")
            # Hidden until an export has actually been produced.
            json_file = gr.File(label="JSON Download", visible=False)
            csv_file = gr.File(label="CSV Download", visible=False)
        with gr.Column(scale=2):
            expl_status = gr.Textbox(label="Explanation Status", interactive=False)
            pairs_df = gr.Dataframe(
                headers=["Word", "Neighbor", "Score", "Type", "Distance", "Context", "LLM Explanation"],
                interactive=False,
                wrap=True,
            )

    # --- Plots --------------------------------------------------------------
    with gr.Tab("🌡️ Attention Heatmap"):
        heatmap_plot = gr.Plot(label="Full Attention Matrix (first 50 tokens)")
    with gr.Tab("📊 Word Frequency"):
        freq_plot = gr.Plot(label="Top 20 Word Frequencies")
    with gr.Tab("📈 Attention Signature"):
        with gr.Row():
            sig_word_input = gr.Textbox(label="Word to visualize", placeholder="e.g. game")
            sig_update_btn = gr.Button("Update Signature")
        sig_plot = gr.Plot(label="Attention Signature (word -> all tokens)")
    with gr.Tab("🌌 3D Semantic Space"):
        space3d_plot = gr.Plot(label="3D PCA of attention signatures")
    with gr.Tab("🔗 3D Pair Space"):
        pair3d_plot = gr.Plot(label="3D PCA of pair midpoints")

    explain_btn.click(
        generate_explanations,
        inputs=[cache_state, max_expl_num],
        outputs=[expl_status, pairs_df]
    )
    run_btn.click(
        run_pipeline,
        inputs=[input_text, thresh_slider, win_slider],
        outputs=[
            summary_box, pairs_df, cache_state,
            heatmap_plot, freq_plot, sig_plot, space3d_plot, pair3d_plot,
            words_state, tok_positions_state, attn_matrix_state, word_ids_state,
        ]
    )
    sig_update_btn.click(
        update_signature,
        inputs=[cache_state, sig_word_input, words_state, tok_positions_state, attn_matrix_state, word_ids_state],
        outputs=[sig_plot]
    )

    # Download handlers
    def _write_export(text, suffix):
        """Write ``text`` to a fresh temp file and return its path (or None)."""
        import tempfile

        if not text:
            return None
        # A unique file per export: a fixed /tmp name would be shared (and
        # overwritten) by every concurrent user. The file is left in place
        # for Gradio to serve.
        with tempfile.NamedTemporaryFile(
            "w", suffix=suffix, prefix="results_", delete=False, encoding="utf-8"
        ) as handle:
            handle.write(text)
        return handle.name

    def _save_json(text):
        """Persist JSON text to disk; return the file path or ``None``."""
        return _write_export(text, ".json")

    def _save_csv(text):
        """Persist CSV text to disk; return the file path or ``None``."""
        return _write_export(text, ".csv")

    def _export_json(cache_json):
        """Build the JSON export and reveal the file widget when it exists."""
        path = _save_json(download_json(cache_json))
        return gr.update(value=path, visible=path is not None)

    def _export_csv(cache_json):
        """Build the CSV export and reveal the file widget when it exists."""
        path = _save_csv(download_csv(cache_json))
        return gr.update(value=path, visible=path is not None)

    # A File output needs a path, not the raw text, so each button runs a
    # single handler that renders the text, writes it out and returns the path.
    json_dl_btn.click(_export_json, inputs=[cache_state], outputs=[json_file])
    csv_dl_btn.click(_export_csv, inputs=[cache_state], outputs=[csv_file])

if __name__ == "__main__":
    demo.launch()