"""
Test custom collate_batch function for variable-length sequences.

Run this file as a script. It imports ``collate_batch`` from the project's
``trainer`` module, feeds it a hand-written batch and batches produced by a
``DataLoader``, and exits with status 0 when every check passes, 1 otherwise.
"""

import sys
import traceback
from itertools import islice
from typing import Any, Callable, Dict, List

BANNER = "=" * 60

# Keys whose tensors are expected to come back padded to (batch, max_len).
SEQUENCE_KEYS = ("input_ids", "attention_mask")
# Keys holding one scalar per sample.
SCALAR_KEYS = ("risk_label", "severity_score", "importance_score")

Sample = Dict[str, Any]
CollateFn = Callable[[List[Sample]], Dict[str, Any]]


def _require(condition: bool, message: str) -> None:
    """Raise ``AssertionError(message)`` unless ``condition`` is true.

    Used instead of the ``assert`` statement because ``python -O`` strips
    ``assert`` entirely, which would make this script report success without
    having checked anything.
    """
    if not condition:
        raise AssertionError(message)


def build_fixed_batch() -> List[Sample]:
    """Return three hand-written samples with sequence lengths 4, 6 and 2."""
    import torch

    # (input_ids, risk_label, severity_score, importance_score)
    rows = [
        ([101, 2023, 2003, 102], 0, 5.5, 3.2),
        ([101, 1996, 2172, 3325, 2003, 102], 1, 7.2, 4.8),
        ([101, 102], 2, 3.1, 2.5),
    ]
    return [
        {
            'input_ids': torch.tensor(ids),
            'attention_mask': torch.tensor([1] * len(ids)),
            'risk_label': torch.tensor(label),
            'severity_score': torch.tensor(severity),
            'importance_score': torch.tensor(importance),
        }
        for ids, label, severity, importance in rows
    ]


def check_fixed_batch(collate_fn: CollateFn) -> None:
    """Run steps 2-6: collate the hand-written batch and verify the result.

    Checks output shapes, that real tokens are kept and the tail is
    zero-padded, and that every label/score survives collation.

    Raises:
        AssertionError: if any expectation is not met.
    """
    import torch

    print("\n🔧 Step 2: Creating mock batch with variable lengths...")
    batch = build_fixed_batch()
    lengths = [len(sample['input_ids']) for sample in batch]
    print(f" ✅ Created batch with {len(batch)} samples")
    for idx, length in enumerate(lengths):
        print(f" Sample {idx}: length {length}")

    print("\n📦 Step 3: Testing collate_batch()...")
    collated = collate_fn(batch)
    print(" ✅ Collate successful!")
    print("\n 📊 Output shapes:")
    for key in SEQUENCE_KEYS + SCALAR_KEYS:
        print(f" {key}: {collated[key].shape}")

    print("\n🔍 Step 4: Verifying shapes...")
    batch_size = len(batch)
    max_len = max(lengths)  # derived from the data, not hard-coded
    for key in SEQUENCE_KEYS:
        _require(
            tuple(collated[key].shape) == (batch_size, max_len),
            f"{key}: Expected ({batch_size}, {max_len}), got {collated[key].shape}",
        )
    _require(
        tuple(collated['risk_label'].shape) == (batch_size,),
        f"risk_label: Expected ({batch_size},), got {collated['risk_label'].shape}",
    )
    print(" ✅ All shapes correct!")

    print("\n🔍 Step 5: Verifying padding...")
    shown = 0
    for idx, (sample, length) in enumerate(zip(batch, lengths)):
        ids_row = collated['input_ids'][idx]
        mask_row = collated['attention_mask'][idx]
        pad_len = max_len - length

        if pad_len:
            if shown:
                print()
            shown += 1
            print(f" Sample {idx} (originally length {length}):")
            print(f" input_ids: {ids_row}")
            print(f" attention_mask: {mask_row}")
            print(f" Expected: {pad_len} padding tokens at end")

        # The real tokens must be untouched ...
        _require(
            ids_row[:length].tolist() == sample['input_ids'].tolist(),
            f"Sample {idx}: input_ids were altered by collation",
        )
        _require(
            mask_row[:length].tolist() == sample['attention_mask'].tolist(),
            f"Sample {idx}: attention_mask was altered by collation",
        )
        # ... and everything after them must be zero (vacuously true when
        # the sample is already the longest in the batch).
        _require(bool((ids_row[length:] == 0).all()), "Padding tokens should be 0")
        _require(bool((mask_row[length:] == 0).all()), "Padding mask should be 0")
    print(" ✅ Padding is correct!")

    print("\n🔍 Step 6: Verifying labels preserved...")
    for idx, sample in enumerate(batch):
        expected_label = int(sample['risk_label'])
        actual_label = int(collated['risk_label'][idx])
        _require(
            actual_label == expected_label,
            f"Sample {idx}: risk_label expected {expected_label}, got {actual_label}",
        )
        for key in ('severity_score', 'importance_score'):
            actual = collated[key][idx]
            # allclose needs matching dtypes; follow whatever collate produced.
            expected = sample[key].to(actual.dtype)
            _require(
                bool(torch.allclose(actual, expected)),
                f"Sample {idx}: {key} expected {float(sample[key])}, got {actual}",
            )
    print(" ✅ All labels preserved correctly!")


def check_dataloader(collate_fn: CollateFn, batch_size: int = 4,
                     max_batches: int = 2) -> None:
    """Run step 7: use ``collate_fn`` inside a real ``DataLoader``.

    Iterates the first ``max_batches`` batches of a 10-item synthetic dataset
    and checks that mask and label shapes agree with ``input_ids``.

    Raises:
        AssertionError: if any batch has inconsistent shapes.
    """
    import torch
    from torch.utils.data import DataLoader, Dataset

    print("\n🔄 Step 7: Testing with actual DataLoader...")

    class MockDataset(Dataset):
        """Ten synthetic samples with sequence lengths drawn from [5, 15)."""

        def __len__(self):
            return 10

        def __getitem__(self, idx):
            # Seed per index so a given idx always yields the same sample and
            # a failure can be reproduced on the next run.
            rng = torch.Generator().manual_seed(idx)
            length = int(torch.randint(5, 15, (1,), generator=rng).item())
            return {
                'input_ids': torch.randint(0, 1000, (length,), generator=rng),
                'attention_mask': torch.ones(length, dtype=torch.long),
                'risk_label': torch.tensor(idx % 3),
                'severity_score': torch.tensor(float(idx)),
                'importance_score': torch.tensor(float(idx) * 0.5),
            }

    dataloader = DataLoader(
        MockDataset(),
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )
    print(f" ✅ Created DataLoader with batch_size={batch_size}")

    print("\n Testing iteration over batches...")
    for batch_idx, loaded in enumerate(islice(dataloader, max_batches)):
        ids_shape = tuple(loaded['input_ids'].shape)
        print(f" Batch {batch_idx}: input_ids shape = {loaded['input_ids'].shape}")

        _require(
            len(ids_shape) == 2,
            f"Batch {batch_idx}: input_ids should be 2-D, got shape {ids_shape}",
        )
        rows, seq_len = ids_shape
        _require(
            tuple(loaded['attention_mask'].shape) == (rows, seq_len),
            f"Batch {batch_idx}: attention_mask shape "
            f"{tuple(loaded['attention_mask'].shape)} != {(rows, seq_len)}",
        )
        _require(
            tuple(loaded['risk_label'].shape) == (rows,),
            f"Batch {batch_idx}: risk_label shape "
            f"{tuple(loaded['risk_label'].shape)} != {(rows,)}",
        )
    print(" ✅ DataLoader iteration successful!")


def run_all_checks(collate_fn: CollateFn) -> int:
    """Run every check against ``collate_fn`` and return a process exit code.

    Returns:
        0 if all checks pass; 1 if a check fails or any other error occurs
        (the error is printed, with a traceback for non-assertion errors).
    """
    try:
        check_fixed_batch(collate_fn)
        check_dataloader(collate_fn)
    except AssertionError as e:
        print(f"\n❌ Assertion failed: {e}")
        return 1
    except Exception as e:  # top-level boundary: report and fail the run
        print(f"\n❌ Test failed: {e}")
        traceback.print_exc()
        return 1

    print("\n" + BANNER)
    print("🎉 ALL TESTS PASSED!")
    print(BANNER)
    print("\n✅ collate_batch works correctly")
    print("✅ Handles variable-length sequences")
    print("✅ Pads to maximum length in batch")
    print("✅ Preserves all labels and scores")
    print("✅ Compatible with DataLoader")
    print("\n🚀 Ready to run: python3 train.py")
    return 0


def main() -> int:
    """Import the function under test, run the checks, return the exit code."""
    print(BANNER)
    print("Testing collate_batch for Variable-Length Sequences")
    print(BANNER)

    try:
        import torch  # noqa: F401  (imported here only to fail early and clearly)
        from trainer import collate_batch
    except ImportError as e:
        print(f"\n❌ Import error: {e}")
        print(" Make sure torch is installed and the trainer module is importable")
        return 1

    print("\n✅ Step 1: Imports successful")
    return run_all_checks(collate_batch)


if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(), which is meant for the
    # interactive shell and is absent when the site module is not loaded.
    sys.exit(main())