Bram van Es committed on
Commit
80323f9
·
1 Parent(s): 738aa6e

first push

Browse files
Files changed (4) hide show
  1. app.py +540 -0
  2. launch.py +48 -0
  3. requirements.txt +7 -0
  4. run.py +181 -0
app.py ADDED
@@ -0,0 +1,540 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM
5
+ import numpy as np
6
+ import pandas as pd
7
+ import spacy
8
+ from spacy import displacy
9
+ import math
10
+ import warnings
11
+ try:
12
+ from config import DEFAULT_MODELS, MODEL_SETTINGS, VIZ_SETTINGS, PROCESSING_SETTINGS, UI_SETTINGS, ERROR_MESSAGES
13
+ except ImportError:
14
+ # Fallback configuration if config.py is not available
15
+ DEFAULT_MODELS = {
16
+ "decoder": ["gpt2", "distilgpt2"],
17
+ "encoder": ["bert-base-uncased", "distilbert-base-uncased"]
18
+ }
19
+ MODEL_SETTINGS = {"max_length": 512}
20
+ VIZ_SETTINGS = {
21
+ "max_perplexity_display": 50.0,
22
+ "color_scheme": {
23
+ "low_perplexity": {"r": 46, "g": 204, "b": 113},
24
+ "medium_perplexity": {"r": 241, "g": 196, "b": 15},
25
+ "high_perplexity": {"r": 231, "g": 76, "b": 60},
26
+ "background_alpha": 0.7,
27
+ "border_alpha": 0.9
28
+ },
29
+ "thresholds": {
30
+ "low_threshold": 0.3,
31
+ "high_threshold": 0.7
32
+ },
33
+ "displacy_options": {"ents": ["PP"], "colors": {}}
34
+ }
35
+ PROCESSING_SETTINGS = {
36
+ "epsilon": 1e-10,
37
+ "default_mask_probability": 0.15,
38
+ "min_mask_probability": 0.05,
39
+ "max_mask_probability": 0.5,
40
+ "default_min_samples": 10,
41
+ "min_samples_range": (5, 50)
42
+ }
43
+ UI_SETTINGS = {
44
+ "title": "πŸ“ˆ Perplexity Viewer",
45
+ "description": "Visualize per-token perplexity using color gradients.",
46
+ "examples": [
47
+ {"text": "The quick brown fox jumps over the lazy dog.", "model": "gpt2", "type": "decoder", "mask_prob": 0.15, "min_samples": 10},
48
+ {"text": "The capital of France is Paris.", "model": "bert-base-uncased", "type": "encoder", "mask_prob": 0.15, "min_samples": 10},
49
+ {"text": "Quantum entanglement defies classical physics intuition completely.", "model": "distilgpt2", "type": "decoder", "mask_prob": 0.15, "min_samples": 10},
50
+ {"text": "Machine learning requires large datasets for training.", "model": "distilbert-base-uncased", "type": "encoder", "mask_prob": 0.2, "min_samples": 15},
51
+ {"text": "Artificial intelligence transforms modern computing paradigms.", "model": "bert-base-uncased", "type": "encoder", "mask_prob": 0.1, "min_samples": 20}
52
+ ]
53
+ }
54
+ ERROR_MESSAGES = {
55
+ "empty_text": "Please enter some text to analyze.",
56
+ "model_load_error": "Error loading model {model_name}: {error}",
57
+ "processing_error": "Error processing text: {error}"
58
+ }
59
+ warnings.filterwarnings("ignore")
60
+
61
+ # Global variables to cache models
62
+ cached_models = {}
63
+ cached_tokenizers = {}
64
+
65
def load_model_and_tokenizer(model_name, model_type):
    """Load a model/tokenizer pair and cache it per (model_name, model_type).

    Args:
        model_name: Hugging Face model identifier (or local path).
        model_type: "decoder" loads a causal LM; any other value loads a
            masked LM.

    Returns:
        Tuple ``(model, tokenizer)``. The model is switched to eval mode.

    Raises:
        gr.Error: if the tokenizer or the model cannot be loaded.
    """
    cache_key = f"{model_name}_{model_type}"
    if cache_key in cached_models:
        return cached_models[cache_key], cached_tokenizers[cache_key]

    use_cuda = torch.cuda.is_available()
    model_cls = AutoModelForCausalLM if model_type == "decoder" else AutoModelForMaskedLM

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Some tokenizers (e.g. GPT-2) ship without a pad token; reuse EOS.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = model_cls.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
            # The model name is free text coming from the UI, so never run
            # repository-supplied Python code on its behalf.
            trust_remote_code=False,
        )
        model.eval()  # inference only
    except Exception as e:
        raise gr.Error(
            ERROR_MESSAGES["model_load_error"].format(model_name=model_name, error=str(e))
        ) from e

    # Populate both caches only after both loads succeeded.
    cached_models[cache_key] = model
    cached_tokenizers[cache_key] = tokenizer
    return model, tokenizer
101
+
102
def calculate_decoder_perplexity(text, model, tokenizer):
    """Compute overall and per-token perplexity with a causal language model.

    Args:
        text: Input string; truncated to ``MODEL_SETTINGS["max_length"]`` tokens.
        model: Causal LM whose forward pass returns an object with ``.logits``.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        Tuple ``(perplexity, tokens, token_perplexities)``: the perplexity of
        the whole sequence, the display-cleaned tokens (the first token is
        dropped because nothing predicts it), and a NumPy array with one
        perplexity per returned token.

    Raises:
        gr.Error: if the text tokenizes to fewer than two tokens.
    """
    device = next(model.parameters()).device

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=MODEL_SETTINGS["max_length"],
    )
    input_ids = inputs.input_ids.to(device)

    if input_ids.size(1) < 2:
        raise gr.Error("Text is too short for perplexity calculation.")

    # A single forward pass provides everything: the sequence loss is just
    # the mean of the per-token losses.
    with torch.no_grad():
        logits = model(input_ids).logits

    # Position i predicts token i + 1. Work in float32 so that exp() of the
    # loss stays accurate when the model itself runs in float16.
    shift_logits = logits[..., :-1, :].float()
    shift_labels = input_ids[..., 1:].to(shift_logits.device)

    token_losses = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        reduction="none",
    )
    perplexity = torch.exp(token_losses.mean()).item()
    token_perplexities = torch.exp(token_losses).cpu().numpy()

    # Tokens aligned with the losses (skip the first, unpredicted token).
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0][1:])

    # Strip sub-word markers for display: "Ġ" is the byte-level BPE
    # word-boundary marker, "##" the WordPiece continuation marker.
    cleaned_tokens = []
    for token in tokens:
        if token.startswith("Ġ"):
            cleaned_tokens.append(token[1:])
        elif token.startswith("##"):
            cleaned_tokens.append(token[2:])
        else:
            cleaned_tokens.append(token)

    return perplexity, cleaned_tokens, token_perplexities
147
+
148
def calculate_encoder_perplexity(text, model, tokenizer, mask_probability=0.15, min_samples_per_token=10):
    """Estimate pseudo-perplexity with a masked LM via random multi-token masking.

    Each iteration masks every content token independently with probability
    ``mask_probability``, runs the model once, and records the loss of the
    original token at every masked position. Sampling stops once every
    content token has ``min_samples_per_token`` samples, or after
    ``min_samples_per_token * 50`` iterations.

    Args:
        text: Input string; truncated to ``MODEL_SETTINGS["max_length"]`` tokens.
        model: Masked LM whose forward pass returns an object with ``.logits``.
        tokenizer: Tokenizer matching ``model``; must define a mask token.
        mask_probability: Per-token masking probability for each iteration.
        min_samples_per_token: Target number of samples per content token.

    Returns:
        Tuple ``(overall_perplexity, tokens, token_perplexities)``. The overall
        value is ``exp(mean loss)`` over all samples (``inf`` if none were
        collected). ``token_perplexities`` is a NumPy array aligned with
        ``tokens``: special tokens get 1.0, unsampled tokens 2.0, the rest
        the mean of their sampled perplexities.

    Raises:
        gr.Error: if the text is too short, the tokenizer has no mask token,
            or no content tokens are found.
    """
    device = next(model.parameters()).device

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=MODEL_SETTINGS["max_length"],
    )
    input_ids = inputs.input_ids.to(device)

    if input_ids.size(1) < 3:  # Need at least [CLS] + 1 token + [SEP]
        raise gr.Error("Text is too short for MLM perplexity calculation.")

    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise gr.Error("The selected tokenizer has no mask token; choose a masked language model.")

    token_ids = input_ids[0].tolist()
    special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}

    content_token_indices = [i for i, tid in enumerate(token_ids) if tid not in special_token_ids]
    if not content_token_indices:
        raise gr.Error("No content tokens found for analysis.")

    token_perplexity_samples = {idx: [] for idx in content_token_indices}
    all_losses = []

    # Same ceiling the additive epsilon used to impose on -log(p + eps),
    # applied after a numerically stable float32 log-softmax so float16
    # models can no longer underflow to an infinite loss.
    max_loss = -math.log(PROCESSING_SETTINGS["epsilon"])
    # Safety limit so sampling always terminates (sliders may pass floats).
    max_iterations = int(min_samples_per_token * 50)

    with torch.no_grad():
        for _ in range(max_iterations):
            masked_indices = [
                idx for idx in content_token_indices
                if torch.rand(1).item() < mask_probability
            ]
            if not masked_indices:
                continue

            masked_input = input_ids.clone()
            masked_input[0, masked_indices] = mask_token_id

            logits = model(masked_input).logits
            log_probs = F.log_softmax(logits[0, masked_indices].float(), dim=-1)
            targets = torch.tensor(
                [token_ids[idx] for idx in masked_indices], device=log_probs.device
            )
            losses = (-log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)).clamp(max=max_loss)

            for idx, loss in zip(masked_indices, losses.tolist()):
                token_perplexity_samples[idx].append(math.exp(loss))
                all_losses.append(loss)

            # Stop as soon as the least-sampled token has enough samples.
            if min(len(s) for s in token_perplexity_samples.values()) >= min_samples_per_token:
                break

    overall_perplexity = math.exp(np.mean(all_losses)) if all_losses else float("inf")

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    token_perplexities = []
    for i, tid in enumerate(token_ids):
        if tid in special_token_ids:
            token_perplexities.append(1.0)  # Low perplexity for special tokens
        elif token_perplexity_samples.get(i):
            token_perplexities.append(np.mean(token_perplexity_samples[i]))
        else:
            # Token never got masked within the iteration budget.
            token_perplexities.append(2.0)

    # Drop the WordPiece continuation marker for display.
    cleaned_tokens = [tok[2:] if tok.startswith("##") else tok for tok in tokens]

    return overall_perplexity, cleaned_tokens, np.array(token_perplexities)
247
+
248
def create_visualization(tokens, perplexities):
    """Render tokens as an HTML fragment colour-coded by perplexity.

    Perplexities are divided by ``min(max(perplexities),
    VIZ_SETTINGS["max_perplexity_display"])`` and clipped to [0, 1]; the
    normalized value is mapped onto the configured low -> medium -> high
    colour ramp. Each token becomes a ``<span>`` with a tooltip showing the
    raw and normalized perplexity.

    Args:
        tokens: Sequence of token strings.
        perplexities: Sequence of per-token perplexities aligned with
            ``tokens``.

    Returns:
        An HTML string; a short placeholder paragraph if ``tokens`` is empty.
    """
    import html  # stdlib; local import keeps this block self-contained

    if len(tokens) == 0:
        return "<p>No tokens to visualize.</p>"

    perplexities = np.asarray(perplexities, dtype=float)

    # Cap the scale so one extreme token does not wash out the others.
    max_perplexity = min(np.max(perplexities), VIZ_SETTINGS["max_perplexity_display"])
    normalized_perplexities = np.clip(perplexities / max_perplexity, 0, 1)

    # Configuration is loop-invariant: read it once.
    thresholds = VIZ_SETTINGS.get("thresholds", {})
    low_thresh = thresholds.get("low_threshold", 0.3)
    high_thresh = thresholds.get("high_threshold", 0.7)
    scheme = VIZ_SETTINGS["color_scheme"]
    low_color = scheme["low_perplexity"]
    med_color = scheme["medium_perplexity"]
    high_color = scheme["high_perplexity"]
    bg_alpha = scheme.get("background_alpha", 0.7)
    border_alpha = scheme.get("border_alpha", 0.9)

    def _blend(start, end, factor):
        """Linearly interpolate between two {"r","g","b"} colours."""
        return tuple(int(start[c] + factor * (end[c] - start[c])) for c in "rgb")

    html_parts = [
        '<div style="font-family: Arial, sans-serif; font-size: 16px; line-height: 1.8; padding: 20px; border: 1px solid #ddd; border-radius: 8px; background-color: #fafafa;">',
        '<h3 style="margin-top: 0; color: #333;">Per-token Perplexity Visualization</h3>',
        '<div style="margin-bottom: 15px;">',
        '<span style="font-size: 12px; color: #666;">',
        '🟢 Low perplexity (confident) → 🟡 Medium → 🔴 High perplexity (uncertain)',
        '</span>',
        '</div>',
        '<div style="line-height: 2.0;">',
    ]

    for i, (token, perp, norm_perp) in enumerate(zip(tokens, perplexities, normalized_perplexities)):
        if not token.strip():
            continue

        clean_token = token.replace("</w>", "").replace("##", "").strip()
        if not clean_token:
            continue

        # Separate tokens with a space, except before punctuation.
        if i > 0 and clean_token[0] not in ".,!?;:":
            html_parts.append(" ")

        if norm_perp < low_thresh:
            # Low perplexity: blend from the low colour towards the medium one.
            red, green, blue = _blend(low_color, med_color, norm_perp / low_thresh)
        elif norm_perp < high_thresh:
            # Medium perplexity: blend from the medium colour towards the high one.
            factor = (norm_perp - low_thresh) / (high_thresh - low_thresh)
            red, green, blue = _blend(med_color, high_color, factor)
        else:
            # High perplexity: darken the high colour further as the value rises.
            factor = min((norm_perp - high_thresh) / (1.0 - high_thresh), 1.0)
            darken = 0.8 - (factor * 0.3)
            red = int(high_color["r"] * darken)
            green = int(high_color["g"] * darken)
            blue = int(high_color["b"] * darken)

        red = max(0, min(255, red))
        green = max(0, min(255, green))
        blue = max(0, min(255, blue))

        tooltip_text = f"Perplexity: {perp:.3f} (normalized: {norm_perp:.3f})"

        # Tokens come from user-supplied text: escape them so "<", ">", "&"
        # and quotes are displayed literally instead of being parsed as markup.
        safe_token = html.escape(clean_token)

        html_parts.append(
            f'<span style="'
            f'background-color: rgba({red}, {green}, {blue}, {bg_alpha}); '
            f'color: #000; '
            f'padding: 2px 4px; '
            f'margin: 1px; '
            f'border-radius: 3px; '
            f'border: 1px solid rgba({red}, {green}, {blue}, {border_alpha}); '
            f'font-weight: 500; '
            f'cursor: help; '
            f'display: inline-block;'
            f'" title="{tooltip_text}">{safe_token}</span>'
        )

    html_parts.extend([
        '</div>',
        '<div style="margin-top: 15px; font-size: 12px; color: #666;">',
        f'Max perplexity in visualization: {max_perplexity:.2f} | ',
        f'Total tokens: {len(tokens)}',
        '</div>',
        '</div>',
    ])

    return "".join(html_parts)
351
+
352
def process_text(text, model_name, model_type, mask_probability=0.15, min_samples=10):
    """Run the full analysis for one request and build the three UI outputs.

    Args:
        text: Text to analyze.
        model_name: Hugging Face model identifier.
        model_type: "decoder" for causal LMs; anything else is treated as an
            encoder (masked LM).
        mask_probability: Encoder only; per-token masking probability.
        min_samples: Encoder only; samples to collect per token.

    Returns:
        Tuple ``(summary_markdown, visualization_html, results_dataframe)``.
        On empty input or any failure the first element is an error message
        and the other two are an empty string and an empty DataFrame.
    """
    # Guard against None as well as blank input.
    if not text or not text.strip():
        return ERROR_MESSAGES["empty_text"], "", pd.DataFrame()

    try:
        model, tokenizer = load_model_and_tokenizer(model_name, model_type)

        if model_type == "decoder":
            avg_perplexity, tokens, token_perplexities = calculate_decoder_perplexity(
                text, model, tokenizer
            )
            sampling_info = ""
        else:  # encoder
            min_samples = int(min_samples)  # sliders may deliver floats
            avg_perplexity, tokens, token_perplexities = calculate_encoder_perplexity(
                text, model, tokenizer, mask_probability, min_samples
            )
            sampling_info = (
                f"**Mask Probability:** {mask_probability:.1%}  \n"
                f"**Min Samples per Token:** {min_samples}  \n"
            )

        viz_html = create_visualization(tokens, token_perplexities)

        # Two trailing spaces before each newline are Markdown hard line
        # breaks; without them the fields collapse into one paragraph.
        summary = (
            "\n### Analysis Results\n\n"
            f"**Model:** `{model_name}`  \n"
            f"**Model Type:** {model_type.title()}  \n"
            f"**Average Perplexity:** {avg_perplexity:.4f}  \n"
            f"**Number of Tokens:** {len(tokens)}  \n"
            f"{sampling_info}"
        )

        df = pd.DataFrame({
            'Token': tokens,
            'Perplexity': [f"{p:.4f}" for p in token_perplexities],
        })

        return summary, viz_html, df

    except Exception as e:
        # UI boundary: report the failure in the summary pane instead of crashing.
        error_msg = ERROR_MESSAGES["processing_error"].format(error=str(e))
        return error_msg, "", pd.DataFrame()
397
+
398
+ # Create Gradio interface
399
# Gradio interface: inputs on the left, summary/visualization on the right,
# a full-width results table below, then examples and a help footer.
with gr.Blocks(title=UI_SETTINGS["title"], theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {UI_SETTINGS['title']}")
    gr.Markdown(UI_SETTINGS["description"])

    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter the text you want to analyze...",
                lines=6,
                max_lines=10,
            )

            with gr.Row():
                model_name = gr.Dropdown(
                    label="Model Name",
                    choices=DEFAULT_MODELS["decoder"] + DEFAULT_MODELS["encoder"],
                    value="gpt2",
                    allow_custom_value=True,
                    info="Select a model or enter a custom HuggingFace model name",
                )

                model_type = gr.Radio(
                    label="Model Type",
                    choices=["decoder", "encoder"],
                    value="decoder",
                    info="Decoder for causal LM, Encoder for masked LM",
                )

            # Advanced settings; only shown while "encoder" is selected.
            with gr.Row():
                mask_probability = gr.Slider(
                    label="Mask Probability",
                    minimum=PROCESSING_SETTINGS["min_mask_probability"],
                    maximum=PROCESSING_SETTINGS["max_mask_probability"],
                    value=PROCESSING_SETTINGS["default_mask_probability"],
                    step=0.05,
                    visible=False,
                    info="Probability of masking each token per iteration (encoder only)",
                )

                min_samples = gr.Slider(
                    label="Min Samples per Token",
                    minimum=PROCESSING_SETTINGS["min_samples_range"][0],
                    maximum=PROCESSING_SETTINGS["min_samples_range"][1],
                    value=PROCESSING_SETTINGS["default_min_samples"],
                    step=5,
                    visible=False,
                    info="Minimum perplexity samples to collect per token (encoder only)",
                )

            analyze_btn = gr.Button("🔍 Analyze Perplexity", variant="primary", size="lg")

        with gr.Column(scale=3):
            summary_output = gr.Markdown(label="Summary")
            viz_output = gr.HTML(label="Perplexity Visualization")

    # Full-width table
    with gr.Row():
        table_output = gr.Dataframe(
            label="Detailed Token Results",
            interactive=False,
            wrap=True,
        )

    def update_model_choices(model_type, current_model=None):
        """Restrict the dropdown to the models of the selected type.

        The current selection is kept when it is valid for the new type, so a
        model filled in by an example is not overwritten; otherwise the first
        model of that type is selected.
        """
        choices = DEFAULT_MODELS[model_type]
        selected = current_model if current_model in choices else choices[0]
        return gr.update(choices=choices, value=selected)

    def toggle_advanced_settings(model_type):
        """Show the sampling sliders only for encoder models."""
        is_encoder = (model_type == "encoder")
        return [
            gr.update(visible=is_encoder),  # mask_probability
            gr.update(visible=is_encoder),  # min_samples
        ]

    def on_model_type_change(selected_type, current_model):
        """Refresh the model dropdown and slider visibility together."""
        return [update_model_choices(selected_type, current_model)] + toggle_advanced_settings(selected_type)

    model_type.change(
        fn=on_model_type_change,
        inputs=[model_type, model_name],
        outputs=[model_name, mask_probability, min_samples],
    )

    analyze_btn.click(
        fn=process_text,
        inputs=[text_input, model_name, model_type, mask_probability, min_samples],
        outputs=[summary_output, viz_output, table_output],
    )

    with gr.Accordion("📝 Example Texts", open=False):
        examples_data = [
            [ex["text"], ex["model"], ex["type"], ex.get("mask_prob", 0.15), ex.get("min_samples", 10)]
            for ex in UI_SETTINGS["examples"]
        ]

        gr.Examples(
            examples=examples_data,
            inputs=[text_input, model_name, model_type, mask_probability, min_samples],
            outputs=[summary_output, viz_output, table_output],
            fn=process_text,
            cache_examples=False,
            label="Click on an example to try it out:",
        )

    # Footer; the truncation limit is read from the settings so the text
    # cannot drift from the value actually used.
    gr.Markdown(f"""
    ---

    ### 📊 How it works:

    - **Decoder Models** (GPT, etc.): Calculate true perplexity by measuring how well the model predicts the next token
    - **Encoder Models** (BERT, etc.): Calculate pseudo-perplexity using statistical sampling with multiple token masking
    - **Mask Probability**: For encoder models, controls what fraction of tokens get masked in each iteration
    - **Min Samples**: Minimum number of perplexity measurements collected per token for robust statistics
    - **Color Coding**: Red = High perplexity (uncertain), Green = Low perplexity (confident)

    ### ⚠️ Notes:
    - First model load may take some time
    - Models are cached after first use
    - Very long texts are truncated to {MODEL_SETTINGS['max_length']} tokens
    - GPU acceleration is used when available
    - Encoder models use Monte Carlo sampling for robust perplexity estimates
    - Higher min samples = more accurate but slower analysis
    """)
524
+
525
if __name__ == "__main__":
    # Preferred launch: listen on all interfaces on the default port.
    try:
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_api=False,
        )
    except Exception as e:
        print(f"❌ Failed to launch app: {e}")
        print("💡 Try running with: python run.py")
        # Fallback: let Gradio pick its own defaults (e.g. when a keyword
        # argument is not supported by the installed version).
        try:
            demo.launch()
        except Exception as fallback_error:
            print(f"❌ Fallback launch also failed: {fallback_error}")
            print("💡 Try updating Gradio: pip install --upgrade gradio")
launch.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Simple launcher for PerplexityViewer that handles common issues
4
+ """
5
+
6
+ import sys
7
+ import os
8
+
9
def main():
    """Launch the app, retrying with alternative settings on failure.

    Exits with status 1 when a dependency is missing or when every launch
    attempt fails.
    """
    print("🚀 Starting PerplexityViewer...")

    try:
        import gradio as gr
        print(f"✅ Gradio version: {gr.__version__}")

        from app import demo

        print("🌐 Launching app at http://localhost:7860")
        demo.launch()

    except ImportError as e:
        print(f"❌ Missing dependency: {e}")
        print("💡 Install requirements with: pip install -r requirements.txt")
        sys.exit(1)

    except Exception as e:
        print(f"❌ Launch failed: {e}")
        print("💡 Trying alternative methods...")

        fallback_configs = (
            {"server_name": "127.0.0.1", "server_port": 7860},
            {"share": False, "debug": True},
        )
        for launch_kwargs in fallback_configs:
            try:
                from app import demo
                demo.launch(**launch_kwargs)
                return
            except Exception:
                # Catch Exception only, so Ctrl+C and sys.exit() still
                # propagate instead of triggering another retry.
                continue

        print("❌ All launch methods failed")
        print("💡 Try running: python app.py directly")
        sys.exit(1)


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.21.0
3
+ gradio>=4.0.0
4
+ pandas>=1.5.0
5
+ spacy>=3.4.0
6
+ numpy>=1.21.0
7
+ accelerate>=0.20.0
run.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Startup script for PerplexityViewer Gradio app
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import subprocess
9
+ import argparse
10
+ import warnings
11
+
12
+ # Suppress warnings
13
+ warnings.filterwarnings("ignore")
14
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
15
+
16
def check_dependencies():
    """Report whether every required package can be located.

    Packages are looked up with ``importlib.util.find_spec`` instead of being
    imported, so heavy libraries are not loaded merely to test for presence.

    Returns:
        True if all packages are found; False otherwise (after printing the
        missing names and an install hint).
    """
    import importlib.util

    required_packages = (
        "torch",
        "transformers",
        "gradio",
        "pandas",
        "spacy",
        "numpy",
    )

    missing_packages = []
    for package in required_packages:
        try:
            found = importlib.util.find_spec(package) is not None
        except (ImportError, ValueError):
            # find_spec raises for some broken or partially initialised modules.
            found = False
        if not found:
            missing_packages.append(package)

    if missing_packages:
        print("❌ Missing required packages:")
        for package in missing_packages:
            print(f" - {package}")
        print("\n📦 Install missing packages with:")
        print(f" pip install {' '.join(missing_packages)}")
        return False

    print("✅ All required packages are installed")
    return True
45
+
46
def install_dependencies():
    """Install the packages listed in requirements.txt with pip.

    The requirements file is resolved next to this script, so the command
    works regardless of the current working directory.

    Returns:
        True if pip exited successfully, False otherwise.
    """
    print("📦 Installing dependencies...")
    requirements_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "requirements.txt"
    )
    try:
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "-r", requirements_path]
        )
        print("✅ Dependencies installed successfully")
        return True
    except subprocess.CalledProcessError as e:
        print(f"❌ Failed to install dependencies: {e}")
        return False
56
+
57
def run_tests():
    """Run test_app.py (located next to this script) in a subprocess.

    Returns:
        True if the tests pass or the test file does not exist; False if the
        test process exits with a non-zero status.
    """
    print("🧪 Running tests...")

    test_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_app.py")

    # A missing script does not raise FileNotFoundError from subprocess.run
    # (the interpreter itself exists); Python just exits non-zero. Check
    # explicitly so a missing file is skipped rather than reported as a
    # test failure.
    if not os.path.isfile(test_path):
        print("⚠️ Test file not found, skipping tests")
        return True

    result = subprocess.run(
        [sys.executable, test_path],
        capture_output=True,
        text=True,
    )

    if result.returncode == 0:
        print("✅ All tests passed")
        return True

    print("❌ Some tests failed:")
    print(result.stdout)
    print(result.stderr)
    return False
75
+
76
def launch_app(share=False, debug=False, port=7860):
    """Import the Gradio app and launch it.

    Args:
        share: Passed through to ``demo.launch(share=...)``.
        debug: When true, sets ``GRADIO_DEBUG=1``, binds to 127.0.0.1 instead
            of 0.0.0.0, and disables quiet mode where supported.
        port: Port passed as ``server_port``.

    Exits with status 1 if the app cannot be imported or launched.
    """
    print("🚀 Starting PerplexityViewer...")

    if debug:
        os.environ["GRADIO_DEBUG"] = "1"

    try:
        from app import demo

        launch_args = {
            "server_name": "0.0.0.0" if not debug else "127.0.0.1",
            "server_port": port,
            "share": share,
            "show_api": False,
        }

        # Pass "quiet" only when the installed Gradio accepts it.
        try:
            import inspect
            if "quiet" in inspect.signature(demo.launch).parameters:
                launch_args["quiet"] = not debug
        except (TypeError, ValueError):
            # Signature not introspectable: launch without the option.
            pass

        demo.launch(**launch_args)

    except KeyboardInterrupt:
        print("\n👋 Shutting down PerplexityViewer")
    except Exception as e:
        print(f"❌ Failed to launch app: {e}")
        print("💡 Try updating Gradio: pip install --upgrade gradio")
        sys.exit(1)
115
+
116
def main():
    """Parse command-line options, run the optional steps, and launch the app.

    Order of operations: optional dependency install (``--install``),
    dependency check (unless ``--skip-checks``), optional test run
    (``--test``), then launch. Exits with status 1 if installation or the
    dependency check fails; failing tests only delay the launch by three
    seconds so the user can cancel with Ctrl+C.
    """
    parser = argparse.ArgumentParser(
        description="PerplexityViewer - Visualize text perplexity with color gradients",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python run.py                      # Launch with default settings
  python run.py --install            # Install dependencies first
  python run.py --test               # Run tests before launching
  python run.py --share              # Create shareable link
  python run.py --debug --port 8080  # Debug mode on custom port
        """,
    )

    parser.add_argument("--install", action="store_true",
                        help="Install dependencies before launching")
    parser.add_argument("--test", action="store_true",
                        help="Run tests before launching")
    parser.add_argument("--share", action="store_true",
                        help="Create a shareable Gradio link")
    parser.add_argument("--debug", action="store_true",
                        help="Enable debug mode")
    parser.add_argument("--port", type=int, default=7860,
                        help="Port to run the server on (default: 7860)")
    parser.add_argument("--skip-checks", action="store_true",
                        help="Skip dependency checks")

    args = parser.parse_args()

    print("=" * 60)
    print("🎯 PerplexityViewer Startup")
    print("=" * 60)

    if args.install and not install_dependencies():
        sys.exit(1)

    if not args.skip_checks and not check_dependencies():
        print("\n💡 Try running with --install to install missing packages")
        sys.exit(1)

    if args.test and not run_tests():
        print("\n⚠️ Tests failed, but continuing anyway...")
        print(" Use Ctrl+C to cancel or wait to launch app")
        try:
            import time
            time.sleep(3)  # grace period in which the user can abort
        except KeyboardInterrupt:
            print("\n👋 Cancelled")
            sys.exit(0)

    print(f"\n🌐 App will be available at: http://localhost:{args.port}")
    if args.share:
        print("🔗 A shareable link will be created")

    launch_app(share=args.share, debug=args.debug, port=args.port)


if __name__ == "__main__":
    main()