
# MLM Pretrain Dimension Mismatch Issue

## Problem Description

When `use_MLM_pretrain = True`, the model encounters a tensor dimension mismatch error:

```
The size of tensor a (56) must match the size of tensor b (55) at non-singleton dimension 1
```

## Root Cause Analysis

The issue occurs due to the following sequence of operations in `Network.forward()`:

1. **MLM pretrain processing** (if enabled):
   - `MLMTransformerPretrain.forward(text_dict)` is called with the original text length N.
   - Internally, it creates embeddings and attention masks for length N.
   - It returns `text_emb_src` with shape `[batch, N, embed_dim]`.
2. **Diagram concatenation**:
   - `diagram_emb_src` is created with shape `[batch, 1, embed_dim]`.
   - The two are concatenated: `all_emb_src = torch.cat([diagram_emb_src, text_emb_src], dim=1)`.
   - The result has shape `[batch, N + 1, embed_dim]`.
3. **Length adjustment**:
   - `text_dict['len'] += 1` (now N + 1).
   - `var_dict['pos'] += 1`.
4. **Issue**:
   - The MLM pretrain's `TransformerEncoder` has already created internal state (position embeddings, attention masks) for length N.
   - The actual sequence now has length N + 1.
   - This causes dimension mismatches in subsequent operations (a minimal repro follows below).
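The off-by-one can be reproduced in isolation. The following is a minimal sketch (the dimensions are illustrative, not taken from the model) of adding a position table built for length N to a sequence that has grown to N + 1:

```python
import torch

batch, N, D = 2, 55, 8
text_emb_src = torch.randn(batch, N, D)  # MLM pretrain output, built for length N
pos_emb = torch.randn(1, N, D)           # position table also sized for N

diagram_emb_src = torch.randn(batch, 1, D)
all_emb_src = torch.cat([diagram_emb_src, text_emb_src], dim=1)  # [batch, N + 1, D]

# RuntimeError: The size of tensor a (56) must match the size of
# tensor b (55) at non-singleton dimension 1
all_emb_src = all_emb_src + pos_emb
```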

## Solution for Demo

For the demo, we've disabled MLM pretrain by setting `use_MLM_pretrain = False` in the `Config` class. This falls back to the simpler embedding path, which handles the length adjustment correctly.
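Assuming `use_MLM_pretrain` is a plain class attribute (as the note above implies), the demo change amounts to:

```python
class Config:
    # ... other hyperparameters unchanged
    use_MLM_pretrain = False  # demo: take the simpler embedding path
```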

## Alternative Solutions (if MLM pretrain is needed)

### Option 1: Pre-allocate space for diagram

Modify the MLM pretrain path to account for the diagram token from the start:

```python
if self.cfg.use_MLM_pretrain:
    # Increment the length before MLM pretrain so its masks and
    # position embeddings are built for N + 1 from the start
    text_dict_copy = text_dict.copy()  # shallow copy; 'len' is replaced below
    text_dict_copy['len'] = text_dict['len'] + 1
    # Add padding for diagram position
    # ... adjust tokens/tags accordingly
    text_emb_src = self.mlm_pretrain(text_dict_copy)
```

### Option 2: Post-process MLM output

Recompute position embeddings and masks after concatenation:

```python
if self.cfg.use_MLM_pretrain:
    text_emb_src = self.mlm_pretrain(text_dict)
    # After concatenation, reapply position encoding
    all_emb_src = torch.cat([diagram_emb_src, text_emb_src], dim=1)
    # Recompute position embeddings for the new length
    all_emb_src = self.recompute_positions(all_emb_src, text_dict['len'] + 1)
```
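`recompute_positions` does not exist in the codebase; it stands in for whatever re-encoding step is chosen. One hypothetical implementation, assuming standard sinusoidal encodings, a scalar sequence length, and an even `embed_dim`, could look like:

```python
import math
import torch

def recompute_positions(emb: torch.Tensor, max_len: int) -> torch.Tensor:
    """Hypothetical helper: add fresh sinusoidal position encodings
    sized for the post-concatenation length (N + 1)."""
    d_model = emb.size(-1)
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32)
        * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return emb + pe[: emb.size(1)].unsqueeze(0).to(emb.device)
```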

### Option 3: Separate diagram processing

Process the diagram separately and combine it with the text at a later stage, rather than concatenating token embeddings; a sketch of this idea follows.
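A hedged sketch of that idea (the `LateFusion` module and its names are hypothetical, not from this repo): pool the text and diagram features separately, then fuse them, so no token sequence ever changes length after its masks are built:

```python
import torch
import torch.nn as nn

class LateFusion(nn.Module):
    """Hypothetical late-fusion head: combine pooled text and diagram
    features instead of concatenating token sequences."""

    def __init__(self, embed_dim: int):
        super().__init__()
        self.fusion = nn.Linear(2 * embed_dim, embed_dim)

    def forward(self, text_emb: torch.Tensor, diagram_emb: torch.Tensor) -> torch.Tensor:
        # text_emb: [batch, N, D] from the MLM pretrain path (masks built for N)
        # diagram_emb: [batch, 1, D] from the diagram encoder
        pooled_text = text_emb.mean(dim=1)    # [batch, D]
        pooled_diag = diagram_emb.squeeze(1)  # [batch, D]
        return self.fusion(torch.cat([pooled_text, pooled_diag], dim=-1))
```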

## Testing

To verify the fix works:

1. Upload an image and text to the demo.
2. The forward pass should complete without dimension errors.
3. Output should be generated (even if less accurate without MLM pretrain).

## Performance Impact

Disabling MLM pretrain may reduce model accuracy since the pre-trained language model helps with understanding geometric relationships. However, it ensures stable operation for the demo.