code2-repo / test_collate_batch.py
Deepu1965's picture
Upload folder using huggingface_hub
9b1c753 verified
"""
Test custom collate_batch function for variable-length sequences
"""
import sys
import traceback


def _check_row_padding(collated, row, sample):
    """Assert that ``row`` of the collated batch is ``sample`` right-padded with zeros.

    Args:
        collated: Mapping returned by ``collate_batch``; its ``input_ids``
            and ``attention_mask`` entries are indexed as ``[row][position]``.
        row: Index of the sample inside the collated batch.
        sample: The un-collated sample dict the row was built from.

    Raises:
        AssertionError: If the real (un-padded) values were altered or any
            padding position is non-zero.
    """
    real_len = len(sample['input_ids'])
    messages = {
        'input_ids': "Padding tokens should be 0",
        'attention_mask': "Padding mask should be 0",
    }
    for key, padding_message in messages.items():
        padded = collated[key][row]
        # Compare as Python lists so a dtype change during collation does not
        # turn a value comparison into a dtype error.
        assert padded[:real_len].tolist() == sample[key].tolist(), \
            f"{key}[{row}] real values were changed by collation"
        # Element-wise check: sum() == 0 would also accept non-zero padding
        # values that happen to cancel out.
        assert not padded[real_len:].any(), padding_message


print("=" * 60)
print("Testing collate_batch for Variable-Length Sequences")
print("=" * 60)

try:
    # Imports live inside the try so a missing dependency is reported by the
    # ImportError handler below instead of an unhandled traceback.
    import torch
    from torch.utils.data import DataLoader, Dataset

    from trainer import collate_batch

    print("\n✅ Step 1: Imports successful")

    # Create mock batch with variable lengths
    print("\n🔧 Step 2: Creating mock batch with variable lengths...")
    batch = [
        {
            'input_ids': torch.tensor([101, 2023, 2003, 102]),  # Length 4
            'attention_mask': torch.tensor([1, 1, 1, 1]),
            'risk_label': torch.tensor(0),
            'severity_score': torch.tensor(5.5),
            'importance_score': torch.tensor(3.2),
        },
        {
            'input_ids': torch.tensor([101, 1996, 2172, 3325, 2003, 102]),  # Length 6
            'attention_mask': torch.tensor([1, 1, 1, 1, 1, 1]),
            'risk_label': torch.tensor(1),
            'severity_score': torch.tensor(7.2),
            'importance_score': torch.tensor(4.8),
        },
        {
            'input_ids': torch.tensor([101, 102]),  # Length 2
            'attention_mask': torch.tensor([1, 1]),
            'risk_label': torch.tensor(2),
            'severity_score': torch.tensor(3.1),
            'importance_score': torch.tensor(2.5),
        },
    ]
    print(" ✅ Created batch with 3 samples")
    for row, sample in enumerate(batch):
        print(f" Sample {row}: length {len(sample['input_ids'])}")

    # Test collate_batch
    print("\n📦 Step 3: Testing collate_batch()...")
    collated = collate_batch(batch)
    print(" ✅ Collate successful!")
    print("\n 📊 Output shapes:")
    for key in ('input_ids', 'attention_mask', 'risk_label',
                'severity_score', 'importance_score'):
        print(f" {key}: {collated[key].shape}")

    # Verify shapes
    print("\n🔍 Step 4: Verifying shapes...")
    batch_size = len(batch)
    max_len = max(len(sample['input_ids']) for sample in batch)  # Longest sequence in batch
    assert collated['input_ids'].shape == (batch_size, max_len), \
        f"Expected ({batch_size}, {max_len}), got {collated['input_ids'].shape}"
    assert collated['attention_mask'].shape == (batch_size, max_len), \
        f"Expected ({batch_size}, {max_len}), got {collated['attention_mask'].shape}"
    assert collated['risk_label'].shape == (batch_size,), \
        f"Expected ({batch_size},), got {collated['risk_label'].shape}"
    print(" ✅ All shapes correct!")

    # Verify padding
    print("\n🔍 Step 5: Verifying padding...")
    print(" Sample 0 (originally length 4):")
    print(f" input_ids: {collated['input_ids'][0]}")
    print(f" attention_mask: {collated['attention_mask'][0]}")
    print(" Expected: 2 padding tokens at end")
    _check_row_padding(collated, 0, batch[0])

    print("\n Sample 2 (originally length 2):")
    print(f" input_ids: {collated['input_ids'][2]}")
    print(f" attention_mask: {collated['attention_mask'][2]}")
    print(" Expected: 4 padding tokens at end")
    _check_row_padding(collated, 2, batch[2])

    # The longest sample sets max_len, so it must come through with no padding.
    _check_row_padding(collated, 1, batch[1])
    print(" ✅ Padding is correct!")

    # Verify labels are preserved
    print("\n🔍 Step 6: Verifying labels preserved...")
    for row, sample in enumerate(batch):
        assert collated['risk_label'][row] == sample['risk_label'], \
            f"risk_label[{row}] changed: got {collated['risk_label'][row]}"
        # Check every row of both score fields; the summary below claims all
        # labels and scores are preserved.
        for key in ('severity_score', 'importance_score'):
            assert torch.allclose(collated[key][row], sample[key]), \
                f"{key}[{row}] changed: expected {sample[key]}, got {collated[key][row]}"
    print(" ✅ All labels preserved correctly!")

    # Test with DataLoader
    print("\n🚀 Step 7: Testing with actual DataLoader...")

    class MockDataset(Dataset):
        """Ten synthetic samples whose sequence length is drawn from [5, 15)."""

        def __init__(self):
            # Private seeded generator: every run sees the same lengths and
            # token ids, and the global torch RNG state is left untouched.
            self._rng = torch.Generator().manual_seed(0)

        def __len__(self):
            return 10

        def __getitem__(self, idx):
            # Simulate variable lengths
            length = torch.randint(5, 15, (1,), generator=self._rng).item()
            return {
                'input_ids': torch.randint(0, 1000, (length,), generator=self._rng),
                'attention_mask': torch.ones(length, dtype=torch.long),
                'risk_label': torch.tensor(idx % 3),
                'severity_score': torch.tensor(float(idx)),
                'importance_score': torch.tensor(float(idx) * 0.5),
            }

    dataset = MockDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=False,
        collate_fn=collate_batch,
    )
    print(" ✅ Created DataLoader with batch_size=4")

    # Test iteration
    print("\n Testing iteration over batches...")
    for batch_idx, loaded in enumerate(dataloader):
        ids_shape = loaded['input_ids'].shape
        print(f" Batch {batch_idx}: input_ids shape = {ids_shape}")
        # Verify all tensors in batch have same size
        assert len(ids_shape) == 2, \
            f"Batch {batch_idx}: expected 2-D input_ids, got shape {ids_shape}"
        n_rows, seq_len = ids_shape
        assert loaded['attention_mask'].shape == (n_rows, seq_len), \
            f"Batch {batch_idx}: attention_mask shape {loaded['attention_mask'].shape} " \
            f"does not match input_ids shape {ids_shape}"
        assert loaded['risk_label'].shape == (n_rows,), \
            f"Batch {batch_idx}: expected risk_label shape ({n_rows},), " \
            f"got {loaded['risk_label'].shape}"
        if batch_idx >= 1:  # Just test first 2 batches
            break
    print(" ✅ DataLoader iteration successful!")

    print("\n" + "=" * 60)
    print("🎉 ALL TESTS PASSED!")
    print("=" * 60)
    print("\n✅ collate_batch works correctly")
    print("✅ Handles variable-length sequences")
    print("✅ Pads to maximum length in batch")
    print("✅ Preserves all labels and scores")
    print("✅ Compatible with DataLoader")
    print("\n🚀 Ready to run: python3 train.py")

except ImportError as e:
    print(f"\n❌ Import error: {e}")
    # Both torch and the project-local trainer module are imported above.
    print(" Make sure torch is installed and trainer.py is importable")
    sys.exit(1)
except AssertionError as e:
    print(f"\n❌ Assertion failed: {e}")
    sys.exit(1)
except Exception as e:
    # Anything unexpected (e.g. an error raised inside collate_batch): show
    # the full traceback so the failure can be located.
    print(f"\n❌ Test failed: {e}")
    traceback.print_exc()
    sys.exit(1)