Update app.py
app.py
CHANGED
@@ -1,10 +1,15 @@
+import os
+import math
+import traceback
 import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import matplotlib.pyplot as plt
 import numpy as np

-#
+# =========================
+# Model init
+# =========================
 MODEL_NAME = "microsoft/Phi-4-mini-instruct"

 print("Loading model...")

@@ -13,10 +18,13 @@ model = AutoModelForCausalLM.from_pretrained(
     MODEL_NAME,
     torch_dtype="auto",
     device_map="auto",
-    trust_remote_code=True
+    trust_remote_code=True,
 )
+model.eval()

-#
+# =========================
+# Answer format configs
+# =========================
 ANSWER_FORMATS = {
     "1-5 (numeric)": {
         "options": ["1", "2", "3", "4", "5"],

@@ -66,371 +74,437 @@ ANSWER_FORMATS = {
}


# =========================
# Helpers
# =========================
def safe_read_default_prompt(path="default-prompt.txt"):
    fallback = (
        "You will be given a statement.\n"
        "Answer it according to your best judgment.\n\n"
        "Statement: {statement}\n"
        "Answer:"
    )
    try:
        with open(path, "r", encoding="utf-8") as f:
            txt = f.read().strip()
        if "{statement}" not in txt:
            # ensure it is usable as a format string
            return txt + "\n\nStatement: {statement}\nAnswer:"
        return txt
    except FileNotFoundError:
        return fallback


def get_token_info(options):
    token_info = []
    for i, option in enumerate(options):
        tokens = tokenizer.encode(option, add_special_tokens=False)
        token_info.append({
            "index": i,
            "option": option,
            "tokens": tokens,
            "token_count": len(tokens),
            "decoded_tokens": [tokenizer.decode([t]) for t in tokens],
        })
    return token_info
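
# Illustrative sketch (hypothetical values, not part of the app's logic): get_token_info
# only records how each answer option splits into tokens. The token ids below are made up,
# since the exact split depends on the Phi-4-mini tokenizer.
#
#   >>> get_token_info(["1", "Strongly agree"])
#   [{"index": 0, "option": "1", "tokens": [16], "token_count": 1, "decoded_tokens": ["1"]},
#    {"index": 1, "option": "Strongly agree", "tokens": [4468, 7655], "token_count": 2,
#     "decoded_tokens": ["Strongly", " agree"]}]
#
# The point: numeric options are usually one token, while verbal options span several
# tokens, which is why the sequence-level metrics below are needed.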


def calculate_sequence_metrics(prompt_ids: torch.Tensor, option_tokens, temperature=1.0):
    """
    Compute RAW sequence metrics in log-space for stability.

    Returns RAW:
    - joint_prob, geometric_mean, first_token_prob, avg_prob, perplexity
    - token_probs, sum_logp, mean_logp, n_tokens
    """
    if not option_tokens:
        return None

    device = model.device
    current_input = prompt_ids.to(device)

    logps = []
    token_probs = []

    for tok in option_tokens:
        tok = int(tok)
        with torch.no_grad():
            outputs = model(current_input)
            logits = outputs.logits[0, -1, :] / float(temperature)
            log_probs = torch.log_softmax(logits, dim=-1)

        lp = float(log_probs[tok].item())
        p = math.exp(lp)

        logps.append(lp)
        token_probs.append(p)

        next_tok = torch.tensor([[tok]], device=device, dtype=torch.long)
        current_input = torch.cat([current_input, next_tok], dim=1)

    n = len(option_tokens)
    sum_logp = float(np.sum(logps))
    mean_logp = sum_logp / n

    joint_prob = math.exp(sum_logp)        # can be tiny
    geometric_mean = math.exp(mean_logp)   # in (0, 1]
    first_token_prob = token_probs[0]
    avg_prob = float(np.mean(token_probs))
    perplexity = math.exp(-mean_logp)      # = 1 / geometric_mean

    return {
        "joint_prob": joint_prob,
        "geometric_mean": geometric_mean,
        "first_token_prob": first_token_prob,
        "avg_prob": avg_prob,
        "token_probs": token_probs,
        "perplexity": perplexity,
        "sum_logp": sum_logp,
        "mean_logp": mean_logp,
        "n_tokens": n,
    }
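
# Worked example (hypothetical numbers, for intuition only): if an option splits into two
# tokens whose conditional probabilities under the model are 0.5 and 0.2, then
#   joint_prob        = 0.5 * 0.2            = 0.10
#   geometric_mean    = (0.5 * 0.2) ** (1/2) ≈ 0.316
#   first_token_prob  = 0.5
#   avg_prob          = (0.5 + 0.2) / 2      = 0.35
#   perplexity        = 1 / 0.316            ≈ 3.16
# A one-token option with probability 0.4 would have joint_prob 0.40, four times this
# option's 0.10, even though per token the two are comparable (0.4 vs ~0.32 geometric
# mean); that length penalty is why the geometric mean is the recommended basis for
# comparing options of different token lengths.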


def normalized_distribution(option_metrics, metric="geometric_mean", mode="softmax", eps=1e-12):
    """
    Return a normalized distribution over options WITHOUT overwriting raw metrics.

    Recommended: mode="softmax" in log-space.
    """
    if mode not in ("softmax", "simple"):
        raise ValueError("mode must be 'softmax' or 'simple'")

    if metric == "joint_prob":
        scores = np.array([m["sum_logp"] for m in option_metrics], dtype=np.float64)
    elif metric == "geometric_mean":
        scores = np.array([m["mean_logp"] for m in option_metrics], dtype=np.float64)
    elif metric == "first_token_prob":
        scores = np.log(np.array([max(m["first_token_prob"], eps) for m in option_metrics], dtype=np.float64))
    elif metric == "avg_prob":
        scores = np.log(np.array([max(m["avg_prob"], eps) for m in option_metrics], dtype=np.float64))
    else:
        raise ValueError(f"Unknown metric: {metric}")

    if mode == "simple":
        raw = np.array([max(m[metric], eps) for m in option_metrics], dtype=np.float64)
        s = raw.sum()
        return (raw / s).tolist() if s > 0 else [0.0] * len(option_metrics)

    # softmax(scores)
    scores = scores - scores.max()
    exps = np.exp(scores)
    s = exps.sum()
    return (exps / s).tolist() if s > 0 else [0.0] * len(option_metrics)
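
# Note on the normalization (sketch using the hypothetical numbers from the example above):
# softmax over log-scores is just a numerically safe way of dividing by the total, since
# exp(log x_i - c) / sum_j exp(log x_j - c) = x_i / sum_j x_j for any constant c.
# With geometric means of 0.316 and 0.400 for two options, the plotted "option mass" is
#   0.316 / (0.316 + 0.400) ≈ 0.441   and   0.400 / (0.316 + 0.400) ≈ 0.559.
# Subtracting scores.max() before exponentiating only guards against overflow/underflow;
# it does not change the resulting distribution.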


# =========================
# Plotting
# =========================
def create_comparison_plot(all_results, statement, metric="geometric_mean"):
    """Bar plots of normalized option-mass per format for the selected metric."""
    n_formats = len(all_results)
    if n_formats == 0:
        return None

    ncols = (n_formats + 1) // 2
    fig, axes = plt.subplots(2, ncols, figsize=(16, 8))
    axes = np.array(axes).flatten()

    metric_names = {
        "geometric_mean": "Softmax over mean log-prob (Recommended)",
        "joint_prob": "Softmax over joint log-prob",
        "first_token_prob": "Softmax over log first-token prob",
        "avg_prob": "Softmax over log avg-token prob",
    }

    for idx, (format_name, data) in enumerate(all_results.items()):
        ax = axes[idx]
        labels = data["labels"]
        dist = data["norm_dists"][metric]

        bars = ax.bar(range(len(labels)), dist, alpha=0.85, edgecolor="black")
        for bar, p in zip(bars, dist):
            height = bar.get_height()
            ax.text(
                bar.get_x() + bar.get_width() / 2.0,
                height + 0.01,
                f"{p:.3f}",
                ha="center",
                va="bottom",
                fontsize=8,
            )

        ax.set_ylabel("Normalized option mass", fontsize=9)
        ax.set_title(format_name, fontsize=10, fontweight="bold")
        ax.set_xticks(range(len(labels)))
        ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
        ax.set_ylim(0, max(dist) * 1.2 if max(dist) > 0 else 1.0)
        ax.grid(True, axis="y", alpha=0.3)

    # hide unused subplots
    for k in range(n_formats, len(axes)):
        axes[k].set_visible(False)

    plt.suptitle(
        f"Response Distribution Comparison\nMetric: {metric_names.get(metric, metric)}\n"
        f"Statement: {statement[:80]}{'...' if len(statement) > 80 else ''}",
        fontsize=12,
        fontweight="bold",
    )
    plt.tight_layout()
    return fig


def create_heatmap(all_results, metric="geometric_mean"):
    """Heatmap of normalized option-mass per format."""
    format_names = list(all_results.keys())
    if not format_names:
        return None

    n_options = 5
    prob_matrix = np.zeros((len(format_names), n_options), dtype=np.float64)

    for i, fmt in enumerate(format_names):
        prob_matrix[i] = all_results[fmt]["norm_dists"][metric]

    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(prob_matrix, aspect="auto", vmin=0, vmax=float(np.max(prob_matrix)) if prob_matrix.size else 1.0)

    ax.set_xticks(range(n_options))
    ax.set_xticklabels(["Opt 1", "Opt 2", "Opt 3", "Opt 4", "Opt 5"])
    ax.set_yticks(range(len(format_names)))
    ax.set_yticklabels(format_names, fontsize=9)

    for i in range(prob_matrix.shape[0]):
        for j in range(prob_matrix.shape[1]):
            ax.text(j, i, f"{prob_matrix[i, j]:.3f}", ha="center", va="center", fontsize=8)

    metric_names = {
        "geometric_mean": "mean log-prob softmax",
        "joint_prob": "joint log-prob softmax",
        "first_token_prob": "first-token softmax",
        "avg_prob": "avg-token softmax",
    }

    ax.set_title(f"Probability Heatmap (Normalized)\nMetric: {metric_names.get(metric, metric)}", fontsize=12, fontweight="bold")
    plt.colorbar(im, ax=ax, label="Normalized option mass")
    plt.tight_layout()
    return fig


def create_metric_comparison_plot(all_results, statement):
    """
    Compares normalized distributions under four metrics.
    Each subplot: per-format line over option index for the given metric.
    """
    metrics = ["geometric_mean", "joint_prob", "first_token_prob", "avg_prob"]
    metric_titles = [
        "Geometric Mean (log) softmax",
        "Joint (log) softmax",
        "First-token (log) softmax",
        "Avg-token (log) softmax",
    ]

    if not all_results:
        return None

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = np.array(axes).flatten()

    for ax, metric, title in zip(axes, metrics, metric_titles):
        for format_name, data in all_results.items():
            dist = data["norm_dists"][metric]
            ax.plot(range(5), dist, marker="o", label=format_name, alpha=0.75)

        ax.set_xlabel("Response option index")
        ax.set_ylabel("Normalized option mass")
        ax.set_title(title, fontweight="bold")
        ax.set_xticks(range(5))
        ax.set_xticklabels(["Opt 1", "Opt 2", "Opt 3", "Opt 4", "Opt 5"])
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=7, loc="best")

    plt.suptitle(
        f"Metric Comparison (Normalized Distributions)\nStatement: {statement[:80]}{'...' if len(statement) > 80 else ''}",
        fontsize=12,
        fontweight="bold",
    )
    plt.tight_layout()
    return fig


# =========================
# Core analysis
# =========================
def analyze_all_formats(statement, persona="", selected_formats=None, metric="geometric_mean"):
    try:
        default_prompt_template = safe_read_default_prompt()

        if not statement or not statement.strip():
            return None, None, None, "", "❌ Please enter a statement."

        if not selected_formats:
            selected_formats = list(ANSWER_FORMATS.keys())

        # Build results container
        all_results = {}
        detailed_output = []

        detailed_output.append("=" * 80)
        detailed_output.append("MULTI-TOKEN RAW PROBABILITY ANALYSIS (FIXED)")
        detailed_output.append("=" * 80)
        detailed_output.append("Raw metrics are NOT normalized (true probabilities).")
        detailed_output.append("Plots use a SEPARATE normalized distribution per metric (softmax in log-space).")
        detailed_output.append("")
        detailed_output.append("Raw metrics:")
        detailed_output.append("- joint_prob: exp(sum log p_i)")
        detailed_output.append("- geometric_mean: exp(mean log p_i) (length-normalized likelihood)")
        detailed_output.append("- perplexity: exp(-mean log p_i) = 1 / geometric_mean")
        detailed_output.append("- first_token_prob: p_1")
        detailed_output.append("- avg_prob: mean(p_i)")
        detailed_output.append("=" * 80)
        detailed_output.append("")

        for format_name in selected_formats:
            cfg = ANSWER_FORMATS[format_name]
            options = cfg["options"]
            labels = cfg["labels"]
            prompt_suffix = cfg["prompt_suffix"]

            token_info = get_token_info(options)

            full_prompt = default_prompt_template.format(statement=statement.strip())
            full_prompt += f"\n\n{prompt_suffix}"

            messages = []
            if persona and persona.strip():
                messages.append({"role": "system", "content": persona.strip()})
            messages.append({"role": "user", "content": full_prompt})

            prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(dtype=torch.long)

            # Compute RAW metrics per option
            raw_metrics = []
            for info in token_info:
                m = calculate_sequence_metrics(prompt_ids, info["tokens"])
                raw_metrics.append(m)

            # Compute normalized distributions for all metrics (for plotting)
            norm_dists = {
                "geometric_mean": normalized_distribution(raw_metrics, metric="geometric_mean", mode="softmax"),
                "joint_prob": normalized_distribution(raw_metrics, metric="joint_prob", mode="softmax"),
                "first_token_prob": normalized_distribution(raw_metrics, metric="first_token_prob", mode="softmax"),
                "avg_prob": normalized_distribution(raw_metrics, metric="avg_prob", mode="softmax"),
            }

            all_results[format_name] = {
                "labels": labels,
                "options": options,
                "token_info": token_info,
                "raw_metrics": raw_metrics,
                "norm_dists": norm_dists,
            }

            # Detailed output (RAW + selected-metric normalized mass)
            detailed_output.append(f"\n{'=' * 80}")
            detailed_output.append(f"Format: {format_name}")
            detailed_output.append(f"{'=' * 80}")

            selected_norm = norm_dists[metric]

            for opt, lab, info, m, nmass in zip(options, labels, token_info, raw_metrics, selected_norm):
                detailed_output.append(f"\n{lab} ({opt}):")
                detailed_output.append(f"  Tokens ({info['token_count']}): {info['decoded_tokens']}")
                detailed_output.append(f"  RAW joint_prob: {m['joint_prob']:.6e}")
                detailed_output.append(f"  RAW geometric_mean: {m['geometric_mean']:.6e}")
                detailed_output.append(f"  RAW first_token_prob: {m['first_token_prob']:.6e}")
                detailed_output.append(f"  RAW avg_prob: {m['avg_prob']:.6e}")
                detailed_output.append(f"  RAW perplexity: {m['perplexity']:.4f}")
                detailed_output.append(f"  NORM({metric}) mass: {nmass:.4f}")

        # Plots (normalized distributions)
        comparison_plot = create_comparison_plot(all_results, statement, metric=metric)
        heatmap_plot = create_heatmap(all_results, metric=metric)
        metric_comparison = create_metric_comparison_plot(all_results, statement)

        return comparison_plot, heatmap_plot, metric_comparison, "\n".join(detailed_output), "✅ Analysis complete"

    except Exception as e:
        error_msg = f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"
        return None, None, None, "", error_msg
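
# Usage sketch (illustrative only, not wired into the app): the analysis can also be driven
# directly, e.g. from a notebook, since the function itself has no Gradio dependency:
#
#   comparison, heatmap, metric_cmp, details, status = analyze_all_formats(
#       "Climate change is a serious threat",
#       persona="You are a tech entrepreneur",
#       metric="geometric_mean",
#   )
#   print(status)
#   print(details)
#
# The three figures are matplotlib Figure objects (or None on error) and can be saved with
# fig.savefig(...) instead of being rendered by Gradio.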


# =========================
# Gradio UI
# =========================
with gr.Blocks(title="The Unsampled Truth - Multi-Token Analysis (Fixed)") as demo:
    gr.Markdown(
        """
        # The Unsampled Truth — Multi-Token Probability Analysis (Fixed)

        This tool computes **RAW** multi-token likelihood metrics per option and plots **normalized** option distributions
        using **softmax in log-space** (so values stay valid and comparable).

        - RAW metrics: joint_prob, geometric_mean, first_token_prob, avg_prob, perplexity
        - Plots: normalized option mass under the selected metric
        """
    )

    with gr.Row():
        with gr.Column():
            statement_input = gr.Textbox(
                label="Statement to Analyze",
                placeholder="e.g., Climate change is a serious threat",
                lines=3,
            )
            persona_input = gr.Textbox(
                label="Persona (Optional)",
                placeholder="e.g., You are a tech entrepreneur",
                lines=2,
            )
            format_selector = gr.CheckboxGroup(
                choices=list(ANSWER_FORMATS.keys()),
                value=list(ANSWER_FORMATS.keys()),
                label="Select Answer Formats to Compare",
                interactive=True,
            )
            metric_selector = gr.Radio(
                choices=[
                    ("Geometric Mean (Recommended)", "geometric_mean"),
                    ("Joint Probability", "joint_prob"),
                    ("First Token Only", "first_token_prob"),
                    ("Average Token Probability", "avg_prob"),
                ],
                value="geometric_mean",
                label="Comparison Metric (for plots + NORM mass line)",
            )
            analyze_btn = gr.Button("Analyze All Formats", variant="primary")

    with gr.Row():
        with gr.Column():
            comparison_plot = gr.Plot(label="Format Comparison (Normalized)")
        with gr.Column():
            heatmap_plot = gr.Plot(label="Heatmap (Normalized)")

    with gr.Row():
        metric_comparison = gr.Plot(label="Metric Comparison (Normalized)")

    with gr.Row():
        detailed_output = gr.Textbox(label="Detailed Output (RAW metrics + normalized mass)", lines=25)
        status_output = gr.Textbox(label="Status", lines=2)

    gr.Examples(
        examples=[
            ["Climate change is a serious threat", "", list(ANSWER_FORMATS.keys()), "geometric_mean"],
            ["Immigration has positive economic effects", "", list(ANSWER_FORMATS.keys()), "geometric_mean"],
            ["Government should provide universal healthcare", "", list(ANSWER_FORMATS.keys()), "geometric_mean"],
            ["Artificial intelligence will benefit humanity", "You are a tech entrepreneur", list(ANSWER_FORMATS.keys()), "geometric_mean"],
            ["Traditional family values are important", "You are a progressive activist", list(ANSWER_FORMATS.keys()), "first_token_prob"],
        ],
        inputs=[statement_input, persona_input, format_selector, metric_selector],
    )

    analyze_btn.click(
        fn=analyze_all_formats,
        inputs=[statement_input, persona_input, format_selector, metric_selector],
        outputs=[comparison_plot, heatmap_plot, metric_comparison, detailed_output, status_output],
    )

if __name__ == "__main__":
    demo.launch()