Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Demo script for PerplexityViewer - shows core functionality without GUI | |
| """ | |
| import torch | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM | |
| import warnings | |
# Install an "ignore" filter for every warning category (no category given, so
# it matches Warning and all subclasses). This runs once, when the module is
# imported, presumably to keep the demo's console output free of library noise.
# NOTE(review): the filter is process-global, so it also hides warnings for any
# code that imports this module — confirm that is intended.
warnings.filterwarnings("ignore")
def demo_decoder_perplexity(test_texts=None, max_tokens_shown=5):
    """Demo decoder (causal LM) perplexity calculation with distilgpt2.

    For each text a single forward pass produces the logits, from which the
    negative log-likelihood (NLL) of every token after the first is derived.
    The sentence perplexity is ``exp(mean NLL)``; each token's perplexity is
    ``exp(its NLL)``.

    Args:
        test_texts: Iterable of strings to score. ``None`` (the default) uses
            the built-in example sentences.
        max_tokens_shown: Number of per-token perplexities printed per text.
    """
    print("=" * 60)
    print("🤖 Decoder Model Demo (GPT-2)")
    print("=" * 60)

    # Load model
    model_name = "distilgpt2"
    print(f"Loading {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Some causal LMs (GPT-2 among them) define no pad token; reuse EOS.
        # Not needed for the single, unpadded inputs below, but harmless.
        tokenizer.pad_token = tokenizer.eos_token
    model.eval()

    if test_texts is None:
        test_texts = [
            "The quick brown fox jumps over the lazy dog.",
            "Machine learning is revolutionizing artificial intelligence.",
            "Buffalo buffalo Buffalo buffalo buffalo buffalo Buffalo buffalo.",
            "The capital of France is Paris.",
        ]

    for i, text in enumerate(test_texts, 1):
        print(f"\n📝 Text {i}: {text}")
        input_ids = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        ).input_ids

        # The first token is never predicted, so fewer than two tokens leaves
        # nothing to score (the mean over zero tokens would be NaN).
        if input_ids.size(1) < 2:
            print(" Need at least two tokens to compute perplexity; skipping")
            continue

        # One forward pass serves both the sentence- and token-level numbers.
        with torch.no_grad():
            logits = model(input_ids).logits

        # Logits at position t score token t+1: drop the last logit and the
        # first label so predictions and targets line up.
        shift_logits = logits[0, :-1, :]
        shift_labels = input_ids[0, 1:]
        token_losses = torch.nn.functional.cross_entropy(
            shift_logits, shift_labels, reduction="none"
        )

        # exp(mean shifted NLL) -- the same quantity as
        # exp(model(input_ids, labels=input_ids).loss), without a second pass.
        perplexity = torch.exp(token_losses.mean()).item()
        print(f" 🎯 Perplexity: {perplexity:.2f}")

        token_perplexities = torch.exp(token_losses).cpu().numpy()
        tokens = tokenizer.convert_ids_to_tokens(shift_labels.tolist())
        print(" 🎯 Token details:")
        shown = zip(tokens[:max_tokens_shown], token_perplexities[:max_tokens_shown])
        for token, pp in shown:
            # Byte-level BPE encodes a leading space as "Ġ"; let the tokenizer
            # map the raw token back to readable text.
            clean_token = tokenizer.convert_tokens_to_string([token])
            color = "🟢" if pp < 3 else "🟡" if pp < 10 else "🔴"
            print(f" {color} '{clean_token}': {pp:.2f}")
        if len(tokens) > max_tokens_shown:
            print(f" ... and {len(tokens) - max_tokens_shown} more tokens")
def demo_encoder_perplexity(test_texts=None, mlm_probability=0.15, seed=None, num_samples=5):
    """Demo encoder (masked LM) pseudo-perplexity calculation with DistilBERT.

    For each text:

    * A random subset of non-special tokens is masked in a single input. Each
      token is chosen with probability ``mlm_probability``, and at least one is
      always masked when the text has any content tokens. The
      pseudo-perplexity is ``exp(mean NLL)`` of the original tokens at the
      masked positions.
    * The first few content tokens are then masked one at a time, and each
      token's pseudo-perplexity ``1 / p(token)`` is printed.

    Args:
        test_texts: Iterable of strings to score; ``None`` uses the built-in
            examples.
        mlm_probability: Probability that each non-special token is masked.
        seed: Optional integer seed for the masking RNG. ``None`` (default)
            draws from torch's global RNG, so masks differ between runs.
        num_samples: Maximum number of per-token pseudo-perplexities printed
            (positions 1..num_samples, excluding the final position).
    """
    print("\n" + "=" * 60)
    print("🤖 Encoder Model Demo (DistilBERT)")
    print("=" * 60)

    # Load model
    model_name = "distilbert-base-uncased"
    print(f"Loading {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    model.eval()

    if test_texts is None:
        test_texts = [
            "The capital of France is Paris.",
            "Python is a programming language.",
            "The weather today is beautiful.",
            "Machine learning requires large datasets.",
        ]

    # Loop-invariant: token ids that are never masked or scored.
    special_token_ids = {
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.pad_token_id,
    }
    # A dedicated generator makes masking reproducible when a seed is given.
    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    for i, text in enumerate(test_texts, 1):
        print(f"\n📝 Text {i}: {text}")
        input_ids = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        ).input_ids
        ids = input_ids[0].tolist()
        candidates = [j for j, tok in enumerate(ids) if tok not in special_token_ids]

        # Choose positions to mask; if none were drawn, fall back to one
        # random content position so every non-empty text gets a score.
        draws = torch.rand(len(candidates), generator=generator).tolist()
        mask_indices = [j for j, r in zip(candidates, draws) if r < mlm_probability]
        if not mask_indices and candidates:
            pick = torch.randint(len(candidates), (1,), generator=generator).item()
            mask_indices = [candidates[pick]]

        # Calculate pseudo-perplexity
        if mask_indices:
            masked_input_ids = input_ids.clone()
            masked_input_ids[0, mask_indices] = tokenizer.mask_token_id
            with torch.no_grad():
                logits = model(masked_input_ids).logits[0]
            positions = torch.tensor(mask_indices)
            # log_softmax is numerically stable, so no epsilon clamp is needed
            # (log(softmax + 1e-10) capped the loss of very unlikely tokens).
            log_probs = torch.log_softmax(logits[positions], dim=-1)
            nll = -log_probs[torch.arange(len(mask_indices)), input_ids[0, positions]]
            pseudo_perplexity = torch.exp(nll.mean()).item()
        else:
            pseudo_perplexity = float("inf")

        print(f" 🎯 Pseudo-perplexity: {pseudo_perplexity:.2f}")
        print(f" 📊 Masked {len(mask_indices)} tokens")

        # Show some token-level pseudo-perplexities, masking one token at a time.
        tokens = tokenizer.convert_ids_to_tokens(ids)
        print(" 🎯 Sample token pseudo-perplexities:")
        for idx in range(1, min(num_samples + 1, len(tokens) - 1)):  # Skip [CLS] and [SEP]
            token_id = ids[idx]
            if token_id in special_token_ids:
                continue
            masked_input = input_ids.clone()
            masked_input[0, idx] = tokenizer.mask_token_id
            with torch.no_grad():
                token_logits = model(masked_input).logits[0, idx]
            log_prob = torch.log_softmax(token_logits, dim=-1)[token_id]
            # exp(-log p) == 1 / p, computed without an epsilon clamp.
            token_pseudo_perplexity = torch.exp(-log_prob).item()
            clean_token = tokens[idx].replace("##", "")
            if token_pseudo_perplexity < 5:
                color = "🟢"
            elif token_pseudo_perplexity < 20:
                color = "🟡"
            else:
                color = "🔴"
            print(f" {color} '{clean_token}': {token_pseudo_perplexity:.2f}")
def demo_comparison(test_text="The quick brown fox jumps over the lazy dog.", models_to_test=None):
    """Compare sentence-level (pseudo-)perplexity across model types.

    Decoder models are scored with ordinary perplexity: ``exp(mean NLL)`` of
    each token given its left context, taken from the model's own loss.
    Every other model type is treated as a masked LM and scored with
    pseudo-perplexity. Each non-special token is masked once (one copy of the
    input per token, run in small batches), and ``exp(mean NLL)`` of the
    original tokens is reported. Both numbers therefore summarize the whole
    sentence rather than a single token.

    Args:
        test_text: Sentence to score.
        models_to_test: Sequence of ``(model_name, model_type)`` pairs;
            ``model_type == "decoder"`` selects a causal LM, anything else a
            masked LM. ``None`` uses distilgpt2 and distilbert-base-uncased.
    """
    print("\n" + "=" * 60)
    print("🔬 Model Comparison Demo")
    print("=" * 60)
    print(f"📝 Comparing models on: {test_text}")

    if models_to_test is None:
        models_to_test = [
            ("distilgpt2", "decoder"),
            ("distilbert-base-uncased", "encoder"),
        ]

    chunk_size = 32  # rows per masked-LM batch; bounds memory on long inputs
    results = []
    for model_name, model_type in models_to_test:
        print(f"\n🤖 Testing {model_name} ({model_type})...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            if model_type == "decoder":
                model = AutoModelForCausalLM.from_pretrained(model_name)
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token
            else:
                model = AutoModelForMaskedLM.from_pretrained(model_name)
            model.eval()

            input_ids = tokenizer(
                test_text, return_tensors="pt", truncation=True, max_length=512
            ).input_ids

            if model_type == "decoder":
                with torch.no_grad():
                    loss = model(input_ids, labels=input_ids).loss
                perplexity = torch.exp(loss).item()
            else:
                special_ids = {
                    tokenizer.cls_token_id,
                    tokenizer.sep_token_id,
                    tokenizer.pad_token_id,
                }
                positions = [
                    j for j, tok in enumerate(input_ids[0].tolist()) if tok not in special_ids
                ]
                if positions:
                    nll_chunks = []
                    for start in range(0, len(positions), chunk_size):
                        cols = torch.tensor(positions[start:start + chunk_size])
                        rows = torch.arange(len(cols))
                        # Row r is the input with only position cols[r] masked.
                        batch = input_ids.repeat(len(cols), 1)
                        batch[rows, cols] = tokenizer.mask_token_id
                        with torch.no_grad():
                            logits = model(batch).logits
                        log_probs = torch.log_softmax(logits[rows, cols], dim=-1)
                        nll_chunks.append(-log_probs[rows, input_ids[0, cols]])
                    perplexity = torch.exp(torch.cat(nll_chunks).mean()).item()
                else:
                    perplexity = float("inf")

            results.append((model_name, model_type, perplexity))
            print(f" ✅ Perplexity: {perplexity:.2f}")
        except Exception as e:
            # Best-effort per model: report the failure and keep comparing.
            print(f" ❌ Error: {e}")
            results.append((model_name, model_type, float("inf")))

    print(f"\n📊 Summary for '{test_text}':")
    for name, kind, score in results:
        # isfinite also rejects NaN, which `!= float('inf')` let through.
        if np.isfinite(score):
            confidence = "High" if score < 5 else "Medium" if score < 15 else "Low"
            print(f" • {name} ({kind}): {score:.2f} - {confidence} confidence")
        else:
            print(f" • {name} ({kind}): Failed")
def main():
    """Run all demos in sequence and report the outcome.

    Returns:
        int: Process exit status. 0 when every demo finished, 130 when
        interrupted with Ctrl-C, and 1 when a demo raised an exception.
    """
    print("🚀 PerplexityViewer Core Functionality Demo")
    print("This demo shows how perplexity calculation works under the hood")
    try:
        demo_decoder_perplexity()
        demo_encoder_perplexity()
        demo_comparison()
    except KeyboardInterrupt:
        print("\n👋 Demo interrupted by user")
        return 130  # conventional status for termination by SIGINT
    except Exception as e:
        # Top-level boundary: report the error with an install hint and signal
        # failure instead of exiting with status 0.
        print(f"\n❌ Demo failed with error: {e}")
        print("💡 Make sure you have installed all dependencies: pip install -r requirements.txt")
        return 1
    else:
        print("\n" + "=" * 60)
        print("🎉 Demo completed successfully!")
        print("💡 To try the interactive web interface, run: python run.py")
        print("=" * 60)
        return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None exits with 0).
    raise SystemExit(main())