Bram van Es committed
Commit 5d8e89e · Parent(s): 3c38d1e

update for simple viewer

Files changed (4)
  1. app.py +125 -120
  2. config.py +113 -0
  3. launch.py +48 -0
  4. run.py +181 -0
app.py CHANGED
@@ -18,27 +18,33 @@ except ImportError:
 }
 MODEL_SETTINGS = {"max_length": 512}
 VIZ_SETTINGS = {
-    "max_perplexity_display": 100.0,
+    "max_perplexity_display": 50.0,
     "color_scheme": {
-        "high_perplexity": {"r": 255, "g": 0, "b": 50},
-        "low_perplexity": {"r": 0, "g": 255, "b": 50}
+        "low_perplexity": {"r": 46, "g": 204, "b": 113},
+        "medium_perplexity": {"r": 241, "g": 196, "b": 15},
+        "high_perplexity": {"r": 231, "g": 76, "b": 60},
+        "background_alpha": 0.7,
+        "border_alpha": 0.9
+    },
+    "thresholds": {
+        "low_threshold": 0.3,
+        "high_threshold": 0.7
     },
     "displacy_options": {"ents": ["PP"], "colors": {}}
 }
 PROCESSING_SETTINGS = {
     "default_iterations": 1,
     "max_iterations": 10,
-    "default_mlm_probability": 0.15,
-    "min_mlm_probability": 0.1,
-    "max_mlm_probability": 0.5,
     "epsilon": 1e-10
 }
 UI_SETTINGS = {
-    "title": "📈 Perplexity Viewer",
-    "description": "Visualize per-token perplexity using color gradients.",
+    "title": "📈 Perplexity Viewer Simple",
+    "description": "Visualize per-token perplexity using color gradients. Assumes single token masking.",
     "examples": [
-        {"text": "The quick brown fox jumps over the lazy dog.", "model": "gpt2", "type": "decoder", "iterations": 1, "mlm_prob": 0.15},
-        {"text": "The capital of France is Paris.", "model": "bert-base-uncased", "type": "encoder", "iterations": 1, "mlm_prob": 0.15}
+        {"text": "The quick brown fox jumps over the lazy dog.", "model": "gpt2", "type": "decoder", "iterations": 1},
+        {"text": "The capital of France is Paris.", "model": "bert-base-uncased", "type": "encoder", "iterations": 1},
+        {"text": "Quantum entanglement defies classical physics intuition completely.", "model": "distilgpt2", "type": "decoder", "iterations": 1},
+        {"text": "Machine learning algorithms require computational resources.", "model": "distilbert-base-uncased", "type": "encoder", "iterations": 1}
     ]
 }
 ERROR_MESSAGES = {
@@ -138,8 +144,8 @@ def calculate_decoder_perplexity(text, model, tokenizer, iterations=1):
 
     return np.mean(perplexities), cleaned_tokens, token_perplexities
 
-def calculate_encoder_perplexity(text, model, tokenizer, mlm_probability=0.15, iterations=1):
-    """Calculate pseudo-perplexity for encoder models (like BERT) using MLM"""
+def calculate_encoder_perplexity(text, model, tokenizer, iterations=1):
+    """Calculate pseudo-perplexity for encoder models (like BERT) using MLM on all tokens"""
     device = next(model.parameters()).device
 
     perplexities = []
@@ -152,48 +158,32 @@ def calculate_encoder_perplexity(text, model, tokenizer, mlm_probability=0.15, i
         if input_ids.size(1) < 3:  # Need at least [CLS] + 1 token + [SEP]
             raise gr.Error("Text is too short for MLM perplexity calculation.")
 
-        # Create a copy for masking
-        masked_input_ids = input_ids.clone()
-        original_tokens = input_ids.clone()
-
-        # Randomly mask tokens (excluding special tokens)
-        seq_length = input_ids.size(1)
-        mask_indices = []
-        special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}
-
-        for i in range(seq_length):
-            if input_ids[0, i].item() not in special_token_ids:
-                if torch.rand(1).item() < mlm_probability:
-                    mask_indices.append(i)
-                    masked_input_ids[0, i] = tokenizer.mask_token_id
-
-        if not mask_indices:
-            # If no tokens were masked, mask at least one non-special token
-            non_special_indices = [i for i in range(seq_length) if input_ids[0, i].item() not in special_token_ids]
-            if non_special_indices:
-                mask_idx = torch.randint(0, len(non_special_indices), (1,)).item()
-                mask_indices = [non_special_indices[mask_idx]]
-                masked_input_ids[0, mask_indices[0]] = tokenizer.mask_token_id
-
+        # Calculate average perplexity by masking all content tokens
         with torch.no_grad():
-            outputs = model(masked_input_ids)
-            predictions = outputs.logits
-
-            # Calculate perplexity for masked tokens
-            masked_token_losses = []
-            for idx in mask_indices:
-                target_id = original_tokens[0, idx]
-                pred_scores = predictions[0, idx]
-                prob = F.softmax(pred_scores, dim=-1)[target_id]
-                loss = -torch.log(prob + PROCESSING_SETTINGS["epsilon"])
-                masked_token_losses.append(loss.item())
-
-            if masked_token_losses:
-                avg_loss = np.mean(masked_token_losses)
+            seq_length = input_ids.size(1)
+            special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}
+
+            all_token_losses = []
+
+            # Mask each non-special token individually and calculate loss
+            for i in range(seq_length):
+                if input_ids[0, i].item() not in special_token_ids:
+                    masked_input = input_ids.clone()
+                    original_token_id = input_ids[0, i]
+                    masked_input[0, i] = tokenizer.mask_token_id
+
+                    outputs = model(masked_input)
+                    predictions = outputs.logits[0, i]
+                    prob = F.softmax(predictions, dim=-1)[original_token_id]
+                    loss = -torch.log(prob + PROCESSING_SETTINGS["epsilon"])
+                    all_token_losses.append(loss.item())
+
+            if all_token_losses:
+                avg_loss = np.mean(all_token_losses)
                 perplexity = math.exp(avg_loss)
                 perplexities.append(perplexity)
 
-        # Calculate per-token pseudo-perplexity for visualization
+        # Calculate per-token pseudo-perplexity for visualization (analyze all tokens)
         with torch.no_grad():
             token_perplexities = []
             tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
@@ -203,6 +193,7 @@ def calculate_encoder_perplexity(text, model, tokenizer, mlm_probability=0.15, i
             if input_ids[0, i].item() in special_token_ids:
                 token_perplexities.append(1.0)  # Low perplexity for special tokens
             else:
+                # Calculate detailed perplexity for every content token
                 masked_input = input_ids.clone()
                 original_token_id = input_ids[0, i]
                 masked_input[0, i] = tokenizer.mask_token_id
@@ -224,7 +215,7 @@ def calculate_encoder_perplexity(text, model, tokenizer, mlm_probability=0.15, i
     return np.mean(perplexities) if perplexities else float('inf'), cleaned_tokens, np.array(token_perplexities)
 
 def create_visualization(tokens, perplexities):
-    """Create displaCy visualization with color-coded perplexities"""
+    """Create custom HTML visualization with color-coded perplexities"""
     if len(tokens) == 0:
         return "<p>No tokens to visualize.</p>"
 
@@ -234,10 +225,17 @@ def create_visualization(tokens, perplexities):
     # Normalize perplexities to 0-1 range for color mapping
     normalized_perplexities = np.clip(perplexities / max_perplexity, 0, 1)
 
-    # Create entities for displaCy
-    entities = []
-    text_parts = []
-    current_pos = 0
+    # Create HTML with inline styles for color coding
+    html_parts = [
+        '<div style="font-family: Arial, sans-serif; font-size: 16px; line-height: 1.8; padding: 20px; border: 1px solid #ddd; border-radius: 8px; background-color: #fafafa;">',
+        '<h3 style="margin-top: 0; color: #333;">Per-token Perplexity Visualization</h3>',
+        '<div style="margin-bottom: 15px;">',
+        '<span style="font-size: 12px; color: #666;">',
+        '🟢 Low perplexity (confident) → 🟡 Medium → 🔴 High perplexity (uncertain)',
+        '</span>',
+        '</div>',
+        '<div style="line-height: 2.0;">'
+    ]
 
     for i, (token, perp, norm_perp) in enumerate(zip(tokens, perplexities, normalized_perplexities)):
         # Skip empty tokens
@@ -245,57 +243,82 @@ def create_visualization(tokens, perplexities):
             continue
 
         # Clean token for display
-        clean_token = token.replace("</w>", "").strip()
+        clean_token = token.replace("</w>", "").replace("##", "").strip()
         if not clean_token:
             continue
 
-        # Add space before token if it's not the first one and doesn't start with punctuation
+        # Add space before token if needed
         if i > 0 and not clean_token[0] in ".,!?;:":
-            text_parts.append(" ")
-            current_pos += 1
+            html_parts.append(" ")
 
-        text_parts.append(clean_token)
+        # Get color thresholds from configuration
+        low_thresh = VIZ_SETTINGS.get("thresholds", {}).get("low_threshold", 0.3)
+        high_thresh = VIZ_SETTINGS.get("thresholds", {}).get("high_threshold", 0.7)
 
-        # Map perplexity to color
-        high_color = VIZ_SETTINGS["color_scheme"]["high_perplexity"]
+        # Get colors from configuration
         low_color = VIZ_SETTINGS["color_scheme"]["low_perplexity"]
+        med_color = VIZ_SETTINGS["color_scheme"]["medium_perplexity"]
+        high_color = VIZ_SETTINGS["color_scheme"]["high_perplexity"]
 
-        red = int(high_color["r"] * norm_perp + low_color["r"] * (1 - norm_perp))
-        green = int(high_color["g"] * norm_perp + low_color["g"] * (1 - norm_perp))
-        blue = int(high_color["b"] * norm_perp + low_color["b"] * (1 - norm_perp))
-
-        color = f"rgb({red}, {green}, {blue})"
-
-        entities.append({
-            "start": current_pos,
-            "end": current_pos + len(clean_token),
-            "label": f"{perp:.2f}",
-            "color": color
-        })
-
-        current_pos += len(clean_token)
-
-    # Join text parts
-    text = "".join(text_parts)
-
-    if not entities:
-        return "<p>No valid tokens found for visualization.</p>"
+        # Map perplexity to color using configuration
+        if norm_perp < low_thresh:  # Low perplexity - green
+            # Interpolate between green and yellow
+            factor = norm_perp / low_thresh
+            red = int(low_color["r"] + factor * (med_color["r"] - low_color["r"]))
+            green = int(low_color["g"] + factor * (med_color["g"] - low_color["g"]))
+            blue = int(low_color["b"] + factor * (med_color["b"] - low_color["b"]))
+        elif norm_perp < high_thresh:  # Medium perplexity - yellow/orange
+            # Interpolate between yellow and red
+            factor = (norm_perp - low_thresh) / (high_thresh - low_thresh)
+            red = int(med_color["r"] + factor * (high_color["r"] - med_color["r"]))
+            green = int(med_color["g"] + factor * (high_color["g"] - med_color["g"]))
+            blue = int(med_color["b"] + factor * (high_color["b"] - med_color["b"]))
+        else:  # High perplexity - red
+            # Use high perplexity color, potentially darker for very high values
+            factor = min((norm_perp - high_thresh) / (1.0 - high_thresh), 1.0)
+            darken = 0.8 - (factor * 0.3)  # Darken by up to 30%
+            red = int(high_color["r"] * darken)
+            green = int(high_color["g"] * darken)
+            blue = int(high_color["b"] * darken)
+
+        tooltip_text = f"Perplexity: {perp:.3f} (normalized: {norm_perp:.3f})"
+
+        # Clamp values
+        red = max(0, min(255, red))
+        green = max(0, min(255, green))
+        blue = max(0, min(255, blue))
+
+        # Get alpha values from configuration
+        bg_alpha = VIZ_SETTINGS["color_scheme"].get("background_alpha", 0.7)
+        border_alpha = VIZ_SETTINGS["color_scheme"].get("border_alpha", 0.9)
+
+        # Create colored span with tooltip
+        html_parts.append(
+            f'<span style="'
+            f'background-color: rgba({red}, {green}, {blue}, {bg_alpha}); '
+            f'color: #000; '
+            f'padding: 2px 4px; '
+            f'margin: 1px; '
+            f'border-radius: 3px; '
+            f'border: 1px solid rgba({red}, {green}, {blue}, {border_alpha}); '
+            f'font-weight: 500; '
+            f'cursor: help; '
+            f'display: inline-block;'
+            f'" title="{tooltip_text}">{clean_token}</span>'
+        )
 
-    # Create displaCy data structure
-    doc_data = {
-        "text": text,
-        "ents": entities,
-        "title": "Per-token Perplexity Visualization"
-    }
+    html_parts.extend([
+        '</div>',
+        '<div style="margin-top: 15px; font-size: 12px; color: #666;">',
+        f'Max perplexity in visualization: {max_perplexity:.2f} | ',
+        f'Total tokens: {len(tokens)}',
+        '</div>',
+        '</div>'
+    ])
 
-    try:
-        # Generate HTML
-        html = displacy.render(doc_data, style="ent", manual=True, options=VIZ_SETTINGS["displacy_options"])
-        return html
-    except Exception as e:
-        return f"<p>Error creating visualization: {str(e)}</p>"
+    return "".join(html_parts)
 
-def process_text(text, model_name, model_type, iterations, mlm_probability):
+def process_text(text, model_name, model_type, iterations):
     """Main processing function"""
     if not text.strip():
         return ERROR_MESSAGES["empty_text"], "", pd.DataFrame()
@@ -303,8 +326,6 @@ def process_text(text, model_name, model_type, iterations, mlm_probability):
     try:
         # Validate inputs
         iterations = max(1, min(iterations, PROCESSING_SETTINGS["max_iterations"]))
-        mlm_probability = max(PROCESSING_SETTINGS["min_mlm_probability"],
-                              min(mlm_probability, PROCESSING_SETTINGS["max_mlm_probability"]))
 
         # Load model and tokenizer
         model, tokenizer = load_model_and_tokenizer(model_name, model_type)
@@ -316,7 +337,7 @@ def process_text(text, model_name, model_type, iterations, mlm_probability):
             )
         else:  # encoder
             avg_perplexity, tokens, token_perplexities = calculate_encoder_perplexity(
-                text, model, tokenizer, mlm_probability, iterations
+                text, model, tokenizer, iterations
             )
 
         # Create visualization
@@ -333,8 +354,6 @@ def process_text(text, model_name, model_type, iterations, mlm_probability):
         **Iterations:** {iterations}
         """
 
-        if model_type == "encoder":
-            summary += f" \n**MLM Probability:** {mlm_probability}"
 
         # Create detailed results table
         df = pd.DataFrame({
@@ -387,17 +406,6 @@ with gr.Blocks(title=UI_SETTINGS["title"], theme=gr.themes.Soft()) as demo:
                 step=1,
                 info="Number of iterations to average over"
             )
-
-            mlm_probability = gr.Slider(
-                label="MLM Probability",
-                minimum=PROCESSING_SETTINGS["min_mlm_probability"],
-                maximum=PROCESSING_SETTINGS["max_mlm_probability"],
-                value=PROCESSING_SETTINGS["default_mlm_probability"],
-                step=0.05,
-                visible=False,
-                info="Probability of masking tokens (encoder models only)"
-            )
-
             analyze_btn = gr.Button("🔍 Analyze Perplexity", variant="primary", size="lg")
 
         with gr.Column(scale=3):
@@ -416,33 +424,29 @@ with gr.Blocks(title=UI_SETTINGS["title"], theme=gr.themes.Soft()) as demo:
     def update_model_choices(model_type):
         return gr.update(choices=DEFAULT_MODELS[model_type], value=DEFAULT_MODELS[model_type][0])
 
-    # Show/hide MLM probability based on model type
-    def toggle_mlm_visibility(model_type):
-        return gr.update(visible=(model_type == "encoder"))
-
     model_type.change(
-        fn=lambda mt: [update_model_choices(mt), toggle_mlm_visibility(mt)],
+        fn=update_model_choices,
         inputs=[model_type],
-        outputs=[model_name, mlm_probability]
+        outputs=[model_name]
     )
 
     # Set up the analysis function
     analyze_btn.click(
         fn=process_text,
-        inputs=[text_input, model_name, model_type, iterations, mlm_probability],
+        inputs=[text_input, model_name, model_type, iterations],
         outputs=[summary_output, viz_output, table_output]
     )
 
     # Add examples
     with gr.Accordion("📝 Example Texts", open=False):
         examples_data = [
-            [ex["text"], ex["model"], ex["type"], ex["iterations"], ex["mlm_prob"]]
+            [ex["text"], ex["model"], ex["type"], ex["iterations"]]
             for ex in UI_SETTINGS["examples"]
         ]
 
         gr.Examples(
             examples=examples_data,
-            inputs=[text_input, model_name, model_type, iterations, mlm_probability],
+            inputs=[text_input, model_name, model_type, iterations],
             outputs=[summary_output, viz_output, table_output],
             fn=process_text,
             cache_examples=False,
@@ -464,6 +468,7 @@ with gr.Blocks(title=UI_SETTINGS["title"], theme=gr.themes.Soft()) as demo:
     - Models are cached after first use
     - Very long texts are truncated to 512 tokens
    - GPU acceleration is used when available
+    - For encoder models, all content tokens are analyzed for comprehensive results
     """)
 
 if __name__ == "__main__":
 
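Side note on the main change above: the removed path masked a random subset of tokens (governed by mlm_probability) per iteration, while the simple viewer now masks every content token exactly once and takes exp(mean negative log-likelihood). A minimal self-contained sketch of that mask-each-token pseudo-perplexity loop, assuming a Hugging Face masked-LM checkpoint; the helper name pseudo_perplexity is ours, not part of the commit:

import math
import torch
import torch.nn.functional as F
from transformers import AutoModelForMaskedLM, AutoTokenizer

def pseudo_perplexity(text, model_name="bert-base-uncased"):
    """Mask each content token once; return exp(mean NLL) over those tokens."""
    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name).eval()
    ids = tok(text, return_tensors="pt", truncation=True, max_length=512)["input_ids"]
    special = {tok.cls_token_id, tok.sep_token_id, tok.pad_token_id}
    losses = []
    with torch.no_grad():
        for i in range(ids.size(1)):
            if ids[0, i].item() in special:
                continue  # skip [CLS]/[SEP]/[PAD], as app.py does
            masked = ids.clone()
            masked[0, i] = tok.mask_token_id  # mask exactly one token
            logits = model(masked).logits[0, i]
            prob = F.softmax(logits, dim=-1)[ids[0, i]]
            losses.append(-torch.log(prob + 1e-10).item())  # epsilon guards log(0)
    return math.exp(sum(losses) / len(losses))  # assumes at least one content token

# e.g. pseudo_perplexity("The capital of France is Paris.") should come out low for fluent text.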
config.py ADDED
@@ -0,0 +1,113 @@
+# Configuration file for PerplexityViewer
+
+# Default models for different types
+DEFAULT_MODELS = {
+    "decoder": [
+        "gpt2",
+        "distilgpt2",
+        "microsoft/DialoGPT-small",
+        "microsoft/DialoGPT-medium",
+        "openai-gpt"
+    ],
+    "encoder": [
+        "bert-base-uncased",
+        "bert-base-cased",
+        "distilbert-base-uncased",
+        "roberta-base",
+        "albert-base-v2",
+        "UMCU/CardioMedRoBERTa.nl",
+        "UMCU/CardioBERTa.nl",
+        "UMCU/CardioBERTa.nl_clinical",
+        "CLTL/MedRoBERTa.nl",
+        "DTAI-KULeuven/robbert-2023-dutch-base",
+        "DTAI-KULeuven/robbert-2023-dutch-large"
+    ]
+}
+
+# Model display settings
+MODEL_SETTINGS = {
+    "max_length": 512,
+    "torch_dtype": "float16",
+    "device_map": "auto"
+}
+
+# Visualization settings
+VIZ_SETTINGS = {
+    "max_perplexity_display": 50.0,  # Cap visualization at this perplexity value
+    "color_scheme": {
+        "low_perplexity": {"r": 46, "g": 204, "b": 113},    # Green for low perplexity (confident)
+        "medium_perplexity": {"r": 241, "g": 196, "b": 15},  # Yellow for medium perplexity
+        "high_perplexity": {"r": 231, "g": 76, "b": 60},     # Red for high perplexity (uncertain)
+        "background_alpha": 0.7,  # Background transparency
+        "border_alpha": 0.9       # Border transparency
+    },
+    "thresholds": {
+        "low_threshold": 0.3,   # Below this is low perplexity (green)
+        "high_threshold": 0.7   # Above this is high perplexity (red)
+    },
+    "displacy_options": {
+        "ents": ["PP"],
+        "colors": {}
+    }
+}
+
+# Processing settings
+PROCESSING_SETTINGS = {
+    "default_iterations": 1,
+    "max_iterations": 10,
+    "epsilon": 1e-10  # Small value to avoid log(0)
+}
+
+# UI settings
+UI_SETTINGS = {
+    "theme": "soft",
+    "title": "📈 Perplexity Viewer",
+    "description": """
+    Visualize per-token perplexity using color gradients.
+    - **Red**: High perplexity (model is uncertain)
+    - **Green**: Low perplexity (model is confident)
+
+    Choose between decoder models (like GPT) for true perplexity or encoder models (like BERT) for pseudo-perplexity via MLM.
+    """,
+    "examples": [
+        {
+            "text": "The quick brown fox jumps over the lazy dog.",
+            "model": "gpt2",
+            "type": "decoder",
+            "iterations": 1
+        },
+        {
+            "text": "The capital of France is Paris.",
+            "model": "bert-base-uncased",
+            "type": "encoder",
+            "iterations": 1
+        },
+        {
+            "text": "Quantum entanglement defies classical physics intuition completely.",
+            "model": "distilgpt2",
+            "type": "decoder",
+            "iterations": 1
+        },
+        {
+            "text": "Buffalo buffalo Buffalo buffalo buffalo buffalo Buffalo buffalo.",
+            "model": "gpt2",
+            "type": "decoder",
+            "iterations": 1
+        },
+        {
+            "text": "Machine learning algorithms require computational resources.",
+            "model": "distilbert-base-uncased",
+            "type": "encoder",
+            "iterations": 1
+        }
+    ]
+}
+
+# Error messages
+ERROR_MESSAGES = {
+    "empty_text": "Please enter some text to analyze.",
+    "model_load_error": "Error loading model {model_name}: {error}",
+    "processing_error": "Error processing text: {error}",
+    "no_tokens_masked": "No tokens were masked during MLM processing.",
+    "invalid_model_type": "Invalid model type. Must be 'encoder' or 'decoder'."
+}
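For reference, a small sketch of how the thresholds and color_scheme entries above drive the piecewise color interpolation in app.py's create_visualization; perplexity_to_rgb is our illustrative name, not part of the commit:

from config import VIZ_SETTINGS

def perplexity_to_rgb(norm_perp):
    """Map a normalized perplexity in [0, 1] to an (r, g, b) tuple, mirroring app.py."""
    scheme = VIZ_SETTINGS["color_scheme"]
    low, med, high = scheme["low_perplexity"], scheme["medium_perplexity"], scheme["high_perplexity"]
    lo_t = VIZ_SETTINGS["thresholds"]["low_threshold"]   # 0.3: green band ends here
    hi_t = VIZ_SETTINGS["thresholds"]["high_threshold"]  # 0.7: red band starts here
    if norm_perp < lo_t:        # interpolate green -> yellow
        f, a, b = norm_perp / lo_t, low, med
    elif norm_perp < hi_t:      # interpolate yellow -> red
        f, a, b = (norm_perp - lo_t) / (hi_t - lo_t), med, high
    else:                       # darken red for the highest values
        f = min((norm_perp - hi_t) / (1.0 - hi_t), 1.0)
        d = 0.8 - 0.3 * f
        return tuple(int(high[c] * d) for c in "rgb")
    return tuple(int(a[c] + f * (b[c] - a[c])) for c in "rgb")

# e.g. perplexity_to_rgb(0.0) -> (46, 204, 113); perplexity_to_rgb(0.5) -> a yellow/orange mix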
launch.py ADDED
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+"""
+Simple launcher for PerplexityViewer that handles common issues
+"""
+
+import sys
+import os
+
+def main():
+    """Simple launcher with fallback options"""
+    print("🚀 Starting PerplexityViewer...")
+
+    try:
+        # Try importing required modules
+        import gradio as gr
+        print(f"✅ Gradio version: {gr.__version__}")
+
+        # Import the app
+        from app import demo
+
+        # Launch with minimal configuration
+        print("🌐 Launching app at http://localhost:7860")
+        demo.launch()
+
+    except ImportError as e:
+        print(f"❌ Missing dependency: {e}")
+        print("💡 Install requirements with: pip install -r requirements.txt")
+        sys.exit(1)
+
+    except Exception as e:
+        print(f"❌ Launch failed: {e}")
+        print("💡 Trying alternative methods...")
+
+        # Try different launch approaches
+        try:
+            from app import demo
+            demo.launch(server_name="127.0.0.1", server_port=7860)
+        except Exception:
+            try:
+                from app import demo
+                demo.launch(share=False, debug=True)
+            except Exception:
+                print("❌ All launch methods failed")
+                print("💡 Try running: python app.py directly")
+                sys.exit(1)
+
+if __name__ == "__main__":
+    main()
run.py ADDED
@@ -0,0 +1,181 @@
+#!/usr/bin/env python3
+"""
+Startup script for PerplexityViewer Gradio app
+"""
+
+import os
+import sys
+import subprocess
+import argparse
+import warnings
+
+# Suppress warnings
+warnings.filterwarnings("ignore")
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+def check_dependencies():
+    """Check if required packages are installed"""
+    required_packages = [
+        "torch",
+        "transformers",
+        "gradio",
+        "pandas",
+        "spacy",
+        "numpy"
+    ]
+
+    missing_packages = []
+
+    for package in required_packages:
+        try:
+            __import__(package)
+        except ImportError:
+            missing_packages.append(package)
+
+    if missing_packages:
+        print("❌ Missing required packages:")
+        for package in missing_packages:
+            print(f"   - {package}")
+        print("\n📦 Install missing packages with:")
+        print(f"   pip install {' '.join(missing_packages)}")
+        return False
+
+    print("✅ All required packages are installed")
+    return True
+
+def install_dependencies():
+    """Install dependencies from requirements.txt"""
+    print("📦 Installing dependencies...")
+    try:
+        subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"])
+        print("✅ Dependencies installed successfully")
+        return True
+    except subprocess.CalledProcessError as e:
+        print(f"❌ Failed to install dependencies: {e}")
+        return False
+
+def run_tests():
+    """Run the test suite"""
+    print("🧪 Running tests...")
+    try:
+        result = subprocess.run([sys.executable, "test_app.py"],
+                                capture_output=True, text=True)
+
+        if result.returncode == 0:
+            print("✅ All tests passed")
+            return True
+        else:
+            print("❌ Some tests failed:")
+            print(result.stdout)
+            print(result.stderr)
+            return False
+    except FileNotFoundError:
+        print("⚠️ Test file not found, skipping tests")
+        return True
+
+def launch_app(share=False, debug=False, port=7860):
+    """Launch the Gradio app"""
+    print("🚀 Starting PerplexityViewer...")
+
+    # Set environment variables
+    if debug:
+        os.environ["GRADIO_DEBUG"] = "1"
+
+    try:
+        # Import and launch the app
+        from app import demo
+
+        # Prepare launch arguments with version compatibility
+        launch_args = {
+            "server_name": "0.0.0.0" if not debug else "127.0.0.1",
+            "server_port": port,
+            "share": share,
+            "show_api": False
+        }
+
+        # Add quiet parameter only if supported (older Gradio versions)
+        try:
+            import gradio as gr
+            # Check if quiet parameter is supported
+            import inspect
+            launch_signature = inspect.signature(demo.launch)
+            if 'quiet' in launch_signature.parameters:
+                launch_args["quiet"] = not debug
+        except Exception:
+            pass  # If we can't check, just skip the quiet parameter
+
+        demo.launch(**launch_args)
+
+    except KeyboardInterrupt:
+        print("\n👋 Shutting down PerplexityViewer")
+    except Exception as e:
+        print(f"❌ Failed to launch app: {e}")
+        print("💡 Try updating Gradio: pip install --upgrade gradio")
+        sys.exit(1)
+
+def main():
+    """Main entry point"""
+    parser = argparse.ArgumentParser(
+        description="PerplexityViewer - Visualize text perplexity with color gradients",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+  python run.py                      # Launch with default settings
+  python run.py --install            # Install dependencies first
+  python run.py --test               # Run tests before launching
+  python run.py --share              # Create shareable link
+  python run.py --debug --port 8080  # Debug mode on custom port
+        """
+    )
+
+    parser.add_argument("--install", action="store_true",
+                        help="Install dependencies before launching")
+    parser.add_argument("--test", action="store_true",
+                        help="Run tests before launching")
+    parser.add_argument("--share", action="store_true",
+                        help="Create a shareable Gradio link")
+    parser.add_argument("--debug", action="store_true",
+                        help="Enable debug mode")
+    parser.add_argument("--port", type=int, default=7860,
+                        help="Port to run the server on (default: 7860)")
+    parser.add_argument("--skip-checks", action="store_true",
+                        help="Skip dependency checks")
+
+    args = parser.parse_args()
+
+    print("="*60)
+    print("🎯 PerplexityViewer Startup")
+    print("="*60)
+
+    # Install dependencies if requested
+    if args.install:
+        if not install_dependencies():
+            sys.exit(1)
+
+    # Check dependencies unless skipped
+    if not args.skip_checks:
+        if not check_dependencies():
+            print("\n💡 Try running with --install to install missing packages")
+            sys.exit(1)
+
+    # Run tests if requested
+    if args.test:
+        if not run_tests():
+            print("\n⚠️ Tests failed, but continuing anyway...")
+            print("   Use Ctrl+C to cancel or wait to launch app")
+            try:
+                import time
+                time.sleep(3)
+            except KeyboardInterrupt:
+                print("\n👋 Cancelled")
+                sys.exit(0)
+
+    # Launch the app
+    print(f"\n🌐 App will be available at: http://localhost:{args.port}")
+    if args.share:
+        print("🔗 A shareable link will be created")
+
+    launch_app(share=args.share, debug=args.debug, port=args.port)
+
+if __name__ == "__main__":
+    main()