Update modeling_gpt2vision.py
Browse files- modeling_gpt2vision.py +7 -5
modeling_gpt2vision.py
CHANGED
|
@@ -26,17 +26,19 @@ class GPT2Vision(PreTrainedModel):
|
|
| 26 |
def __init__(self, config):
|
| 27 |
super().__init__(config)
|
| 28 |
self.vision_encoder = VisionEncoder()
|
| 29 |
-
|
| 30 |
-
self.tokenizer = tokenizer
|
| 31 |
-
tokenizer.pad_token = tokenizer.eos_token
|
| 32 |
-
self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
| 33 |
-
|
| 34 |
if isinstance(config.gpt2_config, dict):
|
| 35 |
gpt2_config = GPT2Config(**config.gpt2_config)
|
| 36 |
else:
|
| 37 |
gpt2_config = config.gpt2_config
|
| 38 |
self.text_model = GPT2LMHeadModel(gpt2_config)
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
@property
|
| 41 |
def device(self):
|
| 42 |
return self.text_model.device
|
|
|
|
| 26 |
def __init__(self, config):
|
| 27 |
super().__init__(config)
|
| 28 |
self.vision_encoder = VisionEncoder()
|
| 29 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
if isinstance(config.gpt2_config, dict):
|
| 31 |
gpt2_config = GPT2Config(**config.gpt2_config)
|
| 32 |
else:
|
| 33 |
gpt2_config = config.gpt2_config
|
| 34 |
self.text_model = GPT2LMHeadModel(gpt2_config)
|
| 35 |
|
| 36 |
+
self.text_model.resize_token_embeddings(len(tokenizer))
|
| 37 |
+
self.tokenizer = tokenizer
|
| 38 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 39 |
+
self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
@property
|
| 43 |
def device(self):
|
| 44 |
return self.text_model.device
|