Update modeling_gpt2vision.py
Browse files- modeling_gpt2vision.py +2 -1
modeling_gpt2vision.py
CHANGED
|
@@ -87,8 +87,9 @@ class GPT2Vision(PreTrainedModel):
|
|
| 87 |
attention_mask = batch['attention_mask'].to(self.device)
|
| 88 |
|
| 89 |
img_embs = self.vision_encoder(images, device=self.device)
|
|
|
|
| 90 |
img_embs = self.mlp(img_embs)
|
| 91 |
-
|
| 92 |
tok_embs = self.language_model.get_input_embeddings()(input_ids)
|
| 93 |
|
| 94 |
inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
|
|
|
|
| 87 |
attention_mask = batch['attention_mask'].to(self.device)
|
| 88 |
|
| 89 |
img_embs = self.vision_encoder(images, device=self.device)
|
| 90 |
+
print("img_embs",img_embs.size)
|
| 91 |
img_embs = self.mlp(img_embs)
|
| 92 |
+
print("img_embs",img_embs.size)
|
| 93 |
tok_embs = self.language_model.get_input_embeddings()(input_ids)
|
| 94 |
|
| 95 |
inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
|