Vasudevakrishna committed on
Commit
1cd5299
·
verified ·
1 Parent(s): 715a82c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -13
app.py CHANGED
@@ -73,19 +73,22 @@ def generate_answers(img=None, aud = None, q = None, max_tokens = 30):
73
  inputs_embeddings.append(end_iq_embeds)
74
  # Combine embeddings
75
  combined_embeds = torch.cat(inputs_embeddings, dim=1)
76
- print("---------",combined_embeds.shape)
77
-
78
- for pos in range(max_tokens - 1):
79
- model_output_logits = phi2_model.forward(inputs_embeds = combined_embeds)['logits']
80
- print(model_output_logits.shape)
81
- predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
82
- predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1)
83
- predicted_caption[:, pos] = predicted_word_token.view(1,-1).to('cpu')
84
- print(predicted_caption)
85
- next_token_embeds = phi2_model.model.embed_tokens(predicted_word_token)
86
- combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
87
- print("combined_embeds", combined_embeds.shape)
88
- predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
 
 
 
89
  predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>","")
90
  return predicted_captions_decoded
91
 
 
73
  inputs_embeddings.append(end_iq_embeds)
74
  # Combine embeddings
75
  combined_embeds = torch.cat(inputs_embeddings, dim=1)
76
+ predicted_caption = phi2_model.generate(inputs_embeds=combined_embeds,
77
+ max_new_tokens=max_tokens,
78
+ return_dict_in_generate = True)
79
+
80
+ # for pos in range(max_tokens - 1):
81
+ # model_output_logits = phi2_model.forward(inputs_embeds = combined_embeds)['logits']
82
+ # print(model_output_logits.shape)
83
+ # predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
84
+ # predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1)
85
+ # predicted_caption[:, pos] = predicted_word_token.view(1,-1).to('cpu')
86
+ # print(predicted_caption)
87
+ # next_token_embeds = phi2_model.model.embed_tokens(predicted_word_token)
88
+ # combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
89
+ # print("combined_embeds", combined_embeds.shape)
90
+ # predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
91
+ predicted_captions_decoded =tokenizer.batch_decode(predicted_caption.sequences[:, 1:])[0]
92
  predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>","")
93
  return predicted_captions_decoded
94