# Training Fix - Variable-Length Sequence Padding
## Problem Identified
```python
RuntimeError: stack expects each tensor to be equal size,
but got [7] at entry 0 and [41] at entry 1
```
**Root Cause:** The DataLoader's default collate function tries to stack tensors of different lengths into a batch, which fails. This happens because different clauses tokenize to different sequence lengths (e.g., 7 tokens vs 41 tokens).
**Location:** Training loop → DataLoader → `torch.stack()`
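For reference, here is a minimal standalone repro of the failure (a toy dataset with made-up lengths, not the project's actual dataset class):
```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyClauseDataset(Dataset):
    """Returns variable-length token sequences, mimicking tokenized clauses."""
    def __init__(self, lengths):
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        return {'input_ids': torch.ones(self.lengths[idx], dtype=torch.long)}

loader = DataLoader(ToyClauseDataset([7, 41]), batch_size=2)  # default collate_fn
next(iter(loader))
# RuntimeError: stack expects each tensor to be equal size,
# but got [7] at entry 0 and [41] at entry 1
```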
---
## Understanding the Problem
### **What Happens:**
1. **Tokenization** produces variable-length sequences:
```python
Clause 1: "The party shall..."                     → [101, 2023, ..., 102]  (length 7)
Clause 2: "This agreement shall be governed by..." → [101, ..., 102]        (length 41)
```
2. **Default collate** tries to stack them:
```python
torch.stack([tensor([...]),   # length 7
             tensor([...])])  # length 41
# → ERROR: Can't stack different sizes!
```
3. **Training fails** before first batch completes
### **Why It Matters:**
- BERT/Transformers require **fixed-size batches**
- Each sequence in a batch must have the **same length**
- Solution: **pad shorter sequences** to match the longest in the batch (see the sketch below)
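To make the fix concrete before the full implementation, here is the padding idea in isolation (token IDs are illustrative):
```python
import torch

short = torch.tensor([101, 2023, 102])                    # length 3
long = torch.tensor([101, 1996, 2172, 3325, 2003, 102])   # length 6

# Pad the short sequence with 0s so both have length 6, then stacking works
padding = torch.zeros(len(long) - len(short), dtype=torch.long)
batch = torch.stack([torch.cat([short, padding]), long])
print(batch.shape)  # torch.Size([2, 6])
```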
---
## Solution Applied
### **Custom Collate Function**
Added `collate_batch()` function in `trainer.py`:
```python
import torch


def collate_batch(batch):
    """
    Custom collate function to handle variable-length sequences.
    Pads all sequences to the maximum length in the batch.
    """
    # Find the max length in this batch
    max_len = max(item['input_ids'].size(0) for item in batch)

    # Prepare batched tensors
    input_ids_batch = []
    attention_mask_batch = []
    risk_labels_batch = []
    severity_scores_batch = []
    importance_scores_batch = []

    for item in batch:
        input_ids = item['input_ids']
        attention_mask = item['attention_mask']
        current_len = input_ids.size(0)

        # Pad if needed
        if current_len < max_len:
            padding_len = max_len - current_len

            # Pad input_ids with 0 (the PAD token)
            input_ids = torch.cat([
                input_ids,
                torch.zeros(padding_len, dtype=torch.long)
            ])

            # Pad attention_mask with 0 (don't attend to padding)
            attention_mask = torch.cat([
                attention_mask,
                torch.zeros(padding_len, dtype=torch.long)
            ])

        input_ids_batch.append(input_ids)
        attention_mask_batch.append(attention_mask)
        risk_labels_batch.append(item['risk_label'])
        severity_scores_batch.append(item['severity_score'])
        importance_scores_batch.append(item['importance_score'])

    # Stack into batched tensors
    return {
        'input_ids': torch.stack(input_ids_batch),
        'attention_mask': torch.stack(attention_mask_batch),
        'risk_label': torch.stack(risk_labels_batch),
        'severity_score': torch.stack(severity_scores_batch),
        'importance_score': torch.stack(importance_scores_batch)
    }
```
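The same behaviour can be expressed more compactly with `torch.nn.utils.rnn.pad_sequence`. This is an equivalent sketch, not the version used in `trainer.py`; either approach produces identical batches, the explicit loop above just makes each padding step easier to follow:
```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_batch_compact(batch):
    """Pads input_ids/attention_mask to the batch max; stacks the labels."""
    return {
        'input_ids': pad_sequence([b['input_ids'] for b in batch],
                                  batch_first=True, padding_value=0),
        'attention_mask': pad_sequence([b['attention_mask'] for b in batch],
                                       batch_first=True, padding_value=0),
        'risk_label': torch.stack([b['risk_label'] for b in batch]),
        'severity_score': torch.stack([b['severity_score'] for b in batch]),
        'importance_score': torch.stack([b['importance_score'] for b in batch]),
    }
```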
### **Updated DataLoader Creation**
```python
dataloader = DataLoader(
    dataset,
    batch_size=self.config.batch_size,
    shuffle=shuffle,
    num_workers=0,
    collate_fn=collate_batch  # ✅ Custom collate function
)
```
---
## How It Works
### **Example Batch:**
**Input:**
```
Sample 0: [101, 2023, 2003, 102] length=4
Sample 1: [101, 1996, 2172, 3325, 2003, 102] length=6
Sample 2: [101, 102] length=2
```
**After Padding (max_len=6):**
```
Sample 0: [101, 2023, 2003, 102, 0, 0]        length=6 ✅
Sample 1: [101, 1996, 2172, 3325, 2003, 102]  length=6 ✅
Sample 2: [101, 102, 0, 0, 0, 0]              length=6 ✅
```
**Attention Masks:**
```
Sample 0: [1, 1, 1, 1, 0, 0] # Don't attend to padding
Sample 1: [1, 1, 1, 1, 1, 1] # Attend to all
Sample 2: [1, 1, 0, 0, 0, 0] # Don't attend to padding
```
**Batched Tensor:**
```python
input_ids.shape = (3, 6) # batch_size=3, max_len=6
attention_mask.shape = (3, 6)
risk_label.shape = (3,)
severity_score.shape = (3,)
importance_score.shape = (3,)
```
---
## Key Features
### **1. Dynamic Padding**
- Pads to the **max length in the current batch** (not a global max)
- More efficient: a batch of short sequences means less padding (see the sketch after this list)
- Example:
  - Batch 1: lengths [10, 12, 11] → pad to 12
  - Batch 2: lengths [50, 48, 52] → pad to 52
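A quick back-of-the-envelope sketch of why batch-level padding is cheaper than padding everything to a fixed 512 (numbers taken from the example above; the helper is purely illustrative):
```python
def padded_slots(lengths, pad_to=None):
    """Total token slots a batch occupies after padding each sequence."""
    target = pad_to if pad_to is not None else max(lengths)
    return target * len(lengths)

print(padded_slots([10, 12, 11]))              # 36 slots (dynamic: pad to 12)
print(padded_slots([10, 12, 11], pad_to=512))  # 1536 slots (fixed: pad to 512)
print(padded_slots([50, 48, 52]))              # 156 slots (dynamic: pad to 52)
```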
### **2. Attention Mask Handling**
- Original tokens: `attention_mask = 1` (attend)
- Padding tokens: `attention_mask = 0` (ignore)
- BERT's self-attention ignores the padding positions
### **3. Preserves All Data**
- Risk labels preserved
- Severity scores preserved
- Importance scores preserved
- No data loss, just shape adjustment
---
## Verification
### **Test Script: `test_collate_batch.py`**
Tests 7 scenarios (a sketch of the core checks follows below):
1. ✅ Import collate_batch function
2. ✅ Create mock batch with variable lengths (4, 6, 2)
3. ✅ Collate into uniform batch
4. ✅ Verify output shapes (3, 6)
5. ✅ Verify padding tokens are 0
6. ✅ Verify labels preserved
7. ✅ Test with actual DataLoader
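A condensed sketch of the core assertions (the real `test_collate_batch.py` is longer and prints per-step output; the mock values below are illustrative):
```python
import torch
from trainer import collate_batch

def make_item(length, label):
    """Build one mock sample shaped like the dataset's output."""
    return {
        'input_ids': torch.randint(1, 30000, (length,)),
        'attention_mask': torch.ones(length, dtype=torch.long),
        'risk_label': torch.tensor(label),
        'severity_score': torch.tensor(0.5),
        'importance_score': torch.tensor(0.7),
    }

batch = collate_batch([make_item(4, 0), make_item(6, 1), make_item(2, 2)])
assert batch['input_ids'].shape == (3, 6)
assert batch['attention_mask'].shape == (3, 6)
assert (batch['input_ids'][0, 4:] == 0).all()       # padding tokens are 0
assert (batch['attention_mask'][2, 2:] == 0).all()  # mask is 0 over padding
assert batch['risk_label'].tolist() == [0, 1, 2]    # labels preserved
print("All checks passed")
```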
**Run test:**
```bash
python3 test_collate_batch.py
```
**Expected output:**
```
✅ Step 1: Imports successful
✅ Step 2: Creating mock batch with variable lengths...
   Sample 0: length 4
   Sample 1: length 6
   Sample 2: length 2
✅ Step 3: Testing collate_batch()...
   Output shapes:
     input_ids: torch.Size([3, 6])
     attention_mask: torch.Size([3, 6])
✅ Step 4: Verifying shapes...
✅ Step 5: Verifying padding...
✅ Step 6: Verifying labels preserved...
✅ Step 7: Testing with actual DataLoader...

ALL TESTS PASSED!
```
---
## Performance Impact
### **Memory Usage:**
**Before (would crash):**
- Can't create batches ❌
**After:**
- Batch 1 (short clauses): `batch_size × 20 tokens` = minimal
- Batch 2 (long clauses): `batch_size × 512 tokens` = max
- **Adaptive:** Uses only what's needed per batch
### **Training Speed:**
- **No slowdown:** Padding is fast (`torch.cat`)
- **Actually faster:** Proper batching enables GPU parallelism
- **Efficient:** Only pads to batch max, not global max
---
## Technical Details
### **Why Attention Mask Matters:**
```python
# Without attention mask (wrong):
#   BERT attends to padding → learns meaningless patterns ❌

# With attention mask (correct):
#   BERT ignores padding → learns only from real tokens ✅
```
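For context, this is roughly how the mask reaches the encoder (a sketch using a Hugging Face `BertModel`; the project's model wrapper may wire it differently):
```python
import torch
from transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')

input_ids = torch.tensor([[101, 2023, 2003, 102, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])

# Self-attention ignores positions where attention_mask == 0
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 6, 768])
```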
### **PyTorch DataLoader Flow:**
```
Dataset.__getitem__(idx)
  → Returns a single sample with variable length
        ↓
collate_fn(batch)
  → Pads all samples to the same length
        ↓
Batched tensors
  → All the same size, ready for the model
```
### **Padding Token ID:**
- **0** is the standard `[PAD]` token ID in BERT
- Doesn't conflict with the vocabulary (reserved ID)
- The attention mask ensures it's ignored anyway (see the check below)
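A quick way to confirm the reserved IDs (assuming a Hugging Face BERT tokenizer, which the `[CLS]`=101 / `[SEP]`=102 examples above suggest):
```python
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
print(tokenizer.pad_token, tokenizer.pad_token_id)      # [PAD] 0
print(tokenizer.cls_token_id, tokenizer.sep_token_id)   # 101 102
```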
---
## Ready to Train
Now training will proceed normally:
```bash
python3 train.py
```
**Expected flow:**
```
1. Load data ✅
2. Discover risk patterns (LDA) ✅
3. Create datasets ✅
4. Create dataloaders with collate_fn ✅
5. Start training...
   Epoch 1/5
     Batch 0: input_ids shape = [16, 235] ✅
     Batch 1: input_ids shape = [16, 178] ✅
     Batch 2: input_ids shape = [16, 312] ✅
...
```
---
## Alternative Solutions (Not Used)
### **1. Pre-pad to Fixed Length**
```python
# Pad ALL sequences to max_length=512
# Wasteful: short clauses pad to 512 unnecessarily
```
**Why not:** Wastes memory and computation
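For comparison, static padding would look roughly like this with a Hugging Face tokenizer (clause texts are illustrative):
```python
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
clauses = ["The party shall...", "This agreement shall be governed by..."]

# Every clause becomes exactly 512 tokens, however short it is
encoded = tokenizer(clauses, padding='max_length', truncation=True,
                    max_length=512, return_tensors='pt')
print(encoded['input_ids'].shape)  # torch.Size([2, 512])
```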
### **2. Filter by Length**
```python
# Only use clauses within certain length range
# Loses data
```
**Why not:** Loses valuable training data
### **3. Bucket Batching**
```python
# Group similar-length sequences into batches
# More complex
```
**Why not:** Our solution is simpler and works well
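If bucket batching were ever needed, the idea looks roughly like this (a hypothetical `LengthBucketSampler` sketch, not part of this codebase):
```python
from torch.utils.data import Sampler

class LengthBucketSampler(Sampler):
    """Yields batches of indices whose sequences have similar lengths."""
    def __init__(self, lengths, batch_size):
        # Sort indices by sequence length so each batch needs little padding
        self.order = sorted(range(len(lengths)), key=lambda i: lengths[i])
        self.batch_size = batch_size

    def __iter__(self):
        for start in range(0, len(self.order), self.batch_size):
            yield self.order[start:start + self.batch_size]

    def __len__(self):
        return (len(self.order) + self.batch_size - 1) // self.batch_size

# Usage sketch: still pairs with collate_batch for the residual padding
# loader = DataLoader(dataset, batch_sampler=LengthBucketSampler(lengths, 16),
#                     collate_fn=collate_batch)
```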
---
## Summary
### **Problem:**
- Variable-length sequences can't be stacked into batches
- Training crashes at first batch
### **Solution:**
- Custom `collate_batch()` function
- Dynamically pads to max length in each batch
- Preserves attention masks and all labels
### **Result:**
- ✅ Training proceeds normally
- ✅ Efficient memory usage
- ✅ No data loss
- ✅ Proper attention mask handling
---
## Files Modified
1. **`trainer.py`** - Added collate_batch function (45 lines)
- Lines 18-61: `collate_batch()` function
- Line 202: Added `collate_fn=collate_batch`
2. **`test_collate_batch.py`** - Verification test (180 lines)
3. **`doc/TRAINING_FIX_PADDING.md`** - This documentation
---
## Status
- [x] Identified root cause (variable-length tensors)
- [x] Implemented custom collate function
- [x] Added to DataLoader creation
- [x] Created comprehensive tests
- [x] Documented solution
- [x] **READY TO TRAIN**
---
**Next:** Run `python3 train.py` - training should now progress through the epochs!