"""Gradio app comparing token-level log probabilities of two ERNIE 4.5 models.

For a piece of input text, each model scores every token given the tokens
before it. The app shows the per-token log probability, the probability
("confidence") of the actual token, the top-k alternatives, a grouped bar
chart, and a short textual summary.
"""

import math

import gradio as gr
import pandas as pd
import torch
import torch.nn.functional as F
from plotly import graph_objects as go
from transformers import AutoModelForCausalLM, AutoTokenizer

# Display names of the two models being compared.
PT_MODEL = "ERNIE-4.5-PT"
BASE_MODEL = "ERNIE-4.5-Base-PT"

# Only the first MAX_DISPLAY_TOKENS scored positions are shown in the table
# and chart; the total log probability still covers the whole input.
MAX_DISPLAY_TOKENS = 20

# Load model and tokenizer
model_ids = {
    PT_MODEL: "baidu/ERNIE-4.5-0.3B-PT",
    BASE_MODEL: "baidu/ERNIE-4.5-0.3B-Base-PT",
}

tokenizers = {
    name: AutoTokenizer.from_pretrained(path)
    for name, path in model_ids.items()
}

models = {
    name: AutoModelForCausalLM.from_pretrained(path).eval()
    for name, path in model_ids.items()
}


# Helper function to format probability
def format_prob(prob):
    """Format a probability in [0, 1] as a percentage with 1 decimal place."""
    return f"{prob*100:.1f}%"


# Helper function to format log probability
def format_log_prob(log_prob):
    """Format a log probability with 3 decimal places."""
    return f"{log_prob:.3f}"


@torch.no_grad()
def _score_text(model_name, text, top_k):
    """Score *text* with one model and collect its per-token statistics.

    Args:
        model_name: Key into ``tokenizers`` / ``models``.
        text: Non-empty input text.
        top_k: Number of alternative predictions to list per position (int).

    Returns:
        A dict with keys ``"df"`` (per-token DataFrame), ``"total_log_prob"``
        (sum over all scored tokens), ``"tokens"`` and ``"confidences"``
        (both limited to the first ``MAX_DISPLAY_TOKENS`` positions).
    """
    tokenizer = tokenizers[model_name]
    model = models[model_name]

    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"]

    # The logits at position i predict token i + 1, so drop the last logit
    # row and the first label to line predictions up with their targets.
    outputs = model(**inputs)
    shift_logits = outputs.logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]

    # Compute log probabilities
    log_probs = F.log_softmax(shift_logits, dim=-1)
    token_log_probs = log_probs.gather(
        2, shift_labels.unsqueeze(-1)
    ).squeeze(-1)
    total_log_prob = token_log_probs.sum().item()

    # The first input token has no preceding context and therefore no
    # prediction; skip it so tokens align with token_log_probs.
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])[1:]

    shown = min(MAX_DISPLAY_TOKENS, shift_logits.shape[1])
    # torch.topk rejects k larger than the last dimension.
    k = max(1, min(top_k, log_probs.shape[-1]))

    # Top-k predictions for each displayed position.
    topk_list = []
    for pos in range(shown):
        best = torch.topk(log_probs[0, pos], k=k)
        cand_tokens = tokenizer.convert_ids_to_tokens(best.indices.tolist())
        cand_probs = [math.exp(score) for score in best.values.tolist()]
        topk_list.append(
            ", ".join(
                f"{tok} ({format_prob(p)})"
                for tok, p in zip(cand_tokens, cand_probs)
            )
        )

    # Confidence is the probability the model assigned to the actual token.
    shown_log_probs = token_log_probs[0, :shown].tolist()
    confidence_list = [math.exp(lp) for lp in shown_log_probs]
    shown_tokens = tokens[:shown]

    # Prepare dataframe for display
    df = pd.DataFrame({
        "Token": shown_tokens,
        "LogProb": [format_log_prob(lp) for lp in shown_log_probs],
        "Confidence": [format_prob(c) for c in confidence_list],
        f"Top-{top_k} Predictions": topk_list,
    })

    return {
        "df": df,
        "total_log_prob": total_log_prob,
        "tokens": shown_tokens,
        "confidences": confidence_list,
    }


def _build_comparison_table(results, top_k):
    """Merge the per-model tables into one flat, side-by-side DataFrame."""
    topk_column = f"Top-{top_k} Predictions"
    columns = {"Token": results[PT_MODEL]["df"]["Token"]}
    # Flat "<model> <metric>" columns: nested dicts as column values would be
    # interpreted by pandas as index->value mappings, not as sub-columns.
    for name in model_ids:
        model_df = results[name]["df"]
        columns[f"{name} LogProb"] = model_df["LogProb"]
        columns[f"{name} Confidence"] = model_df["Confidence"]
        columns[f"{name} Top-k"] = model_df[topk_column]
    return pd.DataFrame(columns)


def _build_confidence_chart(results):
    """Build the grouped bar chart of per-token confidence for both models."""
    fig = go.Figure()

    # Bars are keyed by token position rather than token text so that a
    # token occurring twice in the input is not merged into one category.
    for name, color in ((PT_MODEL, "royalblue"), (BASE_MODEL, "lightseagreen")):
        confidences = results[name]["confidences"]
        fig.add_trace(go.Bar(
            name=name,
            x=list(range(len(confidences))),
            y=confidences,
            hovertext=results[name]["tokens"],
            marker_color=color,
        ))

    tick_labels = results[PT_MODEL]["tokens"]
    fig.update_layout(
        title='Model Confidence Comparison',
        xaxis_title='Token',
        yaxis_title='Confidence (Probability)',
        barmode='group',
        xaxis=dict(
            tickmode="array",
            tickvals=list(range(len(tick_labels))),
            ticktext=tick_labels,
        ),
        yaxis=dict(tickformat='.0%', range=[0, 1]),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
    )
    return fig


def _build_summary(results):
    """Build the Markdown summary comparing the total log probabilities."""
    pt_logprob = results[PT_MODEL]['total_log_prob']
    base_logprob = results[BASE_MODEL]['total_log_prob']

    # Determine which model has higher confidence
    if pt_logprob > base_logprob:
        better_model = "ERNIE-4.5-PT"
        difference = pt_logprob - base_logprob
    else:
        better_model = "ERNIE-4.5-Base-PT"
        difference = base_logprob - pt_logprob

    return (
        f"📊 **Model Comparison Summary**\n\n"
        f"**Total Log Probability**:\n"
        f"- ERNIE-4.5-PT: {pt_logprob:.3f}\n"
        f"- ERNIE-4.5-Base-PT: {base_logprob:.3f}\n\n"
        f"🏆 **Higher Confidence Model**: {better_model}\n"
        f"Difference: {difference:.3f} ({'+' if better_model == 'ERNIE-4.5-PT' else '-'}{difference:.3f})\n\n"
        f"**What this means**:\n"
        f"- Log probability closer to 0 (less negative) indicates higher model confidence\n"
        f"- The {better_model} model is more confident in predicting your input text\n"
        f"- Confidence per token is shown in the table and chart below"
    )


# Main function: compute token-wise log probabilities and top-k predictions
@torch.no_grad()
def compare_models(text, top_k=5):
    """Compare both ERNIE models on *text*.

    Args:
        text: Text to analyze.
        top_k: Number of alternative predictions to list per token. Values
            coming from the UI slider may be floats and are cast to int.

    Returns:
        A 3-tuple ``(comparison_df, summary_markdown, plotly_figure)`` that
        matches the three output components of the click handler. For empty
        or whitespace-only input the table and figure are ``None`` and the
        summary carries a warning message.
    """
    if not text or not text.strip():
        # Must still return one value per output component.
        return None, "⚠️ Please enter some text to analyze", None

    top_k = int(top_k)
    results = {name: _score_text(name, text, top_k) for name in model_ids}

    comparison_df = _build_comparison_table(results, top_k)
    summary = _build_summary(results)
    fig = _build_confidence_chart(results)
    return comparison_df, summary, fig


# Create custom CSS for better styling
css = """
.main-container {
    max-width: 1200px;
    margin: 0 auto;
}
.dataframe-container {
    margin: 20px 0;
}
.confidence-chart {
    margin: 20px 0;
    height: 400px;
}
.summary-box {
    background-color: #f8f9fa;
    border-left: 4px solid #4285f4;
    padding: 15px;
    border-radius: 4px;
    margin: 20px 0;
}
.model-header {
    font-weight: bold;
    color: #1a73e8;
    margin-top: 10px;
}
.token-cell {
    font-family: monospace;
    background-color: #f1f3f4;
    padding: 4px 8px;
    border-radius: 3px;
}
.confidence-high {
    color: #0f9d58;
    font-weight: bold;
}
.confidence-medium {
    color: #f4b400;
}
.confidence-low {
    color: #db4437;
}
"""

# Gradio interface with improved layout
with gr.Blocks(css=css, title="ERNIE Model Comparison Tool") as demo:
    gr.Markdown(
        """
        # 🔍 ERNIE 4.5 Model Comparison Tool

        Compare how different ERNIE models process your text with detailed token-level analysis.

        ## What this tool shows:
        - **Token Log Probability**: How confident the model is in predicting each token (closer to 0 is better)
        - **Confidence**: Probability percentage for each token prediction
        - **Top-k Predictions**: What other tokens the model considered likely
        - **Visual Comparison**: Bar chart showing confidence differences between models
        """
    )

    with gr.Row():
        with gr.Column(scale=3):
            input_text = gr.Textbox(
                lines=3,
                placeholder="Enter text to analyze (e.g., 'Hello, World!')",
                label="Input Text",
                value="Hello, World!",
            )
        with gr.Column(scale=1):
            top_k = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Top-k Predictions",
            )

    with gr.Row():
        compare_btn = gr.Button("Compare Models", variant="primary")

    with gr.Row():
        with gr.Column():
            summary_box = gr.Markdown(
                elem_classes=["summary-box"],
                label="Model Comparison Summary",
            )

    with gr.Row():
        with gr.Column():
            comparison_table = gr.Dataframe(
                label="Token-Level Analysis",
                elem_classes=["dataframe-container"],
                interactive=False,
                wrap=True,
            )

    with gr.Row():
        with gr.Column():
            confidence_chart = gr.Plot(
                label="Model Confidence Comparison",
                elem_classes=["confidence-chart"],
            )

    # Examples section
    gr.Examples(
        examples=[
            ["Hello, World!", 3],
            ["The quick brown fox jumps over the lazy dog.", 5],
            ["Artificial intelligence will transform our society.", 3],
            ["What is the meaning of life?", 4],
        ],
        inputs=[input_text, top_k],
        label="Try these examples:",
    )

    # Footer with explanation
    gr.Markdown(
        """
        ## How to Interpret Results

        1. **Log Probability**: Negative values where closer to 0 means higher model confidence
        2. **Confidence**: Percentage showing how certain the model was about each token
        3. **Top-k Predictions**: Alternative tokens the model considered likely
        4. **Visual Chart**: Bar heights represent model confidence for each token

        **Model Differences**:
        - **ERNIE-4.5-PT**: Instruction-tuned model, better at following complex instructions
        - **ERNIE-4.5-Base-PT**: Base model, better at general language patterns
        """
    )

    # Set up event handler; the handler must return one value per output.
    compare_btn.click(
        fn=compare_models,
        inputs=[input_text, top_k],
        outputs=[comparison_table, summary_box, confidence_chart],
    )

if __name__ == "__main__":
    demo.launch()