Spaces:
Sleeping
Sleeping
Fix token tensor dimensions - should be [batch, 1, seq_len]
Browse files- simple_inference.py +11 -6
simple_inference.py
CHANGED
|
@@ -46,12 +46,14 @@ def simple_process_input(image, text_input, model, src_lang, tgt_lang, cfg):
|
|
| 46 |
sect_tag_indices = [1] * text_len # Default to [PROB]
|
| 47 |
class_tag_indices = [1] * text_len # Default to [GEN]
|
| 48 |
|
| 49 |
-
#
|
| 50 |
-
# For
|
| 51 |
-
#
|
|
|
|
| 52 |
if cfg.use_MLM_pretrain:
|
| 53 |
-
#
|
| 54 |
-
|
|
|
|
| 55 |
|
| 56 |
text_dict = {
|
| 57 |
'token': token_tensor_3d,
|
|
@@ -60,8 +62,11 @@ def simple_process_input(image, text_input, model, src_lang, tgt_lang, cfg):
|
|
| 60 |
'len': torch.LongTensor([text_len]).to(device)
|
| 61 |
}
|
| 62 |
else:
|
|
|
|
|
|
|
|
|
|
| 63 |
text_dict = {
|
| 64 |
-
'token':
|
| 65 |
'sect_tag': torch.LongTensor([sect_tag_indices]).to(device),
|
| 66 |
'class_tag': torch.LongTensor([class_tag_indices]).to(device),
|
| 67 |
'len': torch.LongTensor([text_len]).to(device)
|
|
|
|
| 46 |
sect_tag_indices = [1] * text_len # Default to [PROB]
|
| 47 |
class_tag_indices = [1] * text_len # Default to [GEN]
|
| 48 |
|
| 49 |
+
# The model expects token to be [batch, num_subwords_per_token, seq_len]
|
| 50 |
+
# For simple case, we have 1 subword per token, so shape is [batch, 1, seq_len]
|
| 51 |
+
# This gets embedded and summed over dim=1 to get [batch, seq_len, embed_dim]
|
| 52 |
+
|
| 53 |
if cfg.use_MLM_pretrain:
|
| 54 |
+
# Create 3D tensor: [batch_size, 1, text_len]
|
| 55 |
+
# Each token is a single subword, so middle dimension is 1
|
| 56 |
+
token_tensor_3d = token_tensor.unsqueeze(1) # [batch, 1, seq_len]
|
| 57 |
|
| 58 |
text_dict = {
|
| 59 |
'token': token_tensor_3d,
|
|
|
|
| 62 |
'len': torch.LongTensor([text_len]).to(device)
|
| 63 |
}
|
| 64 |
else:
|
| 65 |
+
# Non-MLM path also needs 3D tensor for consistency
|
| 66 |
+
token_tensor_3d = token_tensor.unsqueeze(1) # [batch, 1, seq_len]
|
| 67 |
+
|
| 68 |
text_dict = {
|
| 69 |
+
'token': token_tensor_3d,
|
| 70 |
'sect_tag': torch.LongTensor([sect_tag_indices]).to(device),
|
| 71 |
'class_tag': torch.LongTensor([class_tag_indices]).to(device),
|
| 72 |
'len': torch.LongTensor([text_len]).to(device)
|