# NOTE: the three lines below are Hugging Face page residue (author avatar
# caption, last commit message, commit hash) captured along with the source;
# kept as comments so the file remains valid Python.
# nsschw's picture
# Refactor debug information in analyze_with_persona function to remove leading newline from output
# c87cd16
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
# Initialize model
# The tokenizer and model are loaded eagerly at import time, so app startup
# blocks until the weights are available. trust_remote_code permits the
# checkpoint's custom modeling code to execute.
MODEL_NAME = "microsoft/Phi-4-mini-instruct"
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",       # take the dtype from the checkpoint (e.g. bf16)
    device_map="auto",        # let accelerate place layers on available devices
    trust_remote_code=True
)
# Get token IDs for numbers 1-5
# Map each Likert value 1..5 to the ID of its first sub-token, so each value's
# probability can be read directly from the logits of the first generated token.
likert_tokens = {}
for i in range(1, 6):
    tokens = tokenizer.encode(str(i), add_special_tokens=False)
    if tokens:  # guard: skip a value if the tokenizer returns no tokens for it
        likert_tokens[i] = tokens[0]
def create_probability_plot(likert_probs, persona=""):
    """Render the Likert-scale probability distribution as a bar chart.

    Args:
        likert_probs: Mapping of Likert value (1..5) -> probability.
        persona: Optional persona text; when non-blank it is appended
            (truncated to 50 chars) to the chart title.

    Returns:
        The matplotlib Figure, ready for display by Gradio.
    """
    fig, ax = plt.subplots(figsize=(8, 5))

    scale_points = list(likert_probs.keys())
    scale_probs = list(likert_probs.values())

    rects = ax.bar(scale_points, scale_probs,
                   color='steelblue', alpha=0.8, edgecolor='navy')

    # Annotate each bar with its numeric probability just above its top edge.
    for rect, p in zip(rects, scale_probs):
        x_center = rect.get_x() + rect.get_width() / 2.
        ax.text(x_center, rect.get_height() + 0.01,
                f'{p:.3f}', ha='center', va='bottom')

    ax.set_xlabel('Likert Scale Value')
    ax.set_ylabel('Probability')

    chart_title = 'Response Probability Distribution'
    if persona.strip():
        if len(persona) > 50:
            chart_title += f'\nPersona: {persona[:50]}...'
        else:
            chart_title += f'\nPersona: {persona}'
    ax.set_title(chart_title)

    ax.set_xticks(scale_points)
    # Leave 20% headroom above the tallest bar; fall back to 1 when empty.
    ax.set_ylim(0, max(scale_probs) * 1.2 if scale_probs else 1)
    ax.grid(True, axis='y', alpha=0.3)
    plt.tight_layout()
    return fig
def analyze_with_persona(statement, persona=""):
    """Score *statement* on a 1-5 Likert scale via next-token probabilities.

    Builds a chat prompt from ``default-prompt.txt`` (optionally prefixed by a
    persona system message), generates exactly one token greedily, and reads
    the probability assigned to each of the tokens "1".."5".

    Args:
        statement: The statement the model is asked to rate.
        persona: Optional system-prompt text describing the responder.

    Returns:
        A tuple ``(figure, probabilities_text, status_text)``. On failure the
        figure is None and ``status_text`` carries the error message.
    """
    try:
        # Read the prompt template; it must contain a "{statement}" placeholder.
        with open("default-prompt.txt", "r", encoding="utf-8") as f:
            default_prompt = f.read().strip()

        # Create chat messages with an optional system (persona) prompt.
        messages = []
        if persona.strip():
            messages.append({"role": "system", "content": persona.strip()})
        messages.append({
            "role": "user",
            "content": default_prompt.format(statement=statement.strip())
        })

        # Render the chat template to plain text ready for generation.
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Tokenize and move tensors to the model's device — without this the
        # inputs stay on CPU and generate() fails with a device mismatch when
        # device_map="auto" has placed the model on GPU.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Greedy single-token generation, keeping the raw scores (logits).
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1,
                return_dict_in_generate=True,
                output_scores=True,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )

        if not outputs.scores:
            return None, "", "❌ No scores generated"

        # Softmax over the vocabulary for the first (and only) generated token.
        logits = outputs.scores[0][0]  # first step, first batch element
        probs = torch.softmax(logits, dim=-1)

        # Probability mass assigned to the tokens "1".."5"; the remainder
        # went to every other token in the vocabulary.
        likert_probs = {}
        outside_probability = 1.0  # fixed typo: was "outside_probabability"
        for value, token_id in likert_tokens.items():
            likert_probs[value] = probs[token_id].item()
            outside_probability -= likert_probs[value]

        fig = create_probability_plot(likert_probs, persona)

        prob_text = "\n".join([f"{k}: {v:.4f}" for k, v in likert_probs.items()])
        prob_text += f"\nLogit probabilities outside of 1-5: {outside_probability}"

        # Show the full decoded sequence (prompt + generated token, including
        # special tokens) so the user can inspect exactly what the model saw.
        debug_info = tokenizer.decode(outputs.sequences[0], skip_special_tokens=False)
        return fig, prob_text, debug_info
    except Exception as e:
        # UI boundary: surface any failure in the Status box instead of crashing.
        return None, "", f"❌ Error: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="The Unsampled Truth") as demo:
    gr.Markdown("""
    # The Unsampled Truth
    Extract probability distributions for Likert scale responses (1-5) using Phi-4-mini-instruct.
    """)
    with gr.Row():
        with gr.Column():
            # Left column: user inputs and the trigger button.
            statement_input = gr.Textbox(
                label="Statement to Analyze",
                placeholder="e.g., Climate change is a serious threat",
                lines=3
            )
            persona_input = gr.Textbox(
                label="Persona (Optional)",
                placeholder="e.g., You are a conservative voter from rural America",
                lines=2
            )
            analyze_btn = gr.Button("Analyze", variant="primary")
        with gr.Column():
            # Right column: plot, raw probability text, and status/debug output.
            plot_output = gr.Plot(label="Probability Distribution")
            prob_output = gr.Textbox(label="Raw Probabilities", lines=6)
            status_output = gr.Textbox(label="Status", lines=3)
    # Examples
    # Clickable example rows; each entry is [statement, persona].
    gr.Examples(
        examples=[
            ["Climate change is a serious threat", ""],
            ["Immigration has positive economic effects", ""],
            ["Government should provide universal healthcare", ""],
            ["Climate change is a serious threat", "You are a conservative voter from rural America"],
            ["Immigration has positive economic effects", "You are a progressive voter from a major city"]
        ],
        inputs=[statement_input, persona_input]
    )
    # Wire the button: output order matches analyze_with_persona's return tuple.
    analyze_btn.click(
        fn=analyze_with_persona,
        inputs=[statement_input, persona_input],
        outputs=[plot_output, prob_output, status_output]
    )

if __name__ == "__main__":
    demo.launch()