#!/usr/bin/env python3
"""
Demo script showing how MLM probability affects encoder model analysis
"""
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import warnings
warnings.filterwarnings("ignore")
def demo_mlm_probability_effect():
"""Demonstrate how MLM probability affects the analysis"""
print("๐ญ MLM Probability Effect Demo")
print("=" * 60)
# Load a BERT model
model_name = "distilbert-base-uncased"
print(f"Loading {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)
model.eval()
# Test text
text = "The capital of France is Paris and it is beautiful."
print(f"๐ Text: {text}")
# Tokenize
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
input_ids = inputs.input_ids
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
print(f"๐ค Tokens: {tokens}")
print()
# Test different MLM probabilities
mlm_probs = [0.1, 0.15, 0.3, 0.5, 0.8]
for mlm_prob in mlm_probs:
print(f"๐ฏ MLM Probability: {mlm_prob}")
# Simulate the analysis process
seq_length = input_ids.size(1)
special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}
# Count how many tokens would be analyzed
analyzed_count = 0
analyzed_tokens = []
torch.manual_seed(42) # For reproducible results
for i in range(seq_length):
token = tokens[i]
if input_ids[0, i].item() not in special_token_ids:
if torch.rand(1).item() < mlm_prob:
analyzed_count += 1
analyzed_tokens.append(f"'{token}'")
total_content_tokens = sum(1 for i in range(seq_length) if input_ids[0, i].item() not in special_token_ids)
print(f" ๐ Analyzed: {analyzed_count}/{total_content_tokens} content tokens ({analyzed_count/total_content_tokens*100:.1f}%)")
print(f" ๐ฏ Analyzed tokens: {', '.join(analyzed_tokens[:5])}" + (f" + {len(analyzed_tokens)-5} more" if len(analyzed_tokens) > 5 else ""))
print()
def simulate_perplexity_calculation():
    """Simulate how different MLM probabilities affect pseudo-perplexity.

    For each probability, runs three seeded masking iterations: randomly masks
    non-special tokens, scores the model's probability of the original token at
    each masked position, and reports the per-iteration and averaged
    pseudo-perplexity (exp of the mean negative log-probability).
    """
    print("๐งฎ Perplexity Calculation Simulation")
    print("=" * 60)

    # Load model
    model_name = "distilbert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    model.eval()

    text = "Machine learning is transforming artificial intelligence rapidly."
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs.input_ids
    print(f"๐ Text: {text}")
    print(f"๐ค Tokens: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
    print()

    mlm_probs = [0.15, 0.3, 0.5]
    for mlm_prob in mlm_probs:
        print(f"๐ญ MLM Probability: {mlm_prob}")
        iteration_results = []
        for iteration in range(3):
            # Simulate masking
            masked_input_ids = input_ids.clone()
            original_tokens = input_ids.clone()
            seq_length = input_ids.size(1)
            mask_indices = []
            # Discard None so tokenizers lacking one of these special tokens
            # don't put None into the membership set.
            special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id,
                                 tokenizer.pad_token_id} - {None}

            torch.manual_seed(42 + iteration)  # Different seed per iteration
            for i in range(seq_length):
                if input_ids[0, i].item() not in special_token_ids:
                    if torch.rand(1).item() < mlm_prob:
                        mask_indices.append(i)
                        masked_input_ids[0, i] = tokenizer.mask_token_id

            if not mask_indices:
                # Ensure at least one token is masked
                non_special_indices = [i for i in range(seq_length)
                                       if input_ids[0, i].item() not in special_token_ids]
                if non_special_indices:
                    mask_idx = torch.randint(0, len(non_special_indices), (1,)).item()
                    mask_indices = [non_special_indices[mask_idx]]
                    masked_input_ids[0, mask_indices[0]] = tokenizer.mask_token_id

            # Calculate pseudo-perplexity for masked tokens
            with torch.no_grad():
                outputs = model(masked_input_ids)
                predictions = outputs.logits

            masked_token_losses = []
            masked_tokens = []
            for idx in mask_indices:
                # BUG FIX: convert to a plain int — convert_ids_to_tokens
                # expects ints, not 0-dim tensors.
                target_id = original_tokens[0, idx].item()
                pred_scores = predictions[0, idx]
                prob = torch.softmax(pred_scores, dim=-1)[target_id]
                loss = -torch.log(prob + 1e-10)  # epsilon avoids log(0)
                masked_token_losses.append(loss.item())
                token = tokenizer.convert_ids_to_tokens([target_id])[0]
                masked_tokens.append(token)

            if masked_token_losses:
                avg_loss = sum(masked_token_losses) / len(masked_token_losses)
                perplexity = torch.exp(torch.tensor(avg_loss)).item()
                iteration_results.append(perplexity)
                print(f" Iteration {iteration + 1}: {len(mask_indices)} tokens masked")
                print(f" Masked: {', '.join(masked_tokens[:3])}" + (f" + {len(masked_tokens)-3} more" if len(masked_tokens) > 3 else ""))
                print(f" Pseudo-perplexity: {perplexity:.2f}")

        if iteration_results:
            avg_perplexity = sum(iteration_results) / len(iteration_results)
            print(f" ๐ Average pseudo-perplexity: {avg_perplexity:.2f}")
        print()
def explain_mlm_probability():
    """Print a plain-text explanation of the MLM probability parameter."""
    # Static explanatory text; held in a local so the print calls stay short.
    body = """
๐ญ **What is MLM Probability?**
MLM (Masked Language Modeling) probability controls what fraction of tokens
get randomly selected for detailed perplexity analysis.
๐ **How it works:**
โข Low MLM prob (0.15): Analyzes ~15% of tokens randomly
โข High MLM prob (0.5): Analyzes ~50% of tokens randomly
โข This affects both the average perplexity AND the visualization
๐ฏ **Why it matters:**
โข Higher MLM prob = More tokens analyzed = More complete picture
โข Lower MLM prob = Fewer tokens analyzed = Faster but less comprehensive
โข The randomness simulates real MLM training conditions
๐ **Visual Effect:**
โข Analyzed tokens: Colored by their actual perplexity
โข Non-analyzed tokens: Shown in gray (baseline)
โข Try 0.15 vs 0.5 to see the difference!
โ๏ธ **Trade-offs:**
โข MLM 0.15: Fast, matches BERT training, but sparse analysis
โข MLM 0.5: Slower, more comprehensive, but artificial
โข MLM 0.8: Very slow, nearly complete, but unrealistic
"""
    print("๐ก Understanding MLM Probability")
    print("=" * 60)
    print(body)
def main():
    """Run MLM probability demonstration"""
    # Demo steps, executed in order: explanation first, then the two demos.
    steps = (
        explain_mlm_probability,
        demo_mlm_probability_effect,
        simulate_perplexity_calculation,
    )
    try:
        for step in steps:
            step()
        print("๐ MLM Probability Demo Complete!")
        print("๐ก Now try the app with different MLM probabilities:")
        print(" โข Use 0.15 for standard analysis")
        print(" โข Use 0.5 for more comprehensive analysis")
        print(" โข Watch how the visualization changes!")
    except Exception as e:
        # Most likely failure: transformers not installed / weights unavailable.
        print(f"โ Demo failed: {e}")
        print("๐ก Make sure you have transformers installed: pip install transformers")


if __name__ == "__main__":
    main()