Upload train_k.py
Browse files- train_k.py +4 -2
train_k.py
CHANGED
|
@@ -156,8 +156,10 @@ for j in range(1, 179+1):
|
|
| 156 |
transform = transforms
|
| 157 |
)
|
| 158 |
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
| 161 |
|
| 162 |
|
| 163 |
model.config.decoder_start_token_id = tokenizer.cls_token_id
|
|
|
|
| 156 |
transform = transforms
|
| 157 |
)
|
| 158 |
|
| 159 |
+
if os.path.exists('VIT_large_gpt2_model'):
|
| 160 |
+
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained('VIT_large_gpt2_model')
|
| 161 |
+
else:
|
| 162 |
+
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(config.ENCODER, config.DECODER)
|
| 163 |
|
| 164 |
|
| 165 |
model.config.decoder_start_token_id = tokenizer.cls_token_id
|