# NOTE: Hugging Face Spaces page header captured by the scraper ("Spaces: Sleeping") — not part of the program.
| import os | |
| import math | |
| import traceback | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
# =========================
# Model init
# =========================
# Module-level globals: the tokenizer and model are loaded once at import
# time and shared by every request handler below.
MODEL_NAME = "microsoft/Phi-4-mini-instruct"
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",   # use the checkpoint's native dtype
    device_map="auto",    # automatic device placement (GPU if available)
    trust_remote_code=True,
)
model.eval()  # inference only — disables dropout / training-mode behavior
# =========================
# Answer format configs
# =========================
# Each entry maps a display name to:
#   options       - the five answer strings (lowest→highest agreement) whose
#                   likelihoods are scored by the model
#   labels        - short per-option labels used on plot axes
#   prompt_suffix - instruction appended to the prompt telling the model how
#                   to phrase its answer
ANSWER_FORMATS = {
    "1-5 (numeric)": {
        "options": ["1", "2", "3", "4", "5"],
        "labels": ["1", "2", "3", "4", "5"],
        "prompt_suffix": "Please respond with only a number from 1 to 5."
    },
    "A-E (uppercase)": {
        "options": ["A", "B", "C", "D", "E"],
        "labels": ["A", "B", "C", "D", "E"],
        "prompt_suffix": "Please respond with only a letter from A to E."
    },
    "a-e (lowercase)": {
        "options": ["a", "b", "c", "d", "e"],
        "labels": ["a", "b", "c", "d", "e"],
        "prompt_suffix": "Please respond with only a letter from a to e."
    },
    "Full words": {
        "options": ["Strongly Disagree", "Disagree", "Neutral", "Agree", "Strongly Agree"],
        "labels": ["SD", "D", "N", "A", "SA"],
        "prompt_suffix": "Please respond with one of: Strongly Disagree, Disagree, Neutral, Agree, Strongly Agree."
    },
    "Full words (lowercase)": {
        "options": ["strongly disagree", "disagree", "neutral", "agree", "strongly agree"],
        "labels": ["sd", "d", "n", "a", "sa"],
        "prompt_suffix": "Please respond with one of: strongly disagree, disagree, neutral, agree, strongly agree."
    },
    "I-V (Roman numerals)": {
        "options": ["I", "II", "III", "IV", "V"],
        "labels": ["I", "II", "III", "IV", "V"],
        "prompt_suffix": "Please respond with only a Roman numeral from I to V."
    },
    "Negative to Positive": {
        "options": ["-2", "-1", "0", "1", "2"],
        "labels": ["-2", "-1", "0", "1", "2"],
        "prompt_suffix": "Please respond with only a number from -2 to 2."
    },
    "Yes/No spectrum": {
        "options": ["Definitely No", "Probably No", "Uncertain", "Probably Yes", "Definitely Yes"],
        "labels": ["DefNo", "ProbNo", "Unc", "ProbYes", "DefYes"],
        "prompt_suffix": "Please respond with one of: Definitely No, Probably No, Uncertain, Probably Yes, Definitely Yes."
    },
    "Agreement levels": {
        "options": ["Completely Disagree", "Somewhat Disagree", "Neither", "Somewhat Agree", "Completely Agree"],
        "labels": ["CompD", "SomeD", "Neith", "SomeA", "CompA"],
        "prompt_suffix": "Please respond with one of: Completely Disagree, Somewhat Disagree, Neither, Somewhat Agree, Completely Agree."
    }
}
| # ========================= | |
| # Helpers | |
| # ========================= | |
def safe_read_default_prompt(path="default-prompt.txt"):
    """Load the prompt template from *path*, falling back to a built-in one.

    The returned string is always usable with ``.format(statement=...)``:
    if the file's text lacks a ``{statement}`` placeholder, a
    "Statement: ... / Answer:" scaffold is appended.

    Args:
        path: Template file to read (UTF-8).

    Returns:
        A prompt template string containing ``{statement}``.
    """
    fallback = (
        "You will be given a statement.\n"
        "Answer it according to your best judgment.\n\n"
        "Statement: {statement}\n"
        "Answer:"
    )
    try:
        with open(path, "r", encoding="utf-8") as f:
            txt = f.read().strip()
    except OSError:
        # Broader than the original FileNotFoundError: a permission or I/O
        # error should also degrade gracefully to the built-in template
        # instead of crashing the whole analysis.
        return fallback
    if "{statement}" not in txt:
        # Ensure the file's text is usable as a format string.
        return txt + "\n\nStatement: {statement}\nAnswer:"
    return txt
def get_token_info(options):
    """Tokenize each answer option and collect per-option token metadata.

    For every option string this records: its position in the list, the raw
    text, the token ids produced by the module-level tokenizer (no special
    tokens), the token count, and each token id decoded back to text.
    """
    def describe(idx, text):
        ids = tokenizer.encode(text, add_special_tokens=False)
        return {
            "index": idx,
            "option": text,
            "tokens": ids,
            "token_count": len(ids),
            "decoded_tokens": [tokenizer.decode([tid]) for tid in ids],
        }

    return [describe(idx, text) for idx, text in enumerate(options)]
def calculate_sequence_metrics(prompt_ids: torch.Tensor, option_tokens, temperature=1.0):
    """
    Compute RAW sequence metrics in log-space for stability.

    Improvement over the original: a single teacher-forced forward pass over
    ``prompt + option`` replaces one model call per continuation token. With
    a causal LM the logits at position i condition only on tokens <= i, so
    the per-token log-probs are identical to the incremental loop, but the
    prompt is no longer re-encoded for every option token.

    Args:
        prompt_ids: (1, prompt_len) LongTensor of prompt token ids.
        option_tokens: token ids of the candidate answer continuation.
        temperature: logits are divided by this before the softmax.

    Returns RAW (un-normalized) metrics, or None if option_tokens is empty:
        - joint_prob, geometric_mean, first_token_prob, avg_prob, perplexity
        - token_probs, sum_logp, mean_logp, n_tokens
    """
    if not option_tokens:
        return None
    device = model.device
    prompt_ids = prompt_ids.to(device)
    option_ids = torch.tensor([[int(t) for t in option_tokens]], device=device, dtype=torch.long)
    full_input = torch.cat([prompt_ids, option_ids], dim=1)

    with torch.no_grad():
        outputs = model(full_input)

    n = len(option_tokens)
    start = prompt_ids.shape[1] - 1  # logits at position i predict token i+1
    logits = outputs.logits[0, start:start + n, :] / float(temperature)
    log_probs = torch.log_softmax(logits, dim=-1)
    # log p(option_token_i | prompt, option_tokens[:i]) for each i
    logps = [float(log_probs[i, int(t)].item()) for i, t in enumerate(option_tokens)]
    token_probs = [math.exp(lp) for lp in logps]

    sum_logp = float(np.sum(logps))
    mean_logp = sum_logp / n
    return {
        "joint_prob": math.exp(sum_logp),       # product of token probs; can be tiny
        "geometric_mean": math.exp(mean_logp),  # length-normalized, in (0, 1]
        "first_token_prob": token_probs[0],
        "avg_prob": float(np.mean(token_probs)),
        "token_probs": token_probs,
        "perplexity": math.exp(-mean_logp),     # = 1 / geometric_mean
        "sum_logp": sum_logp,
        "mean_logp": mean_logp,
        "n_tokens": n,
    }
def normalized_distribution(option_metrics, metric="geometric_mean", mode="softmax", eps=1e-12):
    """
    Build a normalized distribution over options from RAW metrics, without
    mutating them.

    mode="softmax" (recommended): softmax over equivalent log-space scores.
    mode="simple": direct proportional normalization of the raw metric.
    Raises ValueError for an unknown mode or metric.
    """
    if mode not in ("softmax", "simple"):
        raise ValueError("mode must be 'softmax' or 'simple'")

    # Map each metric onto its log-space score (validates `metric` up front).
    log_space_key = {"joint_prob": "sum_logp", "geometric_mean": "mean_logp"}
    if metric in log_space_key:
        key = log_space_key[metric]
        scores = np.array([m[key] for m in option_metrics], dtype=np.float64)
    elif metric in ("first_token_prob", "avg_prob"):
        floored = [max(m[metric], eps) for m in option_metrics]
        scores = np.log(np.array(floored, dtype=np.float64))
    else:
        raise ValueError(f"Unknown metric: {metric}")

    count = len(option_metrics)
    if mode == "simple":
        raw = np.array([max(m[metric], eps) for m in option_metrics], dtype=np.float64)
        total = raw.sum()
        return (raw / total).tolist() if total > 0 else [0.0] * count

    # Numerically stable softmax: shift by the max before exponentiating.
    shifted = np.exp(scores - scores.max())
    total = shifted.sum()
    return (shifted / total).tolist() if total > 0 else [0.0] * count
| # ========================= | |
| # Plotting | |
| # ========================= | |
def create_comparison_plot(all_results, statement, metric="geometric_mean"):
    """Bar plots of normalized option-mass per format for the selected metric."""
    total = len(all_results)
    if total == 0:
        return None

    cols = (total + 1) // 2
    fig, axis_grid = plt.subplots(2, cols, figsize=(16, 8))
    flat_axes = np.array(axis_grid).flatten()

    pretty_metric = {
        "geometric_mean": "Softmax over mean log-prob (Recommended)",
        "joint_prob": "Softmax over joint log-prob",
        "first_token_prob": "Softmax over log first-token prob",
        "avg_prob": "Softmax over log avg-token prob",
    }

    for subplot_idx, (fmt_name, fmt_data) in enumerate(all_results.items()):
        axis = flat_axes[subplot_idx]
        option_labels = fmt_data["labels"]
        masses = fmt_data["norm_dists"][metric]
        positions = range(len(option_labels))

        bars = axis.bar(positions, masses, alpha=0.85, edgecolor="black")
        # Annotate each bar with its normalized mass.
        for rect, mass in zip(bars, masses):
            axis.text(
                rect.get_x() + rect.get_width() / 2.0,
                rect.get_height() + 0.01,
                f"{mass:.3f}",
                ha="center",
                va="bottom",
                fontsize=8,
            )
        axis.set_ylabel("Normalized option mass", fontsize=9)
        axis.set_title(fmt_name, fontsize=10, fontweight="bold")
        axis.set_xticks(positions)
        axis.set_xticklabels(option_labels, rotation=45, ha="right", fontsize=8)
        peak = max(masses)
        axis.set_ylim(0, peak * 1.2 if peak > 0 else 1.0)
        axis.grid(True, axis="y", alpha=0.3)

    # Hide any subplot slots beyond the number of formats.
    for spare in flat_axes[total:]:
        spare.set_visible(False)

    plt.suptitle(
        f"Response Distribution Comparison\nMetric: {pretty_metric.get(metric, metric)}\n"
        f"Statement: {statement[:80]}{'...' if len(statement) > 80 else ''}",
        fontsize=12,
        fontweight="bold",
    )
    plt.tight_layout()
    return fig
def create_heatmap(all_results, metric="geometric_mean"):
    """Heatmap of normalized option-mass per format."""
    fmt_names = list(all_results.keys())
    if not fmt_names:
        return None

    option_count = 5
    matrix = np.zeros((len(fmt_names), option_count), dtype=np.float64)
    for row, name in enumerate(fmt_names):
        matrix[row] = all_results[name]["norm_dists"][metric]

    fig, axis = plt.subplots(figsize=(10, 8))
    upper = float(np.max(matrix)) if matrix.size else 1.0
    image = axis.imshow(matrix, aspect="auto", vmin=0, vmax=upper)

    axis.set_xticks(range(option_count))
    axis.set_xticklabels(["Opt 1", "Opt 2", "Opt 3", "Opt 4", "Opt 5"])
    axis.set_yticks(range(len(fmt_names)))
    axis.set_yticklabels(fmt_names, fontsize=9)

    # Write the numeric value inside every cell.
    n_rows, n_cols = matrix.shape
    for r in range(n_rows):
        for c in range(n_cols):
            axis.text(c, r, f"{matrix[r, c]:.3f}", ha="center", va="center", fontsize=8)

    short_metric = {
        "geometric_mean": "mean log-prob softmax",
        "joint_prob": "joint log-prob softmax",
        "first_token_prob": "first-token softmax",
        "avg_prob": "avg-token softmax",
    }
    axis.set_title(f"Probability Heatmap (Normalized)\nMetric: {short_metric.get(metric, metric)}", fontsize=12, fontweight="bold")
    plt.colorbar(image, ax=axis, label="Normalized option mass")
    plt.tight_layout()
    return fig
def create_metric_comparison_plot(all_results, statement):
    """
    Compares normalized distributions under four metrics.
    Each subplot: per-format line over option index for the given metric.
    """
    if not all_results:
        return None

    panels = [
        ("geometric_mean", "Geometric Mean (log) softmax"),
        ("joint_prob", "Joint (log) softmax"),
        ("first_token_prob", "First-token (log) softmax"),
        ("avg_prob", "Avg-token (log) softmax"),
    ]
    fig, axis_grid = plt.subplots(2, 2, figsize=(14, 10))
    flat_axes = np.array(axis_grid).flatten()

    for axis, (metric_key, panel_title) in zip(flat_axes, panels):
        # One line per answer format, over the five option indices.
        for fmt_name, fmt_data in all_results.items():
            axis.plot(range(5), fmt_data["norm_dists"][metric_key], marker="o", label=fmt_name, alpha=0.75)
        axis.set_xlabel("Response option index")
        axis.set_ylabel("Normalized option mass")
        axis.set_title(panel_title, fontweight="bold")
        axis.set_xticks(range(5))
        axis.set_xticklabels(["Opt 1", "Opt 2", "Opt 3", "Opt 4", "Opt 5"])
        axis.grid(True, alpha=0.3)
        axis.legend(fontsize=7, loc="best")

    plt.suptitle(
        f"Metric Comparison (Normalized Distributions)\nStatement: {statement[:80]}{'...' if len(statement) > 80 else ''}",
        fontsize=12,
        fontweight="bold",
    )
    plt.tight_layout()
    return fig
| # ========================= | |
| # Core analysis | |
| # ========================= | |
def analyze_all_formats(statement, persona="", selected_formats=None, metric="geometric_mean"):
    """Score *statement* under every selected answer format and build outputs.

    For each format: tokenizes the five answer options, scores each with the
    module-level model (RAW metrics via calculate_sequence_metrics), derives
    softmax-normalized distributions for all four metrics, and renders three
    matplotlib figures plus a detailed text report.

    Args:
        statement: The statement to analyze (required, non-blank).
        persona: Optional system-prompt persona text.
        selected_formats: Keys of ANSWER_FORMATS to compare; None/empty means all.
        metric: Which metric drives the first two plots and the NORM line.

    Returns:
        (comparison_plot, heatmap_plot, metric_comparison, detailed_text, status);
        the plot slots are None on validation failure or error.
    """
    try:
        default_prompt_template = safe_read_default_prompt()
        if not statement or not statement.strip():
            return None, None, None, "", "❌ Please enter a statement."
        if not selected_formats:
            # No selection means "compare everything".
            selected_formats = list(ANSWER_FORMATS.keys())
        # Build results container
        all_results = {}
        detailed_output = []
        # Report header explaining the raw-vs-normalized convention.
        detailed_output.append("=" * 80)
        detailed_output.append("MULTI-TOKEN RAW PROBABILITY ANALYSIS (FIXED)")
        detailed_output.append("=" * 80)
        detailed_output.append("Raw metrics are NOT normalized (true probabilities).")
        detailed_output.append("Plots use a SEPARATE normalized distribution per metric (softmax in log-space).")
        detailed_output.append("")
        detailed_output.append("Raw metrics:")
        detailed_output.append("- joint_prob: exp(sum log p_i)")
        detailed_output.append("- geometric_mean: exp(mean log p_i) (length-normalized likelihood)")
        detailed_output.append("- perplexity: exp(-mean log p_i) = 1 / geometric_mean")
        detailed_output.append("- first_token_prob: p_1")
        detailed_output.append("- avg_prob: mean(p_i)")
        detailed_output.append("=" * 80)
        detailed_output.append("")
        for format_name in selected_formats:
            cfg = ANSWER_FORMATS[format_name]
            options = cfg["options"]
            labels = cfg["labels"]
            prompt_suffix = cfg["prompt_suffix"]
            token_info = get_token_info(options)
            # Fill the template and append the format-specific instruction.
            full_prompt = default_prompt_template.format(statement=statement.strip())
            full_prompt += f"\n\n{prompt_suffix}"
            # Chat-template the conversation (optional persona as system turn).
            messages = []
            if persona and persona.strip():
                messages.append({"role": "system", "content": persona.strip()})
            messages.append({"role": "user", "content": full_prompt})
            prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(dtype=torch.long)
            # Compute RAW metrics per option
            raw_metrics = []
            for info in token_info:
                m = calculate_sequence_metrics(prompt_ids, info["tokens"])
                raw_metrics.append(m)
            # Compute normalized distributions for all metrics (for plotting)
            norm_dists = {
                "geometric_mean": normalized_distribution(raw_metrics, metric="geometric_mean", mode="softmax"),
                "joint_prob": normalized_distribution(raw_metrics, metric="joint_prob", mode="softmax"),
                "first_token_prob": normalized_distribution(raw_metrics, metric="first_token_prob", mode="softmax"),
                "avg_prob": normalized_distribution(raw_metrics, metric="avg_prob", mode="softmax"),
            }
            all_results[format_name] = {
                "labels": labels,
                "options": options,
                "token_info": token_info,
                "raw_metrics": raw_metrics,
                "norm_dists": norm_dists,
            }
            # Detailed output (RAW + selected-metric normalized mass)
            detailed_output.append(f"\n{'=' * 80}")
            detailed_output.append(f"Format: {format_name}")
            detailed_output.append(f"{'=' * 80}")
            selected_norm = norm_dists[metric]
            for opt, lab, info, m, nmass in zip(options, labels, token_info, raw_metrics, selected_norm):
                detailed_output.append(f"\n{lab} ({opt}):")
                detailed_output.append(f" Tokens ({info['token_count']}): {info['decoded_tokens']}")
                detailed_output.append(f" RAW joint_prob: {m['joint_prob']:.6e}")
                detailed_output.append(f" RAW geometric_mean: {m['geometric_mean']:.6e}")
                detailed_output.append(f" RAW first_token_prob: {m['first_token_prob']:.6e}")
                detailed_output.append(f" RAW avg_prob: {m['avg_prob']:.6e}")
                detailed_output.append(f" RAW perplexity: {m['perplexity']:.4f}")
                detailed_output.append(f" NORM({metric}) mass: {nmass:.4f}")
        # Plots (normalized distributions)
        comparison_plot = create_comparison_plot(all_results, statement, metric=metric)
        heatmap_plot = create_heatmap(all_results, metric=metric)
        metric_comparison = create_metric_comparison_plot(all_results, statement)
        return comparison_plot, heatmap_plot, metric_comparison, "\n".join(detailed_output), "✅ Analysis complete"
    except Exception as e:
        # Surface the full traceback in the status box instead of crashing the UI.
        error_msg = f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"
        return None, None, None, "", error_msg
| # ========================= | |
| # Gradio UI | |
| # ========================= | |
# Gradio UI: inputs (statement, persona, format/metric selectors) on top,
# three plot panes and the detailed text report below.
with gr.Blocks(title="The Unsampled Truth - Multi-Token Analysis (Fixed)") as demo:
    gr.Markdown(
        """
# The Unsampled Truth — Multi-Token Probability Analysis (Fixed)
This tool computes **RAW** multi-token likelihood metrics per option and plots **normalized** option distributions
using **softmax in log-space** (so values stay valid and comparable).
- RAW metrics: joint_prob, geometric_mean, first_token_prob, avg_prob, perplexity
- Plots: normalized option mass under the selected metric
"""
    )
    with gr.Row():
        with gr.Column():
            statement_input = gr.Textbox(
                label="Statement to Analyze",
                placeholder="e.g., Climate change is a serious threat",
                lines=3,
            )
            persona_input = gr.Textbox(
                label="Persona (Optional)",
                placeholder="e.g., You are a tech entrepreneur",
                lines=2,
            )
            # All formats selected by default.
            format_selector = gr.CheckboxGroup(
                choices=list(ANSWER_FORMATS.keys()),
                value=list(ANSWER_FORMATS.keys()),
                label="Select Answer Formats to Compare",
                interactive=True,
            )
            # (display label, internal metric key) pairs.
            metric_selector = gr.Radio(
                choices=[
                    ("Geometric Mean (Recommended)", "geometric_mean"),
                    ("Joint Probability", "joint_prob"),
                    ("First Token Only", "first_token_prob"),
                    ("Average Token Probability", "avg_prob"),
                ],
                value="geometric_mean",
                label="Comparison Metric (for plots + NORM mass line)",
            )
            analyze_btn = gr.Button("Analyze All Formats", variant="primary")
    with gr.Row():
        with gr.Column():
            comparison_plot = gr.Plot(label="Format Comparison (Normalized)")
        with gr.Column():
            heatmap_plot = gr.Plot(label="Heatmap (Normalized)")
    with gr.Row():
        metric_comparison = gr.Plot(label="Metric Comparison (Normalized)")
    with gr.Row():
        detailed_output = gr.Textbox(label="Detailed Output (RAW metrics + normalized mass)", lines=25)
        status_output = gr.Textbox(label="Status", lines=2)
    # Example rows mirror the click handler's input order.
    gr.Examples(
        examples=[
            ["Climate change is a serious threat", "", list(ANSWER_FORMATS.keys()), "geometric_mean"],
            ["Immigration has positive economic effects", "", list(ANSWER_FORMATS.keys()), "geometric_mean"],
            ["Government should provide universal healthcare", "", list(ANSWER_FORMATS.keys()), "geometric_mean"],
            ["Artificial intelligence will benefit humanity", "You are a tech entrepreneur", list(ANSWER_FORMATS.keys()), "geometric_mean"],
            ["Traditional family values are important", "You are a progressive activist", list(ANSWER_FORMATS.keys()), "first_token_prob"],
        ],
        inputs=[statement_input, persona_input, format_selector, metric_selector],
    )
    # Wire the button to the core analysis function.
    analyze_btn.click(
        fn=analyze_all_formats,
        inputs=[statement_input, persona_input, format_selector, metric_selector],
        outputs=[comparison_plot, heatmap_plot, metric_comparison, detailed_output, status_output],
    )

if __name__ == "__main__":
    demo.launch()