# πŸ”§ Training Fix - Variable-Length Sequence Padding

## ❌ Problem Identified

```python
RuntimeError: stack expects each tensor to be equal size, 
but got [7] at entry 0 and [41] at entry 1
```

**Root Cause:** The DataLoader's default collate function tries to stack tensors of different lengths into a batch, which fails. This happens because different clauses tokenize to different sequence lengths (e.g., 7 tokens vs 41 tokens).

**Location:** Training loop β†’ DataLoader β†’ `torch.stack()`

---

## 🎯 Understanding the Problem

### **What Happens:**

1. **Tokenization** produces variable-length sequences:
   ```python
   Clause 1: "The party shall..." β†’ [101, 2023, 102] (length 7)
   Clause 2: "This agreement shall be governed by..." β†’ [101, ..., 102] (length 41)
   ```

2. **Default collate** tries to stack them:
   ```python
   torch.stack([torch.zeros(7), torch.zeros(41)])
   # ❌ RuntimeError: stack expects each tensor to be equal size
   ```

3. **Training fails** before the first batch completes

### **Why It Matters:**

- BERT/Transformer models require **uniformly shaped batches**
- Every sequence in a batch must have the **same length**
- Solution: **pad shorter sequences** to match the longest one in the batch

---

## βœ… Solution Applied

### **Custom Collate Function**

Added `collate_batch()` function in `trainer.py`:

```python
def collate_batch(batch):
    """
    Custom collate function to handle variable-length sequences.
    Pads all sequences to the maximum length in the batch.
    """
    # Find max length in this batch
    max_len = max(item['input_ids'].size(0) for item in batch)
    
    # Prepare batched tensors
    input_ids_batch = []
    attention_mask_batch = []
    risk_labels_batch = []
    severity_scores_batch = []
    importance_scores_batch = []
    
    for item in batch:
        input_ids = item['input_ids']
        attention_mask = item['attention_mask']
        current_len = input_ids.size(0)
        
        # Pad if needed
        if current_len < max_len:
            padding_len = max_len - current_len
            
            # Pad input_ids with 0 (PAD token)
            input_ids = torch.cat([
                input_ids, 
                torch.zeros(padding_len, dtype=torch.long)
            ])
            
            # Pad attention_mask with 0 (don't attend to padding)
            attention_mask = torch.cat([
                attention_mask, 
                torch.zeros(padding_len, dtype=torch.long)
            ])
        
        input_ids_batch.append(input_ids)
        attention_mask_batch.append(attention_mask)
        risk_labels_batch.append(item['risk_label'])
        severity_scores_batch.append(item['severity_score'])
        importance_scores_batch.append(item['importance_score'])
    
    # Stack into batched tensors
    return {
        'input_ids': torch.stack(input_ids_batch),
        'attention_mask': torch.stack(attention_mask_batch),
        'risk_label': torch.stack(risk_labels_batch),
        'severity_score': torch.stack(severity_scores_batch),
        'importance_score': torch.stack(importance_scores_batch)
    }
```

### **Updated DataLoader Creation**

```python
dataloader = DataLoader(
    dataset,
    batch_size=self.config.batch_size,
    shuffle=shuffle,
    num_workers=0,
    collate_fn=collate_batch  # βœ… Custom collate function
)
```

---

## 🎯 How It Works

### **Example Batch:**

**Input:**
```
Sample 0: [101, 2023, 2003, 102]           length=4
Sample 1: [101, 1996, 2172, 3325, 2003, 102]  length=6
Sample 2: [101, 102]                       length=2
```

**After Padding (max_len=6):**
```
Sample 0: [101, 2023, 2003, 102, 0, 0]    length=6 βœ…
Sample 1: [101, 1996, 2172, 3325, 2003, 102]  length=6 βœ…
Sample 2: [101, 102, 0, 0, 0, 0]          length=6 βœ…
```

**Attention Masks:**
```
Sample 0: [1, 1, 1, 1, 0, 0]    # Don't attend to padding
Sample 1: [1, 1, 1, 1, 1, 1]    # Attend to all
Sample 2: [1, 1, 0, 0, 0, 0]    # Don't attend to padding
```

**Batched Tensor:**
```python
input_ids.shape = (3, 6)         # batch_size=3, max_len=6
attention_mask.shape = (3, 6)
risk_label.shape = (3,)
severity_score.shape = (3,)
importance_score.shape = (3,)
```
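
For reference, here is a minimal sketch that reproduces this example by calling the `collate_batch()` shown above directly. The `make_item()` helper and the label/score values are illustrative placeholders, not part of the actual dataset code:

```python
import torch
from trainer import collate_batch  # the custom collate function shown above

def make_item(token_ids, label):
    """Build one dataset-style sample (label and score values are placeholders)."""
    return {
        'input_ids': torch.tensor(token_ids, dtype=torch.long),
        'attention_mask': torch.ones(len(token_ids), dtype=torch.long),
        'risk_label': torch.tensor(label),
        'severity_score': torch.tensor(0.5),
        'importance_score': torch.tensor(0.5),
    }

batch = [
    make_item([101, 2023, 2003, 102], 0),               # length 4
    make_item([101, 1996, 2172, 3325, 2003, 102], 1),   # length 6
    make_item([101, 102], 2),                            # length 2
]

out = collate_batch(batch)
print(out['input_ids'].shape)    # torch.Size([3, 6])
print(out['attention_mask'][2])  # tensor([1, 1, 0, 0, 0, 0])
```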

---

## πŸ” Key Features

### **1. Dynamic Padding**
- Pads to the **max length in the current batch** (not a global max)
- More efficient: a batch of short sequences needs less padding
- Example:
  - Batch 1: lengths [10, 12, 11] β†’ pad to 12
  - Batch 2: lengths [50, 48, 52] β†’ pad to 52

### **2. Attention Mask Handling**
- Original tokens: `attention_mask = 1` (attend)
- Padding tokens: `attention_mask = 0` (ignore)
- BERT won't process padding tokens

### **3. Preserves All Data**
- Risk labels preserved
- Severity scores preserved  
- Importance scores preserved
- No data loss, just shape adjustment

---

## πŸ§ͺ Verification

### **Test Script: `test_collate_batch.py`**

Tests 7 scenarios:
1. βœ… Import collate_batch function
2. βœ… Create mock batch with variable lengths (4, 6, 2)
3. βœ… Collate into uniform batch
4. βœ… Verify output shapes (3, 6)
5. βœ… Verify padding tokens are 0
6. βœ… Verify labels preserved
7. βœ… Test with actual DataLoader

**Run test:**
```bash
python3 test_collate_batch.py
```

**Expected output:**
```
βœ… Step 1: Imports successful
βœ… Step 2: Creating mock batch with variable lengths...
   Sample 0: length 4
   Sample 1: length 6
   Sample 2: length 2
βœ… Step 3: Testing collate_batch()...
   πŸ“Š Output shapes:
      input_ids: torch.Size([3, 6])
      attention_mask: torch.Size([3, 6])
βœ… Step 4: Verifying shapes...
βœ… Step 5: Verifying padding...
βœ… Step 6: Verifying labels preserved...
βœ… Step 7: Testing with actual DataLoader...

πŸŽ‰ ALL TESTS PASSED!
```

---

## πŸ“Š Performance Impact

### **Memory Usage:**

**Before (would crash):**
- Can't create batches ❌

**After:**
- Batch of short clauses: roughly `batch_size Γ— 20 tokens` β†’ minimal memory
- Batch of long clauses: up to `batch_size Γ— 512 tokens` β†’ worst case
- **Adaptive:** each batch allocates only as much as its longest sequence requires

### **Training Speed:**

- **No slowdown:** Padding is fast (`torch.cat`)
- **Actually faster:** Proper batching enables GPU parallelism
- **Efficient:** Only pads to batch max, not global max

---

## πŸŽ“ Technical Details

### **Why Attention Mask Matters:**

```python
# Without an attention mask (wrong):
#   BERT attends to padding β†’ learns meaningless patterns ❌

# With an attention mask (correct):
#   BERT ignores padding β†’ learns only from real tokens βœ…
```
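
A minimal sketch of how the padded IDs and mask reach the model at the forward pass, assuming a Hugging Face `BertModel` (the project's actual model class may wrap this differently):

```python
import torch
from transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')

input_ids = torch.tensor([[101, 2023, 2003, 102, 0, 0]])  # padded to length 6
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])       # 0 = ignore padding

outputs = model(input_ids=input_ids, attention_mask=attention_mask)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 6, 768])
```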

### **PyTorch DataLoader Flow:**

```
Dataset.__getitem__(idx) 
  β†’ Returns single sample with variable length
  ↓
collate_fn(batch)
  β†’ Pads all samples to same length
  ↓
Batched tensors
  β†’ All same size, ready for model
```

### **Padding Token ID:**

- **0** is the standard `[PAD]` token ID in BERT's vocabulary
- It is a reserved ID, so it never collides with real tokens
- The attention mask ensures it is ignored in any case
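
This can be checked directly against the tokenizer, assuming the Hugging Face tokenizer for `bert-base-uncased`:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
print(tokenizer.pad_token)     # '[PAD]'
print(tokenizer.pad_token_id)  # 0
```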

---

## πŸš€ Ready to Train

Now training will proceed normally:

```bash
python3 train.py
```

**Expected flow:**
```
1. Load data βœ…
2. Discover risk patterns (LDA) βœ…
3. Create datasets βœ…
4. Create dataloaders with collate_fn βœ…
5. Start training...
   πŸ“ˆ Epoch 1/5
   Batch 0: input_ids shape = [16, 235] βœ…
   Batch 1: input_ids shape = [16, 178] βœ…
   Batch 2: input_ids shape = [16, 312] βœ…
   ...
```

---

## πŸ”§ Alternative Solutions (Not Used)

### **1. Pre-pad to Fixed Length**
```python
# Pad ALL sequences to max_length=512
# Wasteful: short clauses pad to 512 unnecessarily
```
**Why not:** Wastes memory and computation
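
For comparison, a sketch of what fixed-length padding would look like with the Hugging Face tokenizer API (`tokenizer` and `clause_text` are placeholder names; this approach is not used here):

```python
# Every clause is padded (or truncated) to 512 tokens, regardless of real length
encoding = tokenizer(
    clause_text,
    padding='max_length',   # always pad to max_length
    truncation=True,
    max_length=512,
    return_tensors='pt'
)
# encoding['input_ids'].shape β†’ (1, 512) even for a 7-token clause
```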

### **2. Filter by Length**
```python
# Only use clauses within certain length range
# Loses data
```
**Why not:** Loses valuable training data

### **3. Bucket Batching**
```python
# Group similar-length sequences into batches
# More complex
```
**Why not:** Our solution is simpler and works well
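
For reference, a rough sketch of the bucketing idea (assumes the dataset items expose `input_ids`; not part of this fix):

```python
# Sort sample indices by tokenized length, then slice consecutive indices
# into batches so each batch holds sequences of similar length.
lengths = [dataset[i]['input_ids'].size(0) for i in range(len(dataset))]
order = sorted(range(len(dataset)), key=lambda i: lengths[i])

batch_size = 16
buckets = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
```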

---

## πŸ“ Summary

### **Problem:**
- Variable-length sequences can't be stacked into batches
- Training crashes at first batch

### **Solution:**
- Custom `collate_batch()` function
- Dynamically pads to max length in each batch
- Preserves attention masks and all labels

### **Result:**
- βœ… Training proceeds normally
- βœ… Efficient memory usage
- βœ… No data loss
- βœ… Proper attention mask handling

---

## πŸ“š Files Modified

1. **`trainer.py`** - Added collate_batch function (45 lines)
   - Lines 18-61: `collate_batch()` function
   - Line 202: Added `collate_fn=collate_batch`

2. **`test_collate_batch.py`** - Verification test (180 lines)

3. **`doc/TRAINING_FIX_PADDING.md`** - This documentation

---

## βœ… Status

- [x] Identified root cause (variable-length tensors)
- [x] Implemented custom collate function
- [x] Added to DataLoader creation
- [x] Created comprehensive tests
- [x] Documented solution
- [x] **READY TO TRAIN** πŸŽ‰

---

**Next:** Run `python3 train.py` - training should now progress through epochs! πŸš€