""" Test custom collate_batch function for variable-length sequences """ print("=" * 60) print("Testing collate_batch for Variable-Length Sequences") print("=" * 60) try: import torch from trainer import collate_batch print("\nāœ… Step 1: Imports successful") # Create mock batch with variable lengths print("\nšŸ”§ Step 2: Creating mock batch with variable lengths...") batch = [ { 'input_ids': torch.tensor([101, 2023, 2003, 102]), # Length 4 'attention_mask': torch.tensor([1, 1, 1, 1]), 'risk_label': torch.tensor(0), 'severity_score': torch.tensor(5.5), 'importance_score': torch.tensor(3.2) }, { 'input_ids': torch.tensor([101, 1996, 2172, 3325, 2003, 102]), # Length 6 'attention_mask': torch.tensor([1, 1, 1, 1, 1, 1]), 'risk_label': torch.tensor(1), 'severity_score': torch.tensor(7.2), 'importance_score': torch.tensor(4.8) }, { 'input_ids': torch.tensor([101, 102]), # Length 2 'attention_mask': torch.tensor([1, 1]), 'risk_label': torch.tensor(2), 'severity_score': torch.tensor(3.1), 'importance_score': torch.tensor(2.5) } ] print(f" āœ… Created batch with 3 samples") print(f" Sample 0: length {len(batch[0]['input_ids'])}") print(f" Sample 1: length {len(batch[1]['input_ids'])}") print(f" Sample 2: length {len(batch[2]['input_ids'])}") # Test collate_batch print("\nšŸ“¦ Step 3: Testing collate_batch()...") collated = collate_batch(batch) print(f" āœ… Collate successful!") print(f"\n šŸ“Š Output shapes:") print(f" input_ids: {collated['input_ids'].shape}") print(f" attention_mask: {collated['attention_mask'].shape}") print(f" risk_label: {collated['risk_label'].shape}") print(f" severity_score: {collated['severity_score'].shape}") print(f" importance_score: {collated['importance_score'].shape}") # Verify shapes print("\nšŸ” Step 4: Verifying shapes...") batch_size = 3 max_len = 6 # Longest sequence in batch assert collated['input_ids'].shape == (batch_size, max_len), \ f"Expected ({batch_size}, {max_len}), got {collated['input_ids'].shape}" assert collated['attention_mask'].shape == (batch_size, max_len), \ f"Expected ({batch_size}, {max_len}), got {collated['attention_mask'].shape}" assert collated['risk_label'].shape == (batch_size,), \ f"Expected ({batch_size},), got {collated['risk_label'].shape}" print(" āœ… All shapes correct!") # Verify padding print("\nšŸ” Step 5: Verifying padding...") print(f" Sample 0 (originally length 4):") print(f" input_ids: {collated['input_ids'][0]}") print(f" attention_mask: {collated['attention_mask'][0]}") print(f" Expected: 2 padding tokens at end") # Check that padding was added correctly assert collated['input_ids'][0][-2:].sum() == 0, "Padding tokens should be 0" assert collated['attention_mask'][0][-2:].sum() == 0, "Padding mask should be 0" print(f"\n Sample 2 (originally length 2):") print(f" input_ids: {collated['input_ids'][2]}") print(f" attention_mask: {collated['attention_mask'][2]}") print(f" Expected: 4 padding tokens at end") assert collated['input_ids'][2][-4:].sum() == 0, "Padding tokens should be 0" assert collated['attention_mask'][2][-4:].sum() == 0, "Padding mask should be 0" print(" āœ… Padding is correct!") # Verify labels are preserved print("\nšŸ” Step 6: Verifying labels preserved...") assert collated['risk_label'][0] == 0 assert collated['risk_label'][1] == 1 assert collated['risk_label'][2] == 2 assert torch.allclose(collated['severity_score'][0], torch.tensor(5.5)) assert torch.allclose(collated['severity_score'][1], torch.tensor(7.2)) print(" āœ… All labels preserved correctly!") # Test with DataLoader print("\nšŸš€ Step 7: Testing with actual DataLoader...") from torch.utils.data import Dataset, DataLoader class MockDataset(Dataset): def __len__(self): return 10 def __getitem__(self, idx): # Simulate variable lengths length = torch.randint(5, 15, (1,)).item() return { 'input_ids': torch.randint(0, 1000, (length,)), 'attention_mask': torch.ones(length, dtype=torch.long), 'risk_label': torch.tensor(idx % 3), 'severity_score': torch.tensor(float(idx)), 'importance_score': torch.tensor(float(idx) * 0.5) } dataset = MockDataset() dataloader = DataLoader( dataset, batch_size=4, shuffle=False, collate_fn=collate_batch ) print(f" āœ… Created DataLoader with batch_size=4") # Test iteration print(f"\n Testing iteration over batches...") for batch_idx, batch in enumerate(dataloader): print(f" Batch {batch_idx}: input_ids shape = {batch['input_ids'].shape}") # Verify all tensors in batch have same size batch_size = batch['input_ids'].shape[0] seq_len = batch['input_ids'].shape[1] assert batch['attention_mask'].shape == (batch_size, seq_len) assert batch['risk_label'].shape == (batch_size,) if batch_idx >= 1: # Just test first 2 batches break print(f" āœ… DataLoader iteration successful!") print("\n" + "=" * 60) print("šŸŽ‰ ALL TESTS PASSED!") print("=" * 60) print("\nāœ… collate_batch works correctly") print("āœ… Handles variable-length sequences") print("āœ… Pads to maximum length in batch") print("āœ… Preserves all labels and scores") print("āœ… Compatible with DataLoader") print("\nšŸš€ Ready to run: python3 train.py") except ImportError as e: print(f"\nāŒ Import error: {e}") print(" Make sure torch is installed") exit(1) except AssertionError as e: print(f"\nāŒ Assertion failed: {e}") exit(1) except Exception as e: print(f"\nāŒ Test failed: {e}") import traceback traceback.print_exc() exit(1)