Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -35,17 +35,17 @@ def model_generate_ans(img,val_q):
|
|
| 35 |
val_image_embeds = projection(clip_val_outputs).to(torch.float16)
|
| 36 |
|
| 37 |
img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
|
| 38 |
-
img_token_embeds =
|
| 39 |
|
| 40 |
val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
|
| 41 |
-
val_q_embeds =
|
| 42 |
|
| 43 |
val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
|
| 44 |
|
| 45 |
predicted_caption = torch.full((1,max_generate_length),50256)
|
| 46 |
|
| 47 |
for g in range(max_generate_length):
|
| 48 |
-
phi_output_logits =
|
| 49 |
predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
|
| 50 |
predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
|
| 51 |
predicted_caption[:,g] = predicted_word_token.view(1,-1).to(device)
|
|
|
|
| 35 |
val_image_embeds = projection(clip_val_outputs).to(torch.float16)
|
| 36 |
|
| 37 |
img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
|
| 38 |
+
img_token_embeds = merged_model.model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
|
| 39 |
|
| 40 |
val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
|
| 41 |
+
val_q_embeds = merged_model.model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
|
| 42 |
|
| 43 |
val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
|
| 44 |
|
| 45 |
predicted_caption = torch.full((1,max_generate_length),50256)
|
| 46 |
|
| 47 |
for g in range(max_generate_length):
|
| 48 |
+
phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
|
| 49 |
predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
|
| 50 |
predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
|
| 51 |
predicted_caption[:,g] = predicted_word_token.view(1,-1).to(device)
|