#!/usr/bin/env python3
"""
Demo script for PerplexityViewer - shows core functionality without GUI
"""
import warnings

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM

warnings.filterwarnings("ignore")
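
# Background (added note): for a causal (decoder) LM, perplexity is the
# exponential of the average next-token negative log-likelihood,
#     PPL = exp(-(1/N) * sum_i log p(x_i | x_<i)),
# so lower is better and PPL == 1 would mean perfect prediction. Masked
# (encoder) models have no left-to-right factorization, so the demos below
# score them with a pseudo-perplexity over masked positions instead.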

def demo_decoder_perplexity():
    """Demo decoder model perplexity calculation"""
    print("="*60)
    print("πŸ€– Decoder Model Demo (GPT-2)")
    print("="*60)

    # Load model
    model_name = "distilgpt2"
    print(f"Loading {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model.eval()

    # Test texts
    test_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Machine learning is revolutionizing artificial intelligence.",
        "Buffalo buffalo Buffalo buffalo buffalo buffalo Buffalo buffalo.",
        "The capital of France is Paris.",
    ]

    for i, text in enumerate(test_texts, 1):
        print(f"\nπŸ“ Text {i}: {text}")

        # Tokenize
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        input_ids = inputs.input_ids

        # Calculate perplexity
        with torch.no_grad():
            outputs = model(input_ids, labels=input_ids)
            loss = outputs.loss
            perplexity = torch.exp(loss).item()
        print(f" πŸ’― Perplexity: {perplexity:.2f}")
        # Get token-level details
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0][1:])  # Skip first token
        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
            token_losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            token_perplexities = torch.exp(token_losses).cpu().numpy()

        print(" 🎯 Token details:")
        for token, pp in zip(tokens[:5], token_perplexities[:5]):  # Show first 5
            clean_token = token.replace('Δ ', ' ').replace('##', '')
            color = '🟒' if pp < 3 else '🟑' if pp < 10 else 'πŸ”΄'
            print(f" {color} '{clean_token}': {pp:.2f}")
        if len(tokens) > 5:
            print(f" ... and {len(tokens) - 5} more tokens")

def demo_encoder_perplexity():
    """Demo encoder model pseudo-perplexity calculation"""
    print("\n" + "="*60)
    print("πŸ€– Encoder Model Demo (DistilBERT)")
    print("="*60)

    # Load model
    model_name = "distilbert-base-uncased"
    print(f"Loading {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    model.eval()

    # Test texts
    test_texts = [
        "The capital of France is Paris.",
        "Python is a programming language.",
        "The weather today is beautiful.",
        "Machine learning requires large datasets.",
    ]

    mlm_probability = 0.15
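
    # Added note: masked LMs define no left-to-right factorization, so true
    # perplexity is undefined for them. Instead we randomly mask ~15% of the
    # non-special tokens and score each original token at its masked position.
    # (Full pseudo-perplexity, cf. Salazar et al. 2020, masks every position in
    # turn; this demo samples positions for speed, so scores vary between runs.)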

    for i, text in enumerate(test_texts, 1):
        print(f"\nπŸ“ Text {i}: {text}")

        # Tokenize
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        input_ids = inputs.input_ids

        # Create masked version
        masked_input_ids = input_ids.clone()
        original_tokens = input_ids.clone()

        # Randomly mask tokens (excluding special tokens)
        seq_length = input_ids.size(1)
        mask_indices = []
        special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}
        for j in range(seq_length):
            if input_ids[0, j].item() not in special_token_ids:
                if torch.rand(1).item() < mlm_probability:
                    mask_indices.append(j)
                    masked_input_ids[0, j] = tokenizer.mask_token_id

        if not mask_indices:  # Ensure at least one token is masked
            non_special_indices = [j for j in range(seq_length) if input_ids[0, j].item() not in special_token_ids]
            if non_special_indices:
                mask_idx = torch.randint(0, len(non_special_indices), (1,)).item()
                mask_indices = [non_special_indices[mask_idx]]
                masked_input_ids[0, mask_indices[0]] = tokenizer.mask_token_id
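
        # A single forward pass scores every masked position at once; the loss
        # at each masked index is -log p(original token | corrupted context).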
        # Calculate pseudo-perplexity
        with torch.no_grad():
            outputs = model(masked_input_ids)
            predictions = outputs.logits

        masked_token_losses = []
        for idx in mask_indices:
            target_id = original_tokens[0, idx]
            pred_scores = predictions[0, idx]
            prob = torch.softmax(pred_scores, dim=-1)[target_id]
            loss = -torch.log(prob + 1e-10)
            masked_token_losses.append(loss.item())

        if masked_token_losses:
            avg_loss = np.mean(masked_token_losses)
            pseudo_perplexity = np.exp(avg_loss)
        else:
            pseudo_perplexity = float('inf')

        print(f" πŸ’― Pseudo-perplexity: {pseudo_perplexity:.2f}")
        print(f" 🎭 Masked {len(mask_indices)} tokens")
        # Show some token-level pseudo-perplexities
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
        print(" 🎯 Sample token pseudo-perplexities:")
        with torch.no_grad():
            sample_indices = list(range(1, min(6, len(tokens) - 1)))  # Skip [CLS] and [SEP]
            for idx in sample_indices:
                if input_ids[0, idx].item() not in special_token_ids:
                    masked_input = input_ids.clone()
                    original_token_id = input_ids[0, idx]
                    masked_input[0, idx] = tokenizer.mask_token_id
                    outputs = model(masked_input)
                    predictions = outputs.logits[0, idx]
                    prob = torch.softmax(predictions, dim=-1)[original_token_id]
                    token_pseudo_perplexity = 1.0 / (prob.item() + 1e-10)
                    clean_token = tokens[idx].replace('##', '')
                    color = '🟒' if token_pseudo_perplexity < 5 else '🟑' if token_pseudo_perplexity < 20 else 'πŸ”΄'
                    print(f" {color} '{clean_token}': {token_pseudo_perplexity:.2f}")

def demo_comparison():
    """Compare perplexity across different model types"""
    print("\n" + "="*60)
    print("πŸ”¬ Model Comparison Demo")
    print("="*60)

    test_text = "The quick brown fox jumps over the lazy dog."
    print(f"πŸ“ Comparing models on: {test_text}")

    models_to_test = [
        ("distilgpt2", "decoder"),
        ("distilbert-base-uncased", "encoder"),
    ]

    results = []
    for model_name, model_type in models_to_test:
        print(f"\nπŸ€– Testing {model_name} ({model_type})...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            if model_type == "decoder":
                model = AutoModelForCausalLM.from_pretrained(model_name)
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token
            else:
                model = AutoModelForMaskedLM.from_pretrained(model_name)
            model.eval()

            inputs = tokenizer(test_text, return_tensors="pt", truncation=True, max_length=512)
            input_ids = inputs.input_ids

            if model_type == "decoder":
                with torch.no_grad():
                    outputs = model(input_ids, labels=input_ids)
                    loss = outputs.loss
                    perplexity = torch.exp(loss).item()
            else:  # encoder
                # Quick pseudo-perplexity calculation
                masked_input_ids = input_ids.clone()
                seq_length = input_ids.size(1)

                # Mask middle token
                if seq_length > 2:
                    middle_idx = seq_length // 2
                    masked_input_ids[0, middle_idx] = tokenizer.mask_token_id
                    with torch.no_grad():
                        outputs = model(masked_input_ids)
                        predictions = outputs.logits[0, middle_idx]
                        prob = torch.softmax(predictions, dim=-1)[input_ids[0, middle_idx]]
                        perplexity = 1.0 / (prob.item() + 1e-10)
                else:
                    perplexity = float('inf')

            results.append((model_name, model_type, perplexity))
            print(f" βœ… Perplexity: {perplexity:.2f}")
        except Exception as e:
            print(f" ❌ Error: {e}")
            results.append((model_name, model_type, float('inf')))
print(f"\nπŸ“Š Summary for '{test_text}':")
for model_name, model_type, perplexity in results:
if perplexity != float('inf'):
confidence = "High" if perplexity < 5 else "Medium" if perplexity < 15 else "Low"
print(f" β€’ {model_name} ({model_type}): {perplexity:.2f} - {confidence} confidence")
else:
print(f" β€’ {model_name} ({model_type}): Failed")

def main():
    """Run all demos"""
    print("🎭 PerplexityViewer Core Functionality Demo")
    print("This demo shows how perplexity calculation works under the hood")

    try:
        demo_decoder_perplexity()
        demo_encoder_perplexity()
        demo_comparison()

        print("\n" + "="*60)
        print("πŸŽ‰ Demo completed successfully!")
        print("πŸ’‘ To try the interactive web interface, run: python run.py")
        print("="*60)
    except KeyboardInterrupt:
        print("\nπŸ‘‹ Demo interrupted by user")
    except Exception as e:
        print(f"\n❌ Demo failed with error: {e}")
        print("πŸ’‘ Make sure you have installed all dependencies: pip install -r requirements.txt")


if __name__ == "__main__":
    main()