#!/usr/bin/env python3
"""
Demo script showing how MLM probability affects encoder model analysis
"""

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import warnings
warnings.filterwarnings("ignore")

def demo_mlm_probability_effect():
    """Demonstrate how MLM probability affects the analysis"""
    print("๐ŸŽญ MLM Probability Effect Demo")
    print("=" * 60)

    # Only the tokenizer is needed here: this demo simulates token selection
    # without running the model itself.
    model_name = "distilbert-base-uncased"
    print(f"Loading {model_name}...")

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Test text
    text = "The capital of France is Paris and it is beautiful."
    print(f"๐Ÿ“ Text: {text}")

    # Tokenize
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = inputs.input_ids
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    print(f"๐Ÿ”ค Tokens: {tokens}")
    print()

    # Test different MLM probabilities
    mlm_probs = [0.1, 0.15, 0.3, 0.5, 0.8]

    for mlm_prob in mlm_probs:
        print(f"๐ŸŽฏ MLM Probability: {mlm_prob}")

        # Simulate the analysis process
        seq_length = input_ids.size(1)
        special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}

        # Count how many tokens would be analyzed
        analyzed_count = 0
        analyzed_tokens = []

        torch.manual_seed(42)  # Same seed for every setting, so the selected sets nest as mlm_prob grows

        for i in range(seq_length):
            token = tokens[i]
            if input_ids[0, i].item() not in special_token_ids:
                if torch.rand(1).item() < mlm_prob:
                    analyzed_count += 1
                    analyzed_tokens.append(f"'{token}'")

        total_content_tokens = sum(1 for i in range(seq_length) if input_ids[0, i].item() not in special_token_ids)

        print(f"   ๐Ÿ“Š Analyzed: {analyzed_count}/{total_content_tokens} content tokens ({analyzed_count/total_content_tokens*100:.1f}%)")
        print(f"   ๐ŸŽฏ Analyzed tokens: {', '.join(analyzed_tokens[:5])}" + (f" + {len(analyzed_tokens)-5} more" if len(analyzed_tokens) > 5 else ""))
        print()

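# A minimal sketch (not part of the original demo): the number of selected tokens
# follows a Binomial(n, p) distribution, so for short texts the realized fraction
# can deviate noticeably from the nominal MLM probability. `expected_mask_stats`
# is an illustrative helper name, not an established API.
def expected_mask_stats(num_content_tokens, mlm_prob):
    """Expected count and standard deviation of tokens selected for analysis."""
    mean = num_content_tokens * mlm_prob
    std = (num_content_tokens * mlm_prob * (1 - mlm_prob)) ** 0.5
    return mean, std

# Example: expected_mask_stats(10, 0.15) -> (1.5, ~1.13), so a 10-token sentence
# frequently masks anywhere from 0 to 3 tokens at the "standard" 15% rate.
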
def simulate_perplexity_calculation():
    """Simulate how different MLM probabilities affect perplexity calculation"""
    print("๐Ÿงฎ Perplexity Calculation Simulation")
    print("=" * 60)

    # Load model
    model_name = "distilbert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    model.eval()

    text = "Machine learning is transforming artificial intelligence rapidly."
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs.input_ids

    print(f"๐Ÿ“ Text: {text}")
    print(f"๐Ÿ”ค Tokens: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
    print()

    mlm_probs = [0.15, 0.3, 0.5]

    for mlm_prob in mlm_probs:
        print(f"๐ŸŽญ MLM Probability: {mlm_prob}")

        # Simulate multiple iterations
        iteration_results = []

        for iteration in range(3):
            # Simulate masking
            masked_input_ids = input_ids.clone()
            original_tokens = input_ids.clone()
            seq_length = input_ids.size(1)

            mask_indices = []
            special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}

            torch.manual_seed(42 + iteration)  # Different seed per iteration

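            # Simplification: every selected token becomes [MASK]. Real BERT
            # pre-training replaces only 80% of selected tokens with [MASK]
            # (10% get a random token, 10% stay unchanged).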
            for i in range(seq_length):
                if input_ids[0, i].item() not in special_token_ids:
                    if torch.rand(1).item() < mlm_prob:
                        mask_indices.append(i)
                        masked_input_ids[0, i] = tokenizer.mask_token_id

            if not mask_indices:
                # Ensure at least one token is masked
                non_special_indices = [i for i in range(seq_length)
                                     if input_ids[0, i].item() not in special_token_ids]
                if non_special_indices:
                    mask_idx = torch.randint(0, len(non_special_indices), (1,)).item()
                    mask_indices = [non_special_indices[mask_idx]]
                    masked_input_ids[0, mask_indices[0]] = tokenizer.mask_token_id

            # Calculate pseudo-perplexity for masked tokens
            with torch.no_grad():
                outputs = model(masked_input_ids)
                predictions = outputs.logits

                masked_token_losses = []
                masked_tokens = []

                for idx in mask_indices:
                    # Probability the model assigns to the original token at this masked position
                    target_id = original_tokens[0, idx].item()
                    pred_scores = predictions[0, idx]
                    prob = torch.softmax(pred_scores, dim=-1)[target_id]
                    loss = -torch.log(prob + 1e-10)  # NLL; epsilon guards against log(0)
                    masked_token_losses.append(loss.item())

                    token = tokenizer.convert_ids_to_tokens([target_id])[0]
                    masked_tokens.append(token)

                if masked_token_losses:
                    # Pseudo-perplexity = exp(mean NLL over the masked positions)
                    avg_loss = sum(masked_token_losses) / len(masked_token_losses)
                    perplexity = torch.exp(torch.tensor(avg_loss)).item()
                    iteration_results.append(perplexity)

                    print(f"   Iteration {iteration + 1}: {len(mask_indices)} tokens masked")
                    print(f"      Masked: {', '.join(masked_tokens[:3])}" + (f" + {len(masked_tokens)-3} more" if len(masked_tokens) > 3 else ""))
                    print(f"      Pseudo-perplexity: {perplexity:.2f}")

        if iteration_results:
            avg_perplexity = sum(iteration_results) / len(iteration_results)
            print(f"   ๐Ÿ“Š Average pseudo-perplexity: {avg_perplexity:.2f}")
        print()

def explain_mlm_probability():
    """Explain what MLM probability actually does"""
    print("๐Ÿ’ก Understanding MLM Probability")
    print("=" * 60)

    print("""
๐ŸŽญ **What is MLM Probability?**
   MLM (Masked Language Modeling) probability controls what fraction of tokens
   get randomly selected for detailed perplexity analysis.

๐Ÿ“Š **How it works:**
   โ€ข Low MLM prob (0.15): Analyzes ~15% of tokens randomly
   โ€ข High MLM prob (0.5):  Analyzes ~50% of tokens randomly
   โ€ข This affects both the average perplexity AND the visualization

๐ŸŽฏ **Why it matters:**
   โ€ข Higher MLM prob = More tokens analyzed = More complete picture
   โ€ข Lower MLM prob = Fewer tokens analyzed = Faster but less comprehensive
   โ€ข The randomness simulates real MLM training conditions

๐ŸŒˆ **Visual Effect:**
   โ€ข Analyzed tokens: Colored by their actual perplexity
   โ€ข Non-analyzed tokens: Shown in gray (baseline)
   โ€ข Try 0.15 vs 0.5 to see the difference!

โš–๏ธ **Trade-offs:**
   โ€ข MLM 0.15: Fast, matches BERT training, but sparse analysis
   โ€ข MLM 0.5:  Slower, more comprehensive, but artificial
   โ€ข MLM 0.8:  Very slow, nearly complete, but unrealistic
""")

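# A minimal sketch of the deterministic alternative to random sampling: mask each
# content token exactly once and average the negative log-likelihoods, i.e. the
# pseudo-perplexity of Salazar et al. (2020), "Masked Language Model Scoring".
# Slower (one forward pass per token) but free of sampling variance.
# `full_pseudo_perplexity` is an illustrative helper, not part of the original
# demo or of any library API.
def full_pseudo_perplexity(model, tokenizer, text):
    """Mask each content token in turn and return exp(mean NLL)."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = inputs.input_ids
    special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}

    losses = []
    with torch.no_grad():
        for i in range(input_ids.size(1)):
            target_id = input_ids[0, i].item()
            if target_id in special_token_ids:
                continue
            masked = input_ids.clone()
            masked[0, i] = tokenizer.mask_token_id
            logits = model(masked).logits
            prob = torch.softmax(logits[0, i], dim=-1)[target_id]
            losses.append(-torch.log(prob + 1e-10).item())

    if not losses:
        return float("nan")
    avg_loss = sum(losses) / len(losses)
    return torch.exp(torch.tensor(avg_loss)).item()
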
def main():
    """Run MLM probability demonstration"""
    try:
        explain_mlm_probability()
        demo_mlm_probability_effect()
        simulate_perplexity_calculation()

        print("๐ŸŽ‰ MLM Probability Demo Complete!")
        print("๐Ÿ’ก Now try the app with different MLM probabilities:")
        print("   โ€ข Use 0.15 for standard analysis")
        print("   โ€ข Use 0.5 for more comprehensive analysis")
        print("   โ€ข Watch how the visualization changes!")

    except Exception as e:
        print(f"โŒ Demo failed: {e}")
        print("๐Ÿ’ก Make sure you have transformers installed: pip install transformers")

if __name__ == "__main__":
    main()