import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM
import numpy as np
import pandas as pd
import spacy
from spacy import displacy
import math
import warnings

try:
    from config import DEFAULT_MODELS, MODEL_SETTINGS, VIZ_SETTINGS, PROCESSING_SETTINGS, UI_SETTINGS, ERROR_MESSAGES
except ImportError:
    # Fallback configuration if config.py is not available
    DEFAULT_MODELS = {
        "decoder": ["gpt2", "distilgpt2"],
        "encoder": ["bert-base-uncased", "distilbert-base-uncased"]
    }
    MODEL_SETTINGS = {"max_length": 512}
    VIZ_SETTINGS = {
        "max_perplexity_display": 50.0,
        "color_scheme": {
            "low_perplexity": {"r": 46, "g": 204, "b": 113},
            "medium_perplexity": {"r": 241, "g": 196, "b": 15},
            "high_perplexity": {"r": 231, "g": 76, "b": 60},
            "background_alpha": 0.7,
            "border_alpha": 0.9
        },
        "thresholds": {
            "low_threshold": 0.3,
            "high_threshold": 0.7
        },
        "displacy_options": {"ents": ["PP"], "colors": {}}
    }
    PROCESSING_SETTINGS = {
        "epsilon": 1e-10
    }
    UI_SETTINGS = {
        "title": "📈 Perplexity Viewer",
        "description": "Visualize per-token perplexity using color gradients.",
        "examples": [
            {"text": "The quick brown fox jumps over the lazy dog.", "model": "gpt2", "type": "decoder"},
            {"text": "The capital of France is Paris.", "model": "bert-base-uncased", "type": "encoder"},
            {"text": "Quantum entanglement defies classical physics intuition completely.", "model": "distilgpt2", "type": "decoder"},
            {"text": "Machine learning algorithms require computational resources.", "model": "distilbert-base-uncased", "type": "encoder"}
        ]
    }
    ERROR_MESSAGES = {
        "empty_text": "Please enter some text to analyze.",
        "model_load_error": "Error loading model {model_name}: {error}",
        "processing_error": "Error processing text: {error}"
    }

warnings.filterwarnings("ignore")

# Global variables to cache models and tokenizers across requests
cached_models = {}
cached_tokenizers = {}


def load_model_and_tokenizer(model_name, model_type):
    """Load and cache a model and tokenizer for the given model type."""
    cache_key = f"{model_name}_{model_type}"
    if cache_key not in cached_models:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            # Add a pad token if the tokenizer doesn't define one (e.g. GPT-2)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Decoders need a causal LM head; encoders a masked LM head
            model_class = AutoModelForCausalLM if model_type == "decoder" else AutoModelForMaskedLM
            model = model_class.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                trust_remote_code=True
            )
            model.eval()  # Set to evaluation mode

            cached_models[cache_key] = model
            cached_tokenizers[cache_key] = tokenizer
            return model, tokenizer
        except Exception as e:
            raise gr.Error(ERROR_MESSAGES["model_load_error"].format(model_name=model_name, error=str(e)))
    return cached_models[cache_key], cached_tokenizers[cache_key]
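
# Usage sketch (illustrative; `_demo_model_cache` is a hypothetical helper, not
# part of the app). It shows the intended calling pattern: a repeated request
# for the same (model_name, model_type) pair is served from the in-memory
# cache instead of reloading the weights from disk or the Hub.
def _demo_model_cache():
    model, tokenizer = load_model_and_tokenizer("distilgpt2", "decoder")
    same_model, same_tokenizer = load_model_and_tokenizer("distilgpt2", "decoder")
    assert model is same_model and tokenizer is same_tokenizer  # cache hit
    return model, tokenizer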


def calculate_decoder_perplexity(text, model, tokenizer):
    """Calculate perplexity for decoder models (like GPT)"""
    device = next(model.parameters()).device

    # Tokenize the text
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MODEL_SETTINGS["max_length"])
    input_ids = inputs.input_ids.to(device)

    if input_ids.size(1) < 2:
        raise gr.Error("Text is too short for perplexity calculation.")

    # A single forward pass yields both the mean loss and the logits
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)

        # Overall perplexity: exponentiated mean next-token cross-entropy
        perplexity = torch.exp(outputs.loss).item()

        # Shift logits and labels so position i predicts token i + 1
        shift_logits = outputs.logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()

        # Per-token losses (no reduction), then per-token perplexities
        loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        token_losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        token_perplexities = torch.exp(token_losses).cpu().numpy()

    # Tokens being predicted (the first token has no preceding context)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0][1:])

    # Strip subword prefixes for display
    cleaned_tokens = []
    for token in tokens:
        if token.startswith('Ġ'):
            cleaned_tokens.append(token[1:])  # GPT-2 byte-level BPE space marker
        elif token.startswith('##'):
            cleaned_tokens.append(token[2:])  # WordPiece continuation marker
        else:
            cleaned_tokens.append(token)

    return perplexity, cleaned_tokens, token_perplexities
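
# Sanity-check sketch (illustrative; `_demo_decoder_consistency` is a
# hypothetical helper, not part of the app). Because the model's reported
# loss is the mean of exactly these shifted per-token losses, the overall
# perplexity should equal the geometric mean of the per-token perplexities.
def _demo_decoder_consistency():
    model, tokenizer = load_model_and_tokenizer("distilgpt2", "decoder")
    ppl, _, token_ppls = calculate_decoder_perplexity(
        "The quick brown fox jumps over the lazy dog.", model, tokenizer
    )
    geo_mean = float(np.exp(np.mean(np.log(token_ppls.astype(np.float64)))))
    assert abs(ppl - geo_mean) / ppl < 1e-2  # allow for float16 rounding on GPU
    return ppl, geo_mean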


def calculate_encoder_perplexity(text, model, tokenizer):
    """Calculate pseudo-perplexity for encoder models (like BERT) using MLM on all tokens"""
    device = next(model.parameters()).device

    # Tokenize the text
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MODEL_SETTINGS["max_length"])
    input_ids = inputs.input_ids.to(device)

    if input_ids.size(1) < 3:  # Need at least [CLS] + 1 token + [SEP]
        raise gr.Error("Text is too short for MLM perplexity calculation.")

    special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}
    special_token_ids.discard(None)  # guard against tokenizers missing a special token

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    content_token_losses = []
    token_perplexities = []

    # Mask each non-special token individually and score the original token:
    # one loop serves both the average and the per-token visualization values
    with torch.no_grad():
        for i in range(input_ids.size(1)):
            original_token_id = input_ids[0, i]
            if original_token_id.item() in special_token_ids:
                token_perplexities.append(1.0)  # Neutral value for special tokens
                continue
            masked_input = input_ids.clone()
            masked_input[0, i] = tokenizer.mask_token_id
            outputs = model(masked_input)
            prob = F.softmax(outputs.logits[0, i], dim=-1)[original_token_id]
            content_token_losses.append(-torch.log(prob + PROCESSING_SETTINGS["epsilon"]).item())
            token_perplexities.append(1.0 / (prob.item() + PROCESSING_SETTINGS["epsilon"]))

    # Pseudo-perplexity: exponentiated mean negative log-likelihood over content tokens
    if content_token_losses:
        perplexity = math.exp(np.mean(content_token_losses))
    else:
        perplexity = float('inf')

    # Strip WordPiece continuation markers for display
    cleaned_tokens = [token[2:] if token.startswith('##') else token for token in tokens]

    return perplexity, cleaned_tokens, np.array(token_perplexities)
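
# Usage sketch (illustrative; `_demo_encoder_pseudo_perplexity` is a
# hypothetical helper, not part of the app). Unlike the single decoder
# forward pass, the encoder path runs one masked forward pass per content
# token, so cost grows linearly with sequence length.
def _demo_encoder_pseudo_perplexity():
    model, tokenizer = load_model_and_tokenizer("distilbert-base-uncased", "encoder")
    ppl, tokens, token_ppls = calculate_encoder_perplexity(
        "The capital of France is Paris.", model, tokenizer
    )
    print(f"pseudo-perplexity: {ppl:.2f}")
    for token, token_ppl in zip(tokens, token_ppls):
        print(f"{token:>12s}  {token_ppl:8.2f}")  # special tokens show as 1.00
    return ppl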
" # Cap perplexities for better visualization max_perplexity = min(np.max(perplexities), VIZ_SETTINGS["max_perplexity_display"]) # Normalize perplexities to 0-1 range for color mapping normalized_perplexities = np.clip(perplexities / max_perplexity, 0, 1) # Create HTML with inline styles for color coding html_parts = [ '