Bram van Es committed on
Commit ef5aa3c · 1 Parent(s): b802cbd

Add application file

Files changed (1)
  1. app.py +484 -0
app.py ADDED
@@ -0,0 +1,484 @@
+ import gradio as gr
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM
+ import numpy as np
+ import pandas as pd
+ import spacy
+ from spacy import displacy
+ import math
+ import warnings
+ try:
+     from config import DEFAULT_MODELS, MODEL_SETTINGS, VIZ_SETTINGS, PROCESSING_SETTINGS, UI_SETTINGS, ERROR_MESSAGES
+ except ImportError:
+     # Fallback configuration if config.py is not available
+     DEFAULT_MODELS = {
+         "decoder": ["gpt2", "distilgpt2"],
+         "encoder": ["bert-base-uncased", "distilbert-base-uncased"]
+     }
+     MODEL_SETTINGS = {"max_length": 512}
+     VIZ_SETTINGS = {
+         "max_perplexity_display": 100.0,
+         "color_scheme": {
+             "high_perplexity": {"r": 255, "g": 0, "b": 50},
+             "low_perplexity": {"r": 0, "g": 255, "b": 50}
+         },
+         "displacy_options": {"ents": ["PP"], "colors": {}}
+     }
+     PROCESSING_SETTINGS = {
+         "default_iterations": 1,
+         "max_iterations": 10,
+         "default_mlm_probability": 0.15,
+         "min_mlm_probability": 0.1,
+         "max_mlm_probability": 0.5,
+         "epsilon": 1e-10
+     }
+     UI_SETTINGS = {
+         "title": "📈 Perplexity Viewer",
+         "description": "Visualize per-token perplexity using color gradients.",
+         "examples": [
+             {"text": "The quick brown fox jumps over the lazy dog.", "model": "gpt2", "type": "decoder", "iterations": 1, "mlm_prob": 0.15},
+             {"text": "The capital of France is Paris.", "model": "bert-base-uncased", "type": "encoder", "iterations": 1, "mlm_prob": 0.15}
+         ]
+     }
+     ERROR_MESSAGES = {
+         "empty_text": "Please enter some text to analyze.",
+         "model_load_error": "Error loading model {model_name}: {error}",
+         "processing_error": "Error processing text: {error}"
+     }
+ warnings.filterwarnings("ignore")
+
+ # Global variables to cache models
+ cached_models = {}
+ cached_tokenizers = {}
+
+ def load_model_and_tokenizer(model_name, model_type):
+     """Load and cache model and tokenizer"""
+     cache_key = f"{model_name}_{model_type}"
+
+     if cache_key not in cached_models:
+         try:
+             tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+             # Add pad token if it doesn't exist
+             if tokenizer.pad_token is None:
+                 tokenizer.pad_token = tokenizer.eos_token
+
+             if model_type == "decoder":
+                 model = AutoModelForCausalLM.from_pretrained(
+                     model_name,
+                     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                     device_map="auto" if torch.cuda.is_available() else None,
+                     trust_remote_code=True
+                 )
+             else:  # encoder
+                 model = AutoModelForMaskedLM.from_pretrained(
+                     model_name,
+                     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                     device_map="auto" if torch.cuda.is_available() else None,
+                     trust_remote_code=True
+                 )
+
+             model.eval()  # Set to evaluation mode
+             cached_models[cache_key] = model
+             cached_tokenizers[cache_key] = tokenizer
+
+             return model, tokenizer
+         except Exception as e:
+             raise gr.Error(ERROR_MESSAGES["model_load_error"].format(model_name=model_name, error=str(e)))
+
+     return cached_models[cache_key], cached_tokenizers[cache_key]
+
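+ # Background (explanatory comment): for a causal LM, the model's loss is the
+ # mean cross-entropy of predicting each token from its left context, and
+ # perplexity is exp(loss). Per-token perplexity is exp of that token's own
+ # cross-entropy, which is what the visualization colors below.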
+ def calculate_decoder_perplexity(text, model, tokenizer, iterations=1):
+     """Calculate perplexity for decoder models (like GPT)"""
+     device = next(model.parameters()).device
+
+     perplexities = []
+
+     for iteration in range(iterations):
+         # Tokenize the text
+         inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MODEL_SETTINGS["max_length"])
+         input_ids = inputs.input_ids.to(device)
+
+         if input_ids.size(1) < 2:
+             raise gr.Error("Text is too short for perplexity calculation.")
+
+         with torch.no_grad():
+             outputs = model(input_ids, labels=input_ids)
+             loss = outputs.loss
+             perplexity = torch.exp(loss).item()
+             perplexities.append(perplexity)
+
+     # Get token-level perplexities for the last iteration
+     with torch.no_grad():
+         outputs = model(input_ids)
+         logits = outputs.logits
+
+         # Shift logits and labels for next-token prediction
+         shift_logits = logits[..., :-1, :].contiguous()
+         shift_labels = input_ids[..., 1:].contiguous()
+
+         # Calculate per-token losses
+         loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
+         token_losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+         token_perplexities = torch.exp(token_losses).cpu().numpy()
+
+     # Get tokens (excluding the first one since we predict next tokens)
+     tokens = tokenizer.convert_ids_to_tokens(input_ids[0][1:])
+
+     # Clean up tokens for display
+     cleaned_tokens = []
+     for token in tokens:
+         if token.startswith('Ġ'):
+             cleaned_tokens.append(token[1:])  # Remove the GPT-2 BPE 'Ġ' space marker
+         elif token.startswith('##'):
+             cleaned_tokens.append(token[2:])  # Remove ## prefix
+         else:
+             cleaned_tokens.append(token)
+
+     return np.mean(perplexities), cleaned_tokens, token_perplexities
+
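+ # Background (explanatory comment): masked LMs have no left-to-right
+ # factorization, so the function below scores positions by masking them and
+ # measuring p(token | rest of sentence), a pseudo-perplexity. Note that
+ # 1 / p == exp(-log p), so the per-token values it produces are on the same
+ # exp(NLL) scale as the decoder case above.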
+ def calculate_encoder_perplexity(text, model, tokenizer, mlm_probability=0.15, iterations=1):
+     """Calculate pseudo-perplexity for encoder models (like BERT) using MLM"""
+     device = next(model.parameters()).device
+
+     perplexities = []
+
+     for iteration in range(iterations):
+         # Tokenize the text
+         inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MODEL_SETTINGS["max_length"])
+         input_ids = inputs.input_ids.to(device)
+
+         if input_ids.size(1) < 3:  # Need at least [CLS] + 1 token + [SEP]
+             raise gr.Error("Text is too short for MLM perplexity calculation.")
+
+         # Create a copy for masking
+         masked_input_ids = input_ids.clone()
+         original_tokens = input_ids.clone()
+
+         # Randomly mask tokens (excluding special tokens)
+         seq_length = input_ids.size(1)
+         mask_indices = []
+         special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}
+
+         for i in range(seq_length):
+             if input_ids[0, i].item() not in special_token_ids:
+                 if torch.rand(1).item() < mlm_probability:
+                     mask_indices.append(i)
+                     masked_input_ids[0, i] = tokenizer.mask_token_id
+
+         if not mask_indices:
+             # If no tokens were masked, mask at least one non-special token
+             non_special_indices = [i for i in range(seq_length) if input_ids[0, i].item() not in special_token_ids]
+             if non_special_indices:
+                 mask_idx = torch.randint(0, len(non_special_indices), (1,)).item()
+                 mask_indices = [non_special_indices[mask_idx]]
+                 masked_input_ids[0, mask_indices[0]] = tokenizer.mask_token_id
+
+         with torch.no_grad():
+             outputs = model(masked_input_ids)
+             predictions = outputs.logits
+
+         # Calculate perplexity for masked tokens
+         masked_token_losses = []
+         for idx in mask_indices:
+             target_id = original_tokens[0, idx]
+             pred_scores = predictions[0, idx]
+             prob = F.softmax(pred_scores, dim=-1)[target_id]
+             loss = -torch.log(prob + PROCESSING_SETTINGS["epsilon"])
+             masked_token_losses.append(loss.item())
+
+         if masked_token_losses:
+             avg_loss = np.mean(masked_token_losses)
+             perplexity = math.exp(avg_loss)
+             perplexities.append(perplexity)
+
+     # Calculate per-token pseudo-perplexity for visualization
+     with torch.no_grad():
+         token_perplexities = []
+         tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
+         special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}
+
+         for i in range(len(tokens)):
+             if input_ids[0, i].item() in special_token_ids:
+                 token_perplexities.append(1.0)  # Low perplexity for special tokens
+             else:
+                 masked_input = input_ids.clone()
+                 original_token_id = input_ids[0, i]
+                 masked_input[0, i] = tokenizer.mask_token_id
+
+                 outputs = model(masked_input)
+                 predictions = outputs.logits[0, i]
+                 prob = F.softmax(predictions, dim=-1)[original_token_id]
+                 token_perplexity = 1.0 / (prob.item() + PROCESSING_SETTINGS["epsilon"])
+                 token_perplexities.append(token_perplexity)
+
+     # Clean up tokens for display
+     cleaned_tokens = []
+     for token in tokens:
+         if token.startswith('##'):
+             cleaned_tokens.append(token[2:])
+         else:
+             cleaned_tokens.append(token)
+
+     return np.mean(perplexities) if perplexities else float('inf'), cleaned_tokens, np.array(token_perplexities)
+
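+ # Color mapping note (explanatory comment): each normalized perplexity w in
+ # [0, 1] below is turned into an RGB value by linearly blending the configured
+ # endpoints, channel = high * w + low * (1 - w), so confident tokens trend
+ # green and uncertain tokens trend red.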
+ def create_visualization(tokens, perplexities):
+     """Create displaCy visualization with color-coded perplexities"""
+     if len(tokens) == 0:
+         return "<p>No tokens to visualize.</p>"
+
+     # Cap perplexities for better visualization
+     max_perplexity = min(np.max(perplexities), VIZ_SETTINGS["max_perplexity_display"])
+
+     # Normalize perplexities to 0-1 range for color mapping
+     normalized_perplexities = np.clip(perplexities / max_perplexity, 0, 1)
+
+     # Create entities for displaCy
+     entities = []
+     text_parts = []
+     current_pos = 0
+
+     for i, (token, perp, norm_perp) in enumerate(zip(tokens, perplexities, normalized_perplexities)):
+         # Skip empty tokens
+         if not token.strip():
+             continue
+
+         # Clean token for display
+         clean_token = token.replace("</w>", "").strip()
+         if not clean_token:
+             continue
+
+         # Add space before token if it's not the first one and doesn't start with punctuation
+         if i > 0 and clean_token[0] not in ".,!?;:":
+             text_parts.append(" ")
+             current_pos += 1
+
+         text_parts.append(clean_token)
+
+         # Map perplexity to color
+         high_color = VIZ_SETTINGS["color_scheme"]["high_perplexity"]
+         low_color = VIZ_SETTINGS["color_scheme"]["low_perplexity"]
+
+         red = int(high_color["r"] * norm_perp + low_color["r"] * (1 - norm_perp))
+         green = int(high_color["g"] * norm_perp + low_color["g"] * (1 - norm_perp))
+         blue = int(high_color["b"] * norm_perp + low_color["b"] * (1 - norm_perp))
+
+         color = f"rgb({red}, {green}, {blue})"
+
+         entities.append({
+             "start": current_pos,
+             "end": current_pos + len(clean_token),
+             "label": f"{perp:.2f}",
+             "color": color
+         })
+
+         current_pos += len(clean_token)
+
+     # Join text parts
+     text = "".join(text_parts)
+
+     if not entities:
+         return "<p>No valid tokens found for visualization.</p>"
+
+     # Create displaCy data structure
+     doc_data = {
+         "text": text,
+         "ents": entities,
+         "title": "Per-token Perplexity Visualization"
+     }
+
+     try:
+         # displaCy takes entity colors via the options dict (keyed by label),
+         # and its "ents" filter must list the labels actually used, so build
+         # both from the entities rather than the static VIZ_SETTINGS defaults
+         options = dict(VIZ_SETTINGS["displacy_options"])
+         options["ents"] = [ent["label"] for ent in entities]
+         options["colors"] = {ent["label"]: ent["color"] for ent in entities}
+         html = displacy.render(doc_data, style="ent", manual=True, options=options)
+         return html
+     except Exception as e:
+         return f"<p>Error creating visualization: {str(e)}</p>"
+
+ def process_text(text, model_name, model_type, iterations, mlm_probability):
+     """Main processing function"""
+     if not text.strip():
+         return ERROR_MESSAGES["empty_text"], "", pd.DataFrame()
+
+     try:
+         # Validate inputs
+         iterations = max(1, min(iterations, PROCESSING_SETTINGS["max_iterations"]))
+         mlm_probability = max(PROCESSING_SETTINGS["min_mlm_probability"],
+                               min(mlm_probability, PROCESSING_SETTINGS["max_mlm_probability"]))
+
+         # Load model and tokenizer
+         model, tokenizer = load_model_and_tokenizer(model_name, model_type)
+
+         # Calculate perplexity
+         if model_type == "decoder":
+             avg_perplexity, tokens, token_perplexities = calculate_decoder_perplexity(
+                 text, model, tokenizer, iterations
+             )
+         else:  # encoder
+             avg_perplexity, tokens, token_perplexities = calculate_encoder_perplexity(
+                 text, model, tokenizer, mlm_probability, iterations
+             )
+
+         # Create visualization
+         viz_html = create_visualization(tokens, token_perplexities)
+
+         # Create summary
+         summary = f"""
+ ### Analysis Results
+
+ **Model:** `{model_name}`
+ **Model Type:** {model_type.title()}
+ **Average Perplexity:** {avg_perplexity:.4f}
+ **Number of Tokens:** {len(tokens)}
+ **Iterations:** {iterations}
+ """
+
+         if model_type == "encoder":
+             summary += f" \n**MLM Probability:** {mlm_probability}"
+
+         # Create detailed results table
+         df = pd.DataFrame({
+             'Token': tokens,
+             'Perplexity': [f"{p:.4f}" for p in token_perplexities]
+         })
+
+         return summary, viz_html, df
+
+     except Exception as e:
+         error_msg = ERROR_MESSAGES["processing_error"].format(error=str(e))
+         return error_msg, "", pd.DataFrame()
+
+ # Create Gradio interface
+ with gr.Blocks(title=UI_SETTINGS["title"], theme=gr.themes.Soft()) as demo:
+     gr.Markdown(f"# {UI_SETTINGS['title']}")
+     gr.Markdown(UI_SETTINGS["description"])
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             text_input = gr.Textbox(
+                 label="Input Text",
+                 placeholder="Enter the text you want to analyze...",
+                 lines=6,
+                 max_lines=10
+             )
+
+             with gr.Row():
+                 model_name = gr.Dropdown(
+                     label="Model Name",
+                     choices=DEFAULT_MODELS["decoder"] + DEFAULT_MODELS["encoder"],
+                     value="gpt2",
+                     allow_custom_value=True,
+                     info="Select a model or enter a custom HuggingFace model name"
+                 )
+
+                 model_type = gr.Radio(
+                     label="Model Type",
+                     choices=["decoder", "encoder"],
+                     value="decoder",
+                     info="Decoder for causal LM, Encoder for masked LM"
+                 )
+
+             with gr.Row():
+                 iterations = gr.Slider(
+                     label="Iterations",
+                     minimum=1,
+                     maximum=PROCESSING_SETTINGS["max_iterations"],
+                     value=PROCESSING_SETTINGS["default_iterations"],
+                     step=1,
+                     info="Number of iterations to average over"
+                 )
+
+                 mlm_probability = gr.Slider(
+                     label="MLM Probability",
+                     minimum=PROCESSING_SETTINGS["min_mlm_probability"],
+                     maximum=PROCESSING_SETTINGS["max_mlm_probability"],
+                     value=PROCESSING_SETTINGS["default_mlm_probability"],
+                     step=0.05,
+                     visible=False,
+                     info="Probability of masking tokens (encoder models only)"
+                 )
+
+             analyze_btn = gr.Button("🔍 Analyze Perplexity", variant="primary", size="lg")
+
+         with gr.Column(scale=3):
+             summary_output = gr.Markdown(label="Summary")
+             viz_output = gr.HTML(label="Perplexity Visualization")
+
+     # Full-width table
+     with gr.Row():
+         table_output = gr.Dataframe(
+             label="Detailed Token Results",
+             interactive=False,
+             wrap=True
+         )
+
+     # Update model dropdown based on type selection
+     def update_model_choices(model_type):
+         return gr.update(choices=DEFAULT_MODELS[model_type], value=DEFAULT_MODELS[model_type][0])
+
+     # Show/hide MLM probability based on model type
+     def toggle_mlm_visibility(model_type):
+         return gr.update(visible=(model_type == "encoder"))
+
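+     # Note (explanatory comment): the lambda below returns a list of two
+     # gr.update payloads, matched by position to the two components in
+     # `outputs` (model_name, mlm_probability)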
+     model_type.change(
+         fn=lambda mt: [update_model_choices(mt), toggle_mlm_visibility(mt)],
+         inputs=[model_type],
+         outputs=[model_name, mlm_probability]
+     )
+
+     # Set up the analysis function
+     analyze_btn.click(
+         fn=process_text,
+         inputs=[text_input, model_name, model_type, iterations, mlm_probability],
+         outputs=[summary_output, viz_output, table_output]
+     )
+
+     # Add examples
+     with gr.Accordion("📝 Example Texts", open=False):
+         examples_data = [
+             [ex["text"], ex["model"], ex["type"], ex["iterations"], ex["mlm_prob"]]
+             for ex in UI_SETTINGS["examples"]
+         ]
+
+         gr.Examples(
+             examples=examples_data,
+             inputs=[text_input, model_name, model_type, iterations, mlm_probability],
+             outputs=[summary_output, viz_output, table_output],
+             fn=process_text,
+             cache_examples=False,
+             label="Click on an example to try it out:"
+         )
+
+     # Add footer with information
+     gr.Markdown("""
+ ---
+
+ ### 📊 How it works:
+
+ - **Decoder Models** (GPT, etc.): Calculate true perplexity by measuring how well the model predicts the next token
+ - **Encoder Models** (BERT, etc.): Calculate pseudo-perplexity using masked language modeling (MLM)
+ - **Color Coding**: Red = High perplexity (uncertain), Green = Low perplexity (confident)
+
+ ### ⚠️ Notes:
+ - First model load may take some time
+ - Models are cached after first use
+ - Very long texts are truncated to 512 tokens
+ - GPU acceleration is used when available
+ """)
+
+ if __name__ == "__main__":
+     try:
+         demo.launch(
+             server_name="0.0.0.0",
+             server_port=7860,
+             show_api=False
+         )
+     except Exception as e:
+         print(f"❌ Failed to launch app: {e}")
+         print("💡 Try running with: python run.py")
+         # Fallback to basic launch
+         try:
+             demo.launch()
+         except Exception as fallback_error:
+             print(f"❌ Fallback launch also failed: {fallback_error}")
+             print("💡 Try updating Gradio: pip install --upgrade gradio")