# PerplexityViewer — mlm_demo.py
# Author: Bram van Es (commit ef12530)
#!/usr/bin/env python3
"""
Demo script showing how MLM probability affects encoder model analysis
"""
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import warnings
warnings.filterwarnings("ignore")
def demo_mlm_probability_effect():
    """Demonstrate how the MLM probability affects which tokens get analyzed.

    Loads a DistilBERT masked-LM, tokenizes a fixed sentence, and for a
    range of masking probabilities counts — with a fixed RNG seed so runs
    are comparable — how many non-special tokens would be randomly
    selected for perplexity analysis. Prints results; returns None.
    """
    print("🎭 MLM Probability Effect Demo")
    print("=" * 60)

    # Load a small BERT-family model (network download on first run).
    model_name = "distilbert-base-uncased"
    print(f"Loading {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    model.eval()

    # Fixed demo sentence.
    text = "The capital of France is Paris and it is beautiful."
    print(f"📝 Text: {text}")

    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = inputs.input_ids
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    print(f"🔤 Tokens: {tokens}")
    print()

    # [CLS]/[SEP]/[PAD] are never candidates for masking.
    special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}
    token_ids = input_ids[0].tolist()

    mlm_probs = [0.1, 0.15, 0.3, 0.5, 0.8]
    for mlm_prob in mlm_probs:
        print(f"🎯 MLM Probability: {mlm_prob}")

        # Reseed per probability so the same Bernoulli draws are reused,
        # making the counts across probabilities directly comparable.
        torch.manual_seed(42)

        analyzed_tokens = []
        total_content_tokens = 0
        # Single pass: count content tokens and draw the Bernoulli trial
        # for each, preserving the original RNG consumption order
        # (one torch.rand call per non-special token).
        for token, token_id in zip(tokens, token_ids):
            if token_id in special_token_ids:
                continue
            total_content_tokens += 1
            if torch.rand(1).item() < mlm_prob:
                analyzed_tokens.append(f"'{token}'")

        analyzed_count = len(analyzed_tokens)
        # Guard against division by zero for a pathological all-special input.
        pct = (analyzed_count / total_content_tokens * 100) if total_content_tokens else 0.0
        print(f" 📊 Analyzed: {analyzed_count}/{total_content_tokens} content tokens ({pct:.1f}%)")
        print(f" 🎯 Analyzed tokens: {', '.join(analyzed_tokens[:5])}" + (f" + {len(analyzed_tokens)-5} more" if len(analyzed_tokens) > 5 else ""))
        print()
def simulate_perplexity_calculation():
    """Simulate how different MLM probabilities affect perplexity calculation"""
    print("🧮 Perplexity Calculation Simulation")
    print("=" * 60)
    # Load model (network download on first run).
    model_name = "distilbert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    model.eval()
    text = "Machine learning is transforming artificial intelligence rapidly."
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs.input_ids
    print(f"📝 Text: {text}")
    print(f"🔤 Tokens: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
    print()
    mlm_probs = [0.15, 0.3, 0.5]
    for mlm_prob in mlm_probs:
        print(f"🎭 MLM Probability: {mlm_prob}")
        # Average over a few independently-seeded masking draws, since a
        # single random mask gives a noisy pseudo-perplexity estimate.
        iteration_results = []
        for iteration in range(3):
            # Simulate masking: keep an unmasked copy for the targets.
            masked_input_ids = input_ids.clone()
            original_tokens = input_ids.clone()
            seq_length = input_ids.size(1)
            mask_indices = []
            # [CLS]/[SEP]/[PAD] are never masked.
            special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}
            torch.manual_seed(42 + iteration)  # Different seed per iteration
            # Bernoulli(mlm_prob) draw per non-special token; note the RNG is
            # consumed only for non-special positions, so draw order matters.
            for i in range(seq_length):
                if input_ids[0, i].item() not in special_token_ids:
                    if torch.rand(1).item() < mlm_prob:
                        mask_indices.append(i)
                        masked_input_ids[0, i] = tokenizer.mask_token_id
            if not mask_indices:
                # Ensure at least one token is masked
                non_special_indices = [i for i in range(seq_length)
                                       if input_ids[0, i].item() not in special_token_ids]
                if non_special_indices:
                    mask_idx = torch.randint(0, len(non_special_indices), (1,)).item()
                    mask_indices = [non_special_indices[mask_idx]]
                    masked_input_ids[0, mask_indices[0]] = tokenizer.mask_token_id
            # Calculate pseudo-perplexity for masked tokens
            with torch.no_grad():
                outputs = model(masked_input_ids)
                predictions = outputs.logits
            masked_token_losses = []
            masked_tokens = []
            # Negative log-likelihood of the true token at each masked slot;
            # 1e-10 avoids log(0) for numerically-zero probabilities.
            for idx in mask_indices:
                target_id = original_tokens[0, idx]
                pred_scores = predictions[0, idx]
                prob = torch.softmax(pred_scores, dim=-1)[target_id]
                loss = -torch.log(prob + 1e-10)
                masked_token_losses.append(loss.item())
                token = tokenizer.convert_ids_to_tokens([target_id])[0]
                masked_tokens.append(token)
            if masked_token_losses:
                # Pseudo-perplexity = exp(mean NLL over masked positions).
                avg_loss = sum(masked_token_losses) / len(masked_token_losses)
                perplexity = torch.exp(torch.tensor(avg_loss)).item()
                iteration_results.append(perplexity)
                print(f" Iteration {iteration + 1}: {len(mask_indices)} tokens masked")
                print(f" Masked: {', '.join(masked_tokens[:3])}" + (f" + {len(masked_tokens)-3} more" if len(masked_tokens) > 3 else ""))
                print(f" Pseudo-perplexity: {perplexity:.2f}")
        if iteration_results:
            avg_perplexity = sum(iteration_results) / len(iteration_results)
            print(f" 📊 Average pseudo-perplexity: {avg_perplexity:.2f}")
        print()
def explain_mlm_probability():
    """Explain what MLM probability actually does"""
    # Purely informational: prints a static explainer text; no model work.
    print("💡 Understanding MLM Probability")
    print("=" * 60)
    print("""
🎭 **What is MLM Probability?**
MLM (Masked Language Modeling) probability controls what fraction of tokens
get randomly selected for detailed perplexity analysis.
📊 **How it works:**
• Low MLM prob (0.15): Analyzes ~15% of tokens randomly
• High MLM prob (0.5): Analyzes ~50% of tokens randomly
• This affects both the average perplexity AND the visualization
🎯 **Why it matters:**
• Higher MLM prob = More tokens analyzed = More complete picture
• Lower MLM prob = Fewer tokens analyzed = Faster but less comprehensive
• The randomness simulates real MLM training conditions
🌈 **Visual Effect:**
• Analyzed tokens: Colored by their actual perplexity
• Non-analyzed tokens: Shown in gray (baseline)
• Try 0.15 vs 0.5 to see the difference!
⚖️ **Trade-offs:**
• MLM 0.15: Fast, matches BERT training, but sparse analysis
• MLM 0.5: Slower, more comprehensive, but artificial
• MLM 0.8: Very slow, nearly complete, but unrealistic
""")
def main():
    """Run MLM probability demonstration"""
    try:
        # Run the demo stages in order: explanation first, then the
        # masking-rate demo, then the perplexity simulation.
        for stage in (
            explain_mlm_probability,
            demo_mlm_probability_effect,
            simulate_perplexity_calculation,
        ):
            stage()
        print("🎉 MLM Probability Demo Complete!")
        print("💡 Now try the app with different MLM probabilities:")
        print(" • Use 0.15 for standard analysis")
        print(" • Use 0.5 for more comprehensive analysis")
        print(" • Watch how the visualization changes!")
    except Exception as e:
        # Broad catch is deliberate here: the demo should report any
        # failure (missing deps, download errors) instead of crashing.
        print(f"❌ Demo failed: {e}")
        print("💡 Make sure you have transformers installed: pip install transformers")
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()