#!/usr/bin/env python3 """ Test script for PerplexityViewer app """ import sys import os import torch import numpy as np from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM # Add the current directory to the path so we can import the app sys.path.append(os.path.dirname(os.path.abspath(__file__))) try: from app import ( load_model_and_tokenizer, calculate_decoder_perplexity, calculate_encoder_perplexity, create_visualization, process_text ) from config import DEFAULT_MODELS, PROCESSING_SETTINGS except ImportError as e: print(f"Error importing app modules: {e}") sys.exit(1) def test_model_loading(): """Test model and tokenizer loading""" print("Testing model loading...") # Test decoder model try: model, tokenizer = load_model_and_tokenizer("distilgpt2", "decoder") print("✓ Decoder model (distilgpt2) loaded successfully") assert model is not None assert tokenizer is not None except Exception as e: print(f"✗ Failed to load decoder model: {e}") return False # Test encoder model try: model, tokenizer = load_model_and_tokenizer("distilbert-base-uncased", "encoder") print("✓ Encoder model (distilbert-base-uncased) loaded successfully") assert model is not None assert tokenizer is not None except Exception as e: print(f"✗ Failed to load encoder model: {e}") return False return True def test_decoder_perplexity(): """Test decoder perplexity calculation""" print("\nTesting decoder perplexity calculation...") try: model, tokenizer = load_model_and_tokenizer("distilgpt2", "decoder") text = "The quick brown fox jumps over the lazy dog." avg_perp, tokens, token_perps = calculate_decoder_perplexity(text, model, tokenizer, iterations=1) print(f"✓ Average perplexity: {avg_perp:.4f}") print(f"✓ Number of tokens: {len(tokens)}") print(f"✓ Token perplexities shape: {token_perps.shape}") assert avg_perp > 0 assert len(tokens) > 0 assert len(token_perps) == len(tokens) assert all(p > 0 for p in token_perps) return True except Exception as e: print(f"✗ Decoder perplexity test failed: {e}") return False def test_encoder_perplexity(): """Test encoder perplexity calculation""" print("\nTesting encoder perplexity calculation...") try: model, tokenizer = load_model_and_tokenizer("distilbert-base-uncased", "encoder") text = "The capital of France is Paris." avg_perp, tokens, token_perps = calculate_encoder_perplexity( text, model, tokenizer, mlm_probability=0.15, iterations=1 ) print(f"✓ Average pseudo-perplexity: {avg_perp:.4f}") print(f"✓ Number of tokens: {len(tokens)}") print(f"✓ Token perplexities shape: {token_perps.shape}") assert avg_perp > 0 assert len(tokens) > 0 assert len(token_perps) == len(tokens) assert all(p > 0 for p in token_perps) return True except Exception as e: print(f"✗ Encoder perplexity test failed: {e}") return False def test_visualization(): """Test visualization creation""" print("\nTesting visualization creation...") try: # Create dummy data tokens = ["The", "quick", "brown", "fox", "jumps"] perplexities = np.array([2.5, 1.8, 3.2, 4.1, 2.9]) html = create_visualization(tokens, perplexities) print("✓ Visualization HTML generated") assert isinstance(html, str) assert len(html) > 0 assert "ent" in html.lower() # displaCy entity visualization return True except Exception as e: print(f"✗ Visualization test failed: {e}") return False def test_edge_cases(): """Test edge cases and error handling""" print("\nTesting edge cases...") # Test empty text try: summary, viz, table = process_text("", "distilgpt2", "decoder", 1, 0.15) assert "enter some text" in summary.lower() print("✓ Empty text handled correctly") except Exception as e: print(f"✗ Empty text test failed: {e}") return False # Test very short text try: model, tokenizer = load_model_and_tokenizer("distilgpt2", "decoder") text = "Hi" avg_perp, tokens, token_perps = calculate_decoder_perplexity(text, model, tokenizer, iterations=1) print(f"✓ Short text handled: {len(tokens)} tokens") except Exception as e: print(f"✓ Short text error handled correctly: {e}") # Test long text (should be truncated) try: long_text = " ".join(["word"] * 600) # More than max_length model, tokenizer = load_model_and_tokenizer("distilgpt2", "decoder") avg_perp, tokens, token_perps = calculate_decoder_perplexity(long_text, model, tokenizer, iterations=1) print(f"✓ Long text truncated to {len(tokens)} tokens") assert len(tokens) <= 512 # Should be truncated except Exception as e: print(f"✗ Long text test failed: {e}") return False return True def test_process_text_integration(): """Test the main process_text function""" print("\nTesting process_text integration...") test_cases = [ { "text": "The quick brown fox jumps over the lazy dog.", "model": "distilgpt2", "type": "decoder", "iterations": 1, "mlm_prob": 0.15 }, { "text": "Machine learning is a subset of artificial intelligence.", "model": "distilbert-base-uncased", "type": "encoder", "iterations": 1, "mlm_prob": 0.2 } ] for i, case in enumerate(test_cases): try: summary, viz_html, df = process_text( case["text"], case["model"], case["type"], case["iterations"], case["mlm_prob"] ) print(f"✓ Test case {i+1} ({case['type']}) processed successfully") assert "Analysis Results" in summary assert len(viz_html) > 0 assert len(df) > 0 except Exception as e: print(f"✗ Test case {i+1} failed: {e}") return False return True def test_configuration(): """Test configuration loading""" print("\nTesting configuration...") try: assert "decoder" in DEFAULT_MODELS assert "encoder" in DEFAULT_MODELS assert len(DEFAULT_MODELS["decoder"]) > 0 assert len(DEFAULT_MODELS["encoder"]) > 0 assert PROCESSING_SETTINGS["default_iterations"] >= 1 print("✓ Configuration loaded correctly") return True except Exception as e: print(f"✗ Configuration test failed: {e}") return False def run_all_tests(): """Run all tests""" print("="*50) print("Running PerplexityViewer Tests") print("="*50) tests = [ ("Configuration", test_configuration), ("Model Loading", test_model_loading), ("Decoder Perplexity", test_decoder_perplexity), ("Encoder Perplexity", test_encoder_perplexity), ("Visualization", test_visualization), ("Edge Cases", test_edge_cases), ("Integration", test_process_text_integration) ] passed = 0 failed = 0 for test_name, test_func in tests: print(f"\n[{test_name}]") try: if test_func(): passed += 1 print(f"✓ {test_name} PASSED") else: failed += 1 print(f"✗ {test_name} FAILED") except Exception as e: failed += 1 print(f"✗ {test_name} FAILED with exception: {e}") print("\n" + "="*50) print(f"Test Results: {passed} passed, {failed} failed") print("="*50) return failed == 0 if __name__ == "__main__": # Check if PyTorch is available print(f"PyTorch version: {torch.__version__}") print(f"CUDA available: {torch.cuda.is_available()}") if torch.cuda.is_available(): print(f"CUDA device: {torch.cuda.get_device_name()}") # Run tests success = run_all_tests() if success: print("\n🎉 All tests passed! The app should work correctly.") sys.exit(0) else: print("\n❌ Some tests failed. Please check the errors above.") sys.exit(1)