File size: 6,457 Bytes
9b1c753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""
Test custom collate_batch function for variable-length sequences
"""

print("=" * 60)
print("Testing collate_batch for Variable-Length Sequences")
print("=" * 60)

try:
    import torch
    from trainer import collate_batch
    
    print("\nβœ… Step 1: Imports successful")
    
    # Create mock batch with variable lengths
    print("\nπŸ”§ Step 2: Creating mock batch with variable lengths...")
    batch = [
        {
            'input_ids': torch.tensor([101, 2023, 2003, 102]),  # Length 4
            'attention_mask': torch.tensor([1, 1, 1, 1]),
            'risk_label': torch.tensor(0),
            'severity_score': torch.tensor(5.5),
            'importance_score': torch.tensor(3.2)
        },
        {
            'input_ids': torch.tensor([101, 1996, 2172, 3325, 2003, 102]),  # Length 6
            'attention_mask': torch.tensor([1, 1, 1, 1, 1, 1]),
            'risk_label': torch.tensor(1),
            'severity_score': torch.tensor(7.2),
            'importance_score': torch.tensor(4.8)
        },
        {
            'input_ids': torch.tensor([101, 102]),  # Length 2
            'attention_mask': torch.tensor([1, 1]),
            'risk_label': torch.tensor(2),
            'severity_score': torch.tensor(3.1),
            'importance_score': torch.tensor(2.5)
        }
    ]
    
    print(f"   βœ… Created batch with 3 samples")
    print(f"      Sample 0: length {len(batch[0]['input_ids'])}")
    print(f"      Sample 1: length {len(batch[1]['input_ids'])}")
    print(f"      Sample 2: length {len(batch[2]['input_ids'])}")
    
    # Test collate_batch
    print("\nπŸ“¦ Step 3: Testing collate_batch()...")
    collated = collate_batch(batch)
    
    print(f"   βœ… Collate successful!")
    print(f"\n   πŸ“Š Output shapes:")
    print(f"      input_ids: {collated['input_ids'].shape}")
    print(f"      attention_mask: {collated['attention_mask'].shape}")
    print(f"      risk_label: {collated['risk_label'].shape}")
    print(f"      severity_score: {collated['severity_score'].shape}")
    print(f"      importance_score: {collated['importance_score'].shape}")
    
    # Verify shapes
    print("\nπŸ” Step 4: Verifying shapes...")
    batch_size = 3
    max_len = 6  # Longest sequence in batch
    
    assert collated['input_ids'].shape == (batch_size, max_len), \
        f"Expected ({batch_size}, {max_len}), got {collated['input_ids'].shape}"
    assert collated['attention_mask'].shape == (batch_size, max_len), \
        f"Expected ({batch_size}, {max_len}), got {collated['attention_mask'].shape}"
    assert collated['risk_label'].shape == (batch_size,), \
        f"Expected ({batch_size},), got {collated['risk_label'].shape}"
    
    print("   βœ… All shapes correct!")
    
    # Verify padding
    print("\nπŸ” Step 5: Verifying padding...")
    print(f"   Sample 0 (originally length 4):")
    print(f"      input_ids: {collated['input_ids'][0]}")
    print(f"      attention_mask: {collated['attention_mask'][0]}")
    print(f"      Expected: 2 padding tokens at end")
    
    # Check that padding was added correctly
    assert collated['input_ids'][0][-2:].sum() == 0, "Padding tokens should be 0"
    assert collated['attention_mask'][0][-2:].sum() == 0, "Padding mask should be 0"
    
    print(f"\n   Sample 2 (originally length 2):")
    print(f"      input_ids: {collated['input_ids'][2]}")
    print(f"      attention_mask: {collated['attention_mask'][2]}")
    print(f"      Expected: 4 padding tokens at end")
    
    assert collated['input_ids'][2][-4:].sum() == 0, "Padding tokens should be 0"
    assert collated['attention_mask'][2][-4:].sum() == 0, "Padding mask should be 0"
    
    print("   βœ… Padding is correct!")
    
    # Verify labels are preserved
    print("\nπŸ” Step 6: Verifying labels preserved...")
    assert collated['risk_label'][0] == 0
    assert collated['risk_label'][1] == 1
    assert collated['risk_label'][2] == 2
    assert torch.allclose(collated['severity_score'][0], torch.tensor(5.5))
    assert torch.allclose(collated['severity_score'][1], torch.tensor(7.2))
    print("   βœ… All labels preserved correctly!")
    
    # Test with DataLoader
    print("\nπŸš€ Step 7: Testing with actual DataLoader...")
    from torch.utils.data import Dataset, DataLoader
    
    class MockDataset(Dataset):
        def __len__(self):
            return 10
        
        def __getitem__(self, idx):
            # Simulate variable lengths
            length = torch.randint(5, 15, (1,)).item()
            return {
                'input_ids': torch.randint(0, 1000, (length,)),
                'attention_mask': torch.ones(length, dtype=torch.long),
                'risk_label': torch.tensor(idx % 3),
                'severity_score': torch.tensor(float(idx)),
                'importance_score': torch.tensor(float(idx) * 0.5)
            }
    
    dataset = MockDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=False,
        collate_fn=collate_batch
    )
    
    print(f"   βœ… Created DataLoader with batch_size=4")
    
    # Test iteration
    print(f"\n   Testing iteration over batches...")
    for batch_idx, batch in enumerate(dataloader):
        print(f"      Batch {batch_idx}: input_ids shape = {batch['input_ids'].shape}")
        
        # Verify all tensors in batch have same size
        batch_size = batch['input_ids'].shape[0]
        seq_len = batch['input_ids'].shape[1]
        
        assert batch['attention_mask'].shape == (batch_size, seq_len)
        assert batch['risk_label'].shape == (batch_size,)
        
        if batch_idx >= 1:  # Just test first 2 batches
            break
    
    print(f"   βœ… DataLoader iteration successful!")
    
    print("\n" + "=" * 60)
    print("πŸŽ‰ ALL TESTS PASSED!")
    print("=" * 60)
    print("\nβœ… collate_batch works correctly")
    print("βœ… Handles variable-length sequences")
    print("βœ… Pads to maximum length in batch")
    print("βœ… Preserves all labels and scores")
    print("βœ… Compatible with DataLoader")
    print("\nπŸš€ Ready to run: python3 train.py")
    
except ImportError as e:
    print(f"\n❌ Import error: {e}")
    print("   Make sure torch is installed")
    exit(1)
    
except AssertionError as e:
    print(f"\n❌ Assertion failed: {e}")
    exit(1)
    
except Exception as e:
    print(f"\n❌ Test failed: {e}")
    import traceback
    traceback.print_exc()
    exit(1)