MinxuanQin
commited on
Commit
·
58c2c99
1
Parent(s):
7b4b5f6
add display visualbert
Browse files- model_loader.py +3 -1
model_loader.py
CHANGED
|
@@ -5,6 +5,7 @@ from datasets import load_dataset, get_dataset_split_names
|
|
| 5 |
import numpy as np
|
| 6 |
|
| 7 |
import requests
|
|
|
|
| 8 |
from transformers import ViltProcessor, ViltForQuestionAnswering
|
| 9 |
from transformers import AutoProcessor, AutoModelForCausalLM
|
| 10 |
from transformers import BlipProcessor, BlipForQuestionAnswering
|
|
@@ -87,6 +88,7 @@ def get_item(image, question, tokenizer, image_model, model_name):
|
|
| 87 |
)
|
| 88 |
visual_embeds = get_img_feats(image, image_model=image_model, name=model_name)\
|
| 89 |
.squeeze(2, 3).unsqueeze(0)
|
|
|
|
| 90 |
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
|
| 91 |
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
|
| 92 |
upd_dict = {
|
|
@@ -95,7 +97,7 @@ def get_item(image, question, tokenizer, image_model, model_name):
|
|
| 95 |
"visual_attention_mask": visual_attention_mask,
|
| 96 |
}
|
| 97 |
inputs.update(upd_dict)
|
| 98 |
-
|
| 99 |
return upd_dict, inputs
|
| 100 |
|
| 101 |
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
|
| 7 |
import requests
|
| 8 |
+
import streamlit as st
|
| 9 |
from transformers import ViltProcessor, ViltForQuestionAnswering
|
| 10 |
from transformers import AutoProcessor, AutoModelForCausalLM
|
| 11 |
from transformers import BlipProcessor, BlipForQuestionAnswering
|
|
|
|
| 88 |
)
|
| 89 |
visual_embeds = get_img_feats(image, image_model=image_model, name=model_name)\
|
| 90 |
.squeeze(2, 3).unsqueeze(0)
|
| 91 |
+
st.text(f"ques embed: {inputs.shape}, visual: {visual_embeds.shape}")
|
| 92 |
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
|
| 93 |
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
|
| 94 |
upd_dict = {
|
|
|
|
| 97 |
"visual_attention_mask": visual_attention_mask,
|
| 98 |
}
|
| 99 |
inputs.update(upd_dict)
|
| 100 |
+
|
| 101 |
return upd_dict, inputs
|
| 102 |
|
| 103 |
|