damerajee commited on
Commit
d9d60c5
·
verified ·
1 Parent(s): 96e263a

Update modeling_gpt2vision.py

Browse files
Files changed (1) hide show
  1. modeling_gpt2vision.py +2 -1
modeling_gpt2vision.py CHANGED
@@ -87,8 +87,9 @@ class GPT2Vision(PreTrainedModel):
87
  attention_mask = batch['attention_mask'].to(self.device)
88
 
89
  img_embs = self.vision_encoder(images, device=self.device)
 
90
  img_embs = self.mlp(img_embs)
91
-
92
  tok_embs = self.language_model.get_input_embeddings()(input_ids)
93
 
94
  inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
 
87
  attention_mask = batch['attention_mask'].to(self.device)
88
 
89
  img_embs = self.vision_encoder(images, device=self.device)
90
+ print("img_embs",img_embs.size)
91
  img_embs = self.mlp(img_embs)
92
+ print("img_embs",img_embs.size)
93
  tok_embs = self.language_model.get_input_embeddings()(input_ids)
94
 
95
  inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)