# Source: Hugging Face Space app.py by Chr-Hau (commit c5738ac, verified)
import os
import math
import traceback
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import numpy as np
# =========================
# Model init
# =========================
MODEL_NAME = "microsoft/Phi-4-mini-instruct"
print("Loading model...")
# NOTE: trust_remote_code=True executes modeling code shipped with the
# checkpoint; acceptable here since the model name is pinned to Microsoft's repo.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",  # load in the checkpoint's stored dtype (e.g. bf16)
    device_map="auto",  # place weights automatically across available devices
    trust_remote_code=True,
)
model.eval()  # inference only: disables dropout and similar training behavior
# =========================
# Answer format configs
# =========================
# Each entry defines one 5-point answer scale:
#   options       - the exact answer strings whose token likelihoods are scored
#   labels        - short names used on plot axes (one per option, same order)
#   prompt_suffix - instruction appended to the prompt for this scale
ANSWER_FORMATS = {
    "1-5 (numeric)": {
        "options": ["1", "2", "3", "4", "5"],
        "labels": ["1", "2", "3", "4", "5"],
        "prompt_suffix": "Please respond with only a number from 1 to 5."
    },
    "A-E (uppercase)": {
        "options": ["A", "B", "C", "D", "E"],
        "labels": ["A", "B", "C", "D", "E"],
        "prompt_suffix": "Please respond with only a letter from A to E."
    },
    "a-e (lowercase)": {
        "options": ["a", "b", "c", "d", "e"],
        "labels": ["a", "b", "c", "d", "e"],
        "prompt_suffix": "Please respond with only a letter from a to e."
    },
    "Full words": {
        "options": ["Strongly Disagree", "Disagree", "Neutral", "Agree", "Strongly Agree"],
        "labels": ["SD", "D", "N", "A", "SA"],
        "prompt_suffix": "Please respond with one of: Strongly Disagree, Disagree, Neutral, Agree, Strongly Agree."
    },
    "Full words (lowercase)": {
        "options": ["strongly disagree", "disagree", "neutral", "agree", "strongly agree"],
        "labels": ["sd", "d", "n", "a", "sa"],
        "prompt_suffix": "Please respond with one of: strongly disagree, disagree, neutral, agree, strongly agree."
    },
    "I-V (Roman numerals)": {
        "options": ["I", "II", "III", "IV", "V"],
        "labels": ["I", "II", "III", "IV", "V"],
        "prompt_suffix": "Please respond with only a Roman numeral from I to V."
    },
    "Negative to Positive": {
        "options": ["-2", "-1", "0", "1", "2"],
        "labels": ["-2", "-1", "0", "1", "2"],
        "prompt_suffix": "Please respond with only a number from -2 to 2."
    },
    "Yes/No spectrum": {
        "options": ["Definitely No", "Probably No", "Uncertain", "Probably Yes", "Definitely Yes"],
        "labels": ["DefNo", "ProbNo", "Unc", "ProbYes", "DefYes"],
        "prompt_suffix": "Please respond with one of: Definitely No, Probably No, Uncertain, Probably Yes, Definitely Yes."
    },
    "Agreement levels": {
        "options": ["Completely Disagree", "Somewhat Disagree", "Neither", "Somewhat Agree", "Completely Agree"],
        "labels": ["CompD", "SomeD", "Neith", "SomeA", "CompA"],
        "prompt_suffix": "Please respond with one of: Completely Disagree, Somewhat Disagree, Neither, Somewhat Agree, Completely Agree."
    }
}
# =========================
# Helpers
# =========================
def safe_read_default_prompt(path="default-prompt.txt"):
    """Load the prompt template from *path*, guaranteeing a usable format string.

    Returns a template containing the ``{statement}`` placeholder. If the file
    is missing or unreadable, a built-in fallback template is returned; if the
    file exists but lacks the placeholder, one is appended so that
    ``template.format(statement=...)`` always works.
    """
    fallback = (
        "You will be given a statement.\n"
        "Answer it according to your best judgment.\n\n"
        "Statement: {statement}\n"
        "Answer:"
    )
    try:
        with open(path, "r", encoding="utf-8") as f:
            txt = f.read().strip()
    except OSError:
        # Broader than the original FileNotFoundError: a permission error or
        # a directory at *path* should degrade to the fallback, not crash.
        return fallback
    if "{statement}" not in txt:
        # Ensure the text is usable as a format string.
        return txt + "\n\nStatement: {statement}\nAnswer:"
    return txt
def get_token_info(options):
    """Describe how the tokenizer splits each answer option.

    For every option string, records its list index, the raw token ids,
    the token count, and each id decoded back to its text piece.
    """
    def _describe(idx, text):
        # Special tokens are excluded: we only score the option text itself.
        ids = tokenizer.encode(text, add_special_tokens=False)
        return {
            "index": idx,
            "option": text,
            "tokens": ids,
            "token_count": len(ids),
            "decoded_tokens": [tokenizer.decode([tid]) for tid in ids],
        }

    return [_describe(idx, text) for idx, text in enumerate(options)]
def calculate_sequence_metrics(prompt_ids: torch.Tensor, option_tokens, temperature=1.0):
    """
    Compute RAW sequence metrics in log-space for stability.
    Returns RAW:
    - joint_prob, geometric_mean, first_token_prob, avg_prob, perplexity
    - token_probs, sum_logp, mean_logp, n_tokens
    """
    if not option_tokens:
        return None

    device = model.device
    seq = prompt_ids.to(device)
    per_token_logps = []
    per_token_probs = []

    for raw_tok in option_tokens:
        tok_id = int(raw_tok)
        with torch.no_grad():
            step_logits = model(seq).logits[0, -1, :] / float(temperature)
        step_logps = torch.log_softmax(step_logits, dim=-1)
        logp = float(step_logps[tok_id].item())
        per_token_logps.append(logp)
        per_token_probs.append(math.exp(logp))
        # Teacher-force the option token so the next step conditions on it.
        next_tok = torch.tensor([[tok_id]], device=device, dtype=torch.long)
        seq = torch.cat([seq, next_tok], dim=1)

    n = len(option_tokens)
    sum_logp = float(np.sum(per_token_logps))
    mean_logp = sum_logp / n
    return {
        "joint_prob": math.exp(sum_logp),  # can be tiny
        "geometric_mean": math.exp(mean_logp),  # in (0, 1]
        "first_token_prob": per_token_probs[0],
        "avg_prob": float(np.mean(per_token_probs)),
        "token_probs": per_token_probs,
        "perplexity": math.exp(-mean_logp),  # = 1 / geometric_mean
        "sum_logp": sum_logp,
        "mean_logp": mean_logp,
        "n_tokens": n,
    }
def normalized_distribution(option_metrics, metric="geometric_mean", mode="softmax", eps=1e-12):
    """
    Return a normalized distribution over options WITHOUT overwriting raw metrics.

    Parameters
    ----------
    option_metrics : list[dict]
        Raw per-option metrics as produced by ``calculate_sequence_metrics``.
    metric : str
        One of "joint_prob", "geometric_mean", "first_token_prob", "avg_prob".
    mode : str
        "softmax" (recommended; operates in log-space) or "simple"
        (direct normalization of the raw metric values).
    eps : float
        Floor applied before logs / normalization to avoid log(0) and a
        zero total.

    Returns
    -------
    list[float]
        Probabilities over the options (all zeros if the total mass is 0).

    Raises
    ------
    ValueError
        If *mode* or *metric* is not a supported name.
    """
    if mode not in ("softmax", "simple"):
        raise ValueError("mode must be 'softmax' or 'simple'")
    if metric not in ("joint_prob", "geometric_mean", "first_token_prob", "avg_prob"):
        raise ValueError(f"Unknown metric: {metric}")

    if mode == "simple":
        # Direct normalization of the raw values. (The original computed the
        # log-space scores here too and threw them away.)
        raw = np.array([max(m[metric], eps) for m in option_metrics], dtype=np.float64)
        total = raw.sum()
        return (raw / total).tolist() if total > 0 else [0.0] * len(option_metrics)

    # Softmax in log-space: use the stored log quantities where available so
    # tiny joint probabilities never underflow before normalization.
    if metric == "joint_prob":
        scores = np.array([m["sum_logp"] for m in option_metrics], dtype=np.float64)
    elif metric == "geometric_mean":
        scores = np.array([m["mean_logp"] for m in option_metrics], dtype=np.float64)
    elif metric == "first_token_prob":
        scores = np.log(np.array([max(m["first_token_prob"], eps) for m in option_metrics], dtype=np.float64))
    else:  # avg_prob
        scores = np.log(np.array([max(m["avg_prob"], eps) for m in option_metrics], dtype=np.float64))

    # Numerically stable softmax(scores): shift by the max before exp.
    scores = scores - scores.max()
    exps = np.exp(scores)
    total = exps.sum()
    return (exps / total).tolist() if total > 0 else [0.0] * len(option_metrics)
# =========================
# Plotting
# =========================
def create_comparison_plot(all_results, statement, metric="geometric_mean"):
    """Bar plots of normalized option-mass per format for the selected metric."""
    total_formats = len(all_results)
    if total_formats == 0:
        return None

    metric_names = {
        "geometric_mean": "Softmax over mean log-prob (Recommended)",
        "joint_prob": "Softmax over joint log-prob",
        "first_token_prob": "Softmax over log first-token prob",
        "avg_prob": "Softmax over log avg-token prob",
    }

    cols = (total_formats + 1) // 2
    fig, grid = plt.subplots(2, cols, figsize=(16, 8))
    panels = np.array(grid).flatten()

    for panel_idx, (fmt_name, fmt_data) in enumerate(all_results.items()):
        panel = panels[panel_idx]
        option_labels = fmt_data["labels"]
        masses = fmt_data["norm_dists"][metric]
        bars = panel.bar(range(len(option_labels)), masses, alpha=0.85, edgecolor="black")
        # Annotate each bar with its probability mass.
        for rect, mass in zip(bars, masses):
            panel.text(
                rect.get_x() + rect.get_width() / 2.0,
                rect.get_height() + 0.01,
                f"{mass:.3f}",
                ha="center",
                va="bottom",
                fontsize=8,
            )
        panel.set_ylabel("Normalized option mass", fontsize=9)
        panel.set_title(fmt_name, fontsize=10, fontweight="bold")
        panel.set_xticks(range(len(option_labels)))
        panel.set_xticklabels(option_labels, rotation=45, ha="right", fontsize=8)
        peak = max(masses)
        panel.set_ylim(0, peak * 1.2 if peak > 0 else 1.0)
        panel.grid(True, axis="y", alpha=0.3)

    # Hide any leftover subplot panels beyond the number of formats.
    for extra in panels[total_formats:]:
        extra.set_visible(False)

    plt.suptitle(
        f"Response Distribution Comparison\nMetric: {metric_names.get(metric, metric)}\n"
        f"Statement: {statement[:80]}{'...' if len(statement) > 80 else ''}",
        fontsize=12,
        fontweight="bold",
    )
    plt.tight_layout()
    return fig
def create_heatmap(all_results, metric="geometric_mean"):
    """Heatmap of normalized option-mass per format."""
    fmt_names = list(all_results.keys())
    if not fmt_names:
        return None

    n_options = 5
    matrix = np.zeros((len(fmt_names), n_options), dtype=np.float64)
    for row, name in enumerate(fmt_names):
        matrix[row] = all_results[name]["norm_dists"][metric]

    fig, ax = plt.subplots(figsize=(10, 8))
    # Scale the colormap to the observed maximum so contrasts stay visible.
    vmax = float(np.max(matrix)) if matrix.size else 1.0
    image = ax.imshow(matrix, aspect="auto", vmin=0, vmax=vmax)

    ax.set_xticks(range(n_options))
    ax.set_xticklabels(["Opt 1", "Opt 2", "Opt 3", "Opt 4", "Opt 5"])
    ax.set_yticks(range(len(fmt_names)))
    ax.set_yticklabels(fmt_names, fontsize=9)

    # Overlay the numeric value on every cell.
    n_rows, n_cols = matrix.shape
    for r in range(n_rows):
        for c in range(n_cols):
            ax.text(c, r, f"{matrix[r, c]:.3f}", ha="center", va="center", fontsize=8)

    metric_names = {
        "geometric_mean": "mean log-prob softmax",
        "joint_prob": "joint log-prob softmax",
        "first_token_prob": "first-token softmax",
        "avg_prob": "avg-token softmax",
    }
    ax.set_title(f"Probability Heatmap (Normalized)\nMetric: {metric_names.get(metric, metric)}", fontsize=12, fontweight="bold")
    plt.colorbar(image, ax=ax, label="Normalized option mass")
    plt.tight_layout()
    return fig
def create_metric_comparison_plot(all_results, statement):
    """
    Compares normalized distributions under four metrics.
    Each subplot: per-format line over option index for the given metric.
    """
    if not all_results:
        return None

    panel_specs = [
        ("geometric_mean", "Geometric Mean (log) softmax"),
        ("joint_prob", "Joint (log) softmax"),
        ("first_token_prob", "First-token (log) softmax"),
        ("avg_prob", "Avg-token (log) softmax"),
    ]

    fig, grid = plt.subplots(2, 2, figsize=(14, 10))
    panels = np.array(grid).flatten()

    for panel, (metric_key, panel_title) in zip(panels, panel_specs):
        # One line per answer format, over the five option indices.
        for fmt_name, fmt_data in all_results.items():
            panel.plot(range(5), fmt_data["norm_dists"][metric_key], marker="o", label=fmt_name, alpha=0.75)
        panel.set_xlabel("Response option index")
        panel.set_ylabel("Normalized option mass")
        panel.set_title(panel_title, fontweight="bold")
        panel.set_xticks(range(5))
        panel.set_xticklabels(["Opt 1", "Opt 2", "Opt 3", "Opt 4", "Opt 5"])
        panel.grid(True, alpha=0.3)
        panel.legend(fontsize=7, loc="best")

    plt.suptitle(
        f"Metric Comparison (Normalized Distributions)\nStatement: {statement[:80]}{'...' if len(statement) > 80 else ''}",
        fontsize=12,
        fontweight="bold",
    )
    plt.tight_layout()
    return fig
# =========================
# Core analysis
# =========================
def analyze_all_formats(statement, persona="", selected_formats=None, metric="geometric_mean"):
    """Run the RAW-probability analysis for every selected answer format.

    Args:
        statement: The survey statement to score.
        persona: Optional system-prompt persona; empty string means none.
        selected_formats: Keys into ANSWER_FORMATS; None/empty selects all.
        metric: Which normalized distribution drives the plots and the
            "NORM(...) mass" line of the text report.

    Returns:
        (comparison_fig, heatmap_fig, metric_comparison_fig, detailed_text,
        status_message); figure slots are None and status carries the
        traceback on failure.
    """
    try:
        default_prompt_template = safe_read_default_prompt()
        if not statement or not statement.strip():
            return None, None, None, "", "❌ Please enter a statement."
        if not selected_formats:
            # No checkbox selection -> analyze every known format.
            selected_formats = list(ANSWER_FORMATS.keys())
        # Build results container
        all_results = {}
        detailed_output = []
        detailed_output.append("=" * 80)
        detailed_output.append("MULTI-TOKEN RAW PROBABILITY ANALYSIS (FIXED)")
        detailed_output.append("=" * 80)
        detailed_output.append("Raw metrics are NOT normalized (true probabilities).")
        detailed_output.append("Plots use a SEPARATE normalized distribution per metric (softmax in log-space).")
        detailed_output.append("")
        detailed_output.append("Raw metrics:")
        detailed_output.append("- joint_prob: exp(sum log p_i)")
        detailed_output.append("- geometric_mean: exp(mean log p_i) (length-normalized likelihood)")
        detailed_output.append("- perplexity: exp(-mean log p_i) = 1 / geometric_mean")
        detailed_output.append("- first_token_prob: p_1")
        detailed_output.append("- avg_prob: mean(p_i)")
        detailed_output.append("=" * 80)
        detailed_output.append("")
        for format_name in selected_formats:
            cfg = ANSWER_FORMATS[format_name]
            options = cfg["options"]
            labels = cfg["labels"]
            prompt_suffix = cfg["prompt_suffix"]
            token_info = get_token_info(options)
            full_prompt = default_prompt_template.format(statement=statement.strip())
            full_prompt += f"\n\n{prompt_suffix}"
            # Wrap the prompt in the model's chat template (optional system persona).
            messages = []
            if persona and persona.strip():
                messages.append({"role": "system", "content": persona.strip()})
            messages.append({"role": "user", "content": full_prompt})
            prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(dtype=torch.long)
            # Compute RAW metrics per option
            raw_metrics = []
            for info in token_info:
                m = calculate_sequence_metrics(prompt_ids, info["tokens"])
                raw_metrics.append(m)
            # Compute normalized distributions for all metrics (for plotting)
            norm_dists = {
                "geometric_mean": normalized_distribution(raw_metrics, metric="geometric_mean", mode="softmax"),
                "joint_prob": normalized_distribution(raw_metrics, metric="joint_prob", mode="softmax"),
                "first_token_prob": normalized_distribution(raw_metrics, metric="first_token_prob", mode="softmax"),
                "avg_prob": normalized_distribution(raw_metrics, metric="avg_prob", mode="softmax"),
            }
            all_results[format_name] = {
                "labels": labels,
                "options": options,
                "token_info": token_info,
                "raw_metrics": raw_metrics,
                "norm_dists": norm_dists,
            }
            # Detailed output (RAW + selected-metric normalized mass)
            detailed_output.append(f"\n{'=' * 80}")
            detailed_output.append(f"Format: {format_name}")
            detailed_output.append(f"{'=' * 80}")
            selected_norm = norm_dists[metric]
            for opt, lab, info, m, nmass in zip(options, labels, token_info, raw_metrics, selected_norm):
                detailed_output.append(f"\n{lab} ({opt}):")
                detailed_output.append(f" Tokens ({info['token_count']}): {info['decoded_tokens']}")
                detailed_output.append(f" RAW joint_prob: {m['joint_prob']:.6e}")
                detailed_output.append(f" RAW geometric_mean: {m['geometric_mean']:.6e}")
                detailed_output.append(f" RAW first_token_prob: {m['first_token_prob']:.6e}")
                detailed_output.append(f" RAW avg_prob: {m['avg_prob']:.6e}")
                detailed_output.append(f" RAW perplexity: {m['perplexity']:.4f}")
                detailed_output.append(f" NORM({metric}) mass: {nmass:.4f}")
        # Plots (normalized distributions)
        comparison_plot = create_comparison_plot(all_results, statement, metric=metric)
        heatmap_plot = create_heatmap(all_results, metric=metric)
        metric_comparison = create_metric_comparison_plot(all_results, statement)
        return comparison_plot, heatmap_plot, metric_comparison, "\n".join(detailed_output), "✅ Analysis complete"
    except Exception as e:
        # Surface the full traceback in the status box instead of crashing the UI.
        error_msg = f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"
        return None, None, None, "", error_msg
# =========================
# Gradio UI
# =========================
# Declarative Gradio layout: inputs on top, plots in the middle, text report
# and status at the bottom. All wiring goes through analyze_all_formats.
with gr.Blocks(title="The Unsampled Truth - Multi-Token Analysis (Fixed)") as demo:
    gr.Markdown(
        """
# The Unsampled Truth — Multi-Token Probability Analysis (Fixed)
This tool computes **RAW** multi-token likelihood metrics per option and plots **normalized** option distributions
using **softmax in log-space** (so values stay valid and comparable).
- RAW metrics: joint_prob, geometric_mean, first_token_prob, avg_prob, perplexity
- Plots: normalized option mass under the selected metric
"""
    )
    # ---- Inputs ----
    with gr.Row():
        with gr.Column():
            statement_input = gr.Textbox(
                label="Statement to Analyze",
                placeholder="e.g., Climate change is a serious threat",
                lines=3,
            )
            persona_input = gr.Textbox(
                label="Persona (Optional)",
                placeholder="e.g., You are a tech entrepreneur",
                lines=2,
            )
            format_selector = gr.CheckboxGroup(
                choices=list(ANSWER_FORMATS.keys()),
                value=list(ANSWER_FORMATS.keys()),  # all formats pre-selected
                label="Select Answer Formats to Compare",
                interactive=True,
            )
            metric_selector = gr.Radio(
                choices=[
                    ("Geometric Mean (Recommended)", "geometric_mean"),
                    ("Joint Probability", "joint_prob"),
                    ("First Token Only", "first_token_prob"),
                    ("Average Token Probability", "avg_prob"),
                ],
                value="geometric_mean",
                label="Comparison Metric (for plots + NORM mass line)",
            )
            analyze_btn = gr.Button("Analyze All Formats", variant="primary")
    # ---- Plot outputs ----
    with gr.Row():
        with gr.Column():
            comparison_plot = gr.Plot(label="Format Comparison (Normalized)")
        with gr.Column():
            heatmap_plot = gr.Plot(label="Heatmap (Normalized)")
    with gr.Row():
        metric_comparison = gr.Plot(label="Metric Comparison (Normalized)")
    # ---- Text outputs ----
    with gr.Row():
        detailed_output = gr.Textbox(label="Detailed Output (RAW metrics + normalized mass)", lines=25)
        status_output = gr.Textbox(label="Status", lines=2)
    # Clickable example inputs (statement, persona, formats, metric).
    gr.Examples(
        examples=[
            ["Climate change is a serious threat", "", list(ANSWER_FORMATS.keys()), "geometric_mean"],
            ["Immigration has positive economic effects", "", list(ANSWER_FORMATS.keys()), "geometric_mean"],
            ["Government should provide universal healthcare", "", list(ANSWER_FORMATS.keys()), "geometric_mean"],
            ["Artificial intelligence will benefit humanity", "You are a tech entrepreneur", list(ANSWER_FORMATS.keys()), "geometric_mean"],
            ["Traditional family values are important", "You are a progressive activist", list(ANSWER_FORMATS.keys()), "first_token_prob"],
        ],
        inputs=[statement_input, persona_input, format_selector, metric_selector],
    )
    # Wire the button to the analysis entry point.
    analyze_btn.click(
        fn=analyze_all_formats,
        inputs=[statement_input, persona_input, format_selector, metric_selector],
        outputs=[comparison_plot, heatmap_plot, metric_comparison, detailed_output, status_output],
    )

if __name__ == "__main__":
    demo.launch()