"""Extract Likert-scale (1-5) response probability distributions from
Phi-4-mini-instruct by reading the first generated token's logits
directly, instead of sampling a response."""

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt

# Initialize model
MODEL_NAME = "microsoft/Phi-4-mini-instruct"
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True
)

# Map each Likert value (1-5) to the ID of its first sub-token, so its
# probability can be read straight off the first-token logits.
likert_tokens = {}
for i in range(1, 6):
    tokens = tokenizer.encode(str(i), add_special_tokens=False)
    if tokens:
        likert_tokens[i] = tokens[0]


def create_probability_plot(likert_probs, persona=""):
    """Create a bar chart for Likert scale probabilities.

    Args:
        likert_probs: dict mapping Likert value (1-5) to its probability.
        persona: optional persona text; shown (truncated to 50 chars) in
            the plot title when non-empty.

    Returns:
        A matplotlib Figure (closed, so it does not accumulate in
        pyplot's registry across requests; it still renders via savefig).
    """
    fig, ax = plt.subplots(figsize=(8, 5))

    values = list(likert_probs.keys())
    probabilities = list(likert_probs.values())

    bars = ax.bar(values, probabilities, color='steelblue', alpha=0.8,
                  edgecolor='navy')

    # Add value labels above each bar
    for bar, prob in zip(bars, probabilities):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                f'{prob:.3f}', ha='center', va='bottom')

    ax.set_xlabel('Likert Scale Value')
    ax.set_ylabel('Probability')
    title = 'Response Probability Distribution'
    if persona.strip():
        title += (f'\nPersona: {persona[:50]}...' if len(persona) > 50
                  else f'\nPersona: {persona}')
    ax.set_title(title)
    ax.set_xticks(values)
    ax.set_ylim(0, max(probabilities) * 1.2 if probabilities else 1)
    ax.grid(True, axis='y', alpha=0.3)
    plt.tight_layout()
    # Detach from pyplot's figure manager to avoid leaking one figure per
    # request in a long-running server; the Figure object stays renderable.
    plt.close(fig)
    return fig


def analyze_with_persona(statement, persona=""):
    """Analyze a statement, optionally under a persona system prompt.

    Builds a chat prompt from default-prompt.txt, generates exactly one
    token greedily, and reads the probabilities of tokens "1".."5" from
    the logits of that first generated token.

    Returns:
        (figure, probabilities_text, status_text) for the Gradio outputs;
        on any failure, (None, "", error_message).
    """
    try:
        # Read default prompt; explicit encoding avoids platform-dependent
        # decoding of the template file.
        with open("default-prompt.txt", "r", encoding="utf-8") as f:
            default_prompt = f.read().strip()

        # Create chat messages with optional system prompt
        messages = []
        if persona.strip():
            messages.append({"role": "system", "content": persona.strip()})
        messages.append({
            "role": "user",
            "content": default_prompt.format(statement=statement.strip())
        })

        # Apply chat template
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Tokenize and move to the model's device: with device_map="auto"
        # the model may live on GPU while the tokenizer returns CPU tensors,
        # and generate() would otherwise fail on a device mismatch.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Generate exactly one token, greedily, keeping the raw scores
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1,
                return_dict_in_generate=True,
                output_scores=True,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )

        if not outputs.scores:
            return None, "", "❌ No scores generated"

        # Probabilities for the first generated token
        logits = outputs.scores[0][0]  # First token, first batch
        probs = torch.softmax(logits, dim=-1)

        # Extract Likert probabilities; track the probability mass that
        # falls on tokens other than "1".."5".
        likert_probs = {}
        outside_probability = 1.0
        for value, token_id in likert_tokens.items():
            likert_probs[value] = probs[token_id].item()
            outside_probability -= likert_probs[value]

        # Create probability plot
        fig = create_probability_plot(likert_probs, persona)

        # Format probabilities text
        prob_text = "\n".join([f"{k}: {v:.4f}" for k, v in likert_probs.items()])
        prob_text += f"\nLogit probabilities outside of 1-5: {outside_probability}"

        # Show what the model actually generated including input and special tokens
        debug_info = f"{tokenizer.decode(outputs.sequences[0], skip_special_tokens=False)}"

        return fig, prob_text, f"{debug_info}"

    except Exception as e:
        # Surface the failure to the UI instead of crashing the request
        return None, "", f"❌ Error: {str(e)}"
# Create Gradio interface: inputs on the left, probability plot and raw
# numbers on the right, wired to analyze_with_persona.
with gr.Blocks(title="The Unsampled Truth") as demo:
    gr.Markdown(
        """
        # The Unsampled Truth

        Extract probability distributions for Likert scale responses (1-5) using Phi-4-mini-instruct.
        """
    )

    with gr.Row():
        # Left column: the statement, an optional persona, and the trigger
        with gr.Column():
            statement_box = gr.Textbox(
                label="Statement to Analyze",
                placeholder="e.g., Climate change is a serious threat",
                lines=3
            )
            persona_box = gr.Textbox(
                label="Persona (Optional)",
                placeholder="e.g., You are a conservative voter from rural America",
                lines=2
            )
            run_button = gr.Button("Analyze", variant="primary")

        # Right column: the three outputs produced by analyze_with_persona
        with gr.Column():
            distribution_plot = gr.Plot(label="Probability Distribution")
            probabilities_box = gr.Textbox(label="Raw Probabilities", lines=6)
            status_box = gr.Textbox(label="Status", lines=3)

    # Clickable example inputs (statement, persona)
    gr.Examples(
        examples=[
            ["Climate change is a serious threat", ""],
            ["Immigration has positive economic effects", ""],
            ["Government should provide universal healthcare", ""],
            ["Climate change is a serious threat",
             "You are a conservative voter from rural America"],
            ["Immigration has positive economic effects",
             "You are a progressive voter from a major city"],
        ],
        inputs=[statement_box, persona_box]
    )

    run_button.click(
        fn=analyze_with_persona,
        inputs=[statement_box, persona_box],
        outputs=[distribution_plot, probabilities_box, status_box]
    )


if __name__ == "__main__":
    demo.launch()