# 🔧 Training Fix - Variable-Length Sequence Padding

## ❌ Problem Identified

```python
RuntimeError: stack expects each tensor to be equal size, but got [7] at entry 0 and [41] at entry 1
```

**Root Cause:** The DataLoader's default collate function tries to stack tensors of different lengths into a single batch tensor, which fails. Different clauses tokenize to different sequence lengths (e.g., 7 tokens vs. 41 tokens).

**Location:** Training loop → DataLoader → `torch.stack()`

---

## 🎯 Understanding the Problem

### **What Happens:**

1. **Tokenization** produces variable-length sequences:
   ```python
   Clause 1: "The party shall..."                      → [101, 2023, ..., 102]  (length 7)
   Clause 2: "This agreement shall be governed by..."  → [101, ...,       102]  (length 41)
   ```

2. **Default collate** tries to stack them:
   ```python
   torch.stack([tensor_of_length_7, tensor_of_length_41])
   # ❌ RuntimeError: stack expects each tensor to be equal size
   ```

3. **Training fails** before the first batch completes.

### **Why It Matters:**

- BERT/Transformer models require **fixed-size batches**
- Every sequence in a batch must have the **same length**
- Solution: **pad shorter sequences** to match the longest one in the batch

---

## ✅ Solution Applied

### **Custom Collate Function**

Added a `collate_batch()` function in `trainer.py`:

```python
def collate_batch(batch):
    """
    Custom collate function to handle variable-length sequences.
    Pads all sequences to the maximum length in the batch.
    """
    # Find max length in this batch
    max_len = max(item['input_ids'].size(0) for item in batch)

    # Prepare batched tensors
    input_ids_batch = []
    attention_mask_batch = []
    risk_labels_batch = []
    severity_scores_batch = []
    importance_scores_batch = []

    for item in batch:
        input_ids = item['input_ids']
        attention_mask = item['attention_mask']
        current_len = input_ids.size(0)

        # Pad if needed
        if current_len < max_len:
            padding_len = max_len - current_len

            # Pad input_ids with 0 (PAD token)
            input_ids = torch.cat([
                input_ids,
                torch.zeros(padding_len, dtype=torch.long)
            ])

            # Pad attention_mask with 0 (don't attend to padding)
            attention_mask = torch.cat([
                attention_mask,
                torch.zeros(padding_len, dtype=torch.long)
            ])

        input_ids_batch.append(input_ids)
        attention_mask_batch.append(attention_mask)
        risk_labels_batch.append(item['risk_label'])
        severity_scores_batch.append(item['severity_score'])
        importance_scores_batch.append(item['importance_score'])

    # Stack into batched tensors
    return {
        'input_ids': torch.stack(input_ids_batch),
        'attention_mask': torch.stack(attention_mask_batch),
        'risk_label': torch.stack(risk_labels_batch),
        'severity_score': torch.stack(severity_scores_batch),
        'importance_score': torch.stack(importance_scores_batch)
    }
```

### **Updated DataLoader Creation**

```python
dataloader = DataLoader(
    dataset,
    batch_size=self.config.batch_size,
    shuffle=shuffle,
    num_workers=0,
    collate_fn=collate_batch  # ✅ Custom collate function
)
```

---

## 🎯 How It Works

### **Example Batch:**

**Input:**
```
Sample 0: [101, 2023, 2003, 102]              length=4
Sample 1: [101, 1996, 2172, 3325, 2003, 102]  length=6
Sample 2: [101, 102]                          length=2
```

**After Padding (max_len=6):**
```
Sample 0: [101, 2023, 2003, 102, 0, 0]        length=6 ✅
Sample 1: [101, 1996, 2172, 3325, 2003, 102]  length=6 ✅
Sample 2: [101, 102, 0, 0, 0, 0]              length=6 ✅
```

**Attention Masks:**
```
Sample 0: [1, 1, 1, 1, 0, 0]  # Don't attend to padding
Sample 1: [1, 1, 1, 1, 1, 1]  # Attend to all
Sample 2: [1, 1, 0, 0, 0, 0]  # Don't attend to padding
```

**Batched Tensor:**
```python
input_ids.shape        = (3, 6)  # batch_size=3, max_len=6
attention_mask.shape   = (3, 6)
risk_label.shape       = (3,)
severity_score.shape   = (3,)
importance_score.shape = (3,)
```
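As a side note on the design, PyTorch ships a helper that performs the same per-batch padding. Below is only a sketch of an equivalent collate built on `torch.nn.utils.rnn.pad_sequence`, assuming the same sample dict keys as above; the hand-rolled `collate_batch()` in `trainer.py` remains the version actually used.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_batch_padseq(batch):
    """Sketch of an equivalent collate using pad_sequence (not the trainer.py version)."""
    return {
        # pad_sequence pads every sequence to the longest one in this batch
        'input_ids': pad_sequence([b['input_ids'] for b in batch],
                                  batch_first=True, padding_value=0),
        'attention_mask': pad_sequence([b['attention_mask'] for b in batch],
                                       batch_first=True, padding_value=0),
        'risk_label': torch.stack([b['risk_label'] for b in batch]),
        'severity_score': torch.stack([b['severity_score'] for b in batch]),
        'importance_score': torch.stack([b['importance_score'] for b in batch]),
    }
```

Both variants pad `input_ids` with 0 and leave the padded tail of the attention mask at 0, so the model should see identical batches either way.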
---

## 🔍 Key Features

### **1. Dynamic Padding**
- Pads to the **max length in the current batch** (not a global max)
- More efficient: a batch of short sequences → less padding
- Example:
  - Batch 1: lengths [10, 12, 11] → pad to 12
  - Batch 2: lengths [50, 48, 52] → pad to 52

### **2. Attention Mask Handling**
- Original tokens: `attention_mask = 1` (attend)
- Padding tokens: `attention_mask = 0` (ignore)
- BERT won't process padding tokens

### **3. Preserves All Data**
- Risk labels preserved
- Severity scores preserved
- Importance scores preserved
- No data loss, just a shape adjustment

---

## 🧪 Verification

### **Test Script: `test_collate_batch.py`**

Tests 7 scenarios:

1. ✅ Import the `collate_batch` function
2. ✅ Create a mock batch with variable lengths (4, 6, 2)
3. ✅ Collate into a uniform batch
4. ✅ Verify output shapes (3, 6)
5. ✅ Verify padding tokens are 0
6. ✅ Verify labels are preserved
7. ✅ Test with an actual DataLoader

**Run test:**
```bash
python3 test_collate_batch.py
```

**Expected output:**
```
✅ Step 1: Imports successful
✅ Step 2: Creating mock batch with variable lengths...
   Sample 0: length 4
   Sample 1: length 6
   Sample 2: length 2
✅ Step 3: Testing collate_batch()...
   📊 Output shapes:
      input_ids: torch.Size([3, 6])
      attention_mask: torch.Size([3, 6])
✅ Step 4: Verifying shapes...
✅ Step 5: Verifying padding...
✅ Step 6: Verifying labels preserved...
✅ Step 7: Testing with actual DataLoader...
🎉 ALL TESTS PASSED!
```

---

## 📊 Performance Impact

### **Memory Usage:**

**Before (would crash):**
- Can't create batches ❌

**After:**
- Batch 1 (short clauses): `batch_size × 20 tokens` = minimal
- Batch 2 (long clauses): `batch_size × 512 tokens` = maximum
- **Adaptive:** uses only what each batch needs

### **Training Speed:**
- **No slowdown:** padding is fast (`torch.cat`)
- **Actually faster:** proper batching enables GPU parallelism
- **Efficient:** pads only to the batch max, not a global max

---

## 🎓 Technical Details

### **Why the Attention Mask Matters:**

```python
# Without attention mask (wrong):
#   BERT attends to padding → learns meaningless patterns ❌

# With attention mask (correct):
#   BERT ignores padding → learns only from real tokens ✅
```

### **PyTorch DataLoader Flow:**

```
Dataset.__getitem__(idx) → returns a single sample with variable length
        ↓
collate_fn(batch)        → pads all samples to the same length
        ↓
Batched tensors          → all the same size, ready for the model
```

### **Padding Token ID:**
- **0** is the standard `[PAD]` token ID in BERT
- It doesn't conflict with the vocabulary (reserved ID)
- The attention mask ensures it's ignored anyway
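The padding ID can be verified directly from the tokenizer. A minimal check, assuming the stock `bert-base-uncased` checkpoint from Hugging Face `transformers` (the project's actual checkpoint may differ):

```python
from transformers import AutoTokenizer

# Hypothetical sanity check; swap in whichever checkpoint the project actually uses.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

print(tokenizer.pad_token)     # '[PAD]'
print(tokenizer.pad_token_id)  # 0 → matches the padding value used in collate_batch()
```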
---

## 🚀 Ready to Train

Training will now proceed normally:

```bash
python3 train.py
```

**Expected flow:**
```
1. Load data ✅
2. Discover risk patterns (LDA) ✅
3. Create datasets ✅
4. Create dataloaders with collate_fn ✅
5. Start training... 📈
   Epoch 1/5
   Batch 0: input_ids shape = [16, 235] ✅
   Batch 1: input_ids shape = [16, 178] ✅
   Batch 2: input_ids shape = [16, 312] ✅
   ...
```

---

## 🔧 Alternative Solutions (Not Used)

### **1. Pre-pad to Fixed Length**
```python
# Pad ALL sequences to max_length=512
# Wasteful: short clauses pad to 512 unnecessarily
```
**Why not:** Wastes memory and computation

### **2. Filter by Length**
```python
# Only use clauses within a certain length range
# Loses data
```
**Why not:** Loses valuable training data

### **3. Bucket Batching**
```python
# Group similar-length sequences into batches
# More complex
```
**Why not:** Our solution is simpler and works well (a sketch of what bucket batching would involve appears at the end of this note)

---

## 📝 Summary

### **Problem:**
- Variable-length sequences can't be stacked into batches
- Training crashes at the first batch

### **Solution:**
- Custom `collate_batch()` function
- Dynamically pads to the max length in each batch
- Preserves attention masks and all labels

### **Result:**
- ✅ Training proceeds normally
- ✅ Efficient memory usage
- ✅ No data loss
- ✅ Proper attention mask handling

---

## 📚 Files Modified

1. **`trainer.py`** - Added `collate_batch` function (45 lines)
   - Lines 18-61: `collate_batch()` function
   - Line 202: Added `collate_fn=collate_batch`
2. **`test_collate_batch.py`** - Verification test (180 lines)
3. **`doc/TRAINING_FIX_PADDING.md`** - This documentation

---

## ✅ Status

- [x] Identified root cause (variable-length tensors)
- [x] Implemented custom collate function
- [x] Added to DataLoader creation
- [x] Created comprehensive tests
- [x] Documented solution
- [x] **READY TO TRAIN** 🎉

---

**Next:** Run `python3 train.py` - training should now progress through epochs! 🚀
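---

**Appendix:** for reference, the bucket-batching alternative mentioned above would amount to grouping indices by sequence length before batching. A minimal sketch only: the `LengthBucketSampler` class is hypothetical and not part of the codebase, and it assumes per-sample lengths can be precomputed from the dataset.

```python
import random
from torch.utils.data import Sampler


class LengthBucketSampler(Sampler):
    """Sketch: yields batches of indices whose sequences have similar lengths."""

    def __init__(self, lengths, batch_size):
        self.lengths = lengths        # one precomputed length per dataset item (assumption)
        self.batch_size = batch_size

    def __iter__(self):
        # Sort indices by sequence length, then cut into contiguous batches
        order = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i])
        batches = [order[i:i + self.batch_size]
                   for i in range(0, len(order), self.batch_size)]
        random.shuffle(batches)       # shuffle batch order, keep lengths grouped within a batch
        return iter(batches)

    def __len__(self):
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size


# Hypothetical usage — assumes each dataset item exposes 'input_ids':
#   lengths = [len(dataset[i]['input_ids']) for i in range(len(dataset))]
#   loader = DataLoader(dataset,
#                       batch_sampler=LengthBucketSampler(lengths, 16),
#                       collate_fn=collate_batch)
```

Because the dynamic padding above already keeps per-batch padding small, this extra machinery was not adopted.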