asdfasdfdsafdsa committed on
Commit
aef03a7
·
verified ·
1 Parent(s): bf8c161

Fix tensor dimension mismatch in MLM pretrain path

Browse files
Files changed (1) hide show
  1. simple_inference.py +29 -6
simple_inference.py CHANGED
@@ -38,14 +38,37 @@ def simple_process_input(image, text_input, model, src_lang, tgt_lang, cfg):
38
  batch_size = 1
39
  text_len = len(text_indices)
40
 
41
- text_dict = {
42
- 'token': torch.LongTensor([text_indices]).to(device),
43
- 'sect_tag': torch.ones(batch_size, text_len, dtype=torch.long).to(device),
44
- 'class_tag': torch.ones(batch_size, text_len, dtype=torch.long).to(device),
45
- 'len': torch.LongTensor([text_len]).to(device)
46
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  # Simple var dict (no variables detected)
 
49
  var_dict = {
50
  'pos': torch.zeros(batch_size, 1, dtype=torch.long).to(device),
51
  'len': torch.zeros(batch_size, dtype=torch.long).to(device),
 
38
  batch_size = 1
39
  text_len = len(text_indices)
40
 
41
+ # For MLM pretrain, tokens need to be 3D: [batch, seq_len, num_tokens_per_word]
42
+ # Here we start from 2D: [batch, seq_len], adding the extra dimension below only when MLM pretrain is enabled
43
+ token_tensor = torch.LongTensor([text_indices]).to(device)
44
+
45
+ # Ensure sect_tag and class_tag match token length
46
+ sect_tag_indices = [1] * text_len # Default to [PROB]
47
+ class_tag_indices = [1] * text_len # Default to [GEN]
48
+
49
+ # MLM pretrain expects token to be [batch, seq_len, num_tokens_per_word]
50
+ # For single word tokens, num_tokens_per_word = 1
51
+ # So we need to add an extra dimension
52
+ if cfg.use_MLM_pretrain:
53
+ # Reshape token tensor to [batch, seq_len, 1] then expand to match expected format
54
+ token_tensor_3d = token_tensor.unsqueeze(-1) # [batch, seq_len, 1]
55
+
56
+ text_dict = {
57
+ 'token': token_tensor_3d,
58
+ 'sect_tag': torch.LongTensor([sect_tag_indices]).to(device),
59
+ 'class_tag': torch.LongTensor([class_tag_indices]).to(device),
60
+ 'len': torch.LongTensor([text_len]).to(device)
61
+ }
62
+ else:
63
+ text_dict = {
64
+ 'token': token_tensor,
65
+ 'sect_tag': torch.LongTensor([sect_tag_indices]).to(device),
66
+ 'class_tag': torch.LongTensor([class_tag_indices]).to(device),
67
+ 'len': torch.LongTensor([text_len]).to(device)
68
+ }
69
 
70
  # Simple var dict (no variables detected)
71
+ # Note: var positions need to account for the diagram token that will be added
72
  var_dict = {
73
  'pos': torch.zeros(batch_size, 1, dtype=torch.long).to(device),
74
  'len': torch.zeros(batch_size, dtype=torch.long).to(device),