Update app.py
app.py CHANGED
@@ -19,6 +19,9 @@ from tasks.mm_tasks.caption import CaptionTask
 from tasks.mm_tasks.refcoco import RefcocoTask
 from tasks.mm_tasks.vqa_gen import VqaGenTask
 
+
+from utils.zero_shot_utils import zero_shot_step
+
 # video
 from data.video_utils import VIDEO_READER_FUNCS
 
@@ -175,8 +178,8 @@ move2gpu(audio_caption_models, general_cfg)
 caption_generator = caption_task.build_generator(caption_models, caption_cfg.generation)
 refcoco_generator = refcoco_task.build_generator(refcoco_models, refcoco_cfg.generation)
 vqa_generator = vqa_task.build_generator(vqa_models, vqa_cfg.generation)
-vqa_generator.zero_shot = True
-vqa_generator.constraint_trie = None
+# vqa_generator.zero_shot = True
+# vqa_generator.constraint_trie = None
 general_generator = general_task.build_generator(general_models, general_cfg.generation)
 
 video_caption_generator = caption_task.build_generator(video_caption_models, video_caption_cfg.generation)
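The two lines commented out above previously forced the shared VQA generator into zero-shot mode at startup. In OFA-style VQA decoding, `constraint_trie` restricts generation to a trie of known candidate answers, while `zero_shot = True` with `constraint_trie = None` allows free-form answers; this commit leaves the generator in its default state and opts into zero-shot decoding per request instead (see the last hunk). A toy sketch of the trie-constraint idea, with a hypothetical candidate set, not the OFA implementation:

# Toy illustration of answer-constrained decoding (hypothetical candidates;
# OFA uses a token-level trie built from the VQA answer vocabulary).
candidates = {"yes", "no", "two dogs"}

def allowed_next(prefix: str) -> set:
    # Characters that keep `prefix` inside at least one candidate answer.
    return {c[len(prefix)] for c in candidates
            if c.startswith(prefix) and len(c) > len(prefix)}

print(allowed_next(""))      # constrained first step: {'y', 'n', 't'}
print(allowed_next("two "))  # only 'd' can follow
# With constraint_trie = None (zero-shot), any token may follow instead.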
@@ -449,8 +452,13 @@ def inference(image, audio, video, task_type, instruction):
 
     # Generate result
     with torch.no_grad():
-        hypos = task.inference_step(generator, models, sample)
-        tokens, bins, imgs = decode_fn(hypos[0][0]["tokens"], task.tgt_dict, task.bpe, generator)
+        if task_type == 'Visual Question Answering':
+            result, scores = zero_shot_step(vqa_task, generator, models, sample)
+            tokens = result[0]['answer']
+            bins = ''
+        else:
+            hypos = task.inference_step(generator, models, sample)
+            tokens, bins, imgs = decode_fn(hypos[0][0]["tokens"], task.tgt_dict, task.bpe, generator)
 
     if bins.strip() != '':
         w, h = image.size
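Taken together, the change routes VQA requests through open-ended generation while every other task keeps the original constrained decoding path. A self-contained sketch of the new dispatch, where `zero_shot_step`, `inference_step`, and `decode_fn` are stand-in stubs that only mimic the shapes the real OFA/fairseq helpers return:

# Sketch of the branching added to inference(); all helpers are stubs.

def zero_shot_step(task, generator, models, sample):
    # Stub: the real utils/zero_shot_utils.py helper runs unconstrained
    # generation and returns (results, scores), with results[0]['answer']
    # holding the decoded answer string.
    return [{"answer": "a red bus"}], None

class StubTask:
    tgt_dict = None
    bpe = None
    def inference_step(self, generator, models, sample):
        # Stub: fairseq-style tasks return hypotheses per sample, each a
        # dict with a 'tokens' entry.
        return [[{"tokens": [7, 8, 9]}]]

def decode_fn(tokens, tgt_dict, bpe, generator):
    # Stub: the app's decoder turns tokens into (text, bins, images).
    return "decoded text", "", None

def generate(task_type, task, vqa_task, generator, models, sample):
    if task_type == "Visual Question Answering":
        # Open-ended VQA: free-form answer, no region "bins" to draw.
        result, scores = zero_shot_step(vqa_task, generator, models, sample)
        tokens, bins = result[0]["answer"], ""
    else:
        # All other tasks keep the original decoding path.
        hypos = task.inference_step(generator, models, sample)
        tokens, bins, _ = decode_fn(hypos[0][0]["tokens"],
                                    task.tgt_dict, task.bpe, generator)
    return tokens, bins

print(generate("Visual Question Answering", StubTask(), None, None, None, None))
print(generate("Image Captioning", StubTask(), None, None, None, None))

Setting bins = '' in the VQA branch keeps the downstream `if bins.strip() != '':` region-drawing step inert for free-form answers.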