import os
import math
import traceback

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import numpy as np

# =========================
# Model init
# =========================
MODEL_NAME = "microsoft/Phi-4-mini-instruct"

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
)
model.eval()

# =========================
# Answer format configs
# =========================
# Each format defines 5 Likert-style response options, short labels used on
# the plots, and the instruction sentence appended to the user prompt.
ANSWER_FORMATS = {
    "1-5 (numeric)": {
        "options": ["1", "2", "3", "4", "5"],
        "labels": ["1", "2", "3", "4", "5"],
        "prompt_suffix": "Please respond with only a number from 1 to 5.",
    },
    "A-E (uppercase)": {
        "options": ["A", "B", "C", "D", "E"],
        "labels": ["A", "B", "C", "D", "E"],
        "prompt_suffix": "Please respond with only a letter from A to E.",
    },
    "a-e (lowercase)": {
        "options": ["a", "b", "c", "d", "e"],
        "labels": ["a", "b", "c", "d", "e"],
        "prompt_suffix": "Please respond with only a letter from a to e.",
    },
    "Full words": {
        "options": ["Strongly Disagree", "Disagree", "Neutral", "Agree", "Strongly Agree"],
        "labels": ["SD", "D", "N", "A", "SA"],
        "prompt_suffix": "Please respond with one of: Strongly Disagree, Disagree, Neutral, Agree, Strongly Agree.",
    },
    "Full words (lowercase)": {
        "options": ["strongly disagree", "disagree", "neutral", "agree", "strongly agree"],
        "labels": ["sd", "d", "n", "a", "sa"],
        "prompt_suffix": "Please respond with one of: strongly disagree, disagree, neutral, agree, strongly agree.",
    },
    "I-V (Roman numerals)": {
        "options": ["I", "II", "III", "IV", "V"],
        "labels": ["I", "II", "III", "IV", "V"],
        "prompt_suffix": "Please respond with only a Roman numeral from I to V.",
    },
    "Negative to Positive": {
        "options": ["-2", "-1", "0", "1", "2"],
        "labels": ["-2", "-1", "0", "1", "2"],
        "prompt_suffix": "Please respond with only a number from -2 to 2.",
    },
    "Yes/No spectrum": {
        "options": ["Definitely No", "Probably No", "Uncertain", "Probably Yes", "Definitely Yes"],
        "labels": ["DefNo", "ProbNo", "Unc", "ProbYes", "DefYes"],
        "prompt_suffix": "Please respond with one of: Definitely No, Probably No, Uncertain, Probably Yes, Definitely Yes.",
    },
    "Agreement levels": {
        "options": ["Completely Disagree", "Somewhat Disagree", "Neither", "Somewhat Agree", "Completely Agree"],
        "labels": ["CompD", "SomeD", "Neith", "SomeA", "CompA"],
        "prompt_suffix": "Please respond with one of: Completely Disagree, Somewhat Disagree, Neither, Somewhat Agree, Completely Agree.",
    },
}


# =========================
# Helpers
# =========================
def safe_read_default_prompt(path="default-prompt.txt"):
    """Read the prompt template from *path*, falling back to a built-in one.

    The returned text is always usable as a format string containing a
    "{statement}" placeholder.
    """
    fallback = (
        "You will be given a statement.\n"
        "Answer it according to your best judgment.\n\n"
        "Statement: {statement}\n"
        "Answer:"
    )
    try:
        with open(path, "r", encoding="utf-8") as f:
            txt = f.read().strip()
        if "{statement}" not in txt:
            # ensure it is usable as a format string
            return txt + "\n\nStatement: {statement}\nAnswer:"
        return txt
    except OSError:
        # Fix: fall back on ANY unreadable file (missing, permissions,
        # is-a-directory, ...), not only FileNotFoundError.
        return fallback


def get_token_info(options):
    """Tokenize each option string (no special tokens) and record the pieces.

    Returns a list of dicts with the option text, its token ids, the token
    count, and each token decoded back to text (for the detailed report).
    """
    token_info = []
    for i, option in enumerate(options):
        tokens = tokenizer.encode(option, add_special_tokens=False)
        token_info.append({
            "index": i,
            "option": option,
            "tokens": tokens,
            "token_count": len(tokens),
            "decoded_tokens": [tokenizer.decode([t]) for t in tokens],
        })
    return token_info


def calculate_sequence_metrics(prompt_ids: torch.Tensor, option_tokens, temperature=1.0):
    """
    Compute RAW sequence metrics in log-space for stability.

    Returns RAW:
    - joint_prob, geometric_mean, first_token_prob, avg_prob, perplexity
    - token_probs, sum_logp, mean_logp, n_tokens

    Returns None when option_tokens is empty.
    """
    if not option_tokens:
        return None

    device = model.device
    prompt_ids = prompt_ids.to(device)
    option_ids = torch.tensor([[int(t) for t in option_tokens]], device=device, dtype=torch.long)

    # Fix (performance, identical math): a single teacher-forced forward pass
    # over prompt + option yields every per-token conditional distribution at
    # once — the logits at position i predict token i+1 — instead of re-running
    # the model on a growing prefix once per option token.
    full_input = torch.cat([prompt_ids, option_ids], dim=1)
    with torch.no_grad():
        logits = model(full_input).logits[0]

    # Index of the logits row that predicts the FIRST option token.
    start = prompt_ids.shape[1] - 1

    logps = []
    token_probs = []
    for j, tok in enumerate(option_tokens):
        step_logits = logits[start + j] / float(temperature)
        log_probs = torch.log_softmax(step_logits, dim=-1)
        lp = float(log_probs[int(tok)].item())
        logps.append(lp)
        token_probs.append(math.exp(lp))

    n = len(option_tokens)
    sum_logp = float(np.sum(logps))
    mean_logp = sum_logp / n

    joint_prob = math.exp(sum_logp)       # can be tiny
    geometric_mean = math.exp(mean_logp)  # in (0, 1]
    first_token_prob = token_probs[0]
    avg_prob = float(np.mean(token_probs))
    perplexity = math.exp(-mean_logp)     # = 1 / geometric_mean

    return {
        "joint_prob": joint_prob,
        "geometric_mean": geometric_mean,
        "first_token_prob": first_token_prob,
        "avg_prob": avg_prob,
        "token_probs": token_probs,
        "perplexity": perplexity,
        "sum_logp": sum_logp,
        "mean_logp": mean_logp,
        "n_tokens": n,
    }


def normalized_distribution(option_metrics, metric="geometric_mean", mode="softmax", eps=1e-12):
    """
    Return a normalized distribution over options WITHOUT overwriting raw metrics.

    Recommended: mode="softmax" in log-space (numerically stable even for
    tiny joint probabilities). mode="simple" divides raw values by their sum.
    """
    if mode not in ("softmax", "simple"):
        raise ValueError("mode must be 'softmax' or 'simple'")

    # Map the requested metric onto a log-space score per option.
    if metric == "joint_prob":
        scores = np.array([m["sum_logp"] for m in option_metrics], dtype=np.float64)
    elif metric == "geometric_mean":
        scores = np.array([m["mean_logp"] for m in option_metrics], dtype=np.float64)
    elif metric == "first_token_prob":
        scores = np.log(np.array([max(m["first_token_prob"], eps) for m in option_metrics], dtype=np.float64))
    elif metric == "avg_prob":
        scores = np.log(np.array([max(m["avg_prob"], eps) for m in option_metrics], dtype=np.float64))
    else:
        raise ValueError(f"Unknown metric: {metric}")

    if mode == "simple":
        raw = np.array([max(m[metric], eps) for m in option_metrics], dtype=np.float64)
        s = raw.sum()
        return (raw / s).tolist() if s > 0 else [0.0] * len(option_metrics)

    # softmax(scores); subtracting the max keeps exp() in range.
    scores = scores - scores.max()
    exps = np.exp(scores)
    s = exps.sum()
    return (exps / s).tolist() if s > 0 else [0.0] * len(option_metrics)


# =========================
# Plotting
# =========================
def create_comparison_plot(all_results, statement, metric="geometric_mean"):
    """Bar plots of normalized option-mass per format for the selected metric."""
    n_formats = len(all_results)
    if n_formats == 0:
        return None

    # 2 rows with just enough columns to fit every selected format.
    ncols = (n_formats + 1) // 2
    fig, axes = plt.subplots(2, ncols, figsize=(16, 8))
    axes = np.array(axes).flatten()

    metric_names = {
        "geometric_mean": "Softmax over mean log-prob (Recommended)",
        "joint_prob": "Softmax over joint log-prob",
        "first_token_prob": "Softmax over log first-token prob",
        "avg_prob": "Softmax over log avg-token prob",
    }

    for idx, (format_name, data) in enumerate(all_results.items()):
        ax = axes[idx]
        labels = data["labels"]
        dist = data["norm_dists"][metric]

        bars = ax.bar(range(len(labels)), dist, alpha=0.85, edgecolor="black")
        for bar, p in zip(bars, dist):
            height = bar.get_height()
            ax.text(
                bar.get_x() + bar.get_width() / 2.0,
                height + 0.01,
                f"{p:.3f}",
                ha="center",
                va="bottom",
                fontsize=8,
            )

        ax.set_ylabel("Normalized option mass", fontsize=9)
        ax.set_title(format_name, fontsize=10, fontweight="bold")
        ax.set_xticks(range(len(labels)))
        ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
        ax.set_ylim(0, max(dist) * 1.2 if max(dist) > 0 else 1.0)
        ax.grid(True, axis="y", alpha=0.3)

    # hide unused subplots
    for k in range(n_formats, len(axes)):
        axes[k].set_visible(False)

    plt.suptitle(
        f"Response Distribution Comparison\nMetric: {metric_names.get(metric, metric)}\n"
        f"Statement: {statement[:80]}{'...' if len(statement) > 80 else ''}",
        fontsize=12,
        fontweight="bold",
    )
    plt.tight_layout()
    return fig


def create_heatmap(all_results, metric="geometric_mean"):
    """Heatmap of normalized option-mass per format."""
    format_names = list(all_results.keys())
    if not format_names:
        return None

    n_options = 5
    prob_matrix = np.zeros((len(format_names), n_options), dtype=np.float64)
    for i, fmt in enumerate(format_names):
        prob_matrix[i] = all_results[fmt]["norm_dists"][metric]

    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(prob_matrix, aspect="auto", vmin=0,
                   vmax=float(np.max(prob_matrix)) if prob_matrix.size else 1.0)

    ax.set_xticks(range(n_options))
    ax.set_xticklabels(["Opt 1", "Opt 2", "Opt 3", "Opt 4", "Opt 5"])
    ax.set_yticks(range(len(format_names)))
    ax.set_yticklabels(format_names, fontsize=9)

    # Annotate every cell with its numeric value.
    for i in range(prob_matrix.shape[0]):
        for j in range(prob_matrix.shape[1]):
            ax.text(j, i, f"{prob_matrix[i, j]:.3f}", ha="center", va="center", fontsize=8)

    metric_names = {
        "geometric_mean": "mean log-prob softmax",
        "joint_prob": "joint log-prob softmax",
        "first_token_prob": "first-token softmax",
        "avg_prob": "avg-token softmax",
    }
    ax.set_title(f"Probability Heatmap (Normalized)\nMetric: {metric_names.get(metric, metric)}",
                 fontsize=12, fontweight="bold")
    plt.colorbar(im, ax=ax, label="Normalized option mass")
    plt.tight_layout()
    return fig


def create_metric_comparison_plot(all_results, statement):
    """
    Compares normalized distributions under four metrics.

    Each subplot: per-format line over option index for the given metric.
    """
    metrics = ["geometric_mean", "joint_prob", "first_token_prob", "avg_prob"]
    metric_titles = ["Geometric Mean (log) softmax", "Joint (log) softmax",
                     "First-token (log) softmax", "Avg-token (log) softmax"]

    if not all_results:
        return None

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = np.array(axes).flatten()

    for ax, metric, title in zip(axes, metrics, metric_titles):
        for format_name, data in all_results.items():
            dist = data["norm_dists"][metric]
            ax.plot(range(5), dist, marker="o", label=format_name, alpha=0.75)
        ax.set_xlabel("Response option index")
        ax.set_ylabel("Normalized option mass")
        ax.set_title(title, fontweight="bold")
        ax.set_xticks(range(5))
        ax.set_xticklabels(["Opt 1", "Opt 2", "Opt 3", "Opt 4", "Opt 5"])
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=7, loc="best")

    plt.suptitle(
        f"Metric Comparison (Normalized Distributions)\nStatement: {statement[:80]}{'...' if len(statement) > 80 else ''}",
        fontsize=12,
        fontweight="bold",
    )
    plt.tight_layout()
    return fig


# =========================
# Core analysis
# =========================
def analyze_all_formats(statement, persona="", selected_formats=None, metric="geometric_mean"):
    """Run the full analysis for every selected answer format.

    Returns (comparison_plot, heatmap_plot, metric_comparison, detailed_text,
    status_message); the plots are None on error or empty input.
    """
    try:
        default_prompt_template = safe_read_default_prompt()

        if not statement or not statement.strip():
            return None, None, None, "", "❌ Please enter a statement."

        if not selected_formats:
            selected_formats = list(ANSWER_FORMATS.keys())

        # Build results container
        all_results = {}
        detailed_output = []
        detailed_output.append("=" * 80)
        detailed_output.append("MULTI-TOKEN RAW PROBABILITY ANALYSIS (FIXED)")
        detailed_output.append("=" * 80)
        detailed_output.append("Raw metrics are NOT normalized (true probabilities).")
        detailed_output.append("Plots use a SEPARATE normalized distribution per metric (softmax in log-space).")
        detailed_output.append("")
        detailed_output.append("Raw metrics:")
        detailed_output.append("- joint_prob: exp(sum log p_i)")
        detailed_output.append("- geometric_mean: exp(mean log p_i) (length-normalized likelihood)")
        detailed_output.append("- perplexity: exp(-mean log p_i) = 1 / geometric_mean")
        detailed_output.append("- first_token_prob: p_1")
        detailed_output.append("- avg_prob: mean(p_i)")
        detailed_output.append("=" * 80)
        detailed_output.append("")

        for format_name in selected_formats:
            cfg = ANSWER_FORMATS[format_name]
            options = cfg["options"]
            labels = cfg["labels"]
            prompt_suffix = cfg["prompt_suffix"]

            token_info = get_token_info(options)

            full_prompt = default_prompt_template.format(statement=statement.strip())
            full_prompt += f"\n\n{prompt_suffix}"

            # Optional persona goes in as a system message.
            messages = []
            if persona and persona.strip():
                messages.append({"role": "system", "content": persona.strip()})
            messages.append({"role": "user", "content": full_prompt})

            prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(dtype=torch.long)

            # Compute RAW metrics per option
            raw_metrics = []
            for info in token_info:
                m = calculate_sequence_metrics(prompt_ids, info["tokens"])
                raw_metrics.append(m)

            # Compute normalized distributions for all metrics (for plotting)
            norm_dists = {
                "geometric_mean": normalized_distribution(raw_metrics, metric="geometric_mean", mode="softmax"),
                "joint_prob": normalized_distribution(raw_metrics, metric="joint_prob", mode="softmax"),
                "first_token_prob": normalized_distribution(raw_metrics, metric="first_token_prob", mode="softmax"),
                "avg_prob": normalized_distribution(raw_metrics, metric="avg_prob", mode="softmax"),
            }

            all_results[format_name] = {
                "labels": labels,
                "options": options,
                "token_info": token_info,
                "raw_metrics": raw_metrics,
                "norm_dists": norm_dists,
            }

            # Detailed output (RAW + selected-metric normalized mass)
            detailed_output.append(f"\n{'=' * 80}")
            detailed_output.append(f"Format: {format_name}")
            detailed_output.append(f"{'=' * 80}")
            selected_norm = norm_dists[metric]
            for opt, lab, info, m, nmass in zip(options, labels, token_info, raw_metrics, selected_norm):
                detailed_output.append(f"\n{lab} ({opt}):")
                detailed_output.append(f"  Tokens ({info['token_count']}): {info['decoded_tokens']}")
                detailed_output.append(f"  RAW joint_prob: {m['joint_prob']:.6e}")
                detailed_output.append(f"  RAW geometric_mean: {m['geometric_mean']:.6e}")
                detailed_output.append(f"  RAW first_token_prob: {m['first_token_prob']:.6e}")
                detailed_output.append(f"  RAW avg_prob: {m['avg_prob']:.6e}")
                detailed_output.append(f"  RAW perplexity: {m['perplexity']:.4f}")
                detailed_output.append(f"  NORM({metric}) mass: {nmass:.4f}")

        # Plots (normalized distributions)
        comparison_plot = create_comparison_plot(all_results, statement, metric=metric)
        heatmap_plot = create_heatmap(all_results, metric=metric)
        metric_comparison = create_metric_comparison_plot(all_results, statement)

        return comparison_plot, heatmap_plot, metric_comparison, "\n".join(detailed_output), "✅ Analysis complete"

    except Exception as e:
        error_msg = f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"
        return None, None, None, "", error_msg


# =========================
# Gradio UI
# =========================
with gr.Blocks(title="The Unsampled Truth - Multi-Token Analysis (Fixed)") as demo:
    gr.Markdown(
        """
        # The Unsampled Truth — Multi-Token Probability Analysis (Fixed)

        This tool computes **RAW** multi-token likelihood metrics per option and plots
        **normalized** option distributions using **softmax in log-space** (so values
        stay valid and comparable).

        - RAW metrics: joint_prob, geometric_mean, first_token_prob, avg_prob, perplexity
        - Plots: normalized option mass under the selected metric
        """
    )

    with gr.Row():
        with gr.Column():
            statement_input = gr.Textbox(
                label="Statement to Analyze",
                placeholder="e.g., Climate change is a serious threat",
                lines=3,
            )
            persona_input = gr.Textbox(
                label="Persona (Optional)",
                placeholder="e.g., You are a tech entrepreneur",
                lines=2,
            )
            format_selector = gr.CheckboxGroup(
                choices=list(ANSWER_FORMATS.keys()),
                value=list(ANSWER_FORMATS.keys()),
                label="Select Answer Formats to Compare",
                interactive=True,
            )
            metric_selector = gr.Radio(
                choices=[
                    ("Geometric Mean (Recommended)", "geometric_mean"),
                    ("Joint Probability", "joint_prob"),
                    ("First Token Only", "first_token_prob"),
                    ("Average Token Probability", "avg_prob"),
                ],
                value="geometric_mean",
                label="Comparison Metric (for plots + NORM mass line)",
            )
            analyze_btn = gr.Button("Analyze All Formats", variant="primary")

    with gr.Row():
        with gr.Column():
            comparison_plot = gr.Plot(label="Format Comparison (Normalized)")
        with gr.Column():
            heatmap_plot = gr.Plot(label="Heatmap (Normalized)")

    with gr.Row():
        metric_comparison = gr.Plot(label="Metric Comparison (Normalized)")

    with gr.Row():
        detailed_output = gr.Textbox(label="Detailed Output (RAW metrics + normalized mass)", lines=25)

    status_output = gr.Textbox(label="Status", lines=2)

    gr.Examples(
        examples=[
            ["Climate change is a serious threat", "", list(ANSWER_FORMATS.keys()), "geometric_mean"],
            ["Immigration has positive economic effects", "", list(ANSWER_FORMATS.keys()), "geometric_mean"],
            ["Government should provide universal healthcare", "", list(ANSWER_FORMATS.keys()), "geometric_mean"],
            ["Artificial intelligence will benefit humanity", "You are a tech entrepreneur", list(ANSWER_FORMATS.keys()), "geometric_mean"],
            ["Traditional family values are important", "You are a progressive activist", list(ANSWER_FORMATS.keys()), "first_token_prob"],
        ],
        inputs=[statement_input, persona_input, format_selector, metric_selector],
    )

    analyze_btn.click(
        fn=analyze_all_formats,
        inputs=[statement_input, persona_input, format_selector, metric_selector],
        outputs=[comparison_plot, heatmap_plot, metric_comparison, detailed_output, status_output],
    )

if __name__ == "__main__":
    demo.launch()