asdfasdfdsafdsa committed on
Commit
cc4d3bc
·
verified ·
1 Parent(s): aef03a7

Fix token tensor dimensions - should be [batch, 1, seq_len]

Browse files
Files changed (1) hide show
  1. simple_inference.py +11 -6
simple_inference.py CHANGED
@@ -46,12 +46,14 @@ def simple_process_input(image, text_input, model, src_lang, tgt_lang, cfg):
46
  sect_tag_indices = [1] * text_len # Default to [PROB]
47
  class_tag_indices = [1] * text_len # Default to [GEN]
48
 
49
- # MLM pretrain expects token to be [batch, seq_len, num_tokens_per_word]
50
- # For single word tokens, num_tokens_per_word = 1
51
- # So we need to add an extra dimension
 
52
  if cfg.use_MLM_pretrain:
53
- # Reshape token tensor to [batch, seq_len, 1] then expand to match expected format
54
- token_tensor_3d = token_tensor.unsqueeze(-1) # [batch, seq_len, 1]
 
55
 
56
  text_dict = {
57
  'token': token_tensor_3d,
@@ -60,8 +62,11 @@ def simple_process_input(image, text_input, model, src_lang, tgt_lang, cfg):
60
  'len': torch.LongTensor([text_len]).to(device)
61
  }
62
  else:
 
 
 
63
  text_dict = {
64
- 'token': token_tensor,
65
  'sect_tag': torch.LongTensor([sect_tag_indices]).to(device),
66
  'class_tag': torch.LongTensor([class_tag_indices]).to(device),
67
  'len': torch.LongTensor([text_len]).to(device)
 
46
  sect_tag_indices = [1] * text_len # Default to [PROB]
47
  class_tag_indices = [1] * text_len # Default to [GEN]
48
 
49
+ # The model expects token to be [batch, num_subwords_per_token, seq_len]
50
+ # For simple case, we have 1 subword per token, so shape is [batch, 1, seq_len]
51
+ # This gets embedded and summed over dim=1 to get [batch, seq_len, embed_dim]
52
+
53
  if cfg.use_MLM_pretrain:
54
+ # Create 3D tensor: [batch_size, 1, text_len]
55
+ # Each token is a single subword, so middle dimension is 1
56
+ token_tensor_3d = token_tensor.unsqueeze(1) # [batch, 1, seq_len]
57
 
58
  text_dict = {
59
  'token': token_tensor_3d,
 
62
  'len': torch.LongTensor([text_len]).to(device)
63
  }
64
  else:
65
+ # Non-MLM path also needs 3D tensor for consistency
66
+ token_tensor_3d = token_tensor.unsqueeze(1) # [batch, 1, seq_len]
67
+
68
  text_dict = {
69
+ 'token': token_tensor_3d,
70
  'sect_tag': torch.LongTensor([sect_tag_indices]).to(device),
71
  'class_tag': torch.LongTensor([class_tag_indices]).to(device),
72
  'len': torch.LongTensor([text_len]).to(device)