MinxuanQin
commited on
Commit
·
d43497c
1
Parent(s):
2e4b982
fix model load error
Browse files- app.py +2 -2
- model_loader.py +15 -1
app.py
CHANGED
|
@@ -17,7 +17,7 @@ df = pd.read_json('vqa_samples.json', orient="columns")
|
|
| 17 |
# define selector
|
| 18 |
model_name = st.sidebar.selectbox(
|
| 19 |
"Select a model: ",
|
| 20 |
-
('vilt', 'git', 'blip', 'vbert')
|
| 21 |
)
|
| 22 |
|
| 23 |
image_selector_unspecific = st.number_input(
|
|
@@ -41,4 +41,4 @@ question = st.text_input(f"Ask the model a question related to the image: \n"
|
|
| 41 |
args = load_model(model_name) # TODO: cache
|
| 42 |
answer = get_answer(args, image, question, model_name)
|
| 43 |
st.text(f"Answer by {model_name}: {answer}")
|
| 44 |
-
st.text(f"Ground truth: {label}")
|
|
|
|
| 17 |
# define selector
|
| 18 |
model_name = st.sidebar.selectbox(
|
| 19 |
"Select a model: ",
|
| 20 |
+
('vilt', 'vilt_finetuned', 'git', 'blip', 'vbert')
|
| 21 |
)
|
| 22 |
|
| 23 |
image_selector_unspecific = st.number_input(
|
|
|
|
| 41 |
args = load_model(model_name) # TODO: cache
|
| 42 |
answer = get_answer(args, image, question, model_name)
|
| 43 |
st.text(f"Answer by {model_name}: {answer}")
|
| 44 |
+
st.text(f"Ground truth (of the example): {label}")
|
model_loader.py
CHANGED
|
@@ -33,7 +33,10 @@ VQA_URL = "https://dl.fbaipublicfiles.com/pythia/data/answers_vqa.txt"
|
|
| 33 |
def load_model(name):
|
| 34 |
if name == "vilt":
|
| 35 |
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
| 36 |
-
model = ViltForQuestionAnswering.from_pretrained("
|
|
|
|
|
|
|
|
|
|
| 37 |
elif name == "git":
|
| 38 |
processor = AutoProcessor.from_pretrained("microsoft/git-base-vqav2")
|
| 39 |
model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vqav2")
|
|
@@ -155,6 +158,17 @@ def get_answer(model_loader_args, img, question, model_name):
|
|
| 155 |
logits = outputs.logits
|
| 156 |
idx = logits.argmax(-1).item()
|
| 157 |
pred = model.config.id2label[idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
elif model_name == "git":
|
| 160 |
try:
|
|
|
|
| 33 |
def load_model(name):
|
| 34 |
if name == "vilt":
|
| 35 |
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
| 36 |
+
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
| 37 |
+
elif name == "vilt_finetuned":
|
| 38 |
+
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
| 39 |
+
model = ViltForQuestionAnswering.from_pretrained("Minqin/carets_vqa_finetuned")
|
| 40 |
elif name == "git":
|
| 41 |
processor = AutoProcessor.from_pretrained("microsoft/git-base-vqav2")
|
| 42 |
model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vqav2")
|
|
|
|
| 158 |
logits = outputs.logits
|
| 159 |
idx = logits.argmax(-1).item()
|
| 160 |
pred = model.config.id2label[idx]
|
| 161 |
+
|
| 162 |
+
elif model_name == "vilt_finetuned":
|
| 163 |
+
try:
|
| 164 |
+
encoding = processor(images=img, text=question, return_tensors="pt")
|
| 165 |
+
except Exception:
|
| 166 |
+
return err_msg()
|
| 167 |
+
else:
|
| 168 |
+
outputs = model(**encoding)
|
| 169 |
+
logits = outputs.logits
|
| 170 |
+
idx = logits.argmax(-1).item()
|
| 171 |
+
pred = model.config.id2label[idx]
|
| 172 |
|
| 173 |
elif model_name == "git":
|
| 174 |
try:
|