ydshieh committed on
Commit ·
06bcf58
1
Parent(s): 3b81fb5
update check against pytorch's version
Browse files- tests/test_model.py +28 -9
tests/test_model.py
CHANGED
|
@@ -193,20 +193,39 @@ from transformers import ViTModel, GPT2Config, GPT2LMHeadModel
|
|
| 193 |
|
| 194 |
vision_model_pt = ViTModel.from_pretrained(vision_model_name)
|
| 195 |
config = GPT2Config.from_pretrained(text_model_name)
|
| 196 |
-
config.is_encoder_decoder = True
|
| 197 |
config.add_cross_attention = True
|
| 198 |
text_model_pt = GPT2LMHeadModel.from_pretrained(text_model_name, config=config)
|
| 199 |
|
| 200 |
-
|
| 201 |
-
|
|
|
|
| 202 |
|
| 203 |
-
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
print('=' * 60)
|
| 207 |
-
print(
|
|
|
|
|
|
|
| 208 |
|
| 209 |
-
|
|
|
|
| 210 |
|
| 211 |
-
print('=' * 60)
|
| 212 |
-
print(f'Pytorch\'s GPT2 LM generated
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
vision_model_pt = ViTModel.from_pretrained(vision_model_name)
|
| 195 |
config = GPT2Config.from_pretrained(text_model_name)
|
| 196 |
+
# config.is_encoder_decoder = True
|
| 197 |
config.add_cross_attention = True
|
| 198 |
text_model_pt = GPT2LMHeadModel.from_pretrained(text_model_name, config=config)
|
| 199 |
|
| 200 |
+
encoder_pt_inputs = feature_extractor(images=image, return_tensors="pt")
|
| 201 |
+
encoder_pt_outputs = vision_model_pt(**encoder_pt_inputs)
|
| 202 |
+
encoder_hidden_states = encoder_pt_outputs.last_hidden_state
|
| 203 |
|
| 204 |
+
# model data
|
| 205 |
+
text_model_pt_inputs = {
|
| 206 |
+
'input_ids': torch.tensor(decoder_input_ids, dtype=torch.int32),
|
| 207 |
+
'attention_mask': torch.tensor(decoder_attention_mask, dtype=torch.int32),
|
| 208 |
+
'position_ids': None,
|
| 209 |
+
'encoder_hidden_states': encoder_hidden_states
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
# Model call
|
| 213 |
+
text_model_pt_outputs = text_model_pt(**text_model_pt_inputs)
|
| 214 |
+
logits = text_model_pt_outputs[0]
|
| 215 |
+
preds = np.argmax(logits.detach().numpy(), axis=-1)
|
| 216 |
|
| 217 |
print('=' * 60)
|
| 218 |
+
print('PyTroch: Vit --> GPT2-LM')
|
| 219 |
+
print('predicted token ids:')
|
| 220 |
+
print(preds)
|
| 221 |
|
| 222 |
+
#generated = text_model_pt.generate(encoder_outputs=vision_model_pt_outputs, **gen_kwargs)
|
| 223 |
+
#token_ids = np.array(generated.sequences)[0]
|
| 224 |
|
| 225 |
+
#print('=' * 60)
|
| 226 |
+
#print(f'Pytorch\'s GPT2 LM generated token ids: {token_ids}')
|
| 227 |
+
|
| 228 |
+
#caption = tokenizer.decode(token_ids)
|
| 229 |
+
|
| 230 |
+
#print('=' * 60)
|
| 231 |
+
#print(f'Pytorch\'s GPT2 LM generated caption: {caption}')
|