Training Fix - Variable-Length Sequence Padding
Problem Identified
RuntimeError: stack expects each tensor to be equal size,
but got [7] at entry 0 and [41] at entry 1
Root Cause: The DataLoader's default collate function tries to stack tensors of different lengths into a batch, which fails. This happens because different clauses tokenize to different sequence lengths (e.g., 7 tokens vs 41 tokens).
Location: Training loop → DataLoader → torch.stack()
Understanding the Problem
What Happens:
1. Tokenization produces variable-length sequences:
   - Clause 1: "The party shall..." → [101, 2023, ..., 102] (length 7)
   - Clause 2: "This agreement shall be governed by..." → [101, ..., 102] (length 41)
2. The default collate function tries to stack them:
   torch.stack([tensor(..., length=7), tensor(..., length=41)]) → ERROR: can't stack tensors of different sizes!
3. Training fails before the first batch completes.
Why It Matters:
- BERT/Transformers require fixed-size batches
- Each sequence in a batch must have same length
- Solution: Pad shorter sequences to match the longest in the batch (see the minimal reproduction below)
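A minimal sketch of the failure and the fix, using toy token IDs rather than real tokenizer output:

```python
import torch

short = torch.tensor([101, 2023, 102])                    # 3 tokens
long = torch.tensor([101, 1996, 2172, 3325, 2003, 102])   # 6 tokens

# torch.stack([short, long])  # RuntimeError: stack expects each tensor to be equal size

# Pad the short sequence with 0 ([PAD]) so both are length 6, then stacking works
padded_short = torch.cat([short, torch.zeros(3, dtype=torch.long)])
batch = torch.stack([padded_short, long])
print(batch.shape)  # torch.Size([2, 6])
```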
Solution Applied
Custom Collate Function
Added collate_batch() function in trainer.py:
import torch

def collate_batch(batch):
    """
    Custom collate function to handle variable-length sequences.
    Pads all sequences to the maximum length in the batch.
    """
    # Find max length in this batch
    max_len = max(item['input_ids'].size(0) for item in batch)

    # Prepare batched tensors
    input_ids_batch = []
    attention_mask_batch = []
    risk_labels_batch = []
    severity_scores_batch = []
    importance_scores_batch = []

    for item in batch:
        input_ids = item['input_ids']
        attention_mask = item['attention_mask']
        current_len = input_ids.size(0)

        # Pad if needed
        if current_len < max_len:
            padding_len = max_len - current_len

            # Pad input_ids with 0 (PAD token)
            input_ids = torch.cat([
                input_ids,
                torch.zeros(padding_len, dtype=torch.long)
            ])

            # Pad attention_mask with 0 (don't attend to padding)
            attention_mask = torch.cat([
                attention_mask,
                torch.zeros(padding_len, dtype=torch.long)
            ])

        input_ids_batch.append(input_ids)
        attention_mask_batch.append(attention_mask)
        risk_labels_batch.append(item['risk_label'])
        severity_scores_batch.append(item['severity_score'])
        importance_scores_batch.append(item['importance_score'])

    # Stack into batched tensors
    return {
        'input_ids': torch.stack(input_ids_batch),
        'attention_mask': torch.stack(attention_mask_batch),
        'risk_label': torch.stack(risk_labels_batch),
        'severity_score': torch.stack(severity_scores_batch),
        'importance_score': torch.stack(importance_scores_batch)
    }
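For reference, PyTorch's built-in torch.nn.utils.rnn.pad_sequence can express the same dynamic padding more compactly. The sketch below is only an equivalent alternative, not the function actually added to trainer.py:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_batch_padseq(batch):
    # pad_sequence right-pads a list of 1-D tensors to the longest one in the batch
    return {
        'input_ids': pad_sequence([b['input_ids'] for b in batch],
                                  batch_first=True, padding_value=0),
        'attention_mask': pad_sequence([b['attention_mask'] for b in batch],
                                       batch_first=True, padding_value=0),
        'risk_label': torch.stack([b['risk_label'] for b in batch]),
        'severity_score': torch.stack([b['severity_score'] for b in batch]),
        'importance_score': torch.stack([b['importance_score'] for b in batch]),
    }
```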
Updated DataLoader Creation
dataloader = DataLoader(
    dataset,
    batch_size=self.config.batch_size,
    shuffle=shuffle,
    num_workers=0,
    collate_fn=collate_batch  # custom collate function handles variable lengths
)
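A quick sanity check, assuming dataloader is the object created above:

```python
for batch in dataloader:
    # Every tensor in the batch now shares the same sequence length
    print(batch['input_ids'].shape, batch['attention_mask'].shape)
    break
```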
How It Works
Example Batch:
Input:
Sample 0: [101, 2023, 2003, 102] length=4
Sample 1: [101, 1996, 2172, 3325, 2003, 102] length=6
Sample 2: [101, 102] length=2
After Padding (max_len=6):
Sample 0: [101, 2023, 2003, 102, 0, 0]        length=6 ✓
Sample 1: [101, 1996, 2172, 3325, 2003, 102]  length=6 ✓
Sample 2: [101, 102, 0, 0, 0, 0]              length=6 ✓
Attention Masks:
Sample 0: [1, 1, 1, 1, 0, 0] # Don't attend to padding
Sample 1: [1, 1, 1, 1, 1, 1] # Attend to all
Sample 2: [1, 1, 0, 0, 0, 0] # Don't attend to padding
Batched Tensor:
input_ids.shape = (3, 6) # batch_size=3, max_len=6
attention_mask.shape = (3, 6)
risk_label.shape = (3,)
severity_score.shape = (3,)
importance_score.shape = (3,)
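The same example as executable code, deriving the attention mask directly from the padded IDs (illustrative values; in this toy batch padding is the only token with ID 0):

```python
import torch

input_ids = torch.tensor([
    [101, 2023, 2003, 102,    0,   0],
    [101, 1996, 2172, 3325, 2003, 102],
    [101,  102,    0,    0,    0,   0],
])
attention_mask = (input_ids != 0).long()  # 1 for real tokens, 0 for padding

assert input_ids.shape == (3, 6)
assert attention_mask[0].tolist() == [1, 1, 1, 1, 0, 0]
assert attention_mask[2].tolist() == [1, 1, 0, 0, 0, 0]
```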
Key Features
1. Dynamic Padding
- Pads to max length in current batch (not global max)
- More efficient: a batch of short sequences → less padding
- Example:
  - Batch 1: lengths [10, 12, 11] → pad to 12
  - Batch 2: lengths [50, 48, 52] → pad to 52
2. Attention Mask Handling
- Original tokens: attention_mask = 1 (attend)
- Padding tokens: attention_mask = 0 (ignore)
- BERT won't process padding tokens
3. Preserves All Data
- Risk labels preserved
- Severity scores preserved
- Importance scores preserved
- No data loss, just shape adjustment
Verification
Test Script: test_collate_batch.py
Tests 7 scenarios:
- Import collate_batch function
- Create mock batch with variable lengths (4, 6, 2)
- Collate into uniform batch
- Verify output shapes (3, 6)
- Verify padding tokens are 0
- Verify labels preserved
- Test with actual DataLoader
Run test:
python3 test_collate_batch.py
Expected output:
✓ Step 1: Imports successful
✓ Step 2: Creating mock batch with variable lengths...
   Sample 0: length 4
   Sample 1: length 6
   Sample 2: length 2
✓ Step 3: Testing collate_batch()...
   Output shapes:
     input_ids: torch.Size([3, 6])
     attention_mask: torch.Size([3, 6])
✓ Step 4: Verifying shapes...
✓ Step 5: Verifying padding...
✓ Step 6: Verifying labels preserved...
✓ Step 7: Testing with actual DataLoader...
ALL TESTS PASSED!
Performance Impact
Memory Usage:
Before (would crash):
- Can't create batches at all ✗
After:
- Batch 1 (short clauses): batch_size × 20 tokens = minimal
- Batch 2 (long clauses): batch_size × 512 tokens = max
- Adaptive: uses only what's needed per batch
Training Speed:
- No slowdown: padding is fast (torch.cat)
- Actually faster: proper batching enables GPU parallelism
- Efficient: only pads to the batch max, not a global max
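Back-of-the-envelope numbers for the adaptive behaviour, assuming a batch size of 16 and int64 input_ids (illustrative figures, not measurements):

```python
# Bytes allocated for input_ids alone (int64 = 8 bytes per token slot)
short_batch = 16 * 20 * 8    # longest clause in the batch is 20 tokens -> ~2.5 KB
long_batch  = 16 * 512 * 8   # batch hits the 512-token maximum         -> 64 KB
fixed_pad   = 16 * 512 * 8   # pre-padding to 512 would pay 64 KB for every batch
```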
Technical Details
Why Attention Mask Matters:
# Without attention mask (wrong):
BERT attends to padding → learns meaningless patterns ✗
# With attention mask (correct):
BERT ignores padding → learns only from real tokens ✓
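For illustration, this is roughly how the padded batch and its mask would be fed to a Hugging Face BERT encoder; the project's actual model wrapper may differ:

```python
from transformers import AutoModel

encoder = AutoModel.from_pretrained("bert-base-uncased")

outputs = encoder(
    input_ids=batch['input_ids'],            # (batch_size, max_len)
    attention_mask=batch['attention_mask'],  # 1 = real token, 0 = padding
)
hidden_states = outputs.last_hidden_state   # (batch_size, max_len, 768)
```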
PyTorch DataLoader Flow:
Dataset.__getitem__(idx)
  → returns a single sample with variable length
    ↓
collate_fn(batch)
  → pads all samples to the same length
    ↓
Batched tensors
  → all the same size, ready for the model
Padding Token ID:
- 0 is standard PAD token in BERT
- Doesn't conflict with vocabulary (reserved ID)
- Attention mask ensures it's ignored anyway
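This can be confirmed from the tokenizer itself (assuming a standard bert-base-uncased tokenizer):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
print(tokenizer.pad_token, tokenizer.pad_token_id)  # [PAD] 0
```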
Ready to Train
Now training will proceed normally:
python3 train.py
Expected flow:
1. Load data ✓
2. Discover risk patterns (LDA) ✓
3. Create datasets ✓
4. Create dataloaders with collate_fn ✓
5. Start training...
Epoch 1/5
  Batch 0: input_ids shape = [16, 235] ✓
  Batch 1: input_ids shape = [16, 178] ✓
  Batch 2: input_ids shape = [16, 312] ✓
...
Alternative Solutions (Not Used)
1. Pre-pad to Fixed Length
# Pad ALL sequences to max_length=512
# Wasteful: short clauses pad to 512 unnecessarily
Why not: Wastes memory and computation
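For comparison, this is what the rejected fixed-length approach would look like with a Hugging Face tokenizer (hypothetical clause_texts list):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

encoded = tokenizer(
    clause_texts,            # hypothetical list of clause strings
    padding="max_length",    # every clause padded to 512 tokens
    truncation=True,
    max_length=512,
    return_tensors="pt",
)
# encoded['input_ids'].shape == (len(clause_texts), 512), even for 5-token clauses
```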
2. Filter by Length
# Only use clauses within certain length range
# Loses data
Why not: Loses valuable training data
3. Bucket Batching
# Group similar-length sequences into batches
# More complex
Why not: Our solution is simpler and works well
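A rough, hypothetical sketch of what bucket batching would involve (sorting samples by length and handing index buckets to the DataLoader); it is not used in this project:

```python
from torch.utils.data import DataLoader

# Group sample indices by token length so each bucket holds similar-length clauses
indices = sorted(range(len(dataset)),
                 key=lambda i: dataset[i]['input_ids'].size(0))
buckets = [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]

loader = DataLoader(dataset, batch_sampler=buckets, collate_fn=collate_batch)
```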
Summary
Problem:
- Variable-length sequences can't be stacked into batches
- Training crashes at first batch
Solution:
- Custom collate_batch() function
- Dynamically pads to the max length in each batch
- Preserves attention masks and all labels
Result:
- ✓ Training proceeds normally
- ✓ Efficient memory usage
- ✓ No data loss
- ✓ Proper attention mask handling
Files Modified
- trainer.py - Added collate_batch function (45 lines)
  - Lines 18-61: collate_batch() function
  - Line 202: Added collate_fn=collate_batch
- test_collate_batch.py - Verification test (180 lines)
- doc/TRAINING_FIX_PADDING.md - This documentation
Status
- Identified root cause (variable-length tensors)
- Implemented custom collate function
- Added to DataLoader creation
- Created comprehensive tests
- Documented solution
- READY TO TRAIN
Next: Run python3 train.py - training should now progress through epochs!