# πŸ”§ Training Fix - Variable-Length Sequence Padding
## ❌ Problem Identified
```python
RuntimeError: stack expects each tensor to be equal size,
but got [7] at entry 0 and [41] at entry 1
```
**Root Cause:** The DataLoader's default collate function tries to stack tensors of different lengths into a batch, which fails. This happens because different clauses tokenize to different sequence lengths (e.g., 7 tokens vs 41 tokens).
**Location:** Training loop β†’ DataLoader β†’ `torch.stack()`
---
## 🎯 Understanding the Problem
### **What Happens:**
1. **Tokenization** produces variable-length sequences:
```python
Clause 1: "The party shall..." β†’ [101, 2023, ..., 102] (length 7)
Clause 2: "This agreement shall be governed by..." β†’ [101, ..., 102] (length 41)
```
2. **Default collate** tries to stack them:
```python
torch.stack([tensor(length=7), tensor(length=41)])
# ❌ ERROR: can't stack tensors of different sizes!
```
3. **Training fails** before first batch completes
### **Why It Matters:**
- BERT/Transformer models require **rectangular batch tensors**
- Every sequence in a batch must have the **same length**
- Solution: **pad shorter sequences** to match the longest in the batch
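The failure is easy to reproduce in isolation. A minimal sketch (the tensor contents are placeholders; only the lengths matter):

```python
import torch

# Two "tokenized clauses" of different lengths, as in the error above
a = torch.zeros(7, dtype=torch.long)   # e.g. a 7-token clause
b = torch.zeros(41, dtype=torch.long)  # e.g. a 41-token clause

try:
    # Default collate effectively does this and fails
    torch.stack([a, b])
except RuntimeError as e:
    print("stack failed:", e)
```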
---
## βœ… Solution Applied
### **Custom Collate Function**
Added `collate_batch()` function in `trainer.py`:
```python
import torch

def collate_batch(batch):
    """
    Custom collate function to handle variable-length sequences.
    Pads all sequences to the maximum length in the batch.
    """
    # Find max length in this batch
    max_len = max(item['input_ids'].size(0) for item in batch)

    # Prepare batched tensors
    input_ids_batch = []
    attention_mask_batch = []
    risk_labels_batch = []
    severity_scores_batch = []
    importance_scores_batch = []

    for item in batch:
        input_ids = item['input_ids']
        attention_mask = item['attention_mask']
        current_len = input_ids.size(0)

        # Pad if needed
        if current_len < max_len:
            padding_len = max_len - current_len

            # Pad input_ids with 0 (BERT's [PAD] token id)
            input_ids = torch.cat([
                input_ids,
                torch.zeros(padding_len, dtype=torch.long)
            ])

            # Pad attention_mask with 0 (don't attend to padding)
            attention_mask = torch.cat([
                attention_mask,
                torch.zeros(padding_len, dtype=torch.long)
            ])

        input_ids_batch.append(input_ids)
        attention_mask_batch.append(attention_mask)
        risk_labels_batch.append(item['risk_label'])
        severity_scores_batch.append(item['severity_score'])
        importance_scores_batch.append(item['importance_score'])

    # Stack into batched tensors
    return {
        'input_ids': torch.stack(input_ids_batch),
        'attention_mask': torch.stack(attention_mask_batch),
        'risk_label': torch.stack(risk_labels_batch),
        'severity_score': torch.stack(severity_scores_batch),
        'importance_score': torch.stack(importance_scores_batch)
    }
```
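The same padding can also be written more compactly with PyTorch's built-in `torch.nn.utils.rnn.pad_sequence`. This is an equivalent sketch, not the version in `trainer.py`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_batch_compact(batch):
    """Equivalent collate using pad_sequence: pads input_ids and
    attention_mask with 0 up to the longest sequence in the batch."""
    return {
        'input_ids': pad_sequence(
            [item['input_ids'] for item in batch],
            batch_first=True, padding_value=0),
        'attention_mask': pad_sequence(
            [item['attention_mask'] for item in batch],
            batch_first=True, padding_value=0),
        'risk_label': torch.stack([item['risk_label'] for item in batch]),
        'severity_score': torch.stack([item['severity_score'] for item in batch]),
        'importance_score': torch.stack([item['importance_score'] for item in batch]),
    }
```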
### **Updated DataLoader Creation**
```python
dataloader = DataLoader(
    dataset,
    batch_size=self.config.batch_size,
    shuffle=shuffle,
    num_workers=0,
    collate_fn=collate_batch  # βœ… Custom collate function
)
```
---
## 🎯 How It Works
### **Example Batch:**
**Input:**
```
Sample 0: [101, 2023, 2003, 102] length=4
Sample 1: [101, 1996, 2172, 3325, 2003, 102] length=6
Sample 2: [101, 102] length=2
```
**After Padding (max_len=6):**
```
Sample 0: [101, 2023, 2003, 102, 0, 0] length=6 βœ…
Sample 1: [101, 1996, 2172, 3325, 2003, 102] length=6 βœ…
Sample 2: [101, 102, 0, 0, 0, 0] length=6 βœ…
```
**Attention Masks:**
```
Sample 0: [1, 1, 1, 1, 0, 0] # Don't attend to padding
Sample 1: [1, 1, 1, 1, 1, 1] # Attend to all
Sample 2: [1, 1, 0, 0, 0, 0] # Don't attend to padding
```
**Batched Tensor:**
```python
input_ids.shape = (3, 6) # batch_size=3, max_len=6
attention_mask.shape = (3, 6)
risk_label.shape = (3,)
severity_score.shape = (3,)
importance_score.shape = (3,)
```
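The worked example above can be checked directly (the token ids are the ones shown; `pad_sequence` stands in for the padding logic):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# The three samples from the example batch above
ids = [torch.tensor([101, 2023, 2003, 102]),
       torch.tensor([101, 1996, 2172, 3325, 2003, 102]),
       torch.tensor([101, 102])]
masks = [torch.ones_like(t) for t in ids]  # 1 for every real token

input_ids = pad_sequence(ids, batch_first=True, padding_value=0)
attention_mask = pad_sequence(masks, batch_first=True, padding_value=0)

print(input_ids.shape)    # torch.Size([3, 6])
print(input_ids[0])       # tensor([101, 2023, 2003, 102, 0, 0])
print(attention_mask[0])  # tensor([1, 1, 1, 1, 0, 0])
```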
---
## πŸ” Key Features
### **1. Dynamic Padding**
- Pads to **max length in current batch** (not global max)
- More efficient: batch with short sequences β†’ less padding
- Example:
- Batch 1: lengths [10, 12, 11] β†’ pad to 12
- Batch 2: lengths [50, 48, 52] β†’ pad to 52
### **2. Attention Mask Handling**
- Original tokens: `attention_mask = 1` (attend)
- Padding tokens: `attention_mask = 0` (ignore)
- BERT won't process padding tokens
### **3. Preserves All Data**
- Risk labels preserved
- Severity scores preserved
- Importance scores preserved
- No data loss, just shape adjustment
---
## πŸ§ͺ Verification
### **Test Script: `test_collate_batch.py`**
Tests 7 scenarios:
1. βœ… Import collate_batch function
2. βœ… Create mock batch with variable lengths (4, 6, 2)
3. βœ… Collate into uniform batch
4. βœ… Verify output shapes (3, 6)
5. βœ… Verify padding tokens are 0
6. βœ… Verify labels preserved
7. βœ… Test with actual DataLoader
**Run test:**
```bash
python3 test_collate_batch.py
```
**Expected output:**
```
βœ… Step 1: Imports successful
βœ… Step 2: Creating mock batch with variable lengths...
Sample 0: length 4
Sample 1: length 6
Sample 2: length 2
βœ… Step 3: Testing collate_batch()...
πŸ“Š Output shapes:
input_ids: torch.Size([3, 6])
attention_mask: torch.Size([3, 6])
βœ… Step 4: Verifying shapes...
βœ… Step 5: Verifying padding...
βœ… Step 6: Verifying labels preserved...
βœ… Step 7: Testing with actual DataLoader...
πŸŽ‰ ALL TESTS PASSED!
```
---
## πŸ“Š Performance Impact
### **Memory Usage:**
**Before (would crash):**
- Can't create batches ❌
**After:**
- Batch 1 (short clauses): `batch_size Γ— 20 tokens` = minimal
- Batch 2 (long clauses): `batch_size Γ— 512 tokens` = max
- **Adaptive:** Uses only what's needed per batch
### **Training Speed:**
- **No slowdown:** Padding is fast (`torch.cat`)
- **Actually faster:** Proper batching enables GPU parallelism
- **Efficient:** Only pads to batch max, not global max
---
## πŸŽ“ Technical Details
### **Why Attention Mask Matters:**
```python
# Without attention mask (wrong):
#   BERT attends to padding β†’ learns meaningless patterns ❌

# With attention mask (correct):
#   BERT ignores padding β†’ learns only from real tokens βœ…
```
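The effect can be shown concretely with a simplified single-query attention step (an illustration, not BERT's actual implementation): masked positions get `-inf` scores, so softmax assigns them zero weight.

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([2.0, 1.0, 0.5, 3.0])  # raw attention scores
mask = torch.tensor([1, 1, 0, 0])            # last two positions are padding

# Masked positions get -inf, so exp(-inf) = 0 in the softmax
masked = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(masked, dim=-1)
print(weights)  # padding positions receive exactly 0 attention weight
```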
### **PyTorch DataLoader Flow:**
```
Dataset.__getitem__(idx)
β†’ Returns single sample with variable length
↓
collate_fn(batch)
β†’ Pads all samples to same length
↓
Batched tensors
β†’ All same size, ready for model
```
### **Padding Token ID:**
- **0** is the `[PAD]` token id in BERT's vocabulary
- Reserved for padding, so it doesn't collide with real tokens
- Attention mask ensures it's ignored anyway
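The Dataset β†’ `collate_fn` β†’ batched-tensors flow above can be demonstrated end to end with a toy dataset (hypothetical names; a sketch, not the project's real dataset):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyClauseDataset(Dataset):
    """Stand-in dataset: each item has a variable-length input_ids tensor."""
    def __init__(self, lengths):
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        n = self.lengths[idx]
        return {'input_ids': torch.ones(n, dtype=torch.long)}

def pad_collate(batch):
    """Pads every sample to the longest sequence in the batch with 0."""
    max_len = max(item['input_ids'].size(0) for item in batch)
    padded = [torch.cat([item['input_ids'],
                         torch.zeros(max_len - item['input_ids'].size(0),
                                     dtype=torch.long)])
              for item in batch]
    return {'input_ids': torch.stack(padded)}

loader = DataLoader(ToyClauseDataset([4, 6, 2]), batch_size=3,
                    collate_fn=pad_collate)
batch = next(iter(loader))
print(batch['input_ids'].shape)  # torch.Size([3, 6])
```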
---
## πŸš€ Ready to Train
Now training will proceed normally:
```bash
python3 train.py
```
**Expected flow:**
```
1. Load data βœ…
2. Discover risk patterns (LDA) βœ…
3. Create datasets βœ…
4. Create dataloaders with collate_fn βœ…
5. Start training...
πŸ“ˆ Epoch 1/5
Batch 0: input_ids shape = [16, 235] βœ…
Batch 1: input_ids shape = [16, 178] βœ…
Batch 2: input_ids shape = [16, 312] βœ…
...
```
---
## πŸ”§ Alternative Solutions (Not Used)
### **1. Pre-pad to Fixed Length**
```python
# Pad ALL sequences to max_length=512
# Wasteful: short clauses pad to 512 unnecessarily
```
**Why not:** Wastes memory and computation
### **2. Filter by Length**
```python
# Only use clauses within certain length range
# Loses data
```
**Why not:** Loses valuable training data
### **3. Bucket Batching**
```python
# Group similar-length sequences into batches
# More complex
```
**Why not:** Our solution is simpler and works well
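For reference, a bucket-batching sketch would look something like this (hypothetical helper, not part of the codebase): sort samples by length, then slice into batches so each batch contains similar lengths and needs minimal padding.

```python
def bucket_batches(samples, batch_size):
    """Group similar-length samples into batches to minimize padding."""
    ordered = sorted(samples, key=len)
    return [ordered[i:i + batch_size]
            for i in range(0, len(ordered), batch_size)]

# Lengths 7, 41, 8, 40 end up grouped as [7, 8] and [40, 41]
samples = [[0] * n for n in (7, 41, 8, 40)]
batches = bucket_batches(samples, batch_size=2)
print([[len(s) for s in b] for b in batches])  # [[7, 8], [40, 41]]
```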
---
## πŸ“ Summary
### **Problem:**
- Variable-length sequences can't be stacked into batches
- Training crashes at first batch
### **Solution:**
- Custom `collate_batch()` function
- Dynamically pads to max length in each batch
- Preserves attention masks and all labels
### **Result:**
- βœ… Training proceeds normally
- βœ… Efficient memory usage
- βœ… No data loss
- βœ… Proper attention mask handling
---
## πŸ“š Files Modified
1. **`trainer.py`** - Added collate_batch function (45 lines)
- Lines 18-61: `collate_batch()` function
- Line 202: Added `collate_fn=collate_batch`
2. **`test_collate_batch.py`** - Verification test (180 lines)
3. **`doc/TRAINING_FIX_PADDING.md`** - This documentation
---
## βœ… Status
- [x] Identified root cause (variable-length tensors)
- [x] Implemented custom collate function
- [x] Added to DataLoader creation
- [x] Created comprehensive tests
- [x] Documented solution
- [x] **READY TO TRAIN** πŸŽ‰
---
**Next:** Run `python3 train.py` - training should now progress through epochs! πŸš€