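"""Perplexity Viewer: a Gradio app that visualizes per-token (pseudo-)perplexity
for Hugging Face decoder (causal LM) and encoder (masked LM) models."""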
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM
import numpy as np
import pandas as pd
import spacy
from spacy import displacy
import math
import warnings
try:
from config import DEFAULT_MODELS, MODEL_SETTINGS, VIZ_SETTINGS, PROCESSING_SETTINGS, UI_SETTINGS, ERROR_MESSAGES
except ImportError:
# Fallback configuration if config.py is not available
DEFAULT_MODELS = {
"decoder": ["gpt2", "distilgpt2"],
"encoder": ["bert-base-uncased", "distilbert-base-uncased"]
}
MODEL_SETTINGS = {"max_length": 512}
VIZ_SETTINGS = {
"max_perplexity_display": 100.0,
"color_scheme": {
"high_perplexity": {"r": 255, "g": 0, "b": 50},
"low_perplexity": {"r": 0, "g": 255, "b": 50}
},
"displacy_options": {"ents": ["PP"], "colors": {}}
}
PROCESSING_SETTINGS = {
"default_iterations": 1,
"max_iterations": 10,
"default_mlm_probability": 0.15,
"min_mlm_probability": 0.1,
"max_mlm_probability": 0.5,
"epsilon": 1e-10
}
UI_SETTINGS = {
"title": "πŸ“ˆ Perplexity Viewer",
"description": "Visualize per-token perplexity using color gradients.",
"examples": [
{"text": "The quick brown fox jumps over the lazy dog.", "model": "gpt2", "type": "decoder", "iterations": 1, "mlm_prob": 0.15},
{"text": "The capital of France is Paris.", "model": "bert-base-uncased", "type": "encoder", "iterations": 1, "mlm_prob": 0.15}
]
}
ERROR_MESSAGES = {
"empty_text": "Please enter some text to analyze.",
"model_load_error": "Error loading model {model_name}: {error}",
"processing_error": "Error processing text: {error}"
}
warnings.filterwarnings("ignore")
# Global variables to cache models
cached_models = {}
cached_tokenizers = {}
def load_model_and_tokenizer(model_name, model_type):
"""Load and cache model and tokenizer"""
cache_key = f"{model_name}_{model_type}"
if cache_key not in cached_models:
try:
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Add pad token if it doesn't exist
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if model_type == "decoder":
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=True
)
else: # encoder
model = AutoModelForMaskedLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=True
)
model.eval() # Set to evaluation mode
cached_models[cache_key] = model
cached_tokenizers[cache_key] = tokenizer
return model, tokenizer
except Exception as e:
raise gr.Error(ERROR_MESSAGES["model_load_error"].format(model_name=model_name, error=str(e)))
return cached_models[cache_key], cached_tokenizers[cache_key]
def calculate_decoder_perplexity(text, model, tokenizer, iterations=1):
"""Calculate perplexity for decoder models (like GPT)"""
device = next(model.parameters()).device
perplexities = []
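    # Note: a decoder forward pass is deterministic in eval mode, so extra
    # iterations recompute the same value; averaging matters for the MLM path.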
for iteration in range(iterations):
# Tokenize the text
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MODEL_SETTINGS["max_length"])
input_ids = inputs.input_ids.to(device)
if input_ids.size(1) < 2:
raise gr.Error("Text is too short for perplexity calculation.")
with torch.no_grad():
outputs = model(input_ids, labels=input_ids)
loss = outputs.loss
perplexity = torch.exp(loss).item()
perplexities.append(perplexity)
    # Token-level perplexities for the last iteration. The labeled forward pass
    # above already returned the logits, so no extra forward pass is needed.
    logits = outputs.logits
    # Shift logits and labels for next-token prediction: the logits at position i
    # score the token at position i + 1
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = input_ids[..., 1:].contiguous()
    # Per-token cross-entropy losses; per-token perplexity is exp(loss)
    loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
    token_losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    token_perplexities = torch.exp(token_losses).cpu().numpy()
# Get tokens (excluding the first one since we predict next tokens)
tokens = tokenizer.convert_ids_to_tokens(input_ids[0][1:])
# Clean up tokens for display
cleaned_tokens = []
    for token in tokens:
        if token.startswith('Ġ'):
            cleaned_tokens.append(token[1:])  # strip the GPT-2 BPE marker for a leading space
        elif token.startswith('##'):
            cleaned_tokens.append(token[2:])  # strip the WordPiece continuation marker
        else:
            cleaned_tokens.append(token)
return np.mean(perplexities), cleaned_tokens, token_perplexities
def calculate_encoder_perplexity(text, model, tokenizer, mlm_probability=0.15, iterations=1):
"""Calculate pseudo-perplexity for encoder models (like BERT) using MLM"""
device = next(model.parameters()).device
perplexities = []
for iteration in range(iterations):
# Tokenize the text
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MODEL_SETTINGS["max_length"])
input_ids = inputs.input_ids.to(device)
if input_ids.size(1) < 3: # Need at least [CLS] + 1 token + [SEP]
raise gr.Error("Text is too short for MLM perplexity calculation.")
# Create a copy for masking
masked_input_ids = input_ids.clone()
original_tokens = input_ids.clone()
# Randomly mask tokens (excluding special tokens)
seq_length = input_ids.size(1)
mask_indices = []
special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}
for i in range(seq_length):
if input_ids[0, i].item() not in special_token_ids:
if torch.rand(1).item() < mlm_probability:
mask_indices.append(i)
masked_input_ids[0, i] = tokenizer.mask_token_id
if not mask_indices:
# If no tokens were masked, mask at least one non-special token
non_special_indices = [i for i in range(seq_length) if input_ids[0, i].item() not in special_token_ids]
if non_special_indices:
mask_idx = torch.randint(0, len(non_special_indices), (1,)).item()
mask_indices = [non_special_indices[mask_idx]]
masked_input_ids[0, mask_indices[0]] = tokenizer.mask_token_id
with torch.no_grad():
outputs = model(masked_input_ids)
predictions = outputs.logits
# Calculate perplexity for masked tokens
masked_token_losses = []
for idx in mask_indices:
target_id = original_tokens[0, idx]
pred_scores = predictions[0, idx]
prob = F.softmax(pred_scores, dim=-1)[target_id]
loss = -torch.log(prob + PROCESSING_SETTINGS["epsilon"])
masked_token_losses.append(loss.item())
if masked_token_losses:
avg_loss = np.mean(masked_token_losses)
perplexity = math.exp(avg_loss)
perplexities.append(perplexity)
# Calculate per-token pseudo-perplexity for visualization
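    # The loop below masks one token at a time (one forward pass per token),
    # i.e. the exact pseudo-log-likelihood, unlike the sampled masking above.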
with torch.no_grad():
token_perplexities = []
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}
for i in range(len(tokens)):
if input_ids[0, i].item() in special_token_ids:
token_perplexities.append(1.0) # Low perplexity for special tokens
else:
masked_input = input_ids.clone()
original_token_id = input_ids[0, i]
masked_input[0, i] = tokenizer.mask_token_id
outputs = model(masked_input)
predictions = outputs.logits[0, i]
prob = F.softmax(predictions, dim=-1)[original_token_id]
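                # 1/p == exp(-log p) (up to the epsilon guard), matching the
                # exp(NLL) convention used for the sentence-level score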
token_perplexity = 1.0 / (prob.item() + PROCESSING_SETTINGS["epsilon"])
token_perplexities.append(token_perplexity)
# Clean up tokens for display
cleaned_tokens = []
for token in tokens:
if token.startswith('##'):
cleaned_tokens.append(token[2:])
else:
cleaned_tokens.append(token)
return np.mean(perplexities) if perplexities else float('inf'), cleaned_tokens, np.array(token_perplexities)
def create_visualization(tokens, perplexities):
"""Create displaCy visualization with color-coded perplexities"""
if len(tokens) == 0:
return "<p>No tokens to visualize.</p>"
# Cap perplexities for better visualization
max_perplexity = min(np.max(perplexities), VIZ_SETTINGS["max_perplexity_display"])
# Normalize perplexities to 0-1 range for color mapping
normalized_perplexities = np.clip(perplexities / max_perplexity, 0, 1)
    # Create entities for displaCy and a label -> color mapping for rendering
    entities = []
    label_colors = {}
    text_parts = []
    current_pos = 0
for i, (token, perp, norm_perp) in enumerate(zip(tokens, perplexities, normalized_perplexities)):
# Skip empty tokens
if not token.strip():
continue
# Clean token for display
clean_token = token.replace("</w>", "").strip()
if not clean_token:
continue
# Add space before token if it's not the first one and doesn't start with punctuation
        if i > 0 and clean_token[0] not in ".,!?;:":
text_parts.append(" ")
current_pos += 1
text_parts.append(clean_token)
# Map perplexity to color
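        # Linear interpolation per RGB channel: c = high*t + low*(1-t), where
        # t is the normalized perplexity (t = 1 gives the high-perplexity color)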
high_color = VIZ_SETTINGS["color_scheme"]["high_perplexity"]
low_color = VIZ_SETTINGS["color_scheme"]["low_perplexity"]
red = int(high_color["r"] * norm_perp + low_color["r"] * (1 - norm_perp))
green = int(high_color["g"] * norm_perp + low_color["g"] * (1 - norm_perp))
blue = int(high_color["b"] * norm_perp + low_color["b"] * (1 - norm_perp))
color = f"rgb({red}, {green}, {blue})"
        # displaCy ignores a per-entity "color" field; colors must be supplied
        # via options["colors"] keyed by label, so record the mapping here
        label = f"{perp:.2f}"
        entities.append({
            "start": current_pos,
            "end": current_pos + len(clean_token),
            "label": label
        })
        label_colors[label] = color
current_pos += len(clean_token)
# Join text parts
text = "".join(text_parts)
if not entities:
return "<p>No valid tokens found for visualization.</p>"
# Create displaCy data structure
doc_data = {
"text": text,
"ents": entities,
"title": "Per-token Perplexity Visualization"
}
    try:
        # displaCy takes colors via options["colors"] (keyed by label); list
        # every label in "ents" so none is filtered out
        options = dict(VIZ_SETTINGS["displacy_options"])
        options.update(ents=list(label_colors), colors={**options.get("colors", {}), **label_colors})
        html = displacy.render(doc_data, style="ent", manual=True, options=options)
        return html
except Exception as e:
return f"<p>Error creating visualization: {str(e)}</p>"
def process_text(text, model_name, model_type, iterations, mlm_probability):
"""Main processing function"""
if not text.strip():
return ERROR_MESSAGES["empty_text"], "", pd.DataFrame()
try:
# Validate inputs
iterations = max(1, min(iterations, PROCESSING_SETTINGS["max_iterations"]))
mlm_probability = max(PROCESSING_SETTINGS["min_mlm_probability"],
min(mlm_probability, PROCESSING_SETTINGS["max_mlm_probability"]))
# Load model and tokenizer
model, tokenizer = load_model_and_tokenizer(model_name, model_type)
# Calculate perplexity
if model_type == "decoder":
avg_perplexity, tokens, token_perplexities = calculate_decoder_perplexity(
text, model, tokenizer, iterations
)
else: # encoder
avg_perplexity, tokens, token_perplexities = calculate_encoder_perplexity(
text, model, tokenizer, mlm_probability, iterations
)
# Create visualization
viz_html = create_visualization(tokens, token_perplexities)
# Create summary
summary = f"""
### Analysis Results
**Model:** `{model_name}`
**Model Type:** {model_type.title()}
**Average Perplexity:** {avg_perplexity:.4f}
**Number of Tokens:** {len(tokens)}
**Iterations:** {iterations}
"""
if model_type == "encoder":
summary += f" \n**MLM Probability:** {mlm_probability}"
# Create detailed results table
df = pd.DataFrame({
'Token': tokens,
'Perplexity': [f"{p:.4f}" for p in token_perplexities]
})
return summary, viz_html, df
except Exception as e:
error_msg = ERROR_MESSAGES["processing_error"].format(error=str(e))
return error_msg, "", pd.DataFrame()
# Create Gradio interface
with gr.Blocks(title=UI_SETTINGS["title"], theme=gr.themes.Soft()) as demo:
gr.Markdown(f"# {UI_SETTINGS['title']}")
gr.Markdown(UI_SETTINGS["description"])
with gr.Row():
with gr.Column(scale=2):
text_input = gr.Textbox(
label="Input Text",
placeholder="Enter the text you want to analyze...",
lines=6,
max_lines=10
)
with gr.Row():
model_name = gr.Dropdown(
label="Model Name",
choices=DEFAULT_MODELS["decoder"] + DEFAULT_MODELS["encoder"],
value="gpt2",
allow_custom_value=True,
info="Select a model or enter a custom HuggingFace model name"
)
model_type = gr.Radio(
label="Model Type",
choices=["decoder", "encoder"],
value="decoder",
info="Decoder for causal LM, Encoder for masked LM"
)
with gr.Row():
iterations = gr.Slider(
label="Iterations",
minimum=1,
maximum=PROCESSING_SETTINGS["max_iterations"],
value=PROCESSING_SETTINGS["default_iterations"],
step=1,
info="Number of iterations to average over"
)
mlm_probability = gr.Slider(
label="MLM Probability",
minimum=PROCESSING_SETTINGS["min_mlm_probability"],
maximum=PROCESSING_SETTINGS["max_mlm_probability"],
value=PROCESSING_SETTINGS["default_mlm_probability"],
step=0.05,
visible=False,
info="Probability of masking tokens (encoder models only)"
)
            analyze_btn = gr.Button("🔍 Analyze Perplexity", variant="primary", size="lg")
with gr.Column(scale=3):
summary_output = gr.Markdown(label="Summary")
viz_output = gr.HTML(label="Perplexity Visualization")
# Full-width table
with gr.Row():
table_output = gr.Dataframe(
label="Detailed Token Results",
interactive=False,
wrap=True
)
# Update model dropdown based on type selection
def update_model_choices(model_type):
return gr.update(choices=DEFAULT_MODELS[model_type], value=DEFAULT_MODELS[model_type][0])
# Show/hide MLM probability based on model type
def toggle_mlm_visibility(model_type):
return gr.update(visible=(model_type == "encoder"))
model_type.change(
fn=lambda mt: [update_model_choices(mt), toggle_mlm_visibility(mt)],
inputs=[model_type],
outputs=[model_name, mlm_probability]
)
# Set up the analysis function
analyze_btn.click(
fn=process_text,
inputs=[text_input, model_name, model_type, iterations, mlm_probability],
outputs=[summary_output, viz_output, table_output]
)
# Add examples
    with gr.Accordion("📝 Example Texts", open=False):
examples_data = [
[ex["text"], ex["model"], ex["type"], ex["iterations"], ex["mlm_prob"]]
for ex in UI_SETTINGS["examples"]
]
gr.Examples(
examples=examples_data,
inputs=[text_input, model_name, model_type, iterations, mlm_probability],
outputs=[summary_output, viz_output, table_output],
fn=process_text,
cache_examples=False,
label="Click on an example to try it out:"
)
# Add footer with information
gr.Markdown("""
---
    ### 📊 How it works:
- **Decoder Models** (GPT, etc.): Calculate true perplexity by measuring how well the model predicts the next token
- **Encoder Models** (BERT, etc.): Calculate pseudo-perplexity using masked language modeling (MLM)
- **Color Coding**: Red = High perplexity (uncertain), Green = Low perplexity (confident)
### ⚠️ Notes:
- First model load may take some time
- Models are cached after first use
- Very long texts are truncated to 512 tokens
- GPU acceleration is used when available
""")
if __name__ == "__main__":
try:
demo.launch(
server_name="0.0.0.0",
server_port=7860,
show_api=False
)
except Exception as e:
print(f"❌ Failed to launch app: {e}")
print("πŸ’‘ Try running with: python run.py")
# Fallback to basic launch
try:
demo.launch()
except Exception as fallback_error:
print(f"❌ Fallback launch also failed: {fallback_error}")
print("πŸ’‘ Try updating Gradio: pip install --upgrade gradio")