import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM
import numpy as np
import pandas as pd
import spacy
from spacy import displacy
import math
import html  # Used to escape tokens before embedding them in the HTML visualization
import warnings
try:
    from config import DEFAULT_MODELS, MODEL_SETTINGS, VIZ_SETTINGS, PROCESSING_SETTINGS, UI_SETTINGS, ERROR_MESSAGES
except ImportError:
    # Fallback configuration if config.py is not available
    DEFAULT_MODELS = {
        "decoder": ["gpt2", "distilgpt2"],
        "encoder": ["bert-base-uncased", "distilbert-base-uncased"]
    }
    MODEL_SETTINGS = {"max_length": 512}
    VIZ_SETTINGS = {
        "max_perplexity_display": 50.0,
        "color_scheme": {
            "low_perplexity": {"r": 46, "g": 204, "b": 113},
            "medium_perplexity": {"r": 241, "g": 196, "b": 15},
            "high_perplexity": {"r": 231, "g": 76, "b": 60},
            "background_alpha": 0.7,
            "border_alpha": 0.9
        },
        "thresholds": {
            "low_threshold": 0.3,
            "high_threshold": 0.7
        },
        "displacy_options": {"ents": ["PP"], "colors": {}}
    }
    PROCESSING_SETTINGS = {
        "epsilon": 1e-10
    }
    UI_SETTINGS = {
        "title": "🔍 Perplexity Viewer",
        "description": "Visualize per-token perplexity using color gradients.",
        "examples": [
            {"text": "The quick brown fox jumps over the lazy dog.", "model": "gpt2", "type": "decoder"},
            {"text": "The capital of France is Paris.", "model": "bert-base-uncased", "type": "encoder"},
            {"text": "Quantum entanglement defies classical physics intuition completely.", "model": "distilgpt2", "type": "decoder"},
            {"text": "Machine learning algorithms require computational resources.", "model": "distilbert-base-uncased", "type": "encoder"}
        ]
    }
    ERROR_MESSAGES = {
        "empty_text": "Please enter some text to analyze.",
        "model_load_error": "Error loading model {model_name}: {error}",
        "processing_error": "Error processing text: {error}"
    }
warnings.filterwarnings("ignore")

# Global variables to cache models and tokenizers
cached_models = {}
cached_tokenizers = {}
def load_model_and_tokenizer(model_name, model_type):
    """Load and cache model and tokenizer"""
    cache_key = f"{model_name}_{model_type}"

    if cache_key not in cached_models:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            # Add a pad token if the tokenizer doesn't define one
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            if model_type == "decoder":
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    device_map="auto" if torch.cuda.is_available() else None,
                    trust_remote_code=True
                )
            else:  # encoder
                model = AutoModelForMaskedLM.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    device_map="auto" if torch.cuda.is_available() else None,
                    trust_remote_code=True
                )

            model.eval()  # Set to evaluation mode
            cached_models[cache_key] = model
            cached_tokenizers[cache_key] = tokenizer
            return model, tokenizer
        except Exception as e:
            raise gr.Error(ERROR_MESSAGES["model_load_error"].format(model_name=model_name, error=str(e)))

    return cached_models[cache_key], cached_tokenizers[cache_key]
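
# Usage sketch for the cache above (illustrative, not part of the app flow):
# the second call returns the cached pair instead of reloading weights.
# Note the cache is never evicted, so memory grows with each distinct model.
#
#   model, tok = load_model_and_tokenizer("gpt2", "decoder")  # loads and caches
#   model, tok = load_model_and_tokenizer("gpt2", "decoder")  # cache hit
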
def calculate_decoder_perplexity(text, model, tokenizer):
    """Calculate perplexity for decoder models (like GPT)"""
    device = next(model.parameters()).device

    # Tokenize the text
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MODEL_SETTINGS["max_length"])
    input_ids = inputs.input_ids.to(device)

    if input_ids.size(1) < 2:
        raise gr.Error("Text is too short for perplexity calculation.")

    # A single forward pass yields both the per-token losses and the overall perplexity
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits

        # Shift logits and labels for next-token prediction
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()

        # Per-token cross-entropy losses
        loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        token_losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        token_perplexities = torch.exp(token_losses).cpu().numpy()

        # Overall perplexity is exp of the mean per-token loss
        perplexity = torch.exp(token_losses.mean()).item()

    # Tokens being predicted (everything except the first token)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0, 1:].tolist())

    # Clean up tokens for display
    cleaned_tokens = []
    for token in tokens:
        if token.startswith('Ġ'):
            cleaned_tokens.append(token[1:])  # Strip the GPT-2/BPE leading-space marker
        elif token.startswith('##'):
            cleaned_tokens.append(token[2:])  # Strip the WordPiece continuation marker
        else:
            cleaned_tokens.append(token)

    return perplexity, cleaned_tokens, token_perplexities
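
# The quantities above follow the standard definition of perplexity:
#
#   loss_i = -log p(x_i | x_<i)            per-token cross-entropy
#   PPL    = exp((1/N) * sum_i loss_i)     sequence perplexity
#
# Each entry of token_perplexities is exp(loss_i), and the sequence value is
# exp of the mean loss, so the highlighted tokens exactly decompose the
# overall score.
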
def calculate_encoder_perplexity(text, model, tokenizer):
    """Calculate pseudo-perplexity for encoder models (like BERT) by masking each token in turn"""
    device = next(model.parameters()).device

    # Tokenize the text
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MODEL_SETTINGS["max_length"])
    input_ids = inputs.input_ids.to(device)

    if input_ids.size(1) < 3:  # Need at least [CLS] + 1 token + [SEP]
        raise gr.Error("Text is too short for MLM perplexity calculation.")

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}

    token_losses = []        # Content-token losses, used for the average
    token_perplexities = []  # One value per token, used for the visualization

    # Mask each token in turn and score the original token under the model's
    # prediction. This needs one forward pass per content token, so long
    # texts take proportionally longer.
    with torch.no_grad():
        for i in range(input_ids.size(1)):
            original_token_id = input_ids[0, i]
            if original_token_id.item() in special_token_ids:
                token_perplexities.append(1.0)  # Neutral value for special tokens
                continue
            masked_input = input_ids.clone()
            masked_input[0, i] = tokenizer.mask_token_id
            outputs = model(masked_input)
            prob = F.softmax(outputs.logits[0, i], dim=-1)[original_token_id]
            loss = -torch.log(prob + PROCESSING_SETTINGS["epsilon"]).item()
            token_losses.append(loss)
            token_perplexities.append(math.exp(loss))  # exp(loss) == 1 / (prob + epsilon)

    perplexity = math.exp(np.mean(token_losses)) if token_losses else float('inf')

    # Clean up WordPiece tokens for display
    cleaned_tokens = [t[2:] if t.startswith('##') else t for t in tokens]

    return perplexity, cleaned_tokens, np.array(token_perplexities)
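
# This is the pseudo-perplexity (PLL) construction for masked language models
# (Salazar et al., 2020, "Masked Language Model Scoring"):
#
#   PPPL = exp(-(1/N) * sum_i log p(x_i | x_\i))
#
# Each token is scored with full bidirectional context, so the result is not
# a true likelihood; values are comparable between encoder models but not
# against decoder perplexities.
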
def create_visualization(tokens, perplexities):
    """Create custom HTML visualization with color-coded perplexities"""
    if len(tokens) == 0:
        return "<p>No tokens to visualize.</p>"

    # Cap perplexities for better visualization
    max_perplexity = min(np.max(perplexities), VIZ_SETTINGS["max_perplexity_display"])

    # Normalize perplexities to the 0-1 range for color mapping
    normalized_perplexities = np.clip(perplexities / max_perplexity, 0, 1)

    # Create HTML with inline styles for color coding
    html_parts = [
        '<div style="font-family: Arial, sans-serif; font-size: 16px; line-height: 1.8; padding: 20px; border: 1px solid #ddd; border-radius: 8px; background-color: #fafafa;">',
        '<h3 style="margin-top: 0; color: #333;">Per-token Perplexity Visualization</h3>',
        '<div style="margin-bottom: 15px;">',
        '<span style="font-size: 12px; color: #666;">',
        '🟢 Low perplexity (confident) → 🟡 Medium → 🔴 High perplexity (uncertain)',
        '</span>',
        '</div>',
        '<div style="line-height: 2.0;">'
    ]

    for i, (token, perp, norm_perp) in enumerate(zip(tokens, perplexities, normalized_perplexities)):
        # Skip empty tokens
        if not token.strip():
            continue

        # Clean token for display
        clean_token = token.replace("</w>", "").replace("##", "").strip()
        if not clean_token:
            continue

        # Add a space before the token unless it starts with punctuation
        if i > 0 and clean_token[0] not in ".,!?;:":
            html_parts.append(" ")

        # Color thresholds and anchor colors from configuration
        low_thresh = VIZ_SETTINGS.get("thresholds", {}).get("low_threshold", 0.3)
        high_thresh = VIZ_SETTINGS.get("thresholds", {}).get("high_threshold", 0.7)
        low_color = VIZ_SETTINGS["color_scheme"]["low_perplexity"]
        med_color = VIZ_SETTINGS["color_scheme"]["medium_perplexity"]
        high_color = VIZ_SETTINGS["color_scheme"]["high_perplexity"]

        # Map normalized perplexity to a color
        if norm_perp < low_thresh:
            # Low perplexity: interpolate between green and yellow
            factor = norm_perp / low_thresh
            red = int(low_color["r"] + factor * (med_color["r"] - low_color["r"]))
            green = int(low_color["g"] + factor * (med_color["g"] - low_color["g"]))
            blue = int(low_color["b"] + factor * (med_color["b"] - low_color["b"]))
        elif norm_perp < high_thresh:
            # Medium perplexity: interpolate between yellow and red
            factor = (norm_perp - low_thresh) / (high_thresh - low_thresh)
            red = int(med_color["r"] + factor * (high_color["r"] - med_color["r"]))
            green = int(med_color["g"] + factor * (high_color["g"] - med_color["g"]))
            blue = int(med_color["b"] + factor * (high_color["b"] - med_color["b"]))
        else:
            # High perplexity: use the red anchor, darkened by up to 30% for very high values
            factor = min((norm_perp - high_thresh) / (1.0 - high_thresh), 1.0)
            darken = 0.8 - (factor * 0.3)
            red = int(high_color["r"] * darken)
            green = int(high_color["g"] * darken)
            blue = int(high_color["b"] * darken)

        # Clamp channel values to the valid range
        red = max(0, min(255, red))
        green = max(0, min(255, green))
        blue = max(0, min(255, blue))

        tooltip_text = f"Perplexity: {perp:.3f} (normalized: {norm_perp:.3f})"

        # Alpha values from configuration
        bg_alpha = VIZ_SETTINGS["color_scheme"].get("background_alpha", 0.7)
        border_alpha = VIZ_SETTINGS["color_scheme"].get("border_alpha", 0.9)

        # Create a colored span with a tooltip; escape the token so markup
        # characters inside tokens cannot break the generated HTML
        html_parts.append(
            f'<span style="'
            f'background-color: rgba({red}, {green}, {blue}, {bg_alpha}); '
            f'color: #000; '
            f'padding: 2px 4px; '
            f'margin: 1px; '
            f'border-radius: 3px; '
            f'border: 1px solid rgba({red}, {green}, {blue}, {border_alpha}); '
            f'font-weight: 500; '
            f'cursor: help; '
            f'display: inline-block;'
            f'" title="{tooltip_text}">{html.escape(clean_token)}</span>'
        )

    html_parts.extend([
        '</div>',
        '<div style="margin-top: 15px; font-size: 12px; color: #666;">',
        f'Max perplexity in visualization: {max_perplexity:.2f} | ',
        f'Total tokens: {len(tokens)}',
        '</div>',
        '</div>'
    ])

    return "".join(html_parts)
def process_text(text, model_name, model_type):
    """Main processing function"""
    if not text.strip():
        return ERROR_MESSAGES["empty_text"], "", pd.DataFrame()

    try:
        # Load model and tokenizer
        model, tokenizer = load_model_and_tokenizer(model_name, model_type)

        # Calculate perplexity
        if model_type == "decoder":
            avg_perplexity, tokens, token_perplexities = calculate_decoder_perplexity(
                text, model, tokenizer
            )
        else:  # encoder
            avg_perplexity, tokens, token_perplexities = calculate_encoder_perplexity(
                text, model, tokenizer
            )

        # Create visualization
        viz_html = create_visualization(tokens, token_perplexities)

        # Create summary
        summary = f"""
### Analysis Results

**Model:** `{model_name}`
**Model Type:** {model_type.title()}
**Average Perplexity:** {avg_perplexity:.4f}
**Number of Tokens:** {len(tokens)}
"""

        # Create detailed results table
        df = pd.DataFrame({
            'Token': tokens,
            'Perplexity': [f"{p:.4f}" for p in token_perplexities]
        })

        return summary, viz_html, df
    except Exception as e:
        error_msg = ERROR_MESSAGES["processing_error"].format(error=str(e))
        return error_msg, "", pd.DataFrame()
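
# Direct usage sketch, bypassing the UI (outputs shown are the markdown
# summary, the HTML visualization, and the per-token DataFrame):
#
#   summary, viz, df = process_text("The cat sat on the mat.", "gpt2", "decoder")
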
# Create Gradio interface
with gr.Blocks(title=UI_SETTINGS["title"], theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {UI_SETTINGS['title']}")
    gr.Markdown(UI_SETTINGS["description"])

    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter the text you want to analyze...",
                lines=6,
                max_lines=10
            )

            with gr.Row():
                model_name = gr.Dropdown(
                    label="Model Name",
                    choices=DEFAULT_MODELS["decoder"] + DEFAULT_MODELS["encoder"],
                    value="gpt2",
                    allow_custom_value=True,
                    info="Select a model or enter a custom HuggingFace model name"
                )
                model_type = gr.Radio(
                    label="Model Type",
                    choices=["decoder", "encoder"],
                    value="decoder",
                    info="Decoder for causal LM, Encoder for masked LM"
                )

            analyze_btn = gr.Button("🔍 Analyze Perplexity", variant="primary", size="lg")

        with gr.Column(scale=3):
            summary_output = gr.Markdown(label="Summary")
            viz_output = gr.HTML(label="Perplexity Visualization")

    # Full-width table
    with gr.Row():
        table_output = gr.Dataframe(
            label="Detailed Token Results",
            interactive=False,
            wrap=True
        )

    # Update model dropdown based on type selection
    def update_model_choices(model_type):
        return gr.update(choices=DEFAULT_MODELS[model_type], value=DEFAULT_MODELS[model_type][0])

    model_type.change(
        fn=update_model_choices,
        inputs=[model_type],
        outputs=[model_name]
    )

    # Set up the analysis function
    analyze_btn.click(
        fn=process_text,
        inputs=[text_input, model_name, model_type],
        outputs=[summary_output, viz_output, table_output]
    )

    # Add examples
    with gr.Accordion("📚 Example Texts", open=False):
        examples_data = [
            [ex["text"], ex["model"], ex["type"]]
            for ex in UI_SETTINGS["examples"]
        ]
        gr.Examples(
            examples=examples_data,
            inputs=[text_input, model_name, model_type],
            outputs=[summary_output, viz_output, table_output],
            fn=process_text,
            cache_examples=False,
            label="Click on an example to try it out:"
        )
    # Add footer with information
    gr.Markdown("""
---
### 📖 How it works:
- **Decoder Models** (GPT, etc.): Calculate true perplexity by measuring how well the model predicts each next token
- **Encoder Models** (BERT, etc.): Calculate pseudo-perplexity using masked language modeling (MLM)
- **Color Coding**: Red = High perplexity (uncertain), Green = Low perplexity (confident)

### ⚠️ Notes:
- First model load may take some time
- Models are cached after first use
- Very long texts are truncated to 512 tokens
- GPU acceleration is used when available
- Encoder analysis masks one token at a time (one forward pass per token), so it is slower on long texts
""")
if __name__ == "__main__":
    try:
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_api=False
        )
    except Exception as e:
        print(f"❌ Failed to launch app: {e}")
        print("💡 Try running with: python run.py")
        # Fall back to a basic launch
        try:
            demo.launch()
        except Exception as fallback_error:
            print(f"❌ Fallback launch also failed: {fallback_error}")
            print("💡 Try updating Gradio: pip install --upgrade gradio")