EmoVIT / LAVIS /tests /models /test_pnp_vqa.py
manhteky123's picture
Upload 1462 files
1762ba0 verified
"""
#
# Copyright (c) 2022 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
Integration tests for PNP-VQA model.
"""
import pytest
import torch
from lavis.models import load_model, load_model_and_preprocess
from PIL import Image
# setup device to use
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load sample image
raw_image = Image.open("docs/_static/merlion.png").convert("RGB")
precision = 1e-1
class TestPNPVQA:
def test_vqa(self):
# loads PNP-VQA base model, with BLIP_itm_large, BLIP_caption_large, Unifiedqav2_base
# this also loads the associated image processors and text processors
model, vis_processors, txt_processors = load_model_and_preprocess(
name="pnp_vqa", model_type="base", is_eval=True, device=device
)
# ask a random question.
question = "Which city is this photo taken?"
image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
question = txt_processors["eval"](question)
samples = {"image": image, "text_input": [question]}
answer, caption, gradcam = model.predict_answers(
samples=samples,
inference_method="generate",
num_captions=5,
num_patches=20,
)
assert isinstance(answer, list)
assert isinstance(caption, list)
assert isinstance(gradcam, torch.Tensor)
assert len(answer) == 1
assert len(caption) == 1
assert len(caption[0]) == 5
assert gradcam.size() == torch.Size([1,576])
def test_itm(self):
# loads PNP-VQA base model, with BLIP_itm_large, BLIP_caption_large, Unifiedqav2_base
# this also loads the associated image processors and text processors
model, vis_processors, txt_processors = load_model_and_preprocess(
name="pnp_vqa", model_type="base", is_eval=True, device=device
)
# ask a random question.
question = "Which city is this photo taken?"
image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
question = txt_processors["eval"](question)
samples = {"image": image, "text_input": [question]}
samples = model.forward_itm(samples=samples)
assert samples['gradcams'].size() == torch.Size([1,576])
def test_caption(self):
# loads PNP-VQA base model, with BLIP_itm_large, BLIP_caption_large, Unifiedqav2_base
# this also loads the associated image processors and text processors
model, vis_processors, txt_processors = load_model_and_preprocess(
name="pnp_vqa", model_type="base", is_eval=True, device=device
)
# ask a random question.
question = "Which city is this photo taken?"
image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
question = txt_processors["eval"](question)
samples = {"image": image, "text_input": [question]}
samples['gradcams'] = torch.rand(1,576)
samples = model.forward_cap(samples=samples, num_captions=5, num_patches=20)
assert len(samples['captions']) == 1
assert len(samples['captions'][0]) == 5
def test_qa(self):
# loads PNP-VQA base model, with BLIP_itm_large, BLIP_caption_large, Unifiedqav2_base
# this also loads the associated image processors and text processors
model, vis_processors, txt_processors = load_model_and_preprocess(
name="pnp_vqa", model_type="base", is_eval=True, device=device
)
# ask a random question.
question = "Which city is this photo taken?"
image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
question = txt_processors["eval"](question)
samples = {"image": image, "text_input": [question]}
samples['captions'] = [['the city is singapore', 'the picture is taken in singapore']]
answer = model.forward_qa(samples=samples, num_captions=2)
assert isinstance(answer, list)
assert len(answer) == 1
assert answer[0]== 'singapore'