# 🔧 Training Fix - Variable-Length Sequence Padding

## ❌ Problem Identified

```python
RuntimeError: stack expects each tensor to be equal size,
but got [7] at entry 0 and [41] at entry 1
```

**Root Cause:** The DataLoader's default collate function tries to stack tensors of different lengths into a batch, which fails because different clauses tokenize to different sequence lengths (e.g., 7 tokens vs. 41 tokens).

**Location:** Training loop → DataLoader → `torch.stack()`
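For reference, the underlying failure is easy to reproduce outside the project code; a minimal standalone sketch:

```python
import torch

a = torch.zeros(7, dtype=torch.long)   # stand-in for a 7-token clause
b = torch.zeros(41, dtype=torch.long)  # stand-in for a 41-token clause
torch.stack([a, b])
# RuntimeError: stack expects each tensor to be equal size,
# but got [7] at entry 0 and [41] at entry 1
```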
---

## 🎯 Understanding the Problem

### **What Happens:**

1. **Tokenization** produces variable-length sequences:
```python
Clause 1: "The party shall..."                     → [101, 2023, ..., 102]  (length 7)
Clause 2: "This agreement shall be governed by..." → [101, ..., 102]        (length 41)
```
2. **Default collate** tries to stack them:
```python
torch.stack([tensor([...]),   # length 7
             tensor([...])])  # length 41
→ ERROR: Can't stack different sizes!
```
3. **Training fails** before the first batch completes.

### **Why It Matters:**
- BERT/Transformers require **fixed-size batches**
- Each sequence in a batch must have the **same length**
- Solution: **pad shorter sequences** to match the longest in the batch
---

## ✅ Solution Applied

### **Custom Collate Function**

Added a `collate_batch()` function in `trainer.py`:

```python
import torch

def collate_batch(batch):
    """
    Custom collate function to handle variable-length sequences.
    Pads all sequences to the maximum length in the batch.
    """
    # Find max length in this batch
    max_len = max(item['input_ids'].size(0) for item in batch)

    # Prepare batched tensors
    input_ids_batch = []
    attention_mask_batch = []
    risk_labels_batch = []
    severity_scores_batch = []
    importance_scores_batch = []

    for item in batch:
        input_ids = item['input_ids']
        attention_mask = item['attention_mask']
        current_len = input_ids.size(0)

        # Pad if needed
        if current_len < max_len:
            padding_len = max_len - current_len
            # Pad input_ids with 0 (BERT's PAD token ID)
            input_ids = torch.cat([
                input_ids,
                torch.zeros(padding_len, dtype=torch.long)
            ])
            # Pad attention_mask with 0 (don't attend to padding)
            attention_mask = torch.cat([
                attention_mask,
                torch.zeros(padding_len, dtype=torch.long)
            ])

        input_ids_batch.append(input_ids)
        attention_mask_batch.append(attention_mask)
        risk_labels_batch.append(item['risk_label'])
        severity_scores_batch.append(item['severity_score'])
        importance_scores_batch.append(item['importance_score'])

    # Stack into batched tensors
    return {
        'input_ids': torch.stack(input_ids_batch),
        'attention_mask': torch.stack(attention_mask_batch),
        'risk_label': torch.stack(risk_labels_batch),
        'severity_score': torch.stack(severity_scores_batch),
        'importance_score': torch.stack(importance_scores_batch)
    }
```
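As an aside, PyTorch's built-in `torch.nn.utils.rnn.pad_sequence` performs the same per-batch padding; a minimal sketch of an equivalent collate function (shown for comparison only, not what `trainer.py` uses):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_batch_alt(batch):
    """Equivalent collate using pad_sequence instead of manual torch.cat."""
    return {
        'input_ids': pad_sequence([item['input_ids'] for item in batch],
                                  batch_first=True, padding_value=0),
        'attention_mask': pad_sequence([item['attention_mask'] for item in batch],
                                       batch_first=True, padding_value=0),
        'risk_label': torch.stack([item['risk_label'] for item in batch]),
        'severity_score': torch.stack([item['severity_score'] for item in batch]),
        'importance_score': torch.stack([item['importance_score'] for item in batch]),
    }
```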
### **Updated DataLoader Creation**

```python
dataloader = DataLoader(
    dataset,
    batch_size=self.config.batch_size,
    shuffle=shuffle,
    num_workers=0,
    collate_fn=collate_batch  # ← custom collate function
)
```
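A quick sanity check after wiring this up (illustrative; assumes the dataset and dataloader above):

```python
# Every yielded batch should now be rectangular
for batch in dataloader:
    print(batch['input_ids'].shape)       # e.g. torch.Size([16, 235])
    print(batch['attention_mask'].shape)  # same shape as input_ids
    break
```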
---

## 🎯 How It Works

### **Example Batch:**

**Input:**
```
Sample 0: [101, 2023, 2003, 102]              length=4
Sample 1: [101, 1996, 2172, 3325, 2003, 102]  length=6
Sample 2: [101, 102]                          length=2
```

**After Padding (max_len=6):**
```
Sample 0: [101, 2023, 2003, 102, 0, 0]        length=6 ✅
Sample 1: [101, 1996, 2172, 3325, 2003, 102]  length=6 ✅
Sample 2: [101, 102, 0, 0, 0, 0]              length=6 ✅
```

**Attention Masks:**
```
Sample 0: [1, 1, 1, 1, 0, 0]  # Don't attend to padding
Sample 1: [1, 1, 1, 1, 1, 1]  # Attend to all
Sample 2: [1, 1, 0, 0, 0, 0]  # Don't attend to padding
```

**Batched Tensors:**
```python
input_ids.shape        = (3, 6)  # batch_size=3, max_len=6
attention_mask.shape   = (3, 6)
risk_label.shape       = (3,)
severity_score.shape   = (3,)
importance_score.shape = (3,)
```
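The example above can be verified directly by feeding three mock samples through `collate_batch()` (the label values are placeholders):

```python
import torch

mock_batch = [
    {'input_ids': torch.tensor([101, 2023, 2003, 102]),
     'attention_mask': torch.ones(4, dtype=torch.long),
     'risk_label': torch.tensor(0), 'severity_score': torch.tensor(0.2),
     'importance_score': torch.tensor(0.5)},
    {'input_ids': torch.tensor([101, 1996, 2172, 3325, 2003, 102]),
     'attention_mask': torch.ones(6, dtype=torch.long),
     'risk_label': torch.tensor(1), 'severity_score': torch.tensor(0.9),
     'importance_score': torch.tensor(0.8)},
    {'input_ids': torch.tensor([101, 102]),
     'attention_mask': torch.ones(2, dtype=torch.long),
     'risk_label': torch.tensor(0), 'severity_score': torch.tensor(0.1),
     'importance_score': torch.tensor(0.3)},
]

out = collate_batch(mock_batch)
print(out['input_ids'].shape)  # torch.Size([3, 6])
print(out['input_ids'][2])     # tensor([101, 102,   0,   0,   0,   0])
```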
---

## 🔍 Key Features

### **1. Dynamic Padding**
- Pads to the **max length in the current batch** (not a global max)
- More efficient: a batch of short sequences → less padding
- Example:
  - Batch 1: lengths [10, 12, 11] → pad to 12
  - Batch 2: lengths [50, 48, 52] → pad to 52
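The savings are easy to quantify; a rough comparison for the first batch above, assuming a fixed global max of 512:

```python
lengths = [10, 12, 11]
dynamic_slots = len(lengths) * max(lengths)  # 3 * 12  = 36 token slots
fixed_slots   = len(lengths) * 512           # 3 * 512 = 1536 token slots
# Dynamic padding uses ~2% of what fixed-length padding would here
```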
### **2. Attention Mask Handling**
- Original tokens: `attention_mask = 1` (attend)
- Padding tokens: `attention_mask = 0` (ignore)
- BERT won't process padding tokens
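For context, a sketch of how the mask reaches the model, assuming a Hugging Face `BertModel` (the project's actual model wiring may differ):

```python
from transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')
outputs = model(input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'])
# Positions where attention_mask == 0 are excluded from self-attention
```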
### **3. Preserves All Data**
- Risk labels, severity scores, and importance scores all pass through unchanged
- No data loss, just shape adjustment

---
## 🧪 Verification

### **Test Script: `test_collate_batch.py`**

Tests 7 scenarios:
1. ✅ Import the `collate_batch` function
2. ✅ Create a mock batch with variable lengths (4, 6, 2)
3. ✅ Collate into a uniform batch
4. ✅ Verify output shapes (3, 6)
5. ✅ Verify padding tokens are 0
6. ✅ Verify labels are preserved
7. ✅ Test with an actual DataLoader

**Run the test:**
```bash
python3 test_collate_batch.py
```

**Expected output:**
```
✅ Step 1: Imports successful
✅ Step 2: Creating mock batch with variable lengths...
   Sample 0: length 4
   Sample 1: length 6
   Sample 2: length 2
✅ Step 3: Testing collate_batch()...
   📊 Output shapes:
      input_ids: torch.Size([3, 6])
      attention_mask: torch.Size([3, 6])
✅ Step 4: Verifying shapes...
✅ Step 5: Verifying padding...
✅ Step 6: Verifying labels preserved...
✅ Step 7: Testing with actual DataLoader...
🎉 ALL TESTS PASSED!
```
---

## 📊 Performance Impact

### **Memory Usage:**

**Before (would crash):**
- Can't create batches ❌

**After:**
- Batch 1 (short clauses): `batch_size × 20 tokens` = minimal
- Batch 2 (long clauses): `batch_size × 512 tokens` = max
- **Adaptive:** Uses only what's needed per batch

### **Training Speed:**
- **No slowdown:** Padding is fast (`torch.cat`)
- **Actually faster:** Proper batching enables GPU parallelism
- **Efficient:** Only pads to the batch max, not a global max
---

## 🔍 Technical Details

### **Why the Attention Mask Matters:**
```python
# Without attention mask (wrong):
#   BERT attends to padding → learns meaningless patterns ❌
# With attention mask (correct):
#   BERT ignores padding → learns only from real tokens ✅
```

### **PyTorch DataLoader Flow:**
```
Dataset.__getitem__(idx)
  → Returns a single sample with variable length
        ↓
collate_fn(batch)
  → Pads all samples to the same length
        ↓
Batched tensors
  → All the same size, ready for the model
```

### **Padding Token ID:**
- **0** is the standard PAD token ID in BERT (`[PAD]`)
- Doesn't conflict with the vocabulary (reserved ID)
- The attention mask ensures it's ignored anyway
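This can be confirmed from the tokenizer itself (assuming `bert-base-uncased`):

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
print(tokenizer.pad_token)     # '[PAD]'
print(tokenizer.pad_token_id)  # 0
```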
---

## 🚀 Ready to Train

Now training will proceed normally:

```bash
python3 train.py
```

**Expected flow:**
```
1. Load data ✅
2. Discover risk patterns (LDA) ✅
3. Create datasets ✅
4. Create dataloaders with collate_fn ✅
5. Start training...
   📊 Epoch 1/5
   Batch 0: input_ids shape = [16, 235] ✅
   Batch 1: input_ids shape = [16, 178] ✅
   Batch 2: input_ids shape = [16, 312] ✅
   ...
```
---

## 🔧 Alternative Solutions (Not Used)

### **1. Pre-pad to a Fixed Length**
```python
# Pad ALL sequences to a fixed max_length at tokenization time, e.g.:
#   tokenizer(text, padding='max_length', max_length=512, truncation=True)
# Wasteful: short clauses pad to 512 unnecessarily
```
**Why not:** Wastes memory and computation on padding that dynamic batching avoids.

### **2. Filter by Length**
```python
# Only keep clauses within a certain length range
# Loses data outside the range
```
**Why not:** Discards valuable training data.

### **3. Bucket Batching**
```python
# Group similar-length sequences into the same batch
# More complex to implement
```
**Why not:** The custom collate function is simpler and works well; a sketch of bucketing follows for reference.
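A minimal sketch of what bucketing could look like (hypothetical, not implemented here):

```python
# Sort sample indices by sequence length, then slice into batches so each
# batch holds similar lengths and needs minimal padding. Re-shuffling
# between epochs would need extra care.
batch_size = 16
indices = sorted(range(len(dataset)),
                 key=lambda i: dataset[i]['input_ids'].size(0))
buckets = [indices[i:i + batch_size]
           for i in range(0, len(indices), batch_size)]
```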
---

## 📝 Summary

### **Problem:**
- Variable-length sequences can't be stacked into batches
- Training crashes on the first batch

### **Solution:**
- Custom `collate_batch()` function
- Dynamically pads to the max length in each batch
- Preserves attention masks and all labels

### **Result:**
- ✅ Training proceeds normally
- ✅ Efficient memory usage
- ✅ No data loss
- ✅ Proper attention mask handling
---

## 📁 Files Modified

1. **`trainer.py`** - Added the `collate_batch()` function (45 lines)
   - Lines 18-61: `collate_batch()` function
   - Line 202: Added `collate_fn=collate_batch`
2. **`test_collate_batch.py`** - Verification test (180 lines)
3. **`doc/TRAINING_FIX_PADDING.md`** - This documentation
---

## ✅ Status

- [x] Identified root cause (variable-length tensors)
- [x] Implemented custom collate function
- [x] Added to DataLoader creation
- [x] Created comprehensive tests
- [x] Documented solution
- [x] **READY TO TRAIN** 🚀

---

**Next:** Run `python3 train.py` - training should now progress through epochs! 🎉