ydshieh commited on
Commit
06bcf58
·
1 Parent(s): 3b81fb5

update check against pytorch's version

Browse files
Files changed (1) hide show
  1. tests/test_model.py +28 -9
tests/test_model.py CHANGED
@@ -193,20 +193,39 @@ from transformers import ViTModel, GPT2Config, GPT2LMHeadModel
193
 
194
  vision_model_pt = ViTModel.from_pretrained(vision_model_name)
195
  config = GPT2Config.from_pretrained(text_model_name)
196
- config.is_encoder_decoder = True
197
  config.add_cross_attention = True
198
  text_model_pt = GPT2LMHeadModel.from_pretrained(text_model_name, config=config)
199
 
200
- encoder_inputs_pt = feature_extractor(images=image, return_tensors="pt")
201
- vision_model_pt_outputs = vision_model_pt(**encoder_inputs)
 
202
 
203
- generated = text_model_pt.generate(encoder_outputs=vision_model_pt_outputs, **gen_kwargs)
204
- token_ids = np.array(generated.sequences)[0]
 
 
 
 
 
 
 
 
 
 
205
 
206
  print('=' * 60)
207
- print(f'Pytorch\'s GPT2 LM generated token ids: {token_ids}')
 
 
208
 
209
- caption = tokenizer.decode(token_ids)
 
210
 
211
- print('=' * 60)
212
- print(f'Pytorch\'s GPT2 LM generated caption: {caption}')
 
 
 
 
 
 
193
 
194
  vision_model_pt = ViTModel.from_pretrained(vision_model_name)
195
  config = GPT2Config.from_pretrained(text_model_name)
196
+ # config.is_encoder_decoder = True
197
  config.add_cross_attention = True
198
  text_model_pt = GPT2LMHeadModel.from_pretrained(text_model_name, config=config)
199
 
200
+ encoder_pt_inputs = feature_extractor(images=image, return_tensors="pt")
201
+ encoder_pt_outputs = vision_model_pt(**encoder_pt_inputs)
202
+ encoder_hidden_states = encoder_pt_outputs.last_hidden_state
203
 
204
+ # model data
205
+ text_model_pt_inputs = {
206
+ 'input_ids': torch.tensor(decoder_input_ids, dtype=torch.int32),
207
+ 'attention_mask': torch.tensor(decoder_attention_mask, dtype=torch.int32),
208
+ 'position_ids': None,
209
+ 'encoder_hidden_states': encoder_hidden_states
210
+ }
211
+
212
+ # Model call
213
+ text_model_pt_outputs = text_model_pt(**text_model_pt_inputs)
214
+ logits = text_model_pt_outputs[0]
215
+ preds = np.argmax(logits.detach().numpy(), axis=-1)
216
 
217
  print('=' * 60)
218
+ print('PyTorch: ViT --> GPT2-LM')
219
+ print('predicted token ids:')
220
+ print(preds)
221
 
222
+ #generated = text_model_pt.generate(encoder_outputs=vision_model_pt_outputs, **gen_kwargs)
223
+ #token_ids = np.array(generated.sequences)[0]
224
 
225
+ #print('=' * 60)
226
+ #print(f'Pytorch\'s GPT2 LM generated token ids: {token_ids}')
227
+
228
+ #caption = tokenizer.decode(token_ids)
229
+
230
+ #print('=' * 60)
231
+ #print(f'Pytorch\'s GPT2 LM generated caption: {caption}')