PerplexityViewer / test_app.py
Bram van Es
bla
ef12530
#!/usr/bin/env python3
"""
Test script for PerplexityViewer app
"""
import sys
import os
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM
# Add the current directory to the path so we can import the app
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
try:
from app import (
load_model_and_tokenizer,
calculate_decoder_perplexity,
calculate_encoder_perplexity,
create_visualization,
process_text
)
from config import DEFAULT_MODELS, PROCESSING_SETTINGS
except ImportError as e:
print(f"Error importing app modules: {e}")
sys.exit(1)
def test_model_loading():
    """Check that one decoder and one encoder checkpoint can be loaded.

    Loads ``distilgpt2`` as a decoder and ``distilbert-base-uncased`` as an
    encoder through ``load_model_and_tokenizer`` and verifies that neither
    the returned model nor the returned tokenizer is ``None``.

    Returns:
        bool: ``True`` if both loads succeed, ``False`` as soon as one of
        them fails (the failure is printed, not raised).
    """
    print("Testing model loading...")

    checkpoints = (
        ("distilgpt2", "decoder"),
        ("distilbert-base-uncased", "encoder"),
    )
    for model_name, model_type in checkpoints:
        try:
            model, tokenizer = load_model_and_tokenizer(model_name, model_type)
            # Verify before announcing success, and give the asserts a
            # message so the failure line below never ends in an empty reason.
            assert model is not None, "model is None"
            assert tokenizer is not None, "tokenizer is None"
        except Exception as e:
            print(f"✗ Failed to load {model_type} model: {e}")
            return False
        print(
            f"✓ {model_type.capitalize()} model ({model_name}) "
            "loaded successfully"
        )
    return True
def test_decoder_perplexity():
    """Check the decoder perplexity calculation on a short sentence.

    Runs ``calculate_decoder_perplexity`` with ``distilgpt2`` for a single
    iteration and verifies that the average perplexity is positive, that at
    least one token is returned, and that there is exactly one positive
    perplexity value per token.

    Returns:
        bool: ``True`` if every check passes, ``False`` if anything raised
        (the error is printed, not propagated).
    """
    print("\nTesting decoder perplexity calculation...")
    try:
        model, tokenizer = load_model_and_tokenizer("distilgpt2", "decoder")
        text = "The quick brown fox jumps over the lazy dog."
        avg_perp, tokens, token_perps = calculate_decoder_perplexity(
            text, model, tokenizer, iterations=1
        )

        # Print the raw values first so they are visible even when one of
        # the checks below fails.
        print(f"✓ Average perplexity: {avg_perp:.4f}")
        print(f"✓ Number of tokens: {len(tokens)}")
        print(f"✓ Token perplexities shape: {token_perps.shape}")

        assert avg_perp > 0, f"average perplexity is not positive: {avg_perp}"
        assert len(tokens) > 0, "no tokens returned"
        assert len(token_perps) == len(tokens), (
            f"{len(token_perps)} perplexities for {len(tokens)} tokens"
        )
        assert all(p > 0 for p in token_perps), (
            "found a non-positive token perplexity"
        )
        return True
    except Exception as e:
        print(f"✗ Decoder perplexity test failed: {e}")
        return False
def test_encoder_perplexity():
    """Check the encoder pseudo-perplexity calculation on a short sentence.

    Runs ``calculate_encoder_perplexity`` with ``distilbert-base-uncased``
    (``mlm_probability=0.15``, one iteration) and verifies that the average
    value is positive, that at least one token is returned, and that there
    is exactly one positive value per token.

    Returns:
        bool: ``True`` if every check passes, ``False`` if anything raised
        (the error is printed, not propagated).
    """
    print("\nTesting encoder perplexity calculation...")
    try:
        model, tokenizer = load_model_and_tokenizer(
            "distilbert-base-uncased", "encoder"
        )
        text = "The capital of France is Paris."
        avg_perp, tokens, token_perps = calculate_encoder_perplexity(
            text, model, tokenizer, mlm_probability=0.15, iterations=1
        )

        # Print the raw values first so they are visible even when one of
        # the checks below fails.
        print(f"✓ Average pseudo-perplexity: {avg_perp:.4f}")
        print(f"✓ Number of tokens: {len(tokens)}")
        print(f"✓ Token perplexities shape: {token_perps.shape}")

        assert avg_perp > 0, (
            f"average pseudo-perplexity is not positive: {avg_perp}"
        )
        assert len(tokens) > 0, "no tokens returned"
        assert len(token_perps) == len(tokens), (
            f"{len(token_perps)} perplexities for {len(tokens)} tokens"
        )
        assert all(p > 0 for p in token_perps), (
            "found a non-positive token perplexity"
        )
        return True
    except Exception as e:
        print(f"✗ Encoder perplexity test failed: {e}")
        return False
def test_visualization():
    """Check that ``create_visualization`` produces HTML for dummy data.

    Feeds five hand-written tokens and matching perplexity values to
    ``create_visualization`` and verifies the result is a non-empty string
    containing ``"ent"`` (case-insensitively).

    Returns:
        bool: ``True`` if every check passes, ``False`` if anything raised
        (the error is printed, not propagated).
    """
    print("\nTesting visualization creation...")
    try:
        # Dummy data: no model is needed for this test.
        tokens = ["The", "quick", "brown", "fox", "jumps"]
        perplexities = np.array([2.5, 1.8, 3.2, 4.1, 2.9])

        html = create_visualization(tokens, perplexities)
        print("✓ Visualization HTML generated")

        assert isinstance(html, str), (
            f"expected str, got {type(html).__name__}"
        )
        assert len(html) > 0, "visualization HTML is empty"
        # The markup is expected to be a displaCy entity visualization,
        # which carries "ent" in its element/class names.
        assert "ent" in html.lower(), "no entity markup found in HTML"
        return True
    except Exception as e:
        print(f"✗ Visualization test failed: {e}")
        return False
def test_edge_cases():
    """Exercise empty, very short and over-long inputs.

    Three scenarios are covered:

    * empty text passed to ``process_text`` must yield a summary containing
      "enter some text";
    * a one-word text passed to ``calculate_decoder_perplexity`` may either
      succeed or raise -- both outcomes are accepted;
    * a 600-word text must come back with at most 512 tokens.

    Returns:
        bool: ``True`` if the empty-text and long-text checks pass,
        ``False`` otherwise (failures are printed, not raised).
    """
    print("\nTesting edge cases...")
    max_tokens = 512  # upper bound the long-text check expects after truncation

    # Empty text: process_text should answer with a prompt instead of raising.
    try:
        summary, _, _ = process_text("", "distilgpt2", "decoder", 1, 0.15)
        assert "enter some text" in summary.lower(), (
            f"unexpected summary for empty input: {summary!r}"
        )
        print("✓ Empty text handled correctly")
    except Exception as e:
        print(f"✗ Empty text test failed: {e}")
        return False

    # Load the model once, outside the tolerant short-text block, so that a
    # load failure is reported as a failure rather than as a "handled" error.
    try:
        model, tokenizer = load_model_and_tokenizer("distilgpt2", "decoder")
    except Exception as e:
        print(f"✗ Failed to load decoder model: {e}")
        return False

    # Very short text: an error from the calculation itself is acceptable.
    try:
        _, tokens, _ = calculate_decoder_perplexity(
            "Hi", model, tokenizer, iterations=1
        )
        print(f"✓ Short text handled: {len(tokens)} tokens")
    except Exception as e:
        print(f"✓ Short text error handled correctly: {e}")

    # Long text: 600 words exceeds the expected limit and should be truncated.
    try:
        long_text = " ".join(["word"] * 600)
        _, tokens, _ = calculate_decoder_perplexity(
            long_text, model, tokenizer, iterations=1
        )
        print(f"✓ Long text truncated to {len(tokens)} tokens")
        assert len(tokens) <= max_tokens, (
            f"{len(tokens)} tokens returned, expected at most {max_tokens}"
        )
    except Exception as e:
        print(f"✗ Long text test failed: {e}")
        return False

    return True
def test_process_text_integration():
    """Run ``process_text`` end to end for one decoder and one encoder case.

    For each case the returned summary must contain "Analysis Results" and
    both the visualization HTML and the result table must be non-empty.

    Returns:
        bool: ``True`` if every case passes, ``False`` on the first failing
        case (the error is printed, not raised).
    """
    print("\nTesting process_text integration...")

    test_cases = [
        {
            "text": "The quick brown fox jumps over the lazy dog.",
            "model": "distilgpt2",
            "type": "decoder",
            "iterations": 1,
            "mlm_prob": 0.15,
        },
        {
            "text": "Machine learning is a subset of artificial intelligence.",
            "model": "distilbert-base-uncased",
            "type": "encoder",
            "iterations": 1,
            "mlm_prob": 0.2,
        },
    ]

    # Number cases from 1 so the printed index matches human counting.
    for case_no, case in enumerate(test_cases, start=1):
        try:
            summary, viz_html, df = process_text(
                case["text"],
                case["model"],
                case["type"],
                case["iterations"],
                case["mlm_prob"],
            )
            print(
                f"✓ Test case {case_no} ({case['type']}) processed successfully"
            )
            assert "Analysis Results" in summary, (
                "summary lacks the 'Analysis Results' heading"
            )
            assert len(viz_html) > 0, "visualization HTML is empty"
            assert len(df) > 0, "result table is empty"
        except Exception as e:
            print(f"✗ Test case {case_no} failed: {e}")
            return False

    return True
def test_configuration():
    """Sanity-check the values imported from the ``config`` module.

    ``DEFAULT_MODELS`` must have non-empty ``"decoder"`` and ``"encoder"``
    entries and ``PROCESSING_SETTINGS["default_iterations"]`` must be at
    least 1.

    Returns:
        bool: ``True`` if every check passes, ``False`` otherwise (the
        error is printed, not raised).
    """
    print("\nTesting configuration...")
    try:
        for model_type in ("decoder", "encoder"):
            assert model_type in DEFAULT_MODELS, (
                f"DEFAULT_MODELS has no {model_type!r} entry"
            )
            assert len(DEFAULT_MODELS[model_type]) > 0, (
                f"DEFAULT_MODELS[{model_type!r}] is empty"
            )
        assert PROCESSING_SETTINGS["default_iterations"] >= 1, (
            "default_iterations must be at least 1"
        )
        print("✓ Configuration loaded correctly")
        return True
    except Exception as e:
        print(f"✗ Configuration test failed: {e}")
        return False
def run_all_tests(tests=None):
    """Run a list of tests and print a pass/fail summary.

    Args:
        tests: Optional iterable of ``(name, callable)`` pairs. Each
            callable takes no arguments and returns a truthy value on
            success. When omitted, the full PerplexityViewer suite defined
            in this module is run.

    Returns:
        bool: ``True`` when no test failed (including when there were no
        tests to run), ``False`` otherwise. A test that raises is counted
        as failed; the exception is printed, not propagated.
    """
    if tests is None:
        # Resolved lazily so callers passing their own list do not need
        # the module-level test functions.
        tests = [
            ("Configuration", test_configuration),
            ("Model Loading", test_model_loading),
            ("Decoder Perplexity", test_decoder_perplexity),
            ("Encoder Perplexity", test_encoder_perplexity),
            ("Visualization", test_visualization),
            ("Edge Cases", test_edge_cases),
            ("Integration", test_process_text_integration),
        ]

    banner = "=" * 50
    print(banner)
    print("Running PerplexityViewer Tests")
    print(banner)

    passed = 0
    failed = 0
    for test_name, test_func in tests:
        print(f"\n[{test_name}]")
        try:
            ok = test_func()
        except Exception as e:
            failed += 1
            print(f"✗ {test_name} FAILED with exception: {e}")
            continue
        if ok:
            passed += 1
            print(f"✓ {test_name} PASSED")
        else:
            failed += 1
            print(f"✗ {test_name} FAILED")

    print("\n" + banner)
    print(f"Test Results: {passed} passed, {failed} failed")
    print(banner)
    return failed == 0
if __name__ == "__main__":
    # Report the runtime environment first so a failing run can be triaged
    # without re-running anything.
    print(f"PyTorch version: {torch.__version__}")
    cuda_available = torch.cuda.is_available()
    print(f"CUDA available: {cuda_available}")
    if cuda_available:
        print(f"CUDA device: {torch.cuda.get_device_name()}")

    # Run the suite and mirror its outcome in the process exit status
    # (0 = all passed, 1 = at least one failure).
    success = run_all_tests()
    if success:
        print("\n🎉 All tests passed! The app should work correctly.")
    else:
        print("\n❌ Some tests failed. Please check the errors above.")
    sys.exit(0 if success else 1)