
🔧 Training Fix - Variable-Length Sequence Padding

❌ Problem Identified

RuntimeError: stack expects each tensor to be equal size, 
but got [7] at entry 0 and [41] at entry 1

Root Cause: The DataLoader's default collate function tries to stack tensors of different lengths into a batch, which fails. This happens because different clauses tokenize to different sequence lengths (e.g., 7 tokens vs 41 tokens).

Location: Training loop → DataLoader → torch.stack()


🎯 Understanding the Problem

What Happens:

  1. Tokenization produces variable-length sequences:

    Clause 1: "The party shall..." → [101, 2023, ..., 102] (length 7)
    Clause 2: "This agreement shall be governed by..." → [101, ..., 102] (length 41)
    
  2. Default collate tries to stack them:

    torch.stack([tensor([...]),    # length 7
                 tensor([...])])   # length 41
    ❌ ERROR: Can't stack different sizes!
    
  3. Training fails before first batch completes
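The failure above can be reproduced in two lines. This minimal sketch uses dummy zero tensors in place of real token IDs:

```python
import torch

# Minimal reproduction of the crash: the default collate function
# ultimately calls torch.stack, which requires identical shapes.
short_clause = torch.zeros(7, dtype=torch.long)   # 7 tokens
long_clause = torch.zeros(41, dtype=torch.long)   # 41 tokens

try:
    torch.stack([short_clause, long_clause])
except RuntimeError as e:
    print(f"RuntimeError: {e}")  # stack expects each tensor to be equal size
```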

Why It Matters:

  • BERT/Transformers require fixed-size batches
  • Every sequence in a batch must have the same length
  • Solution: pad shorter sequences to match the longest in the batch

✅ Solution Applied

Custom Collate Function

Added collate_batch() function in trainer.py:

import torch

def collate_batch(batch):
    """
    Custom collate function to handle variable-length sequences.
    Pads all sequences to the maximum length in the batch.
    """
    # Find max length in this batch
    max_len = max(item['input_ids'].size(0) for item in batch)
    
    # Prepare batched tensors
    input_ids_batch = []
    attention_mask_batch = []
    risk_labels_batch = []
    severity_scores_batch = []
    importance_scores_batch = []
    
    for item in batch:
        input_ids = item['input_ids']
        attention_mask = item['attention_mask']
        current_len = input_ids.size(0)
        
        # Pad if needed
        if current_len < max_len:
            padding_len = max_len - current_len
            
            # Pad input_ids with 0 (PAD token)
            input_ids = torch.cat([
                input_ids, 
                torch.zeros(padding_len, dtype=torch.long)
            ])
            
            # Pad attention_mask with 0 (don't attend to padding)
            attention_mask = torch.cat([
                attention_mask, 
                torch.zeros(padding_len, dtype=torch.long)
            ])
        
        input_ids_batch.append(input_ids)
        attention_mask_batch.append(attention_mask)
        risk_labels_batch.append(item['risk_label'])
        severity_scores_batch.append(item['severity_score'])
        importance_scores_batch.append(item['importance_score'])
    
    # Stack into batched tensors
    return {
        'input_ids': torch.stack(input_ids_batch),
        'attention_mask': torch.stack(attention_mask_batch),
        'risk_label': torch.stack(risk_labels_batch),
        'severity_score': torch.stack(severity_scores_batch),
        'importance_score': torch.stack(importance_scores_batch)
    }
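As a side note, PyTorch's built-in torch.nn.utils.rnn.pad_sequence can replace the manual padding loop. This compact equivalent is a sketch, not the code shipped in trainer.py:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_batch_compact(batch):
    """Same behavior as collate_batch, using pad_sequence for the padding."""
    return {
        # pad_sequence pads every tensor in the list to the longest one
        # with padding_value, stacking them along a new batch dimension
        'input_ids': pad_sequence([b['input_ids'] for b in batch],
                                  batch_first=True, padding_value=0),
        'attention_mask': pad_sequence([b['attention_mask'] for b in batch],
                                       batch_first=True, padding_value=0),
        'risk_label': torch.stack([b['risk_label'] for b in batch]),
        'severity_score': torch.stack([b['severity_score'] for b in batch]),
        'importance_score': torch.stack([b['importance_score'] for b in batch]),
    }
```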

Updated DataLoader Creation

dataloader = DataLoader(
    dataset,
    batch_size=self.config.batch_size,
    shuffle=shuffle,
    num_workers=0,
    collate_fn=collate_batch  # ✅ Custom collate function
)
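A quick end-to-end check that the wiring works. ToyClauseDataset and pad_collate below are illustrative stand-ins for the project's real dataset and collate function:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyClauseDataset(Dataset):
    """Yields variable-length dummy samples, like tokenized clauses."""
    def __init__(self, lengths):
        self.lengths = lengths
    def __len__(self):
        return len(self.lengths)
    def __getitem__(self, idx):
        n = self.lengths[idx]
        return {'input_ids': torch.ones(n, dtype=torch.long),
                'attention_mask': torch.ones(n, dtype=torch.long)}

def pad_collate(batch):
    # Pad each tensor to the batch max, then stack into (B, max_len)
    max_len = max(item['input_ids'].size(0) for item in batch)
    out = {}
    for key in ('input_ids', 'attention_mask'):
        out[key] = torch.stack([
            torch.cat([item[key],
                       torch.zeros(max_len - item[key].size(0), dtype=torch.long)])
            for item in batch])
    return out

loader = DataLoader(ToyClauseDataset([4, 6, 2]), batch_size=3,
                    collate_fn=pad_collate)
batch = next(iter(loader))
print(batch['input_ids'].shape)  # torch.Size([3, 6])
```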

🎯 How It Works

Example Batch:

Input:

Sample 0: [101, 2023, 2003, 102]           length=4
Sample 1: [101, 1996, 2172, 3325, 2003, 102]  length=6
Sample 2: [101, 102]                       length=2

After Padding (max_len=6):

Sample 0: [101, 2023, 2003, 102, 0, 0]        length=6 ✅
Sample 1: [101, 1996, 2172, 3325, 2003, 102]  length=6 ✅
Sample 2: [101, 102, 0, 0, 0, 0]              length=6 ✅

Attention Masks:

Sample 0: [1, 1, 1, 1, 0, 0]    # Don't attend to padding
Sample 1: [1, 1, 1, 1, 1, 1]    # Attend to all
Sample 2: [1, 1, 0, 0, 0, 0]    # Don't attend to padding

Batched Tensor:

input_ids.shape = (3, 6)         # batch_size=3, max_len=6
attention_mask.shape = (3, 6)
risk_label.shape = (3,)
severity_score.shape = (3,)
importance_score.shape = (3,)

πŸ” Key Features

1. Dynamic Padding

  • Pads to the max length in the current batch (not the global max)
  • More efficient: a batch of short sequences → less padding
  • Example:
    • Batch 1: lengths [10, 12, 11] → pad to 12
    • Batch 2: lengths [50, 48, 52] → pad to 52

2. Attention Mask Handling

  • Original tokens: attention_mask = 1 (attend)
  • Padding tokens: attention_mask = 0 (ignore)
  • BERT's self-attention ignores the padding positions
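To make the mask's effect concrete, here is a sketch of how scores at padding positions are removed before the attention softmax (this illustrates the standard mechanism, not code from this repo):

```python
import torch

# Raw attention scores for a 4-position sequence; last two are padding.
scores = torch.tensor([[2.0, 1.0, 0.5, 0.5]])
mask = torch.tensor([[1, 1, 0, 0]])

# Masked positions get -inf, so softmax assigns them exactly 0 weight.
masked = scores.masked_fill(mask == 0, float('-inf'))
weights = torch.softmax(masked, dim=-1)
print(weights)  # padding positions receive weight 0.0
```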

3. Preserves All Data

  • Risk labels preserved
  • Severity scores preserved
  • Importance scores preserved
  • No data loss, just shape adjustment

🧪 Verification

Test Script: test_collate_batch.py

Tests 7 scenarios:

  1. ✅ Import collate_batch function
  2. ✅ Create mock batch with variable lengths (4, 6, 2)
  3. ✅ Collate into uniform batch
  4. ✅ Verify output shapes (3, 6)
  5. ✅ Verify padding tokens are 0
  6. ✅ Verify labels preserved
  7. ✅ Test with actual DataLoader

Run test:

python3 test_collate_batch.py

Expected output:

✅ Step 1: Imports successful
✅ Step 2: Creating mock batch with variable lengths...
   Sample 0: length 4
   Sample 1: length 6
   Sample 2: length 2
✅ Step 3: Testing collate_batch()...
   📊 Output shapes:
      input_ids: torch.Size([3, 6])
      attention_mask: torch.Size([3, 6])
✅ Step 4: Verifying shapes...
✅ Step 5: Verifying padding...
✅ Step 6: Verifying labels preserved...
✅ Step 7: Testing with actual DataLoader...

🎉 ALL TESTS PASSED!

📊 Performance Impact

Memory Usage:

Before (would crash):

  • Can't create batches ❌

After:

  • Batch 1 (short clauses): batch_size × 20 tokens = minimal
  • Batch 2 (long clauses): batch_size × 512 tokens = max
  • Adaptive: uses only what's needed per batch

Training Speed:

  • No slowdown: Padding is fast (torch.cat)
  • Actually faster: Proper batching enables GPU parallelism
  • Efficient: Only pads to batch max, not global max

🎓 Technical Details

Why Attention Mask Matters:

# Without attention mask (wrong):
BERT attends to padding → learns meaningless patterns ❌

# With attention mask (correct):
BERT ignores padding → learns only from real tokens ✅

PyTorch DataLoader Flow:

Dataset.__getitem__(idx)
  → Returns a single sample with variable length
  ↓
collate_fn(batch)
  → Pads all samples to the same length
  ↓
Batched tensors
  → All the same size, ready for the model

Padding Token ID:

  • 0 is BERT's standard [PAD] token ID
  • Reserved in the vocabulary, so it doesn't collide with real tokens
  • The attention mask ensures it's ignored anyway

🚀 Ready to Train

Now training will proceed normally:

python3 train.py

Expected flow:

1. Load data ✅
2. Discover risk patterns (LDA) ✅
3. Create datasets ✅
4. Create dataloaders with collate_fn ✅
5. Start training...
   📈 Epoch 1/5
   Batch 0: input_ids shape = [16, 235] ✅
   Batch 1: input_ids shape = [16, 178] ✅
   Batch 2: input_ids shape = [16, 312] ✅
   ...

🔧 Alternative Solutions (Not Used)

1. Pre-pad to Fixed Length

# Pad ALL sequences to max_length=512
# Wasteful: short clauses pad to 512 unnecessarily

Why not: Wastes memory and computation

2. Filter by Length

# Only use clauses within certain length range
# Loses data

Why not: Loses valuable training data

3. Bucket Batching

# Group similar-length sequences into batches
# More complex

Why not: Our solution is simpler and works well
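For reference, bucket batching would look roughly like this; bucket_batches is a hypothetical helper sketched here, not part of this repo:

```python
def bucket_batches(lengths, batch_size):
    # Sort sample indices by sequence length so each batch groups
    # similar lengths, minimizing padding per batch.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]

# Clauses of lengths 50, 7, 41, 9, 48, 12: short and long grouped apart
print(bucket_batches([50, 7, 41, 9, 48, 12], batch_size=2))
# [[1, 3], [5, 2], [4, 0]]
```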


πŸ“ Summary

Problem:

  • Variable-length sequences can't be stacked into batches
  • Training crashes at first batch

Solution:

  • Custom collate_batch() function
  • Dynamically pads to max length in each batch
  • Preserves attention masks and all labels

Result:

  • ✅ Training proceeds normally
  • ✅ Efficient memory usage
  • ✅ No data loss
  • ✅ Proper attention mask handling

📚 Files Modified

  1. trainer.py - Added collate_batch function (45 lines)

    • Lines 18-61: collate_batch() function
    • Line 202: Added collate_fn=collate_batch
  2. test_collate_batch.py - Verification test (180 lines)

  3. doc/TRAINING_FIX_PADDING.md - This documentation


✅ Status

  • Identified root cause (variable-length tensors)
  • Implemented custom collate function
  • Added to DataLoader creation
  • Created comprehensive tests
  • Documented solution
  • READY TO TRAIN 🎉

Next: Run python3 train.py - training should now progress through epochs! 🚀