Spaces:
Sleeping
Sleeping
| """ | |
| Test Script to Verify Model Configuration and Data Flow | |
| Run this BEFORE training to ensure everything is correct! | |
| """ | |
| import torch | |
| import sys | |
| import os | |
| from pathlib import Path | |
| # Add parent to path | |
| sys.path.insert(0, os.path.dirname(__file__)) | |
| from models_best_old import BestTransformer, TransformerConfig | |
| from utils.data_processing import DataProcessor, collate_fn | |
| from config import Config | |
| def test_special_tokens(): | |
| """Test 1: Verify special token indices match""" | |
| print("=" * 60) | |
| print("TEST 1: Special Token Indices") | |
| print("=" * 60) | |
| # DataProcessor indices | |
| processor = DataProcessor(Config) | |
| print(f"\nDataProcessor:") | |
| print(f" PAD: {processor.pad_idx}") | |
| print(f" UNK: {processor.unk_idx}") | |
| print(f" SOS (BOS): {processor.sos_idx}") | |
| print(f" EOS: {processor.eos_idx}") | |
| # Model config indices | |
| config = TransformerConfig.base() | |
| print(f"\nTransformerConfig:") | |
| print(f" PAD: {config.pad_idx}") | |
| print(f" BOS: {config.bos_idx}") | |
| print(f" EOS: {config.eos_idx}") | |
| # Verify match | |
| assert config.pad_idx == processor.pad_idx, "PAD index mismatch!" | |
| assert config.bos_idx == processor.sos_idx, "BOS/SOS index mismatch!" | |
| assert config.eos_idx == processor.eos_idx, "EOS index mismatch!" | |
| print(f"\n✅ PASS: All special token indices match!") | |
| return True | |
| def test_data_shapes(): | |
| """Test 2: Verify data shapes and masks""" | |
| print("\n" + "=" * 60) | |
| print("TEST 2: Data Shapes and Masks") | |
| print("=" * 60) | |
| # Create dummy batch | |
| batch_size = 4 | |
| src_len = 10 | |
| tgt_len = 12 | |
| # Simulate batch from collate_fn | |
| src = torch.randint(4, 100, (batch_size, src_len)) # Token IDs 4-99 | |
| tgt = torch.randint(4, 100, (batch_size, tgt_len)) | |
| # Set PAD tokens | |
| src[:, -2:] = 0 # Last 2 tokens are PAD | |
| tgt[:, -3:] = 0 # Last 3 tokens are PAD | |
| # Create masks manually (như collate_fn) | |
| src_mask = (src != 0).unsqueeze(1).unsqueeze(2) # [B, 1, 1, S] | |
| tgt_padding = (tgt != 0).unsqueeze(1).unsqueeze(3) # [B, 1, T, 1] | |
| tgt_causal = torch.tril(torch.ones(tgt_len, tgt_len)).unsqueeze(0).unsqueeze(0) # [1, 1, T, T] | |
| tgt_mask = tgt_padding & tgt_causal.bool() # [B, 1, T, T] | |
| print(f"\nShapes:") | |
| print(f" src: {src.shape}") | |
| print(f" tgt: {tgt.shape}") | |
| print(f" src_mask: {src_mask.shape}") | |
| print(f" tgt_mask: {tgt_mask.shape}") | |
| # Verify shapes | |
| assert src.shape == (batch_size, src_len) | |
| assert tgt.shape == (batch_size, tgt_len) | |
| assert src_mask.shape == (batch_size, 1, 1, src_len) | |
| assert tgt_mask.shape == (batch_size, 1, tgt_len, tgt_len) | |
| # Test mask types | |
| print(f"\nMask dtypes:") | |
| print(f" src_mask: {src_mask.dtype}") | |
| print(f" tgt_mask: {tgt_mask.dtype}") | |
| # Verify causal mask | |
| assert torch.all(tgt_mask[0, 0].diagonal() == True), "Diagonal should be True" | |
| assert torch.all(tgt_mask[0, 0][0, 1:] == False), "Future tokens should be False" | |
| print(f"\n✅ PASS: All data shapes and masks correct!") | |
| return True | |
| def test_model_forward(): | |
| """Test 3: Model forward pass""" | |
| print("\n" + "=" * 60) | |
| print("TEST 3: Model Forward Pass") | |
| print("=" * 60) | |
| # Create small model | |
| config = TransformerConfig.small() | |
| config.device = 'cpu' | |
| config.pad_idx = 0 | |
| config.bos_idx = 2 | |
| config.eos_idx = 3 | |
| vocab_size = 1000 | |
| model = BestTransformer(vocab_size, vocab_size, config) | |
| model.eval() | |
| print(f"\nModel created:") | |
| print(f" Parameters: {model.count_parameters():,}") | |
| print(f" d_model: {config.d_model}") | |
| print(f" n_layers: {config.n_encoder_layers}") | |
| # Create dummy input | |
| batch_size = 2 | |
| src_len = 10 | |
| tgt_len = 8 | |
| src = torch.randint(4, 100, (batch_size, src_len)) | |
| tgt = torch.randint(4, 100, (batch_size, tgt_len)) | |
| # Create masks | |
| src_mask = (src != 0).unsqueeze(1).unsqueeze(2) | |
| tgt_padding = (tgt != 0).unsqueeze(1).unsqueeze(3) | |
| tgt_causal = torch.tril(torch.ones(tgt_len, tgt_len)).unsqueeze(0).unsqueeze(0) | |
| tgt_mask = tgt_padding & tgt_causal.bool() | |
| print(f"\nInput shapes:") | |
| print(f" src: {src.shape}") | |
| print(f" tgt: {tgt.shape}") | |
| print(f" src_mask: {src_mask.shape}") | |
| print(f" tgt_mask: {tgt_mask.shape}") | |
| # Forward pass | |
| with torch.no_grad(): | |
| logits = model(src, tgt, src_mask, tgt_mask) | |
| print(f"\nOutput shape:") | |
| print(f" logits: {logits.shape}") | |
| # Verify output shape | |
| expected_shape = (batch_size, tgt_len, vocab_size) | |
| assert logits.shape == expected_shape, f"Expected {expected_shape}, got {logits.shape}" | |
| # Check logits are valid | |
| assert not torch.isnan(logits).any(), "NaN detected in logits!" | |
| assert not torch.isinf(logits).any(), "Inf detected in logits!" | |
| print(f"\n✅ PASS: Model forward pass successful!") | |
| return True | |
| def test_training_step(): | |
| """Test 4: Training step (forward + loss + backward)""" | |
| print("\n" + "=" * 60) | |
| print("TEST 4: Training Step") | |
| print("=" * 60) | |
| # Create model | |
| config = TransformerConfig.small() | |
| config.device = 'cpu' | |
| config.pad_idx = 0 | |
| config.bos_idx = 2 | |
| config.eos_idx = 3 | |
| vocab_size = 1000 | |
| model = BestTransformer(vocab_size, vocab_size, config) | |
| model.train() | |
| # Create optimizer | |
| optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) | |
| # Create loss | |
| from models_best_old import LabelSmoothingLoss | |
| criterion = LabelSmoothingLoss( | |
| num_classes=vocab_size, | |
| smoothing=0.1, | |
| ignore_index=0 # PAD | |
| ) | |
| # Create dummy batch | |
| batch_size = 4 | |
| src_len = 12 | |
| tgt_len = 10 | |
| src = torch.randint(4, 100, (batch_size, src_len)) | |
| tgt = torch.randint(4, 100, (batch_size, tgt_len)) | |
| src[:, -2:] = 0 # PAD | |
| tgt[:, -2:] = 0 # PAD | |
| # Create masks | |
| src_mask = (src != 0).unsqueeze(1).unsqueeze(2) | |
| tgt_padding = (tgt != 0).unsqueeze(1).unsqueeze(3) | |
| tgt_causal = torch.tril(torch.ones(tgt_len, tgt_len)).unsqueeze(0).unsqueeze(0) | |
| tgt_mask = tgt_padding & tgt_causal.bool() | |
| # Training step | |
| tgt_input = tgt[:, :-1] # Remove last token | |
| tgt_output = tgt[:, 1:] # Remove first token | |
| tgt_mask = tgt_mask[:, :, :-1, :-1] # Adjust mask | |
| print(f"\nTraining shapes:") | |
| print(f" src: {src.shape}") | |
| print(f" tgt_input: {tgt_input.shape}") | |
| print(f" tgt_output: {tgt_output.shape}") | |
| print(f" tgt_mask: {tgt_mask.shape}") | |
| # Forward | |
| logits = model(src, tgt_input, src_mask, tgt_mask) | |
| print(f" logits: {logits.shape}") | |
| # Loss | |
| loss = criterion(logits, tgt_output) | |
| print(f"\nLoss: {loss.item():.4f}") | |
| assert not torch.isnan(loss), "Loss is NaN!" | |
| assert not torch.isinf(loss), "Loss is Inf!" | |
| assert loss.item() > 0, "Loss should be positive!" | |
| # Backward | |
| optimizer.zero_grad() | |
| loss.backward() | |
| # Check gradients | |
| has_grad = False | |
| for name, param in model.named_parameters(): | |
| if param.grad is not None: | |
| has_grad = True | |
| assert not torch.isnan(param.grad).any(), f"NaN gradient in {name}" | |
| assert not torch.isinf(param.grad).any(), f"Inf gradient in {name}" | |
| break | |
| assert has_grad, "No gradients computed!" | |
| # Optimizer step | |
| optimizer.step() | |
| print(f"\n✅ PASS: Training step successful!") | |
| return True | |
| def test_beam_search(): | |
| """Test 5: Beam search inference""" | |
| print("\n" + "=" * 60) | |
| print("TEST 5: Beam Search Inference") | |
| print("=" * 60) | |
| # Create model | |
| config = TransformerConfig.small() | |
| config.device = 'cpu' | |
| config.pad_idx = 0 | |
| config.bos_idx = 2 | |
| config.eos_idx = 3 | |
| config.max_decode_length = 20 | |
| config.beam_size = 4 | |
| vocab_size = 1000 | |
| model = BestTransformer(vocab_size, vocab_size, config) | |
| model.eval() | |
| # Create source sentence | |
| src = torch.randint(4, 100, (1, 15)) # Single sentence | |
| print(f"\nBeam search config:") | |
| print(f" src: {src.shape}") | |
| print(f" beam_size: {config.beam_size}") | |
| print(f" max_len: {config.max_decode_length}") | |
| print(f" BOS: {config.bos_idx}") | |
| print(f" EOS: {config.eos_idx}") | |
| # Translate | |
| with torch.no_grad(): | |
| translation = model.translate_beam( | |
| src, | |
| max_len=config.max_decode_length, | |
| beam_size=config.beam_size, | |
| length_penalty=0.6 | |
| ) | |
| print(f"\nTranslation result:") | |
| print(f" Type: {type(translation)}") | |
| print(f" Shape: {translation.shape if isinstance(translation, torch.Tensor) else len(translation)}") | |
| print(f" First 10 tokens: {translation[:10].tolist() if isinstance(translation, torch.Tensor) else translation[:10]}") | |
| # Verify translation | |
| if isinstance(translation, torch.Tensor): | |
| translation = translation.tolist() | |
| assert len(translation) > 0, "Empty translation!" | |
| assert len(translation) <= config.max_decode_length, "Translation too long!" | |
| # Check if starts with BOS (if included) or just tokens | |
| # Check if ends with EOS | |
| if config.eos_idx in translation: | |
| print(f" ✅ EOS token found at position {translation.index(config.eos_idx)}") | |
| else: | |
| print(f" ⚠️ EOS token not found (may have reached max_len)") | |
| print(f"\n✅ PASS: Beam search inference successful!") | |
| return True | |
| def main(): | |
| """Run all tests""" | |
| print("\n" + "=" * 60) | |
| print("RUNNING PRE-TRAINING VERIFICATION TESTS") | |
| print("=" * 60) | |
| tests = [ | |
| ("Special Token Indices", test_special_tokens), | |
| ("Data Shapes and Masks", test_data_shapes), | |
| ("Model Forward Pass", test_model_forward), | |
| ("Training Step", test_training_step), | |
| ("Beam Search Inference", test_beam_search), | |
| ] | |
| results = [] | |
| for test_name, test_func in tests: | |
| try: | |
| passed = test_func() | |
| results.append((test_name, "PASS" if passed else "FAIL")) | |
| except Exception as e: | |
| print(f"\n❌ FAIL: {str(e)}") | |
| results.append((test_name, f"FAIL: {str(e)}")) | |
| # Summary | |
| print("\n" + "=" * 60) | |
| print("TEST SUMMARY") | |
| print("=" * 60) | |
| for test_name, result in results: | |
| status = "✅" if result == "PASS" else "❌" | |
| print(f"{status} {test_name}: {result}") | |
| all_passed = all(result == "PASS" for _, result in results) | |
| print("\n" + "=" * 60) | |
| if all_passed: | |
| print("🎉 ALL TESTS PASSED! Ready to train!") | |
| else: | |
| print("❌ SOME TESTS FAILED! Fix issues before training!") | |
| print("=" * 60) | |
| return all_passed | |
| if __name__ == '__main__': | |
| success = main() | |
| sys.exit(0 if success else 1) | |