Spaces:
Sleeping
Sleeping
Fix tensor dimension mismatch in MLM pretrain path
Browse files- simple_inference.py +29 -6
simple_inference.py
CHANGED
|
@@ -38,14 +38,37 @@ def simple_process_input(image, text_input, model, src_lang, tgt_lang, cfg):
|
|
| 38 |
batch_size = 1
|
| 39 |
text_len = len(text_indices)
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
# Simple var dict (no variables detected)
|
|
|
|
| 49 |
var_dict = {
|
| 50 |
'pos': torch.zeros(batch_size, 1, dtype=torch.long).to(device),
|
| 51 |
'len': torch.zeros(batch_size, dtype=torch.long).to(device),
|
|
|
|
| 38 |
batch_size = 1
|
| 39 |
text_len = len(text_indices)
|
| 40 |
|
| 41 |
+
# For MLM pretrain, tokens need to be 3D: [batch, seq_len, vocab_size]
|
| 42 |
+
# But here we use 2D: [batch, seq_len] and let the embedding layer handle it
|
| 43 |
+
token_tensor = torch.LongTensor([text_indices]).to(device)
|
| 44 |
+
|
| 45 |
+
# Ensure sect_tag and class_tag match token length
|
| 46 |
+
sect_tag_indices = [1] * text_len # Default to [PROB]
|
| 47 |
+
class_tag_indices = [1] * text_len # Default to [GEN]
|
| 48 |
+
|
| 49 |
+
# MLM pretrain expects token to be [batch, seq_len, num_tokens_per_word]
|
| 50 |
+
# For single word tokens, num_tokens_per_word = 1
|
| 51 |
+
# So we need to add an extra dimension
|
| 52 |
+
if cfg.use_MLM_pretrain:
|
| 53 |
+
# Reshape token tensor to [batch, seq_len, 1] then expand to match expected format
|
| 54 |
+
token_tensor_3d = token_tensor.unsqueeze(-1) # [batch, seq_len, 1]
|
| 55 |
+
|
| 56 |
+
text_dict = {
|
| 57 |
+
'token': token_tensor_3d,
|
| 58 |
+
'sect_tag': torch.LongTensor([sect_tag_indices]).to(device),
|
| 59 |
+
'class_tag': torch.LongTensor([class_tag_indices]).to(device),
|
| 60 |
+
'len': torch.LongTensor([text_len]).to(device)
|
| 61 |
+
}
|
| 62 |
+
else:
|
| 63 |
+
text_dict = {
|
| 64 |
+
'token': token_tensor,
|
| 65 |
+
'sect_tag': torch.LongTensor([sect_tag_indices]).to(device),
|
| 66 |
+
'class_tag': torch.LongTensor([class_tag_indices]).to(device),
|
| 67 |
+
'len': torch.LongTensor([text_len]).to(device)
|
| 68 |
+
}
|
| 69 |
|
| 70 |
# Simple var dict (no variables detected)
|
| 71 |
+
# Note: var positions need to account for the diagram token that will be added
|
| 72 |
var_dict = {
|
| 73 |
'pos': torch.zeros(batch_size, 1, dtype=torch.long).to(device),
|
| 74 |
'len': torch.zeros(batch_size, dtype=torch.long).to(device),
|