import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM
import numpy as np
import pandas as pd
import spacy
from spacy import displacy
import math
import warnings

try:
    from config import (
        DEFAULT_MODELS,
        MODEL_SETTINGS,
        VIZ_SETTINGS,
        PROCESSING_SETTINGS,
        UI_SETTINGS,
        ERROR_MESSAGES,
    )
except ImportError:
    # Fallback configuration if config.py is not available
    DEFAULT_MODELS = {
        "decoder": ["gpt2", "distilgpt2"],
        "encoder": ["bert-base-uncased", "distilbert-base-uncased"],
    }
    MODEL_SETTINGS = {"max_length": 512}
    VIZ_SETTINGS = {
        "max_perplexity_display": 100.0,
        "color_scheme": {
            "high_perplexity": {"r": 255, "g": 0, "b": 50},
            "low_perplexity": {"r": 0, "g": 255, "b": 50},
        },
        # NOTE: displaCy options are rebuilt per-render in create_visualization,
        # because entity labels (perplexity values) are only known at runtime.
        "displacy_options": {"ents": ["PP"], "colors": {}},
    }
    PROCESSING_SETTINGS = {
        "default_iterations": 1,
        "max_iterations": 10,
        "default_mlm_probability": 0.15,
        "min_mlm_probability": 0.1,
        "max_mlm_probability": 0.5,
        "epsilon": 1e-10,
    }
    UI_SETTINGS = {
        "title": "📈 Perplexity Viewer",
        "description": "Visualize per-token perplexity using color gradients.",
        "examples": [
            {
                "text": "The quick brown fox jumps over the lazy dog.",
                "model": "gpt2",
                "type": "decoder",
                "iterations": 1,
                "mlm_prob": 0.15,
            },
            {
                "text": "The capital of France is Paris.",
                "model": "bert-base-uncased",
                "type": "encoder",
                "iterations": 1,
                "mlm_prob": 0.15,
            },
        ],
    }
    ERROR_MESSAGES = {
        "empty_text": "Please enter some text to analyze.",
        "model_load_error": "Error loading model {model_name}: {error}",
        "processing_error": "Error processing text: {error}",
    }

warnings.filterwarnings("ignore")

# Module-level caches so repeated analyses reuse already-loaded weights.
cached_models = {}
cached_tokenizers = {}


def load_model_and_tokenizer(model_name, model_type):
    """Load (or fetch from cache) a model/tokenizer pair.

    Args:
        model_name: HuggingFace model id (e.g. "gpt2", "bert-base-uncased").
        model_type: "decoder" (causal LM) or "encoder" (masked LM).

    Returns:
        (model, tokenizer) tuple; the model is in eval mode.

    Raises:
        gr.Error: if the model or tokenizer cannot be loaded.
    """
    cache_key = f"{model_name}_{model_type}"
    if cache_key not in cached_models:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            # GPT-style tokenizers ship without a pad token; reuse EOS.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            model_cls = AutoModelForCausalLM if model_type == "decoder" else AutoModelForMaskedLM
            model = model_cls.from_pretrained(
                model_name,
                # fp16 only on GPU; CPU kernels for fp16 are slow/incomplete.
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                trust_remote_code=True,
            )
            model.eval()  # disable dropout so results are deterministic

            cached_models[cache_key] = model
            cached_tokenizers[cache_key] = tokenizer
            return model, tokenizer
        except Exception as e:
            raise gr.Error(
                ERROR_MESSAGES["model_load_error"].format(model_name=model_name, error=str(e))
            )
    return cached_models[cache_key], cached_tokenizers[cache_key]


def calculate_decoder_perplexity(text, model, tokenizer, iterations=1):
    """Calculate true perplexity for decoder (causal LM) models such as GPT-2.

    A causal LM in eval mode is deterministic, so one forward pass suffices;
    the `iterations` parameter is kept for interface compatibility but does
    not change the result.

    Returns:
        (mean_perplexity, cleaned_tokens, token_perplexities). The per-token
        arrays cover tokens[1:] — the first token has no left context, so it
        has no next-token prediction loss.

    Raises:
        gr.Error: if the text tokenizes to fewer than 2 tokens.
    """
    device = next(model.parameters()).device

    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=MODEL_SETTINGS["max_length"]
    )
    input_ids = inputs.input_ids.to(device)

    if input_ids.size(1) < 2:
        raise gr.Error("Text is too short for perplexity calculation.")

    # One labeled forward pass yields both the mean loss (-> perplexity)
    # and the logits needed for per-token losses.
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
        avg_perplexity = torch.exp(outputs.loss).item()
        logits = outputs.logits

    # Shift so position i predicts token i+1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = input_ids[..., 1:].contiguous()
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    token_losses = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    )
    token_perplexities = torch.exp(token_losses).cpu().numpy()

    # Tokens being predicted (everything after the first one).
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0][1:])

    # Strip subword markers for display (BPE 'Ġ' prefix, WordPiece '##').
    cleaned_tokens = []
    for token in tokens:
        if token.startswith("Ġ"):
            cleaned_tokens.append(token[1:])
        elif token.startswith("##"):
            cleaned_tokens.append(token[2:])
        else:
            cleaned_tokens.append(token)

    return avg_perplexity, cleaned_tokens, token_perplexities


def calculate_encoder_perplexity(text, model, tokenizer, mlm_probability=0.15, iterations=1):
    """Calculate pseudo-perplexity for encoder (masked LM) models such as BERT.

    Each iteration randomly masks ~mlm_probability of the non-special tokens
    and scores the model's predictions at the masked positions; results are
    averaged over iterations. Per-token values for the visualization are
    computed by masking each token individually (one forward pass per token).

    Returns:
        (mean_pseudo_perplexity, cleaned_tokens, token_perplexities).

    Raises:
        gr.Error: if the tokenizer has no mask token or the text is too short.
    """
    if tokenizer.mask_token_id is None:
        raise gr.Error(
            "This tokenizer has no mask token — select an encoder (MLM) model "
            "or switch the model type to 'decoder'."
        )

    device = next(model.parameters()).device
    perplexities = []

    # Some tokenizers leave these ids unset (None); filter them out so the
    # membership test below never matches a real token id by accident.
    special_token_ids = {
        tid
        for tid in (tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id)
        if tid is not None
    }

    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=MODEL_SETTINGS["max_length"]
    )
    input_ids = inputs.input_ids.to(device)

    if input_ids.size(1) < 3:  # need at least [CLS] + 1 token + [SEP]
        raise gr.Error("Text is too short for MLM perplexity calculation.")

    seq_length = input_ids.size(1)

    for _ in range(iterations):
        masked_input_ids = input_ids.clone()
        original_tokens = input_ids.clone()

        # Randomly choose non-special positions to mask.
        mask_indices = []
        for i in range(seq_length):
            if input_ids[0, i].item() not in special_token_ids:
                if torch.rand(1).item() < mlm_probability:
                    mask_indices.append(i)
                    masked_input_ids[0, i] = tokenizer.mask_token_id

        # Guarantee at least one masked token per iteration.
        if not mask_indices:
            non_special = [
                i for i in range(seq_length) if input_ids[0, i].item() not in special_token_ids
            ]
            if non_special:
                pick = torch.randint(0, len(non_special), (1,)).item()
                mask_indices = [non_special[pick]]
                masked_input_ids[0, mask_indices[0]] = tokenizer.mask_token_id

        with torch.no_grad():
            predictions = model(masked_input_ids).logits

        # NLL of the original token at each masked position.
        masked_token_losses = []
        for idx in mask_indices:
            target_id = original_tokens[0, idx]
            prob = F.softmax(predictions[0, idx], dim=-1)[target_id]
            masked_token_losses.append(
                (-torch.log(prob + PROCESSING_SETTINGS["epsilon"])).item()
            )

        if masked_token_losses:
            perplexities.append(math.exp(np.mean(masked_token_losses)))

    # Per-token pseudo-perplexity: mask one position at a time.
    # NOTE: this is O(seq_length) forward passes — acceptable for short texts.
    with torch.no_grad():
        token_perplexities = []
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
        for i in range(len(tokens)):
            if input_ids[0, i].item() in special_token_ids:
                token_perplexities.append(1.0)  # treat special tokens as "certain"
            else:
                masked_input = input_ids.clone()
                original_token_id = input_ids[0, i]
                masked_input[0, i] = tokenizer.mask_token_id
                predictions = model(masked_input).logits[0, i]
                prob = F.softmax(predictions, dim=-1)[original_token_id]
                token_perplexities.append(
                    1.0 / (prob.item() + PROCESSING_SETTINGS["epsilon"])
                )

    # Strip WordPiece '##' continuation markers for display.
    cleaned_tokens = [t[2:] if t.startswith("##") else t for t in tokens]

    mean_ppl = np.mean(perplexities) if perplexities else float("inf")
    return mean_ppl, cleaned_tokens, np.array(token_perplexities)


def create_visualization(tokens, perplexities):
    """Create a displaCy entity visualization with perplexity-coded colors.

    Each token becomes an "entity" labeled with its perplexity value; colors
    interpolate between the configured low- and high-perplexity colors.

    Returns:
        An HTML string (also a plain <p> message when there is nothing to show
        or rendering fails).
    """
    if len(tokens) == 0:
        return "<p>No tokens to visualize.</p>"

    # Cap extreme values so one outlier doesn't wash out the color scale.
    max_perplexity = min(np.max(perplexities), VIZ_SETTINGS["max_perplexity_display"])
    normalized_perplexities = np.clip(perplexities / max_perplexity, 0, 1)

    high_color = VIZ_SETTINGS["color_scheme"]["high_perplexity"]
    low_color = VIZ_SETTINGS["color_scheme"]["low_perplexity"]

    entities = []
    label_colors = {}  # displaCy takes colors per *label*, not per entity
    text_parts = []
    current_pos = 0

    for i, (token, perp, norm_perp) in enumerate(
        zip(tokens, perplexities, normalized_perplexities)
    ):
        clean_token = token.strip()
        if not clean_token:
            continue

        # Insert a space before the token unless it opens with punctuation.
        if i > 0 and clean_token[0] not in ".,!?;:":
            text_parts.append(" ")
            current_pos += 1

        text_parts.append(clean_token)

        # Linear interpolation between the low- and high-perplexity colors.
        red = int(high_color["r"] * norm_perp + low_color["r"] * (1 - norm_perp))
        green = int(high_color["g"] * norm_perp + low_color["g"] * (1 - norm_perp))
        blue = int(high_color["b"] * norm_perp + low_color["b"] * (1 - norm_perp))

        label = f"{perp:.2f}"
        # Tokens with the same rounded perplexity share a label; their colors
        # are nearly identical, so last-write-wins is harmless here.
        label_colors[label] = f"rgb({red}, {green}, {blue})"

        entities.append(
            {"start": current_pos, "end": current_pos + len(clean_token), "label": label}
        )
        current_pos += len(clean_token)

    if not entities:
        return "<p>No valid tokens found for visualization.</p>"

    doc_data = {
        "text": "".join(text_parts),
        "ents": entities,
        "title": "Per-token Perplexity Visualization",
    }

    # Build options at render time: displaCy's "ents" filter must list the
    # actual labels in use, and "colors" maps label -> CSS color. A static
    # options dict (e.g. {"ents": ["PP"]}) would filter out every entity.
    options = {"ents": list(label_colors), "colors": label_colors}

    try:
        return displacy.render(doc_data, style="ent", manual=True, options=options)
    except Exception as e:
        return f"<p>Error creating visualization: {str(e)}</p>"


def process_text(text, model_name, model_type, iterations, mlm_probability):
    """Main processing function wired to the Analyze button.

    Returns:
        (markdown_summary, visualization_html, results_dataframe).
        Errors are reported in the summary slot rather than raised, so the UI
        always receives a well-formed triple.
    """
    if not text.strip():
        return ERROR_MESSAGES["empty_text"], "", pd.DataFrame()

    try:
        # Clamp UI inputs into their configured ranges (sliders may hand us
        # floats, so coerce iterations to int explicitly).
        iterations = max(1, min(int(iterations), PROCESSING_SETTINGS["max_iterations"]))
        mlm_probability = max(
            PROCESSING_SETTINGS["min_mlm_probability"],
            min(mlm_probability, PROCESSING_SETTINGS["max_mlm_probability"]),
        )

        model, tokenizer = load_model_and_tokenizer(model_name, model_type)

        if model_type == "decoder":
            avg_perplexity, tokens, token_perplexities = calculate_decoder_perplexity(
                text, model, tokenizer, iterations
            )
        else:  # encoder
            avg_perplexity, tokens, token_perplexities = calculate_encoder_perplexity(
                text, model, tokenizer, mlm_probability, iterations
            )

        viz_html = create_visualization(tokens, token_perplexities)

        summary = f"""
### Analysis Results

**Model:** `{model_name}`
**Model Type:** {model_type.title()}
**Average Perplexity:** {avg_perplexity:.4f}
**Number of Tokens:** {len(tokens)}
**Iterations:** {iterations}
"""
        if model_type == "encoder":
            summary += f" \n**MLM Probability:** {mlm_probability}"

        df = pd.DataFrame(
            {
                "Token": tokens,
                "Perplexity": [f"{p:.4f}" for p in token_perplexities],
            }
        )

        return summary, viz_html, df

    except Exception as e:
        error_msg = ERROR_MESSAGES["processing_error"].format(error=str(e))
        return error_msg, "", pd.DataFrame()


# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks(title=UI_SETTINGS["title"], theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {UI_SETTINGS['title']}")
    gr.Markdown(UI_SETTINGS["description"])

    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter the text you want to analyze...",
                lines=6,
                max_lines=10,
            )

            with gr.Row():
                model_name = gr.Dropdown(
                    label="Model Name",
                    choices=DEFAULT_MODELS["decoder"] + DEFAULT_MODELS["encoder"],
                    value="gpt2",
                    allow_custom_value=True,
                    info="Select a model or enter a custom HuggingFace model name",
                )
                model_type = gr.Radio(
                    label="Model Type",
                    choices=["decoder", "encoder"],
                    value="decoder",
                    info="Decoder for causal LM, Encoder for masked LM",
                )

            with gr.Row():
                iterations = gr.Slider(
                    label="Iterations",
                    minimum=1,
                    maximum=PROCESSING_SETTINGS["max_iterations"],
                    value=PROCESSING_SETTINGS["default_iterations"],
                    step=1,
                    info="Number of iterations to average over",
                )
                mlm_probability = gr.Slider(
                    label="MLM Probability",
                    minimum=PROCESSING_SETTINGS["min_mlm_probability"],
                    maximum=PROCESSING_SETTINGS["max_mlm_probability"],
                    value=PROCESSING_SETTINGS["default_mlm_probability"],
                    step=0.05,
                    visible=False,  # shown only for encoder models
                    info="Probability of masking tokens (encoder models only)",
                )

            analyze_btn = gr.Button("🔍 Analyze Perplexity", variant="primary", size="lg")

        with gr.Column(scale=3):
            summary_output = gr.Markdown(label="Summary")
            viz_output = gr.HTML(label="Perplexity Visualization")

    # Full-width table below the two columns.
    with gr.Row():
        table_output = gr.Dataframe(
            label="Detailed Token Results",
            interactive=False,
            wrap=True,
        )

    def update_model_choices(model_type):
        # Swap the dropdown to the model list matching the selected type.
        return gr.update(
            choices=DEFAULT_MODELS[model_type], value=DEFAULT_MODELS[model_type][0]
        )

    def toggle_mlm_visibility(model_type):
        # The MLM probability slider only applies to encoder models.
        return gr.update(visible=(model_type == "encoder"))

    model_type.change(
        fn=lambda mt: [update_model_choices(mt), toggle_mlm_visibility(mt)],
        inputs=[model_type],
        outputs=[model_name, mlm_probability],
    )

    analyze_btn.click(
        fn=process_text,
        inputs=[text_input, model_name, model_type, iterations, mlm_probability],
        outputs=[summary_output, viz_output, table_output],
    )

    with gr.Accordion("📝 Example Texts", open=False):
        examples_data = [
            [ex["text"], ex["model"], ex["type"], ex["iterations"], ex["mlm_prob"]]
            for ex in UI_SETTINGS["examples"]
        ]
        gr.Examples(
            examples=examples_data,
            inputs=[text_input, model_name, model_type, iterations, mlm_probability],
            outputs=[summary_output, viz_output, table_output],
            fn=process_text,
            cache_examples=False,
            label="Click on an example to try it out:",
        )

    gr.Markdown(
        """
---
### 📊 How it works:
- **Decoder Models** (GPT, etc.): Calculate true perplexity by measuring how well the model predicts the next token
- **Encoder Models** (BERT, etc.): Calculate pseudo-perplexity using masked language modeling (MLM)
- **Color Coding**: Red = High perplexity (uncertain), Green = Low perplexity (confident)

### ⚠️ Notes:
- First model load may take some time
- Models are cached after first use
- Very long texts are truncated to 512 tokens
- GPU acceleration is used when available
"""
    )

if __name__ == "__main__":
    try:
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_api=False,
        )
    except Exception as e:
        print(f"❌ Failed to launch app: {e}")
        print("💡 Try running with: python run.py")
        # Fallback to a default launch configuration.
        try:
            demo.launch()
        except Exception as fallback_error:
            print(f"❌ Fallback launch also failed: {fallback_error}")
            print("💡 Try updating Gradio: pip install --upgrade gradio")