Fix tensor dimension mismatch by disabling MLM pretrain for demo
- MLM_PRETRAIN_NOTE.md  +71 -0
- app.py  +1 -1
- simple_inference.py  +10 -21
MLM_PRETRAIN_NOTE.md
ADDED

@@ -0,0 +1,71 @@

# MLM Pretrain Dimension Mismatch Issue

## Problem Description

When `use_MLM_pretrain = True`, the model encounters a tensor dimension mismatch error:

```
The size of tensor a (56) must match the size of tensor b (55) at non-singleton dimension 1
```

## Root Cause Analysis

The issue occurs due to the following sequence of operations in `Network.forward()`:

1. **MLM Pretrain Processing (if enabled):**
   - `MLMTransformerPretrain.forward(text_dict)` is called with the original text length N
   - Internally, it creates embeddings and attention masks for length N
   - It returns `text_emb_src` with shape `[batch, N, embed_dim]`

2. **Diagram Concatenation:**
   - `diagram_emb_src` is created with shape `[batch, 1, embed_dim]`
   - The two are concatenated: `all_emb_src = torch.cat([diagram_emb_src, text_emb_src], dim=1)`
   - The result has shape `[batch, N+1, embed_dim]`

3. **Length Adjustment:**
   - `text_dict['len'] += 1` (now N+1)
   - `var_dict['pos'] += 1`

4. **Issue:**
   - The MLM pretrain's TransformerEncoder has already created internal state (position embeddings, attention masks) for length N
   - But the actual sequence now has length N+1
   - This causes dimension mismatches in subsequent operations (reproduced in the sketch below)
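The failure mode can be reproduced in isolation. The following is a minimal, self-contained sketch, not the project's actual code; the tensor names mirror the description above and the sizes are illustrative. It shows how a mask built for length N stops broadcasting once the diagram token is prepended:

```python
import torch

batch, N, embed_dim = 1, 55, 8

# What the MLM pretrain path produces internally for a length-N text
text_emb_src = torch.randn(batch, N, embed_dim)
pad_mask_for_N = torch.zeros(batch, N, dtype=torch.bool)   # mask sized for N positions

# The diagram token is prepended afterwards, so the sequence grows to N+1
diagram_emb_src = torch.randn(batch, 1, embed_dim)
all_emb_src = torch.cat([diagram_emb_src, text_emb_src], dim=1)  # [batch, N+1, embed_dim]

# Any later op that pairs the N+1-length sequence with the N-length mask fails
scores = all_emb_src.sum(dim=-1)                            # [batch, N+1]
try:
    scores.masked_fill(pad_mask_for_N, float('-inf'))
except RuntimeError as err:
    print(err)  # The size of tensor a (56) must match the size of tensor b (55) ...
```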
## Solution for Demo

For the demo, we've disabled MLM pretrain by setting `use_MLM_pretrain = False` in the Config class. This uses the simpler embedding path, which properly handles the dimension adjustments.

## Alternative Solutions (if MLM pretrain is needed)

### Option 1: Pre-allocate space for diagram

Modify the MLM pretrain path to account for the diagram token from the start:

```python
if self.cfg.use_MLM_pretrain:
    # Increment length before MLM pretrain
    text_dict_copy = text_dict.copy()
    text_dict_copy['len'] = text_dict['len'] + 1
    # Add padding for diagram position
    # ... adjust tokens/tags accordingly
    text_emb_src = self.mlm_pretrain(text_dict_copy)
```
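The padding step is left open above. One hypothetical way to fill it (the helper name `reserve_diagram_slot`, the `pad_idx` default, and the exact dict keys are assumptions, not verified against the repository) is to prepend a padded position to every per-token field, so the MLM encoder sizes its masks and position embeddings for N+1 from the start:

```python
import torch

def reserve_diagram_slot(text_dict, pad_idx=0):
    """Return a copy of text_dict with one extra padded position at the front.
    The MLM encoder then builds internal state for length N+1, and the diagram
    embedding can later overwrite the reserved first position."""
    out = dict(text_dict)
    token = text_dict['token']                                 # [batch, n_subwords, seq_len]
    out['token'] = torch.cat(
        [torch.full_like(token[:, :, :1], pad_idx), token], dim=2)
    for key in ('sect_tag', 'class_tag'):                      # [batch, seq_len] each
        tag = text_dict[key]
        out[key] = torch.cat([torch.full_like(tag[:, :1], pad_idx), tag], dim=1)
    out['len'] = text_dict['len'] + 1
    return out
```

Padding at the front (rather than the back) matches how `Network.forward()` prepends `diagram_emb_src` during concatenation.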
### Option 2: Post-process MLM output

Recompute position embeddings and masks after concatenation:

```python
if self.cfg.use_MLM_pretrain:
    text_emb_src = self.mlm_pretrain(text_dict)
    # After concatenation, reapply position encoding
    all_emb_src = torch.cat([diagram_emb_src, text_emb_src], dim=1)
    # Recompute position embeddings for new length
    all_emb_src = self.recompute_positions(all_emb_src, text_dict['len'] + 1)
```
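Note that `recompute_positions` is a placeholder and does not necessarily exist in the codebase. A minimal sketch of such a helper, assuming standard sinusoidal position encodings (if the model uses learned position embeddings, they would instead need to be re-indexed for the new length):

```python
import math
import torch

def recompute_positions(emb, lengths=None):
    """Add fresh sinusoidal position encodings sized to the current sequence.

    emb:     [batch, seq_len, embed_dim] with embed_dim even
    lengths: unused in this sketch; could be used to zero out padded positions
    """
    _, seq_len, dim = emb.shape
    pos = torch.arange(seq_len, dtype=torch.float, device=emb.device).unsqueeze(1)
    div = torch.exp(torch.arange(0, dim, 2, dtype=torch.float, device=emb.device)
                    * (-math.log(10000.0) / dim))
    pe = torch.zeros(seq_len, dim, device=emb.device)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return emb + pe.unsqueeze(0)
```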
### Option 3: Separate diagram processing

Process the diagram separately and combine it at a later stage rather than concatenating embeddings.
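A hypothetical illustration of this option (the module and its names are made up, not taken from the repository): encode the text at its original length N, and only attach the projected diagram feature when assembling the combined representation, so every mask built inside the MLM path keeps its original size.

```python
import torch
import torch.nn as nn

class LateDiagramFusion(nn.Module):
    """Combine text and diagram features after text encoding (sketch only)."""
    def __init__(self, embed_dim):
        super().__init__()
        self.diagram_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, text_emb_src, diagram_emb_src):
        # text_emb_src:    [batch, N, embed_dim], produced with masks sized for N
        # diagram_emb_src: [batch, 1, embed_dim]
        memory = torch.cat([self.diagram_proj(diagram_emb_src), text_emb_src], dim=1)
        return memory    # length N+1, built after text encoding is finished
```

The mask handed to later stages would still need one extra leading entry for the diagram position, but it would be constructed once, outside the MLM encoder.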
## Testing

To verify the fix works:

1. Upload an image and text to the demo
2. The model should process the input without dimension errors
3. Output should be generated (even if not perfectly accurate without MLM pretrain)

## Performance Impact

Disabling MLM pretrain may reduce model accuracy, since the pre-trained language model helps the model understand geometric relationships. However, it ensures stable operation for the demo.
app.py
CHANGED

@@ -43,7 +43,7 @@ class Config:
        # General
        self.dropout_rate = 0.2
        self.beam_size = 10
-       self.use_MLM_pretrain = True
+       self.use_MLM_pretrain = False  # Disabled due to dimension mismatch issues in demo
        self.MLM_pretrain_path = './LM_MODEL.pth'
        self.pretrain_emb_path = ''
simple_inference.py
CHANGED

@@ -50,27 +50,16 @@ def simple_process_input(image, text_input, model, src_lang, tgt_lang, cfg):
    # For simple case, we have 1 subword per token, so shape is [batch, 1, seq_len]
    # This gets embedded and summed over dim=1 to get [batch, seq_len, embed_dim]

-        ...
-        }
-    else:
-        # Non-MLM path also needs 3D tensor for consistency
-        token_tensor_3d = token_tensor.unsqueeze(1)  # [batch, 1, seq_len]
-
-        text_dict = {
-            'token': token_tensor_3d,
-            'sect_tag': torch.LongTensor([sect_tag_indices]).to(device),
-            'class_tag': torch.LongTensor([class_tag_indices]).to(device),
-            'len': torch.LongTensor([text_len]).to(device)
-        }
+    # Create 3D tensor: [batch_size, 1, text_len]
+    # Each token is a single subword, so middle dimension is 1
+    token_tensor_3d = token_tensor.unsqueeze(1)  # [batch, 1, seq_len]
+
+    text_dict = {
+        'token': token_tensor_3d,
+        'sect_tag': torch.LongTensor([sect_tag_indices]).to(device),
+        'class_tag': torch.LongTensor([class_tag_indices]).to(device),
+        'len': torch.LongTensor([text_len]).to(device)
+    }

    # Simple var dict (no variables detected)
    # Note: var positions need to account for the diagram token that will be added
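As a standalone sanity check of the shape flow described in the comments above (illustrative sizes, not the repository's actual embedding module):

```python
import torch
import torch.nn as nn

vocab_size, embed_dim, seq_len = 100, 8, 5
token_tensor = torch.randint(0, vocab_size, (1, seq_len))   # [batch, seq_len]

token_tensor_3d = token_tensor.unsqueeze(1)                  # [batch, 1, seq_len]

embed = nn.Embedding(vocab_size, embed_dim)
out = embed(token_tensor_3d)                                 # [batch, 1, seq_len, embed_dim]
out = out.sum(dim=1)                                         # [batch, seq_len, embed_dim]
print(out.shape)                                             # torch.Size([1, 5, 8])
```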