Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Demo script showing how MLM probability affects encoder model analysis | |
| """ | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForMaskedLM | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
def demo_mlm_probability_effect():
    """Show how many content tokens each MLM probability selects.

    Tokenizes a fixed sentence with the DistilBERT tokenizer and, for each
    probability in a fixed list, draws one uniform random number per
    non-special token. A token counts as "analyzed" when its draw falls
    below the probability. The RNG is re-seeded before every probability,
    so the same draws are compared against each threshold.

    Only the tokenizer is loaded: this demo never runs a forward pass, so
    the masked-LM weights are not needed here.

    Returns:
        None. All results are printed to stdout.
    """
    print("🎭 MLM Probability Effect Demo")
    print("=" * 60)

    # Only the tokenizer is required for this token-selection simulation.
    model_name = "distilbert-base-uncased"
    print(f"Loading {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Test text
    text = "The capital of France is Paris and it is beautiful."
    print(f"📝 Text: {text}")

    # Tokenize
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = inputs.input_ids
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    print(f"🔤 Tokens: {tokens}")
    print()

    # These do not depend on the probability, so compute them once.
    special_token_ids = {
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.pad_token_id,
    }
    content_tokens = [
        token
        for token, token_id in zip(tokens, input_ids[0].tolist())
        if token_id not in special_token_ids
    ]
    total_content_tokens = len(content_tokens)

    # Test different MLM probabilities
    mlm_probs = [0.1, 0.15, 0.3, 0.5, 0.8]
    for mlm_prob in mlm_probs:
        print(f"🎯 MLM Probability: {mlm_prob}")

        # Re-seed so every probability is compared against the same draws;
        # exactly one draw is consumed per content token, in token order.
        torch.manual_seed(42)
        analyzed_tokens = [
            f"'{token}'"
            for token in content_tokens
            if torch.rand(1).item() < mlm_prob
        ]
        analyzed_count = len(analyzed_tokens)

        # Guard the ratio so an input with no content tokens cannot divide by zero.
        percent = (
            analyzed_count / total_content_tokens * 100
            if total_content_tokens
            else 0.0
        )
        hidden = analyzed_count - 5
        suffix = f" + {hidden} more" if hidden > 0 else ""

        print(f" 📊 Analyzed: {analyzed_count}/{total_content_tokens} content tokens ({percent:.1f}%)")
        print(f" 🎯 Analyzed tokens: {', '.join(analyzed_tokens[:5])}" + suffix)
        print()
def simulate_perplexity_calculation():
    """Print pseudo-perplexities of a fixed sentence under several mask rates.

    For each MLM probability, three seeded masking rounds are run. In every
    round a random subset of non-special tokens is replaced by the mask
    token (at least one, when any non-special token exists), the model
    predicts the masked positions, and the exponential of the mean negative
    log-probability of the original tokens is reported. The mean over the
    rounds is printed per probability.

    Returns:
        None. All results are printed to stdout.
    """
    print("🧮 Perplexity Calculation Simulation")
    print("=" * 60)

    # Load model
    checkpoint = "distilbert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForMaskedLM.from_pretrained(checkpoint)
    model.eval()

    text = "Machine learning is transforming artificial intelligence rapidly."
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    print(f"📝 Text: {text}")
    print(f"🔤 Tokens: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
    print()

    # Positions that may be masked: everything except CLS / SEP / PAD.
    reserved_ids = {
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.pad_token_id,
    }
    maskable = [
        pos
        for pos, token_id in enumerate(input_ids[0].tolist())
        if token_id not in reserved_ids
    ]

    for mlm_prob in (0.15, 0.3, 0.5):
        print(f"🎭 MLM Probability: {mlm_prob}")

        round_perplexities = []
        for run in range(3):
            # Different seed per round; one draw per maskable position, in order.
            torch.manual_seed(42 + run)
            chosen = [pos for pos in maskable if torch.rand(1).item() < mlm_prob]

            # Ensure at least one token is masked.
            if not chosen and maskable:
                pick = torch.randint(0, len(maskable), (1,)).item()
                chosen = [maskable[pick]]

            corrupted = input_ids.clone()
            for pos in chosen:
                corrupted[0, pos] = tokenizer.mask_token_id

            # Score the masked positions against the original tokens.
            with torch.no_grad():
                logits = model(corrupted).logits

            losses = []
            masked_names = []
            for pos in chosen:
                target = input_ids[0, pos]
                target_prob = torch.softmax(logits[0, pos], dim=-1)[target]
                nll = -torch.log(target_prob + 1e-10)
                losses.append(nll.item())
                masked_names.append(tokenizer.convert_ids_to_tokens([target])[0])

            if not losses:
                continue

            mean_loss = sum(losses) / len(losses)
            perplexity = torch.exp(torch.tensor(mean_loss)).item()
            round_perplexities.append(perplexity)

            overflow = (
                f" + {len(masked_names)-3} more" if len(masked_names) > 3 else ""
            )
            print(f" Iteration {run + 1}: {len(chosen)} tokens masked")
            print(f" Masked: {', '.join(masked_names[:3])}" + overflow)
            print(f" Pseudo-perplexity: {perplexity:.2f}")

        if round_perplexities:
            avg_perplexity = sum(round_perplexities) / len(round_perplexities)
            print(f" 📊 Average pseudo-perplexity: {avg_perplexity:.2f}")
        print()
def explain_mlm_probability():
    """Print a plain-text explainer of what the MLM probability setting does.

    Writes a heading, a 60-character divider and a multi-section description
    (definition, mechanics, motivation, visual effect, trade-offs) to stdout.

    Returns:
        None.
    """
    heading = "💡 Understanding MLM Probability"
    divider = "=" * 60
    explanation = """
🎭 **What is MLM Probability?**
MLM (Masked Language Modeling) probability controls what fraction of tokens
get randomly selected for detailed perplexity analysis.
📊 **How it works:**
• Low MLM prob (0.15): Analyzes ~15% of tokens randomly
• High MLM prob (0.5): Analyzes ~50% of tokens randomly
• This affects both the average perplexity AND the visualization
🎯 **Why it matters:**
• Higher MLM prob = More tokens analyzed = More complete picture
• Lower MLM prob = Fewer tokens analyzed = Faster but less comprehensive
• The randomness simulates real MLM training conditions
🌈 **Visual Effect:**
• Analyzed tokens: Colored by their actual perplexity
• Non-analyzed tokens: Shown in gray (baseline)
• Try 0.15 vs 0.5 to see the difference!
⚖️ **Trade-offs:**
• MLM 0.15: Fast, matches BERT training, but sparse analysis
• MLM 0.5: Slower, more comprehensive, but artificial
• MLM 0.8: Very slow, nearly complete, but unrealistic
"""
    print(heading)
    print(divider)
    print(explanation)
def main():
    """Run the explainer and both demos, then print follow-up suggestions.

    Any exception raised along the way is caught and reported on stdout
    together with an installation hint; nothing is re-raised.

    Returns:
        None.
    """
    try:
        # The stage list is built inside the try so every failure mode is
        # reported the same way.
        stages = (
            explain_mlm_probability,
            demo_mlm_probability_effect,
            simulate_perplexity_calculation,
        )
        for stage in stages:
            stage()

        closing_lines = (
            "🎉 MLM Probability Demo Complete!",
            "💡 Now try the app with different MLM probabilities:",
            " • Use 0.15 for standard analysis",
            " • Use 0.5 for more comprehensive analysis",
            " • Watch how the visualization changes!",
        )
        for line in closing_lines:
            print(line)
    except Exception as e:
        print(f"❌ Demo failed: {e}")
        print("💡 Make sure you have transformers installed: pip install transformers")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()