"""
Test custom collate_batch function for variable-length sequences
"""
print("=" * 60)
print("Testing collate_batch for Variable-Length Sequences")
print("=" * 60)
try:
    import torch
    from trainer import collate_batch

    print("\n✅ Step 1: Imports successful")
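
    # For reference only: a minimal sketch of what collate_batch is assumed to
    # do here (an assumption based on the assertions below; the real
    # implementation lives in trainer.py). It pads input_ids and attention_mask
    # with zeros up to the longest sequence in the batch and stacks the scalar
    # labels/scores into 1-D tensors.
    def _reference_collate_batch(batch):
        from torch.nn.utils.rnn import pad_sequence
        out = {
            'input_ids': pad_sequence([s['input_ids'] for s in batch],
                                      batch_first=True, padding_value=0),
            'attention_mask': pad_sequence([s['attention_mask'] for s in batch],
                                           batch_first=True, padding_value=0),
        }
        for key in ('risk_label', 'severity_score', 'importance_score'):
            out[key] = torch.stack([s[key] for s in batch])
        return out
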
    # Create mock batch with variable lengths
    print("\n🔧 Step 2: Creating mock batch with variable lengths...")
    batch = [
        {
            'input_ids': torch.tensor([101, 2023, 2003, 102]),  # Length 4
            'attention_mask': torch.tensor([1, 1, 1, 1]),
            'risk_label': torch.tensor(0),
            'severity_score': torch.tensor(5.5),
            'importance_score': torch.tensor(3.2)
        },
        {
            'input_ids': torch.tensor([101, 1996, 2172, 3325, 2003, 102]),  # Length 6
            'attention_mask': torch.tensor([1, 1, 1, 1, 1, 1]),
            'risk_label': torch.tensor(1),
            'severity_score': torch.tensor(7.2),
            'importance_score': torch.tensor(4.8)
        },
        {
            'input_ids': torch.tensor([101, 102]),  # Length 2
            'attention_mask': torch.tensor([1, 1]),
            'risk_label': torch.tensor(2),
            'severity_score': torch.tensor(3.1),
            'importance_score': torch.tensor(2.5)
        }
    ]
    print(f"   ✅ Created batch with 3 samples")
    print(f"      Sample 0: length {len(batch[0]['input_ids'])}")
    print(f"      Sample 1: length {len(batch[1]['input_ids'])}")
    print(f"      Sample 2: length {len(batch[2]['input_ids'])}")
    # Test collate_batch
    print("\n📦 Step 3: Testing collate_batch()...")
    collated = collate_batch(batch)
    print(f"   ✅ Collate successful!")
    print(f"\n   📊 Output shapes:")
    print(f"      input_ids: {collated['input_ids'].shape}")
    print(f"      attention_mask: {collated['attention_mask'].shape}")
    print(f"      risk_label: {collated['risk_label'].shape}")
    print(f"      severity_score: {collated['severity_score'].shape}")
    print(f"      importance_score: {collated['importance_score'].shape}")
    # Verify shapes
    print("\n🔍 Step 4: Verifying shapes...")
    batch_size = 3
    max_len = 6  # Longest sequence in batch
    assert collated['input_ids'].shape == (batch_size, max_len), \
        f"Expected ({batch_size}, {max_len}), got {collated['input_ids'].shape}"
    assert collated['attention_mask'].shape == (batch_size, max_len), \
        f"Expected ({batch_size}, {max_len}), got {collated['attention_mask'].shape}"
    assert collated['risk_label'].shape == (batch_size,), \
        f"Expected ({batch_size},), got {collated['risk_label'].shape}"
    print("   ✅ All shapes correct!")
    # Verify padding
    print("\n🔍 Step 5: Verifying padding...")
    print(f"   Sample 0 (originally length 4):")
    print(f"      input_ids: {collated['input_ids'][0]}")
    print(f"      attention_mask: {collated['attention_mask'][0]}")
    print(f"      Expected: 2 padding tokens at end")
    # Check that padding was added correctly
    assert collated['input_ids'][0][-2:].sum() == 0, "Padding tokens should be 0"
    assert collated['attention_mask'][0][-2:].sum() == 0, "Padding mask should be 0"
    print(f"\n   Sample 2 (originally length 2):")
    print(f"      input_ids: {collated['input_ids'][2]}")
    print(f"      attention_mask: {collated['attention_mask'][2]}")
    print(f"      Expected: 4 padding tokens at end")
    assert collated['input_ids'][2][-4:].sum() == 0, "Padding tokens should be 0"
    assert collated['attention_mask'][2][-4:].sum() == 0, "Padding mask should be 0"
    print("   ✅ Padding is correct!")
    # Verify labels are preserved
    print("\n🔍 Step 6: Verifying labels preserved...")
    assert collated['risk_label'][0] == 0
    assert collated['risk_label'][1] == 1
    assert collated['risk_label'][2] == 2
    assert torch.allclose(collated['severity_score'][0], torch.tensor(5.5))
    assert torch.allclose(collated['severity_score'][1], torch.tensor(7.2))
    print("   ✅ All labels preserved correctly!")
    # Test with DataLoader
    print("\n🔄 Step 7: Testing with actual DataLoader...")
    from torch.utils.data import Dataset, DataLoader

    class MockDataset(Dataset):
        def __len__(self):
            return 10

        def __getitem__(self, idx):
            # Simulate variable lengths
            length = torch.randint(5, 15, (1,)).item()
            return {
                'input_ids': torch.randint(0, 1000, (length,)),
                'attention_mask': torch.ones(length, dtype=torch.long),
                'risk_label': torch.tensor(idx % 3),
                'severity_score': torch.tensor(float(idx)),
                'importance_score': torch.tensor(float(idx) * 0.5)
            }

    dataset = MockDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=False,
        collate_fn=collate_batch
    )
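    # Without a custom collate_fn, DataLoader's default collate would raise a
    # RuntimeError here: it cannot stack input_ids tensors of different lengths.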
    print(f"   ✅ Created DataLoader with batch_size=4")
    # Test iteration
    print(f"\n   Testing iteration over batches...")
    for batch_idx, batch in enumerate(dataloader):
        print(f"      Batch {batch_idx}: input_ids shape = {batch['input_ids'].shape}")
        # Verify all tensors in batch have same size
        batch_size = batch['input_ids'].shape[0]
        seq_len = batch['input_ids'].shape[1]
        assert batch['attention_mask'].shape == (batch_size, seq_len)
        assert batch['risk_label'].shape == (batch_size,)
        if batch_idx >= 1:  # Just test first 2 batches
            break
    print(f"   ✅ DataLoader iteration successful!")
    print("\n" + "=" * 60)
    print("🎉 ALL TESTS PASSED!")
    print("=" * 60)
    print("\n✅ collate_batch works correctly")
    print("✅ Handles variable-length sequences")
    print("✅ Pads to maximum length in batch")
    print("✅ Preserves all labels and scores")
    print("✅ Compatible with DataLoader")
    print("\n🚀 Ready to run: python3 train.py")
except ImportError as e:
    print(f"\n❌ Import error: {e}")
    print("   Make sure torch is installed")
    exit(1)
except AssertionError as e:
    print(f"\n❌ Assertion failed: {e}")
    exit(1)
except Exception as e:
    print(f"\n❌ Test failed: {e}")
    import traceback
    traceback.print_exc()
    exit(1)