Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Test script for PerplexityViewer app | |
| """ | |
| import sys | |
| import os | |
| import torch | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM | |
| # Add the current directory to the path so we can import the app | |
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
| try: | |
| from app import ( | |
| load_model_and_tokenizer, | |
| calculate_decoder_perplexity, | |
| calculate_encoder_perplexity, | |
| create_visualization, | |
| process_text | |
| ) | |
| from config import DEFAULT_MODELS, PROCESSING_SETTINGS | |
| except ImportError as e: | |
| print(f"Error importing app modules: {e}") | |
| sys.exit(1) | |
def test_model_loading():
    """Check that the decoder and encoder checkpoints can be loaded.

    Loads ``distilgpt2`` as a decoder and ``distilbert-base-uncased`` as an
    encoder through ``load_model_and_tokenizer`` and verifies that neither
    the returned model nor the returned tokenizer is ``None``.

    Returns:
        bool: ``True`` if both checkpoints load and validate, ``False`` as
        soon as one of them fails (remaining checkpoints are not tried).
    """
    print("Testing model loading...")

    # (label used in the success message, model name, model type)
    checkpoints = [
        ("Decoder", "distilgpt2", "decoder"),
        ("Encoder", "distilbert-base-uncased", "encoder"),
    ]
    for label, model_name, model_type in checkpoints:
        try:
            model, tokenizer = load_model_and_tokenizer(model_name, model_type)
            # Validate before announcing success so that a None result is
            # never reported as "loaded successfully". The messages matter:
            # the handler below prints only str(e), which is empty for a
            # bare assert.
            assert model is not None, f"{model_name}: model is None"
            assert tokenizer is not None, f"{model_name}: tokenizer is None"
        except Exception as e:
            print(f"✗ Failed to load {model_type} model: {e}")
            return False
        print(f"✓ {label} model ({model_name}) loaded successfully")
    return True
def test_decoder_perplexity():
    """Smoke-test ``calculate_decoder_perplexity`` on a short sentence.

    Loads ``distilgpt2``, computes perplexity for a fixed sentence with a
    single iteration, prints the results and checks that:

    * the average perplexity is positive,
    * at least one token was produced,
    * there is exactly one per-token perplexity per token,
    * every per-token perplexity is positive.

    The per-token result must expose ``.shape`` (it is printed), e.g. a
    NumPy array.

    Returns:
        bool: ``True`` if every check passes, ``False`` if anything raises.
    """
    print("\nTesting decoder perplexity calculation...")
    try:
        model, tokenizer = load_model_and_tokenizer("distilgpt2", "decoder")
        text = "The quick brown fox jumps over the lazy dog."
        avg_perp, tokens, token_perps = calculate_decoder_perplexity(
            text, model, tokenizer, iterations=1
        )

        # Print the raw numbers first: they are useful context if one of
        # the checks below fails.
        print(f"✓ Average perplexity: {avg_perp:.4f}")
        print(f"✓ Number of tokens: {len(tokens)}")
        print(f"✓ Token perplexities shape: {token_perps.shape}")

        # Each assert carries a message because the handler below prints
        # only str(e), which is empty for a bare assert.
        assert avg_perp > 0, f"average perplexity must be > 0, got {avg_perp}"
        assert len(tokens) > 0, "no tokens were returned"
        assert len(token_perps) == len(tokens), (
            f"{len(token_perps)} token perplexities for {len(tokens)} tokens"
        )
        assert all(p > 0 for p in token_perps), (
            "every token perplexity must be > 0"
        )
        return True
    except Exception as e:
        print(f"✗ Decoder perplexity test failed: {e}")
        return False
def test_encoder_perplexity():
    """Smoke-test ``calculate_encoder_perplexity`` on a short sentence.

    Loads ``distilbert-base-uncased``, computes pseudo-perplexity for a
    fixed sentence with ``mlm_probability=0.15`` and a single iteration,
    prints the results and checks that:

    * the average pseudo-perplexity is positive,
    * at least one token was produced,
    * there is exactly one per-token perplexity per token,
    * every per-token perplexity is positive.

    The per-token result must expose ``.shape`` (it is printed), e.g. a
    NumPy array.

    Returns:
        bool: ``True`` if every check passes, ``False`` if anything raises.
    """
    print("\nTesting encoder perplexity calculation...")
    try:
        model, tokenizer = load_model_and_tokenizer(
            "distilbert-base-uncased", "encoder"
        )
        text = "The capital of France is Paris."
        avg_perp, tokens, token_perps = calculate_encoder_perplexity(
            text, model, tokenizer, mlm_probability=0.15, iterations=1
        )

        # Print the raw numbers first: they are useful context if one of
        # the checks below fails.
        print(f"✓ Average pseudo-perplexity: {avg_perp:.4f}")
        print(f"✓ Number of tokens: {len(tokens)}")
        print(f"✓ Token perplexities shape: {token_perps.shape}")

        # Each assert carries a message because the handler below prints
        # only str(e), which is empty for a bare assert.
        assert avg_perp > 0, (
            f"average pseudo-perplexity must be > 0, got {avg_perp}"
        )
        assert len(tokens) > 0, "no tokens were returned"
        assert len(token_perps) == len(tokens), (
            f"{len(token_perps)} token perplexities for {len(tokens)} tokens"
        )
        assert all(p > 0 for p in token_perps), (
            "every token perplexity must be > 0"
        )
        return True
    except Exception as e:
        print(f"✗ Encoder perplexity test failed: {e}")
        return False
def test_visualization():
    """Check that ``create_visualization`` returns non-empty HTML.

    Feeds five dummy tokens with hand-picked perplexities to
    ``create_visualization`` and verifies that the result is a non-empty
    string containing the substring ``"ent"`` (case-insensitive).

    Returns:
        bool: ``True`` if every check passes, ``False`` if anything raises.
    """
    print("\nTesting visualization creation...")
    try:
        # Dummy data: no model is needed to exercise the renderer.
        tokens = ["The", "quick", "brown", "fox", "jumps"]
        perplexities = np.array([2.5, 1.8, 3.2, 4.1, 2.9])

        html = create_visualization(tokens, perplexities)

        # Validate before reporting success. The messages matter: the
        # handler below prints only str(e), which is empty for a bare
        # assert.
        assert isinstance(html, str), (
            f"expected str, got {type(html).__name__}"
        )
        assert len(html) > 0, "visualization HTML is empty"
        # Per the original author's note the output is displaCy entity
        # markup, which is expected to contain "ent".
        assert "ent" in html.lower(), "no entity markup found in HTML"

        print("✓ Visualization HTML generated")
        return True
    except Exception as e:
        print(f"✗ Visualization test failed: {e}")
        return False
def test_edge_cases():
    """Exercise empty, very short and over-long inputs.

    Three scenarios are run in order:

    1. Empty text through ``process_text``: the returned summary must
       contain "enter some text" (case-insensitive).
    2. A one-word text through ``calculate_decoder_perplexity``: either a
       result or an exception is accepted (this step never fails the test).
    3. A 600-word text through ``calculate_decoder_perplexity``: the number
       of returned tokens must not exceed 512.

    Returns:
        bool: ``True`` if scenarios 1 and 3 pass, ``False`` otherwise.
    """
    print("\nTesting edge cases...")

    # Upper bound on returned tokens for the long-text scenario.
    max_tokens = 512

    # 1. Empty text: only the summary is inspected.
    try:
        summary, _viz, _table = process_text(
            "", "distilgpt2", "decoder", 1, 0.15
        )
        # The message matters: the handler below prints only str(e),
        # which is empty for a bare assert.
        assert "enter some text" in summary.lower(), (
            f"expected a prompt to enter text, got: {summary!r}"
        )
        print("✓ Empty text handled correctly")
    except Exception as e:
        print(f"✗ Empty text test failed: {e}")
        return False

    # 2. Very short text. Deliberately lenient: raising on a too-short
    # input is treated as acceptable handling, so this step only reports.
    try:
        model, tokenizer = load_model_and_tokenizer("distilgpt2", "decoder")
        _, tokens, _ = calculate_decoder_perplexity(
            "Hi", model, tokenizer, iterations=1
        )
        print(f"✓ Short text handled: {len(tokens)} tokens")
    except Exception as e:
        print(f"✓ Short text error handled correctly: {e}")

    # 3. Long text (should be truncated).
    try:
        long_text = " ".join(["word"] * 600)  # more words than max_tokens
        model, tokenizer = load_model_and_tokenizer("distilgpt2", "decoder")
        _, tokens, _ = calculate_decoder_perplexity(
            long_text, model, tokenizer, iterations=1
        )
        # Check the bound before claiming the text was truncated.
        assert len(tokens) <= max_tokens, (
            f"expected at most {max_tokens} tokens, got {len(tokens)}"
        )
        print(f"✓ Long text truncated to {len(tokens)} tokens")
    except Exception as e:
        print(f"✗ Long text test failed: {e}")
        return False

    return True
def test_process_text_integration():
    """Run ``process_text`` end to end for one decoder and one encoder case.

    For each case the three values returned by ``process_text`` are checked:
    the summary must contain "Analysis Results", and both the visualization
    HTML and the result table must be non-empty.

    Returns:
        bool: ``True`` if every case passes, ``False`` at the first failure.
    """
    print("\nTesting process_text integration...")

    test_cases = [
        {
            "text": "The quick brown fox jumps over the lazy dog.",
            "model": "distilgpt2",
            "type": "decoder",
            "iterations": 1,
            "mlm_prob": 0.15,
        },
        {
            "text": "Machine learning is a subset of artificial intelligence.",
            "model": "distilbert-base-uncased",
            "type": "encoder",
            "iterations": 1,
            "mlm_prob": 0.2,
        },
    ]

    for number, case in enumerate(test_cases, start=1):
        try:
            summary, viz_html, df = process_text(
                case["text"],
                case["model"],
                case["type"],
                case["iterations"],
                case["mlm_prob"],
            )
            # Validate first so a failing case is not also reported as
            # "processed successfully". The messages matter: the handler
            # below prints only str(e), which is empty for a bare assert.
            assert "Analysis Results" in summary, (
                "summary does not contain 'Analysis Results'"
            )
            assert len(viz_html) > 0, "visualization HTML is empty"
            assert len(df) > 0, "result table is empty"
            print(
                f"✓ Test case {number} ({case['type']}) processed successfully"
            )
        except Exception as e:
            print(f"✗ Test case {number} failed: {e}")
            return False

    return True
def test_configuration():
    """Sanity-check the values imported from ``config``.

    Verifies that ``DEFAULT_MODELS`` has non-empty ``"decoder"`` and
    ``"encoder"`` entries and that
    ``PROCESSING_SETTINGS["default_iterations"]`` is at least 1.

    Returns:
        bool: ``True`` if every check passes, ``False`` if anything raises
        (including a missing key).
    """
    print("\nTesting configuration...")
    try:
        # Each assert carries a message because the handler below prints
        # only str(e), which is empty for a bare assert.
        for model_type in ("decoder", "encoder"):
            assert model_type in DEFAULT_MODELS, (
                f"DEFAULT_MODELS has no '{model_type}' entry"
            )
        for model_type in ("decoder", "encoder"):
            assert len(DEFAULT_MODELS[model_type]) > 0, (
                f"DEFAULT_MODELS['{model_type}'] is empty"
            )
        assert PROCESSING_SETTINGS["default_iterations"] >= 1, (
            "PROCESSING_SETTINGS['default_iterations'] must be >= 1"
        )
        print("✓ Configuration loaded correctly")
        return True
    except Exception as e:
        print(f"✗ Configuration test failed: {e}")
        return False
def run_all_tests(tests=None):
    """Run a sequence of tests and print a pass/fail summary.

    Args:
        tests: Optional iterable of ``(name, callable)`` pairs. Each
            callable is invoked with no arguments; a truthy return value
            counts as a pass, a falsy one or a raised exception as a
            failure. When omitted, the full PerplexityViewer suite defined
            in this file is run.

    Returns:
        bool: ``True`` when no test failed (also for an empty sequence),
        ``False`` otherwise.
    """
    if tests is None:
        # Built lazily so that callers supplying their own tests do not
        # depend on the module-level test functions.
        tests = [
            ("Configuration", test_configuration),
            ("Model Loading", test_model_loading),
            ("Decoder Perplexity", test_decoder_perplexity),
            ("Encoder Perplexity", test_encoder_perplexity),
            ("Visualization", test_visualization),
            ("Edge Cases", test_edge_cases),
            ("Integration", test_process_text_integration),
        ]

    banner = "=" * 50
    print(banner)
    print("Running PerplexityViewer Tests")
    print(banner)

    passed = 0
    failed = 0
    for test_name, test_func in tests:
        print(f"\n[{test_name}]")
        # Only the test call itself is guarded, so one crashing test
        # cannot abort the rest of the run.
        try:
            succeeded = test_func()
        except Exception as e:
            failed += 1
            print(f"✗ {test_name} FAILED with exception: {e}")
            continue

        if succeeded:
            passed += 1
            print(f"✓ {test_name} PASSED")
        else:
            failed += 1
            print(f"✗ {test_name} FAILED")

    print("\n" + banner)
    print(f"Test Results: {passed} passed, {failed} failed")
    print(banner)
    return failed == 0
if __name__ == "__main__":
    # Report the runtime environment first so that failures can be read in
    # context (CPU-only vs. CUDA).
    print(f"PyTorch version: {torch.__version__}")
    cuda_available = torch.cuda.is_available()
    print(f"CUDA available: {cuda_available}")
    if cuda_available:
        print(f"CUDA device: {torch.cuda.get_device_name()}")

    # Run the suite and map the outcome onto the process exit status
    # (0 = all passed, 1 = at least one failure).
    if run_all_tests():
        print("\n🎉 All tests passed! The app should work correctly.")
        sys.exit(0)

    print("\n❌ Some tests failed. Please check the errors above.")
    sys.exit(1)