Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import matplotlib.pyplot as plt | |
# ---------------------------------------------------------------------------
# Model initialization (runs once at import time).
# ---------------------------------------------------------------------------
MODEL_NAME = "microsoft/Phi-4-mini-instruct"

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True
)

# Map each Likert value (1-5) to the ID of its first sub-token, so the
# probability of answering "3" can be read straight off the logits of the
# first generated token.
likert_tokens = {
    value: ids[0]
    for value in range(1, 6)
    if (ids := tokenizer.encode(str(value), add_special_tokens=False))
}
def create_probability_plot(likert_probs, persona=""):
    """Create a bar chart for Likert scale probabilities.

    Args:
        likert_probs: Mapping of Likert value (1-5) -> probability in [0, 1].
        persona: Optional persona text shown in the title (truncated to 50
            characters so long personas don't blow up the figure).

    Returns:
        The matplotlib Figure containing the bar chart.
    """
    # NOTE(review): figures created through pyplot stay registered in
    # matplotlib's global state until plt.close(); in a long-running server
    # consider closing figures after Gradio has rendered them.
    fig, ax = plt.subplots(figsize=(8, 5))

    values = list(likert_probs.keys())
    probabilities = list(likert_probs.values())

    bars = ax.bar(values, probabilities, color='steelblue', alpha=0.8, edgecolor='navy')

    # Annotate each bar with its probability, slightly above the bar top.
    for bar, prob in zip(bars, probabilities):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                f'{prob:.3f}', ha='center', va='bottom')

    ax.set_xlabel('Likert Scale Value')
    ax.set_ylabel('Probability')

    title = 'Response Probability Distribution'
    if persona.strip():
        title += f'\nPersona: {persona[:50]}...' if len(persona) > 50 else f'\nPersona: {persona}'
    ax.set_title(title)
    ax.set_xticks(values)

    # Guard against a degenerate axis: if every probability is zero (or the
    # mapping is empty), max(...) * 1.2 would be 0 and ylim(0, 0) collapses
    # the plot. Fall back to a unit-height axis in that case.
    top = max(probabilities) * 1.2 if probabilities else 1
    ax.set_ylim(0, top if top > 0 else 1)

    ax.grid(True, axis='y', alpha=0.3)
    plt.tight_layout()
    return fig
def analyze_with_persona(statement, persona=""):
    """Analyze a statement and return Likert (1-5) response probabilities.

    Builds a chat prompt (optionally prefixed with a persona system message),
    runs a single greedy decoding step, and reads the probability mass the
    model places on the tokens "1".."5" at the first generated position.

    Args:
        statement: The statement the model is asked to rate.
        persona: Optional system-prompt persona. Empty string = no persona.

    Returns:
        Tuple of (matplotlib Figure or None, probabilities text, status/debug
        text) matching the three Gradio output components.
    """
    try:
        # Reject empty input early instead of formatting a blank prompt.
        if not statement or not statement.strip():
            return None, "", "❌ Error: statement is empty"

        # Read the shared prompt template; be explicit about the encoding so
        # behavior doesn't depend on the platform's default codec.
        with open("default-prompt.txt", "r", encoding="utf-8") as f:
            default_prompt = f.read().strip()

        # Build the chat with an optional persona as the system message.
        messages = []
        if persona.strip():
            messages.append({"role": "system", "content": persona.strip()})
        messages.append({
            "role": "user",
            "content": default_prompt.format(statement=statement.strip())
        })

        # Render the chat into a plain prompt string for the model.
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Tokenize and move the tensors to the model's device — with
        # device_map="auto" the model may live on GPU while the tokenizer
        # produces CPU tensors, which would otherwise crash generate().
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # One greedy step with scores so we can inspect the full logits
        # of the first generated token.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1,
                return_dict_in_generate=True,
                output_scores=True,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )

        if outputs.scores:
            logits = outputs.scores[0][0]  # First token, first batch
            probs = torch.softmax(logits, dim=-1)

            # Extract per-value probabilities; whatever mass is left went to
            # tokens outside 1-5 (refusals, whitespace, multi-digit, ...).
            likert_probs = {}
            outside_probability = 1.0
            for value, token_id in likert_tokens.items():
                likert_probs[value] = probs[token_id].item()
                outside_probability -= likert_probs[value]

            fig = create_probability_plot(likert_probs, persona)

            prob_text = "\n".join([f"{k}: {v:.4f}" for k, v in likert_probs.items()])
            prob_text += f"\nLogit probabilities outside of 1-5: {outside_probability}"

            # Show what the model actually generated including input and special tokens
            debug_info = f"{tokenizer.decode(outputs.sequences[0], skip_special_tokens=False)}"

            return fig, prob_text, f"{debug_info}"
        else:
            return None, "", "❌ No scores generated"

    except Exception as e:
        # Surface the failure in the Status textbox instead of crashing the UI.
        return None, "", f"❌ Error: {str(e)}"
# ---------------------------------------------------------------------------
# Gradio user interface.
# ---------------------------------------------------------------------------

# Example (statement, persona) pairs shown below the inputs.
EXAMPLE_INPUTS = [
    ["Climate change is a serious threat", ""],
    ["Immigration has positive economic effects", ""],
    ["Government should provide universal healthcare", ""],
    ["Climate change is a serious threat", "You are a conservative voter from rural America"],
    ["Immigration has positive economic effects", "You are a progressive voter from a major city"]
]

with gr.Blocks(title="The Unsampled Truth") as demo:
    gr.Markdown("""
    # The Unsampled Truth
    Extract probability distributions for Likert scale responses (1-5) using Phi-4-mini-instruct.
    """)

    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            statement_box = gr.Textbox(
                label="Statement to Analyze",
                placeholder="e.g., Climate change is a serious threat",
                lines=3
            )
            persona_box = gr.Textbox(
                label="Persona (Optional)",
                placeholder="e.g., You are a conservative voter from rural America",
                lines=2
            )
            run_button = gr.Button("Analyze", variant="primary")

        # Right column: analysis results.
        with gr.Column():
            distribution_plot = gr.Plot(label="Probability Distribution")
            raw_probs_box = gr.Textbox(label="Raw Probabilities", lines=6)
            status_box = gr.Textbox(label="Status", lines=3)

    gr.Examples(
        examples=EXAMPLE_INPUTS,
        inputs=[statement_box, persona_box]
    )

    run_button.click(
        fn=analyze_with_persona,
        inputs=[statement_box, persona_box],
        outputs=[distribution_plot, raw_probs_box, status_box]
    )

if __name__ == "__main__":
    demo.launch()