MinxuanQin
commited on
Commit
·
a5ab0ec
1
Parent(s):
58c2c99
debug vbert
Browse files- model_loader.py +10 -10
model_loader.py
CHANGED
|
@@ -189,16 +189,16 @@ def get_answer(model_loader_args, img, question, model_name):
|
|
| 189 |
|
| 190 |
elif model_name == "vbert":
|
| 191 |
vqa_answers = get_data(VQA_URL)
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
except Exception:
|
| 198 |
-
|
| 199 |
-
else:
|
| 200 |
-
|
| 201 |
-
|
| 202 |
|
| 203 |
elif model_name == "blip":
|
| 204 |
try:
|
|
|
|
| 189 |
|
| 190 |
elif model_name == "vbert":
|
| 191 |
vqa_answers = get_data(VQA_URL)
|
| 192 |
+
|
| 193 |
+
# load question and image (processor = tokenizer)
|
| 194 |
+
## MOD Minxuan: fix error
|
| 195 |
+
_, inputs = get_item(img, question, processor, "resnet50")
|
| 196 |
+
outputs = model(**inputs)
|
| 197 |
+
#except Exception:
|
| 198 |
+
# return err_msg()
|
| 199 |
+
# else:
|
| 200 |
+
answer_idx = torch.argmax(outputs.logits, dim=1).item() # from 3129
|
| 201 |
+
pred = vqa_answers[answer_idx]
|
| 202 |
|
| 203 |
elif model_name == "blip":
|
| 204 |
try:
|