asdfasdfdsafdsa committed on
Commit 96836c8 · verified · 1 Parent(s): cc4d3bc

Fix tensor dimension mismatch by disabling MLM pretrain for demo

Files changed (3)
  1. MLM_PRETRAIN_NOTE.md +71 -0
  2. app.py +1 -1
  3. simple_inference.py +10 -21
MLM_PRETRAIN_NOTE.md ADDED
@@ -0,0 +1,71 @@
+ # MLM Pretrain Dimension Mismatch Issue
+
+ ## Problem Description
+ When `use_MLM_pretrain = True`, the model raises a tensor dimension mismatch error:
+ ```
+ The size of tensor a (56) must match the size of tensor b (55) at non-singleton dimension 1
+ ```
+
+ ## Root Cause Analysis
+
+ The issue occurs due to the following sequence of operations in `Network.forward()`:
+
+ 1. **MLM Pretrain Processing (if enabled):**
+    - `MLMTransformerPretrain.forward(text_dict)` is called with the original text length N
+    - Internally, it creates embeddings and attention masks for length N
+    - It returns `text_emb_src` with shape `[batch, N, embed_dim]`
+
+ 2. **Diagram Concatenation:**
+    - `diagram_emb_src` is created with shape `[batch, 1, embed_dim]`
+    - The two are concatenated: `all_emb_src = torch.cat([diagram_emb_src, text_emb_src], dim=1)`
+    - The result has shape `[batch, N+1, embed_dim]`
+
+ 3. **Length Adjustment:**
+    - `text_dict['len'] += 1` (now N+1)
+    - `var_dict['pos'] += 1`
+
+ 4. **Issue:**
+    - The MLM pretrain's TransformerEncoder has already built internal state (position embeddings, attention masks) for length N
+    - The actual sequence now has length N+1
+    - This causes dimension mismatches in subsequent operations (see the minimal sketch below)
+
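+ The mismatch can be reproduced in isolation. This is a minimal sketch, not the repo's actual code; it assumes the pretrain encoder holds positional state sized for the original text length N, which then meets the N+1-length concatenated sequence:
+
+ ```python
+ import torch
+
+ batch, N, embed_dim = 2, 55, 128
+
+ # Stand-in for the MLM pretrain encoder's internal positional state, sized for length N
+ pos_emb = torch.randn(1, N, embed_dim)
+
+ # Text embeddings of length N, with a diagram token prepended -> length N+1
+ text_emb_src = torch.randn(batch, N, embed_dim)
+ diagram_emb_src = torch.randn(batch, 1, embed_dim)
+ all_emb_src = torch.cat([diagram_emb_src, text_emb_src], dim=1)  # [batch, N+1, embed_dim]
+
+ # Combining length-(N+1) activations with length-N positional state raises:
+ # "The size of tensor a (56) must match the size of tensor b (55) at non-singleton dimension 1"
+ out = all_emb_src + pos_emb
+ ```
+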
+ ## Solution for Demo
+
+ For the demo, we've disabled MLM pretrain by setting `use_MLM_pretrain = False` in the Config class. This uses the simpler embedding path, which properly handles the dimension adjustments.
+
+ ## Alternative Solutions (if MLM pretrain is needed)
+
+ ### Option 1: Pre-allocate space for diagram
+ Modify the MLM pretrain path to account for the diagram token from the start:
+ ```python
+ if self.cfg.use_MLM_pretrain:
+     # Increment length before MLM pretrain
+     text_dict_copy = text_dict.copy()
+     text_dict_copy['len'] = text_dict['len'] + 1
+     # Add padding for diagram position
+     # ... adjust tokens/tags accordingly
+     text_emb_src = self.mlm_pretrain(text_dict_copy)
+ ```
+
+ ### Option 2: Post-process MLM output
+ Recompute position embeddings and masks after concatenation:
+ ```python
+ if self.cfg.use_MLM_pretrain:
+     text_emb_src = self.mlm_pretrain(text_dict)
+     # After concatenation, reapply position encoding
+     all_emb_src = torch.cat([diagram_emb_src, text_emb_src], dim=1)
+     # Recompute position embeddings for the new length
+     all_emb_src = self.recompute_positions(all_emb_src, text_dict['len'] + 1)
+ ```
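+
+ `recompute_positions` is not an existing helper in this repo; it stands in for whatever re-applies positional information after concatenation. A possible sketch, assuming sinusoidal encodings and an even embedding dimension:
+
+ ```python
+ import math
+ import torch
+
+ def recompute_positions(emb: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
+     """Add fresh sinusoidal position encodings sized to the post-concatenation sequence.
+
+     emb:     [batch, seq_len, dim] embeddings (diagram token + text)
+     lengths: [batch] valid length per sample (text length + 1 for the diagram token)
+     """
+     _, seq_len, dim = emb.shape
+     position = torch.arange(seq_len, device=emb.device, dtype=torch.float).unsqueeze(1)
+     div_term = torch.exp(torch.arange(0, dim, 2, device=emb.device, dtype=torch.float)
+                          * (-math.log(10000.0) / dim))
+     pe = torch.zeros(seq_len, dim, device=emb.device)
+     pe[:, 0::2] = torch.sin(position * div_term)
+     pe[:, 1::2] = torch.cos(position * div_term)
+     # Zero out positions beyond each sample's valid length so padding gets no positional signal
+     valid = (torch.arange(seq_len, device=emb.device).unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(-1)
+     return emb + pe.unsqueeze(0) * valid
+ ```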
+
+ ### Option 3: Separate diagram processing
+ Process the diagram separately and combine it at a later stage, rather than concatenating embeddings (see the sketch below).
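+
+ One way to read this option is as late fusion: the text sequence stays at length N and the diagram embedding is injected through cross-attention instead of concatenation, so no `len`/`pos` bookkeeping is needed. The module below is purely illustrative and not part of this repo:
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ class LateDiagramFusion(nn.Module):
+     """Illustrative late-fusion module: inject the diagram embedding into the
+     text encoding via cross-attention, keeping the sequence length at N."""
+
+     def __init__(self, embed_dim: int, num_heads: int = 8):
+         super().__init__()
+         self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
+         self.norm = nn.LayerNorm(embed_dim)
+
+     def forward(self, text_emb_src: torch.Tensor, diagram_emb_src: torch.Tensor) -> torch.Tensor:
+         # text_emb_src:    [batch, N, embed_dim]  (e.g. the MLM pretrain output)
+         # diagram_emb_src: [batch, 1, embed_dim]
+         attended, _ = self.cross_attn(query=text_emb_src,
+                                       key=diagram_emb_src,
+                                       value=diagram_emb_src)
+         return self.norm(text_emb_src + attended)  # still [batch, N, embed_dim]
+ ```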
+
+ ## Testing
+ To verify the fix works:
+ 1. Upload an image and text to the demo
+ 2. The model should process the input without dimension errors
+ 3. Output should be generated (even if not perfectly accurate without MLM pretrain)
+
+ ## Performance Impact
+ Disabling MLM pretrain may reduce model accuracy, since the pre-trained language model helps with understanding geometric relationships. However, it ensures stable operation for the demo.
app.py CHANGED
@@ -43,7 +43,7 @@ class Config:
         # General
         self.dropout_rate = 0.2
         self.beam_size = 10
-        self.use_MLM_pretrain = True
+        self.use_MLM_pretrain = False  # Disabled due to dimension mismatch issues in demo
         self.MLM_pretrain_path = './LM_MODEL.pth'
         self.pretrain_emb_path = ''
 
simple_inference.py CHANGED
@@ -50,27 +50,16 @@ def simple_process_input(image, text_input, model, src_lang, tgt_lang, cfg):
     # For simple case, we have 1 subword per token, so shape is [batch, 1, seq_len]
     # This gets embedded and summed over dim=1 to get [batch, seq_len, embed_dim]
 
-    if cfg.use_MLM_pretrain:
-        # Create 3D tensor: [batch_size, 1, text_len]
-        # Each token is a single subword, so middle dimension is 1
-        token_tensor_3d = token_tensor.unsqueeze(1)  # [batch, 1, seq_len]
-
-        text_dict = {
-            'token': token_tensor_3d,
-            'sect_tag': torch.LongTensor([sect_tag_indices]).to(device),
-            'class_tag': torch.LongTensor([class_tag_indices]).to(device),
-            'len': torch.LongTensor([text_len]).to(device)
-        }
-    else:
-        # Non-MLM path also needs 3D tensor for consistency
-        token_tensor_3d = token_tensor.unsqueeze(1)  # [batch, 1, seq_len]
-
-        text_dict = {
-            'token': token_tensor_3d,
-            'sect_tag': torch.LongTensor([sect_tag_indices]).to(device),
-            'class_tag': torch.LongTensor([class_tag_indices]).to(device),
-            'len': torch.LongTensor([text_len]).to(device)
-        }
+    # Create 3D tensor: [batch_size, 1, text_len]
+    # Each token is a single subword, so middle dimension is 1
+    token_tensor_3d = token_tensor.unsqueeze(1)  # [batch, 1, seq_len]
+
+    text_dict = {
+        'token': token_tensor_3d,
+        'sect_tag': torch.LongTensor([sect_tag_indices]).to(device),
+        'class_tag': torch.LongTensor([class_tag_indices]).to(device),
+        'len': torch.LongTensor([text_len]).to(device)
+    }
 
     # Simple var dict (no variables detected)
     # Note: var positions need to account for the diagram token that will be added