#!/usr/bin/env python3
"""
Quick Test Script - Verify Installation and Setup
Tests basic functionality without full training
"""

import importlib
import os
import sys

# (import name, display name) pairs that must be importable for training.
_REQUIRED_PACKAGES = [
    ('torch', 'PyTorch'),
    ('transformers', 'Transformers'),
    ('sklearn', 'scikit-learn'),
    ('pandas', 'Pandas'),
    ('numpy', 'NumPy'),
]

# (import name, display name) pairs that are only reported, never fatal.
_OPTIONAL_PACKAGES = [
    ('matplotlib', 'Matplotlib'),
    ('seaborn', 'Seaborn'),
]

# Project-local modules expected to be importable from the script's sys.path.
_PROJECT_MODULES = [
    'config',
    'data_loader',
    'risk_discovery',
    'model',
    'trainer',
    'evaluator',
    'utils',
]


def _import_failure(module_name):
    """Import *module_name*; return the exception it raised, or None on success."""
    try:
        importlib.import_module(module_name)
    except Exception as exc:
        # Diagnostic boundary: a broken install can raise more than
        # ImportError (e.g. OSError while loading a native library), and the
        # purpose of this script is to report that rather than crash.
        return exc
    return None


def test_imports():
    """Test if all required packages can be imported.

    Prints one status line per package. Optional packages are reported but
    never count as failures.

    Returns:
        A ``(passed, errors)`` tuple: ``passed`` is True when every required
        package imported, ``errors`` lists the display names that did not.
    """
    print("🔍 Testing imports...")

    errors = []

    # Core dependencies
    for package, name in _REQUIRED_PACKAGES:
        failure = _import_failure(package)
        if failure is None:
            print(f" ✅ {name}")
            continue
        if isinstance(failure, ImportError):
            print(f" ❌ {name} - NOT INSTALLED")
        else:
            detail = f"{type(failure).__name__}: {str(failure)[:50]}"
            print(f" ❌ {name} - IMPORT FAILED ({detail})")
        errors.append(name)

    # Optional dependencies
    print("\n📊 Optional packages:")
    for package, name in _OPTIONAL_PACKAGES:
        failure = _import_failure(package)
        if failure is None:
            print(f" ✅ {name}")
        elif isinstance(failure, ImportError):
            print(f" ⚠️ {name} - Not installed (visualizations will be skipped)")
        else:
            print(f" ⚠️ {name} - Import failed (visualizations will be skipped)")

    return not errors, errors


def test_module_imports():
    """Test if project modules can be imported.

    Returns:
        A ``(passed, errors)`` tuple: ``passed`` is True when every project
        module imported, ``errors`` lists the module names that raised.
    """
    print("\n🔍 Testing project modules...")

    errors = []
    for module in _PROJECT_MODULES:
        failure = _import_failure(module)
        if failure is None:
            print(f" ✅ {module}.py")
        else:
            # Only the first 50 characters keep the report to one line.
            print(f" ❌ {module}.py - {str(failure)[:50]}")
            errors.append(module)

    return not errors, errors


def test_configuration():
    """Test configuration loading.

    Instantiates ``LegalBertConfig``, prints its main settings, and reports
    whether ``config.data_path`` exists. A missing dataset is only a warning.

    Returns:
        ``(True, None)`` on success, ``(False, message)`` if loading raised.
    """
    print("\n🔍 Testing configuration...")

    try:
        from config import LegalBertConfig

        config = LegalBertConfig()
        print(" ✅ Configuration loaded")
        print(f" - BERT Model: {config.bert_model_name}")
        print(f" - Batch Size: {config.batch_size}")
        print(f" - Epochs: {config.num_epochs}")
        print(f" - Device: {config.device}")
        print(f" - Data Path: {config.data_path}")

        # Check if data path exists
        if os.path.exists(config.data_path):
            print(f" ✅ CUAD dataset found at {config.data_path}")
        else:
            print(f" ⚠️ CUAD dataset NOT found at {config.data_path}")
            print(" Please download and place it at this location")

        return True, None
    except Exception as e:
        print(f" ❌ Configuration failed: {e}")
        return False, str(e)


def test_model_initialization():
    """Test model can be initialized.

    Only the tokenizer is exercised; the model classes are imported to confirm
    they exist but are not instantiated.

    Returns:
        ``(True, None)`` on success, ``(False, message)`` if anything raised.
    """
    print("\n🔍 Testing model initialization...")

    try:
        from config import LegalBertConfig
        # The model classes are imported purely to verify they are defined.
        from model import (  # noqa: F401
            FullyLearningBasedLegalBERT,
            HierarchicalLegalBERT,
            LegalBertTokenizer,
        )

        config = LegalBertConfig()

        # Test tokenizer
        tokenizer = LegalBertTokenizer(config.bert_model_name)
        test_text = "This agreement shall be governed by the laws of Delaware."
        encoded = tokenizer.tokenize_clauses([test_text], max_length=128)

        print(" ✅ Tokenizer works")
        print(f" Input shape: {encoded['input_ids'].shape}")

        # Test model (without downloading weights for quick test)
        print(" ℹ️ Model initialization skipped (use train.py to fully initialize)")

        return True, None
    except Exception as e:
        print(f" ❌ Model initialization failed: {e}")
        return False, str(e)


def test_data_loader():
    """Test data loader (if data exists).

    A missing data file is treated as a skipped test, not a failure.

    Returns:
        ``(True, None)`` on success or skip, ``(False, message)`` if the
        loader could not be imported or constructed.
    """
    print("\n🔍 Testing data loader...")

    try:
        from config import LegalBertConfig
        from data_loader import CUADDataLoader

        config = LegalBertConfig()

        if not os.path.exists(config.data_path):
            print(" ⚠️ Data file not found, skipping data loader test")
            return True, None

        # Construction alone is the check; the instance is not used further.
        CUADDataLoader(config.data_path)
        print(" ✅ Data loader initialized")

        # Try to load a small sample
        print(" ℹ️ Full data loading skipped (use train.py for full load)")

        return True, None
    except Exception as e:
        print(f" ❌ Data loader failed: {e}")
        return False, str(e)


def main():
    """Run all tests.

    Every check runs even if an earlier one failed, so a single invocation
    reports all problems.

    Returns:
        Process exit code: 0 if every check passed, 1 otherwise.
    """
    print("=" * 80)
    print("🧪 LEGAL-BERT PROJECT - QUICK TEST")
    print("=" * 80)

    all_passed = True

    # Test imports
    passed, errors = test_imports()
    if not passed:
        print(f"\n❌ Missing required packages: {', '.join(errors)}")
        print(" Install with: pip install -r requirements.txt")
        all_passed = False

    # Test modules
    passed, errors = test_module_imports()
    if not passed:
        print(f"\n❌ Module import errors: {', '.join(errors)}")
        all_passed = False

    # Configuration, model and data-loader checks print their own diagnostics.
    for check in (test_configuration, test_model_initialization, test_data_loader):
        passed, _ = check()
        if not passed:
            all_passed = False

    # Summary
    print("\n" + "=" * 80)
    if all_passed:
        print("✅ ALL TESTS PASSED!")
        print("=" * 80)
        print("\n🚀 Ready to train! Run: python train.py")
    else:
        print("❌ SOME TESTS FAILED")
        print("=" * 80)
        print("\n⚠️ Please fix the issues above before training")

    return 0 if all_passed else 1


if __name__ == "__main__":
    sys.exit(main())