Chr-Hau committed on
Commit c5738ac · verified · 1 Parent(s): e1b854f

Update app.py

Files changed (1): app.py +330 -256
app.py CHANGED
@@ -1,10 +1,15 @@
 import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import matplotlib.pyplot as plt
 import numpy as np

-# Initialize model
 MODEL_NAME = "microsoft/Phi-4-mini-instruct"

 print("Loading model...")
@@ -13,10 +18,13 @@ model = AutoModelForCausalLM.from_pretrained(
     MODEL_NAME,
     torch_dtype="auto",
     device_map="auto",
-    trust_remote_code=True
 )

-# Define answer format configurations
 ANSWER_FORMATS = {
     "1-5 (numeric)": {
         "options": ["1", "2", "3", "4", "5"],
@@ -66,371 +74,437 @@ ANSWER_FORMATS = {
 }


 def get_token_info(options):
-    """Get detailed token information for each option"""
     token_info = []
     for i, option in enumerate(options):
         tokens = tokenizer.encode(option, add_special_tokens=False)
         token_info.append({
-            'index': i,
-            'option': option,
-            'tokens': tokens,
-            'token_count': len(tokens),
-            'decoded_tokens': [tokenizer.decode([t]) for t in tokens]
         })
     return token_info


-def calculate_sequence_probability(prompt_ids, option_tokens, temperature=1.0):
     """
-    Calculate the joint probability of generating a complete token sequence.
-
-    Returns multiple probability metrics:
-    - joint_prob: P(token1, token2, ..., tokenN)
-    - geometric_mean: (joint_prob)^(1/N) - normalized for length
-    - first_token_prob: P(token1) - for comparison with single-token options
-    - avg_prob: arithmetic mean of individual token probabilities
     """
-    if len(option_tokens) == 0:
         return None
-
-    joint_prob = 1.0
     token_probs = []
-
-    current_input = prompt_ids.clone()
-
-    for token_id in option_tokens:
         with torch.no_grad():
             outputs = model(current_input)
-        logits = outputs.logits[0, -1, :] / temperature
-        probs = torch.softmax(logits, dim=-1)
-
-        token_prob = probs[token_id].item()
-        token_probs.append(token_prob)
-        joint_prob *= token_prob
-
-        # Append token for next iteration
-        current_input = torch.cat([current_input, torch.tensor([[token_id]])], dim=1)
-
-    n_tokens = len(option_tokens)
-
     return {
-        'joint_prob': joint_prob,
-        'geometric_mean': joint_prob ** (1.0 / n_tokens),
-        'first_token_prob': token_probs[0] if token_probs else 0.0,
-        'avg_prob': np.mean(token_probs) if token_probs else 0.0,
-        'token_probs': token_probs,
-        'perplexity': np.exp(-np.mean([np.log(p) for p in token_probs])) if token_probs else float('inf')
     }


-def create_comparison_plot(all_results, statement, metric='geometric_mean'):
-    """Create a comprehensive comparison plot using specified metric"""
     n_formats = len(all_results)
     if n_formats == 0:
         return None
-
-    fig, axes = plt.subplots(2, (n_formats + 1) // 2, figsize=(16, 8))
-    if n_formats == 1:
-        axes = [axes]
-    else:
-        axes = axes.flatten()
-
     metric_names = {
-        'geometric_mean': 'Geometric Mean Probability',
-        'joint_prob': 'Joint Probability',
-        'first_token_prob': 'First Token Probability',
-        'avg_prob': 'Average Token Probability'
     }
-
     for idx, (format_name, data) in enumerate(all_results.items()):
         ax = axes[idx]
-        labels = data['labels']
-        probabilities = [p[metric] for p in data['probabilities']]
-
-        bars = ax.bar(range(len(labels)), probabilities, color='steelblue', alpha=0.8, edgecolor='navy')
-
-        # Add value labels
-        for bar, prob in zip(bars, probabilities):
             height = bar.get_height()
-            ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
-                    f'{prob:.3f}', ha='center', va='bottom', fontsize=8)
-
-        ax.set_ylabel('Probability', fontsize=9)
-        ax.set_title(format_name, fontsize=10, fontweight='bold')
         ax.set_xticks(range(len(labels)))
-        ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=8)
-        ax.set_ylim(0, max(probabilities) * 1.2 if probabilities and max(probabilities) > 0 else 1)
-        ax.grid(True, axis='y', alpha=0.3)
-
-    # Hide unused subplots
-    for idx in range(len(all_results), len(axes)):
-        axes[idx].set_visible(False)
-
-    plt.suptitle(f'Response Distribution Comparison ({metric_names[metric]})\nStatement: {statement[:80]}{"..." if len(statement) > 80 else ""}',
-                 fontsize=12, fontweight='bold')
     plt.tight_layout()
     return fig


-def create_heatmap(all_results, metric='geometric_mean'):
-    """Create a heatmap showing probability distributions across formats"""
     format_names = list(all_results.keys())
-    if len(format_names) == 0:
         return None
-
     n_options = 5
-
-    prob_matrix = np.zeros((len(format_names), n_options))
-    for i, (format_name, data) in enumerate(all_results.items()):
-        prob_matrix[i] = [p[metric] for p in data['probabilities']]
-
     fig, ax = plt.subplots(figsize=(10, 8))
-    im = ax.imshow(prob_matrix, cmap='YlOrRd', aspect='auto', vmin=0, vmax=np.max(prob_matrix))
-
     ax.set_xticks(range(n_options))
-    ax.set_xticklabels(['Option 1', 'Option 2', 'Option 3', 'Option 4', 'Option 5'])
     ax.set_yticks(range(len(format_names)))
     ax.set_yticklabels(format_names, fontsize=9)
-
-    # Add probability values
-    for i in range(len(format_names)):
-        for j in range(n_options):
-            text = ax.text(j, i, f'{prob_matrix[i, j]:.3f}',
-                           ha="center", va="center", color="black", fontsize=8)
-
     metric_names = {
-        'geometric_mean': 'Geometric Mean',
-        'joint_prob': 'Joint Probability',
-        'first_token_prob': 'First Token Probability',
-        'avg_prob': 'Average Token Probability'
     }
-
-    ax.set_title(f'Probability Heatmap: All Formats ({metric_names[metric]})', fontsize=12, fontweight='bold')
-    plt.colorbar(im, ax=ax, label='Probability')
     plt.tight_layout()
     return fig


 def create_metric_comparison_plot(all_results, statement):
-    """Create a plot comparing different probability metrics"""
-    metrics = ['geometric_mean', 'joint_prob', 'first_token_prob', 'avg_prob']
-    metric_names = ['Geometric Mean', 'Joint Prob', 'First Token', 'Avg Token']
-
-    n_formats = len(all_results)
-    if n_formats == 0:
         return None
-
     fig, axes = plt.subplots(2, 2, figsize=(14, 10))
-    axes = axes.flatten()
-
-    for metric_idx, (metric, metric_name) in enumerate(zip(metrics, metric_names)):
-        ax = axes[metric_idx]
-
         for format_name, data in all_results.items():
-            labels = data['labels']
-            probabilities = [p[metric] for p in data['probabilities']]
-            ax.plot(range(len(labels)), probabilities, marker='o', label=format_name, alpha=0.7)
-
-        ax.set_xlabel('Response Option')
-        ax.set_ylabel('Probability')
-        ax.set_title(f'{metric_name}', fontweight='bold')
         ax.set_xticks(range(5))
-        ax.set_xticklabels(['Opt 1', 'Opt 2', 'Opt 3', 'Opt 4', 'Opt 5'])
-        ax.legend(fontsize=7, loc='best')
         ax.grid(True, alpha=0.3)
-
-    plt.suptitle(f'Metric Comparison Across Formats\nStatement: {statement[:80]}{"..." if len(statement) > 80 else ""}',
-                 fontsize=12, fontweight='bold')
     plt.tight_layout()
     return fig


-def analyze_all_formats(statement, persona="", selected_formats=None, metric='geometric_mean'):
-    """Analyze statement with all selected answer formats"""
     try:
-        # Read default prompt template
-        with open("default-prompt.txt", "r") as f:
-            default_prompt_template = f.read().strip()
-
-        if selected_formats is None or len(selected_formats) == 0:
             selected_formats = list(ANSWER_FORMATS.keys())
-
         all_results = {}
         detailed_output = []
-
-        detailed_output.append("="*80)
-        detailed_output.append("MULTI-TOKEN PROBABILITY ANALYSIS")
-        detailed_output.append("="*80)
-        detailed_output.append(f"Metric used for comparison: {metric}")
         detailed_output.append("")
-        detailed_output.append("Methodology:")
-        detailed_output.append("- Joint Prob: P(token1) × P(token2|token1) × ... × P(tokenN|prev)")
-        detailed_output.append("- Geometric Mean: (Joint Prob)^(1/N) - length-normalized comparison")
-        detailed_output.append("- First Token: P(token1) - only the initial token probability")
-        detailed_output.append("- Avg Token: Arithmetic mean of all token probabilities")
-        detailed_output.append("="*80)
         detailed_output.append("")
-
         for format_name in selected_formats:
-            format_config = ANSWER_FORMATS[format_name]
-            options = format_config['options']
-            labels = format_config['labels']
-            prompt_suffix = format_config['prompt_suffix']
-
-            # Get token information
             token_info = get_token_info(options)
-
-            # Create prompt with format-specific instructions
             full_prompt = default_prompt_template.format(statement=statement.strip())
             full_prompt += f"\n\n{prompt_suffix}"
-
-            # Create chat messages
             messages = []
-            if persona.strip():
                 messages.append({"role": "system", "content": persona.strip()})
             messages.append({"role": "user", "content": full_prompt})
-
-            # Apply chat template
-            prompt = tokenizer.apply_chat_template(
-                messages,
-                tokenize=False,
-                add_generation_prompt=True
-            )
-
-            # Tokenize prompt
-            prompt_ids = tokenizer(prompt, return_tensors="pt")['input_ids']
-
-            # Calculate probabilities for each option
-            option_probs = []
             for info in token_info:
-                prob_metrics = calculate_sequence_probability(prompt_ids, info['tokens'])
-                option_probs.append(prob_metrics)
-
-            # Normalize probabilities so they sum to 1
-            total_prob = sum([p[metric] for p in option_probs])
-            normalized_probs = []
-            for p in option_probs:
-                normalized = {}
-                for key in p.keys():
-                    if key == 'token_probs' or key == 'perplexity':
-                        normalized[key] = p[key]
-                    else:
-                        normalized[key] = p[key] / total_prob if total_prob > 0 else 0.0
-                normalized_probs.append(normalized)
-
             all_results[format_name] = {
-                'labels': labels,
-                'probabilities': normalized_probs,
-                'options': options,
-                'token_info': token_info
             }
-
-            # Add to detailed output
-            detailed_output.append(f"\n{'='*80}")
             detailed_output.append(f"Format: {format_name}")
-            detailed_output.append(f"{'='*80}")
-
-            for i, (option, label, info, prob) in enumerate(zip(options, labels, token_info, normalized_probs)):
-                detailed_output.append(f"\n{label} ({option}):")
-                detailed_output.append(f"  Tokens: {info['token_count']} - {info['decoded_tokens']}")
-                detailed_output.append(f"  Joint Probability: {prob['joint_prob']:.6f}")
-                detailed_output.append(f"  Geometric Mean: {prob['geometric_mean']:.6f}")
-                detailed_output.append(f"  First Token Prob: {prob['first_token_prob']:.6f}")
-                detailed_output.append(f"  Avg Token Prob: {prob['avg_prob']:.6f}")
-                detailed_output.append(f"  Perplexity: {prob['perplexity']:.4f}")
-
-        # Create plots
-        comparison_plot = create_comparison_plot(all_results, statement, metric)
-        heatmap_plot = create_heatmap(all_results, metric)
         metric_comparison = create_metric_comparison_plot(all_results, statement)
-
-        detailed_text = "\n".join(detailed_output)
-
-        return comparison_plot, heatmap_plot, metric_comparison, detailed_text, "✅ Analysis complete"
-
     except Exception as e:
-        import traceback
         error_msg = f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"
         return None, None, None, "", error_msg


-# Create Gradio interface
-with gr.Blocks(title="The Unsampled Truth - Multi-Token Analysis") as demo:
-    gr.Markdown("""
-    # The Unsampled Truth - Multi-Token Probability Analysis
-
-    **ACL-Ready Methodology**: This tool calculates joint probabilities for multi-token responses and provides
-    multiple normalization metrics for fair comparison across different answer formats.
-
-    **Key Features**:
-    - Handles both single and multi-token answer options
-    - Calculates joint probabilities: P(token₁, token₂, ..., tokenₙ)
-    - Provides length-normalized comparison via geometric mean
-    - Reports multiple metrics for robustness analysis
-    """)
-
     with gr.Row():
         with gr.Column():
             statement_input = gr.Textbox(
-                label="Statement to Analyze",
                 placeholder="e.g., Climate change is a serious threat",
-                lines=3
             )
             persona_input = gr.Textbox(
-                label="Persona (Optional)",
-                placeholder="e.g., You are a conservative voter from rural America",
-                lines=2
             )
             format_selector = gr.CheckboxGroup(
                 choices=list(ANSWER_FORMATS.keys()),
                 value=list(ANSWER_FORMATS.keys()),
                 label="Select Answer Formats to Compare",
-                interactive=True
             )
             metric_selector = gr.Radio(
                 choices=[
                     ("Geometric Mean (Recommended)", "geometric_mean"),
                     ("Joint Probability", "joint_prob"),
                     ("First Token Only", "first_token_prob"),
-                    ("Average Token Probability", "avg_prob")
                 ],
                 value="geometric_mean",
-                label="Comparison Metric",
-                info="Geometric mean normalizes for sequence length - recommended for comparing across formats"
             )
             analyze_btn = gr.Button("Analyze All Formats", variant="primary")
-
     with gr.Row():
         with gr.Column():
-            comparison_plot = gr.Plot(label="Format Comparison")
         with gr.Column():
-            heatmap_plot = gr.Plot(label="Probability Heatmap")
-
     with gr.Row():
-        metric_comparison = gr.Plot(label="Metric Comparison")
-
     with gr.Row():
-        detailed_output = gr.Textbox(label="Detailed Probabilities & Token Analysis", lines=25)
         status_output = gr.Textbox(label="Status", lines=2)
-
-    # Examples
     gr.Examples(
         examples=[
-            ["Climate change is a serious threat", "", None, "geometric_mean"],
-            ["Immigration has positive economic effects", "", None, "geometric_mean"],
-            ["Government should provide universal healthcare", "", None, "geometric_mean"],
-            ["Artificial intelligence will benefit humanity", "You are a tech entrepreneur", None, "geometric_mean"],
-            ["Traditional family values are important", "You are a progressive activist", None, "first_token_prob"]
         ],
-        inputs=[statement_input, persona_input, format_selector, metric_selector]
     )
-
     analyze_btn.click(
         fn=analyze_all_formats,
         inputs=[statement_input, persona_input, format_selector, metric_selector],
-        outputs=[comparison_plot, heatmap_plot, metric_comparison, detailed_output, status_output]
     )

 if __name__ == "__main__":
-    demo.launch()

+import os
+import math
+import traceback
 import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import matplotlib.pyplot as plt
 import numpy as np

+# =========================
+# Model init
+# =========================
 MODEL_NAME = "microsoft/Phi-4-mini-instruct"

 print("Loading model...")

     MODEL_NAME,
     torch_dtype="auto",
     device_map="auto",
+    trust_remote_code=True,
 )
+model.eval()

+# =========================
+# Answer format configs
+# =========================
 ANSWER_FORMATS = {
     "1-5 (numeric)": {
         "options": ["1", "2", "3", "4", "5"],

 }

+# =========================
+# Helpers
+# =========================
+def safe_read_default_prompt(path="default-prompt.txt"):
+    fallback = (
+        "You will be given a statement.\n"
+        "Answer it according to your best judgment.\n\n"
+        "Statement: {statement}\n"
+        "Answer:"
+    )
+    try:
+        with open(path, "r", encoding="utf-8") as f:
+            txt = f.read().strip()
+        if "{statement}" not in txt:
+            # ensure it is usable as a format string
+            return txt + "\n\nStatement: {statement}\nAnswer:"
+        return txt
+    except FileNotFoundError:
+        return fallback
+
+
 def get_token_info(options):
     token_info = []
     for i, option in enumerate(options):
         tokens = tokenizer.encode(option, add_special_tokens=False)
         token_info.append({
+            "index": i,
+            "option": option,
+            "tokens": tokens,
+            "token_count": len(tokens),
+            "decoded_tokens": [tokenizer.decode([t]) for t in tokens],
         })
     return token_info


+def calculate_sequence_metrics(prompt_ids: torch.Tensor, option_tokens, temperature=1.0):
     """
+    Compute RAW sequence metrics in log-space for stability.
+
+    Returns RAW:
+    - joint_prob, geometric_mean, first_token_prob, avg_prob, perplexity
+    - token_probs, sum_logp, mean_logp, n_tokens
     """
+    if not option_tokens:
         return None
+
+    device = model.device
+    current_input = prompt_ids.to(device)
+
+    logps = []
     token_probs = []
+
+    for tok in option_tokens:
+        tok = int(tok)
         with torch.no_grad():
             outputs = model(current_input)
+        logits = outputs.logits[0, -1, :] / float(temperature)
+        log_probs = torch.log_softmax(logits, dim=-1)
+
+        lp = float(log_probs[tok].item())
+        p = math.exp(lp)
+
+        logps.append(lp)
+        token_probs.append(p)
+
+        next_tok = torch.tensor([[tok]], device=device, dtype=torch.long)
+        current_input = torch.cat([current_input, next_tok], dim=1)
+
+    n = len(option_tokens)
+    sum_logp = float(np.sum(logps))
+    mean_logp = sum_logp / n
+
+    joint_prob = math.exp(sum_logp)       # can be tiny
+    geometric_mean = math.exp(mean_logp)  # in (0, 1]
+    first_token_prob = token_probs[0]
+    avg_prob = float(np.mean(token_probs))
+    perplexity = math.exp(-mean_logp)     # = 1 / geometric_mean
+
     return {
+        "joint_prob": joint_prob,
+        "geometric_mean": geometric_mean,
+        "first_token_prob": first_token_prob,
+        "avg_prob": avg_prob,
+        "token_probs": token_probs,
+        "perplexity": perplexity,
+        "sum_logp": sum_logp,
+        "mean_logp": mean_logp,
+        "n_tokens": n,
     }


+def normalized_distribution(option_metrics, metric="geometric_mean", mode="softmax", eps=1e-12):
+    """
+    Return a normalized distribution over options WITHOUT overwriting raw metrics.
+
+    Recommended: mode="softmax" in log-space.
+    """
+    if mode not in ("softmax", "simple"):
+        raise ValueError("mode must be 'softmax' or 'simple'")
+
+    if metric == "joint_prob":
+        scores = np.array([m["sum_logp"] for m in option_metrics], dtype=np.float64)
+    elif metric == "geometric_mean":
+        scores = np.array([m["mean_logp"] for m in option_metrics], dtype=np.float64)
+    elif metric == "first_token_prob":
+        scores = np.log(np.array([max(m["first_token_prob"], eps) for m in option_metrics], dtype=np.float64))
+    elif metric == "avg_prob":
+        scores = np.log(np.array([max(m["avg_prob"], eps) for m in option_metrics], dtype=np.float64))
+    else:
+        raise ValueError(f"Unknown metric: {metric}")
+
+    if mode == "simple":
+        raw = np.array([max(m[metric], eps) for m in option_metrics], dtype=np.float64)
+        s = raw.sum()
+        return (raw / s).tolist() if s > 0 else [0.0] * len(option_metrics)
+
+    # softmax(scores)
+    scores = scores - scores.max()
+    exps = np.exp(scores)
+    s = exps.sum()
+    return (exps / s).tolist() if s > 0 else [0.0] * len(option_metrics)
+
+
+# =========================
+# Plotting
+# =========================
+def create_comparison_plot(all_results, statement, metric="geometric_mean"):
+    """Bar plots of normalized option-mass per format for the selected metric."""
     n_formats = len(all_results)
     if n_formats == 0:
         return None
+
+    ncols = (n_formats + 1) // 2
+    fig, axes = plt.subplots(2, ncols, figsize=(16, 8))
+    axes = np.array(axes).flatten()
+
     metric_names = {
+        "geometric_mean": "Softmax over mean log-prob (Recommended)",
+        "joint_prob": "Softmax over joint log-prob",
+        "first_token_prob": "Softmax over log first-token prob",
+        "avg_prob": "Softmax over log avg-token prob",
     }
+
     for idx, (format_name, data) in enumerate(all_results.items()):
         ax = axes[idx]
+        labels = data["labels"]
+        dist = data["norm_dists"][metric]
+
+        bars = ax.bar(range(len(labels)), dist, alpha=0.85, edgecolor="black")
+        for bar, p in zip(bars, dist):
             height = bar.get_height()
+            ax.text(
+                bar.get_x() + bar.get_width() / 2.0,
+                height + 0.01,
+                f"{p:.3f}",
+                ha="center",
+                va="bottom",
+                fontsize=8,
+            )
+
+        ax.set_ylabel("Normalized option mass", fontsize=9)
+        ax.set_title(format_name, fontsize=10, fontweight="bold")
         ax.set_xticks(range(len(labels)))
+        ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
+        ax.set_ylim(0, max(dist) * 1.2 if max(dist) > 0 else 1.0)
+        ax.grid(True, axis="y", alpha=0.3)
+
+    # hide unused subplots
+    for k in range(n_formats, len(axes)):
+        axes[k].set_visible(False)
+
+    plt.suptitle(
+        f"Response Distribution Comparison\nMetric: {metric_names.get(metric, metric)}\n"
+        f"Statement: {statement[:80]}{'...' if len(statement) > 80 else ''}",
+        fontsize=12,
+        fontweight="bold",
+    )
     plt.tight_layout()
     return fig


+def create_heatmap(all_results, metric="geometric_mean"):
+    """Heatmap of normalized option-mass per format."""
     format_names = list(all_results.keys())
+    if not format_names:
         return None
+
     n_options = 5
+    prob_matrix = np.zeros((len(format_names), n_options), dtype=np.float64)
+
+    for i, fmt in enumerate(format_names):
+        prob_matrix[i] = all_results[fmt]["norm_dists"][metric]
+
     fig, ax = plt.subplots(figsize=(10, 8))
+    im = ax.imshow(prob_matrix, aspect="auto", vmin=0, vmax=float(np.max(prob_matrix)) if prob_matrix.size else 1.0)
+
     ax.set_xticks(range(n_options))
+    ax.set_xticklabels(["Opt 1", "Opt 2", "Opt 3", "Opt 4", "Opt 5"])
     ax.set_yticks(range(len(format_names)))
     ax.set_yticklabels(format_names, fontsize=9)
+
+    for i in range(prob_matrix.shape[0]):
+        for j in range(prob_matrix.shape[1]):
+            ax.text(j, i, f"{prob_matrix[i, j]:.3f}", ha="center", va="center", fontsize=8)
+
     metric_names = {
+        "geometric_mean": "mean log-prob softmax",
+        "joint_prob": "joint log-prob softmax",
+        "first_token_prob": "first-token softmax",
+        "avg_prob": "avg-token softmax",
     }
+
+    ax.set_title(f"Probability Heatmap (Normalized)\nMetric: {metric_names.get(metric, metric)}", fontsize=12, fontweight="bold")
+    plt.colorbar(im, ax=ax, label="Normalized option mass")
     plt.tight_layout()
     return fig


 def create_metric_comparison_plot(all_results, statement):
+    """
+    Compares normalized distributions under four metrics.
+    Each subplot: per-format line over option index for the given metric.
+    """
+    metrics = ["geometric_mean", "joint_prob", "first_token_prob", "avg_prob"]
+    metric_titles = ["Geometric Mean (log) softmax", "Joint (log) softmax", "First-token (log) softmax", "Avg-token (log) softmax"]
+
+    if not all_results:
         return None
+
     fig, axes = plt.subplots(2, 2, figsize=(14, 10))
+    axes = np.array(axes).flatten()
+
+    for ax, metric, title in zip(axes, metrics, metric_titles):
         for format_name, data in all_results.items():
+            dist = data["norm_dists"][metric]
+            ax.plot(range(5), dist, marker="o", label=format_name, alpha=0.75)
+
+        ax.set_xlabel("Response option index")
+        ax.set_ylabel("Normalized option mass")
+        ax.set_title(title, fontweight="bold")
         ax.set_xticks(range(5))
+        ax.set_xticklabels(["Opt 1", "Opt 2", "Opt 3", "Opt 4", "Opt 5"])
         ax.grid(True, alpha=0.3)
+        ax.legend(fontsize=7, loc="best")
+
+    plt.suptitle(
+        f"Metric Comparison (Normalized Distributions)\nStatement: {statement[:80]}{'...' if len(statement) > 80 else ''}",
+        fontsize=12,
+        fontweight="bold",
+    )
     plt.tight_layout()
     return fig


+# =========================
+# Core analysis
+# =========================
+def analyze_all_formats(statement, persona="", selected_formats=None, metric="geometric_mean"):
     try:
+        default_prompt_template = safe_read_default_prompt()
+
+        if not statement or not statement.strip():
+            return None, None, None, "", "❌ Please enter a statement."
+
+        if not selected_formats:
             selected_formats = list(ANSWER_FORMATS.keys())
+
+        # Build results container
         all_results = {}
         detailed_output = []
+
+        detailed_output.append("=" * 80)
+        detailed_output.append("MULTI-TOKEN RAW PROBABILITY ANALYSIS (FIXED)")
+        detailed_output.append("=" * 80)
+        detailed_output.append("Raw metrics are NOT normalized (true probabilities).")
+        detailed_output.append("Plots use a SEPARATE normalized distribution per metric (softmax in log-space).")
         detailed_output.append("")
+        detailed_output.append("Raw metrics:")
+        detailed_output.append("- joint_prob: exp(sum log p_i)")
+        detailed_output.append("- geometric_mean: exp(mean log p_i) (length-normalized likelihood)")
+        detailed_output.append("- perplexity: exp(-mean log p_i) = 1 / geometric_mean")
+        detailed_output.append("- first_token_prob: p_1")
+        detailed_output.append("- avg_prob: mean(p_i)")
+        detailed_output.append("=" * 80)
         detailed_output.append("")
+
         for format_name in selected_formats:
+            cfg = ANSWER_FORMATS[format_name]
+            options = cfg["options"]
+            labels = cfg["labels"]
+            prompt_suffix = cfg["prompt_suffix"]
+
             token_info = get_token_info(options)
+
             full_prompt = default_prompt_template.format(statement=statement.strip())
             full_prompt += f"\n\n{prompt_suffix}"
+
             messages = []
+            if persona and persona.strip():
                 messages.append({"role": "system", "content": persona.strip()})
             messages.append({"role": "user", "content": full_prompt})
+
+            prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(dtype=torch.long)
+
+            # Compute RAW metrics per option
+            raw_metrics = []
             for info in token_info:
+                m = calculate_sequence_metrics(prompt_ids, info["tokens"])
+                raw_metrics.append(m)
+
+            # Compute normalized distributions for all metrics (for plotting)
+            norm_dists = {
+                "geometric_mean": normalized_distribution(raw_metrics, metric="geometric_mean", mode="softmax"),
+                "joint_prob": normalized_distribution(raw_metrics, metric="joint_prob", mode="softmax"),
+                "first_token_prob": normalized_distribution(raw_metrics, metric="first_token_prob", mode="softmax"),
+                "avg_prob": normalized_distribution(raw_metrics, metric="avg_prob", mode="softmax"),
+            }
+
             all_results[format_name] = {
+                "labels": labels,
+                "options": options,
+                "token_info": token_info,
+                "raw_metrics": raw_metrics,
+                "norm_dists": norm_dists,
             }
+
+            # Detailed output (RAW + selected-metric normalized mass)
+            detailed_output.append(f"\n{'=' * 80}")
             detailed_output.append(f"Format: {format_name}")
+            detailed_output.append(f"{'=' * 80}")
+
+            selected_norm = norm_dists[metric]
+
+            for opt, lab, info, m, nmass in zip(options, labels, token_info, raw_metrics, selected_norm):
+                detailed_output.append(f"\n{lab} ({opt}):")
+                detailed_output.append(f"  Tokens ({info['token_count']}): {info['decoded_tokens']}")
+                detailed_output.append(f"  RAW joint_prob: {m['joint_prob']:.6e}")
+                detailed_output.append(f"  RAW geometric_mean: {m['geometric_mean']:.6e}")
+                detailed_output.append(f"  RAW first_token_prob: {m['first_token_prob']:.6e}")
+                detailed_output.append(f"  RAW avg_prob: {m['avg_prob']:.6e}")
+                detailed_output.append(f"  RAW perplexity: {m['perplexity']:.4f}")
+                detailed_output.append(f"  NORM({metric}) mass: {nmass:.4f}")
+
+        # Plots (normalized distributions)
+        comparison_plot = create_comparison_plot(all_results, statement, metric=metric)
+        heatmap_plot = create_heatmap(all_results, metric=metric)
         metric_comparison = create_metric_comparison_plot(all_results, statement)
+
+        return comparison_plot, heatmap_plot, metric_comparison, "\n".join(detailed_output), "✅ Analysis complete"
+
     except Exception as e:
         error_msg = f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"
         return None, None, None, "", error_msg


+# =========================
+# Gradio UI
+# =========================
+with gr.Blocks(title="The Unsampled Truth - Multi-Token Analysis (Fixed)") as demo:
+    gr.Markdown(
+        """
+        # The Unsampled Truth - Multi-Token Probability Analysis (Fixed)
+
+        This tool computes **RAW** multi-token likelihood metrics per option and plots **normalized** option
+        distributions using **softmax in log-space** (so values stay valid and comparable).
+
+        - RAW metrics: joint_prob, geometric_mean, first_token_prob, avg_prob, perplexity
+        - Plots: normalized option mass under the selected metric
+        """
+    )
+
     with gr.Row():
         with gr.Column():
             statement_input = gr.Textbox(
+                label="Statement to Analyze",
                 placeholder="e.g., Climate change is a serious threat",
+                lines=3,
             )
             persona_input = gr.Textbox(
+                label="Persona (Optional)",
+                placeholder="e.g., You are a tech entrepreneur",
+                lines=2,
             )
             format_selector = gr.CheckboxGroup(
                 choices=list(ANSWER_FORMATS.keys()),
                 value=list(ANSWER_FORMATS.keys()),
                 label="Select Answer Formats to Compare",
+                interactive=True,
             )
             metric_selector = gr.Radio(
                 choices=[
                     ("Geometric Mean (Recommended)", "geometric_mean"),
                     ("Joint Probability", "joint_prob"),
                     ("First Token Only", "first_token_prob"),
+                    ("Average Token Probability", "avg_prob"),
                 ],
                 value="geometric_mean",
+                label="Comparison Metric (for plots + NORM mass line)",
             )
             analyze_btn = gr.Button("Analyze All Formats", variant="primary")
+
     with gr.Row():
         with gr.Column():
+            comparison_plot = gr.Plot(label="Format Comparison (Normalized)")
         with gr.Column():
+            heatmap_plot = gr.Plot(label="Heatmap (Normalized)")
+
     with gr.Row():
+        metric_comparison = gr.Plot(label="Metric Comparison (Normalized)")
+
     with gr.Row():
+        detailed_output = gr.Textbox(label="Detailed Output (RAW metrics + normalized mass)", lines=25)
         status_output = gr.Textbox(label="Status", lines=2)
+
     gr.Examples(
         examples=[
+            ["Climate change is a serious threat", "", list(ANSWER_FORMATS.keys()), "geometric_mean"],
+            ["Immigration has positive economic effects", "", list(ANSWER_FORMATS.keys()), "geometric_mean"],
+            ["Government should provide universal healthcare", "", list(ANSWER_FORMATS.keys()), "geometric_mean"],
+            ["Artificial intelligence will benefit humanity", "You are a tech entrepreneur", list(ANSWER_FORMATS.keys()), "geometric_mean"],
+            ["Traditional family values are important", "You are a progressive activist", list(ANSWER_FORMATS.keys()), "first_token_prob"],
         ],
+        inputs=[statement_input, persona_input, format_selector, metric_selector],
     )
+
     analyze_btn.click(
         fn=analyze_all_formats,
         inputs=[statement_input, persona_input, format_selector, metric_selector],
         outputs=[comparison_plot, heatmap_plot, metric_comparison, detailed_output, status_output],
     )

 if __name__ == "__main__":
+    demo.launch()
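
A quick sanity check on the commit's new normalization: taking a softmax over mean log-probs (what normalized_distribution does for metric="geometric_mean") is mathematically the same as normalizing the raw geometric means, but it never multiplies tiny floats together. A minimal standalone sketch, where the per-token probabilities are made-up toy values and mean_logp is a hypothetical helper, not part of the committed code:

    import math
    import numpy as np

    # Toy per-token probabilities for two options of different lengths (hypothetical).
    option_a = [0.30, 0.50]          # two-token option
    option_b = [0.20, 0.40, 0.90]    # three-token option

    def mean_logp(ps):
        # average log-probability of a token sequence
        return sum(math.log(p) for p in ps) / len(ps)

    # Route 1: compute geometric means directly, then normalize.
    g = np.array([math.exp(mean_logp(option_a)), math.exp(mean_logp(option_b))])
    direct = g / g.sum()

    # Route 2: softmax over mean log-probs, as in the commit.
    scores = np.array([mean_logp(option_a), mean_logp(option_b)])
    scores -= scores.max()           # shift for numerical stability
    soft = np.exp(scores) / np.exp(scores).sum()

    assert np.allclose(direct, soft)  # identical distributions

The same identity is why the RAW joint_prob can be reported as exp(sum of log p_i): the sum of logs stays well-scaled even when a direct product of per-token probabilities would underflow.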