| | """ |
| | # |
| | # Copyright (c) 2022 salesforce.com, inc. |
| | # All rights reserved. |
| | # SPDX-License-Identifier: BSD-3-Clause |
| | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
| | # |
| | |
| | Integration tests for BLIP models. |
| | """ |

import pytest
import torch
from lavis.models import load_model, load_model_and_preprocess
from PIL import Image

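# run on GPU when available, otherwise fall back to CPU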
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

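# sample image shared by the tests below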
raw_image = Image.open("docs/_static/merlion.png").convert("RGB")

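# relative tolerance for floating-point comparisons against reference values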
precision = 1e-1


class TestBlip:
    def test_caption(self):
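        # load the BLIP caption base model finetuned on the MSCOCO captioning
        # dataset, together with its associated image processors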
        model, vis_processors, _ = load_model_and_preprocess(
            name="blip_caption", model_type="base_coco", is_eval=True, device=device
        )

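        # preprocess the image; vis_processors holds the "train" and "eval"
        # (validation / testing / inference) image transforms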
        image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)

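        # generate a single caption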
        caption = model.generate({"image": image})

        assert caption == ["a large fountain spewing water into the air"]

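        # generate three captions for the same image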
        captions = model.generate({"image": image}, num_captions=3)

        assert len(captions) == 3

    def test_caption_large(self):
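        # same flow as test_caption, using the BLIP caption large model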
        model, vis_processors, _ = load_model_and_preprocess(
            name="blip_caption", model_type="large_coco", is_eval=True, device=device
        )

        image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)

        caption = model.generate({"image": image})

        assert caption == ["a large statue of a person spraying water from a fountain"]

        captions = model.generate({"image": image}, num_captions=3)

        assert len(captions) == 3

    def test_caption_forward(self):
        model, vis_processors, _ = load_model_and_preprocess(
            name="blip_caption", model_type="base_coco", is_eval=True, device=device
        )

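        # run a training-style forward pass on an image/caption pair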
        image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
        text_input = ["a large statue of a person spraying water from a fountain"]

        samples = {"image": image, "text_input": text_input}
        output = model(samples)

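        # 577 image tokens = (384 / 16)^2 patches + 1 [CLS], assuming the ViT-B/16 backbone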
        assert output.intermediate_output.image_embeds.shape == torch.Size(
            [1, 577, 768]
        )
        assert output.intermediate_output.decoder_labels.shape == torch.Size([1, 13])

        assert pytest.approx(2.7152, precision) == output.loss.item()
        assert (
            pytest.approx(-0.0200, precision)
            == torch.mean(output.intermediate_output.image_embeds).item()
        )

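        # the leading -100 labels mark prompt positions ignored by the loss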
        assert all(
            output.intermediate_output.decoder_labels[0]
            == torch.LongTensor(
                [
                    -100,
                    -100,
                    -100,
                    -100,
                    1997,
                    1037,
                    2711,
                    29035,
                    2300,
                    2013,
                    1037,
                    9545,
                    102,
                ]
            ).to(device)
        )

    def test_vqa(self):
        model, vis_processors, txt_processors = load_model_and_preprocess(
            name="blip_vqa", model_type="vqav2", is_eval=True, device=device
        )

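        # ask a question about the image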
| | question = "Which city is this photo taken?" |
| |
|
| | image = vis_processors["eval"](raw_image).unsqueeze(0).to(device) |
| | question = txt_processors["eval"](question) |
| |
|
| | samples = {"image": image, "text_input": question} |
| |
|
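        # open-ended answer generation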
        answer = model.predict_answers(
            samples=samples,
            inference_method="generate",
        )
        assert answer == ["singapore"]

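        # rank a closed list of candidate answers instead of free-form generation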
        answer_list = ["Singapore", "London", "Palo Alto", "Tokyo"]
        answers = model.predict_answers(samples, answer_list=answer_list)

        assert answers == ["Singapore"]

    def test_retrieval(self):
        model = load_model("blip_retrieval", "coco", is_eval=True, device=device)

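        # random tensors stand in for a batch of four images; image_id marks
        # which captions describe the same underlying image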
        images = torch.randn(4, 3, 384, 384).to(device)
        text_input = [
            "caption of image 1",
            "another caption of image 1",
            "caption of image 2",
            "caption of image 3",
        ]
        image_id = torch.tensor([1, 1, 2, 3]).to(device)
        samples = {
            "image": images,
            "text_input": text_input,
            "image_id": image_id,
            "epoch": 0,
            "iters": 0,
            "num_iters_per_epoch": 100,
        }
        output = model(samples)

        assert output.intermediate_output.image_embeds.shape == torch.Size(
            [4, 577, 768]
        )
        assert output.intermediate_output.text_embeds.shape == torch.Size([4, 35, 768])
        # the *_m tensors come from the momentum encoders
        assert output.intermediate_output.image_embeds_m.shape == torch.Size(
            [4, 577, 768]
        )
        assert output.intermediate_output.text_embeds_m.shape == torch.Size(
            [4, 35, 768]
        )
        assert (
            output.intermediate_output.encoder_output.last_hidden_state.shape
            == torch.Size([4, 35, 768])
        )
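        # the ITM head scores 4 positive and 8 negative image-text pairs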
        assert output.intermediate_output.itm_logits.shape == torch.Size([12, 2])
        assert all(
            output.intermediate_output.itm_labels
            == torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]).to(device)
        )

    def test_pretrain(self):
        model = load_model("blip_pretrain", "base", is_eval=True, device=device)

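        # pretraining uses 224x224 inputs, so 197 = 14^2 patch tokens + [CLS]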
        images = torch.randn(4, 3, 224, 224).to(device)

        text_input = [
            "caption of image 1",
            "another caption of image 1",
            "caption of image 2",
            "caption of image 3",
        ]
        samples = {
            "image": images,
            "text_input": text_input,
            "epoch": 0,
            "iters": 0,
            "num_iters_per_epoch": 100,
        }
        output = model(samples)

        assert output.intermediate_output.image_embeds.shape == torch.Size(
            [4, 197, 768]
        )
        assert output.intermediate_output.text_embeds.shape == torch.Size([4, 30, 768])
        assert output.intermediate_output.itm_logits.shape == torch.Size([12, 2])
        assert all(
            output.intermediate_output.itm_labels
            == torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]).to(device)
        )
        assert output.intermediate_output.decoder_labels.shape == torch.Size([4, 30])
        # 30524 = BERT-base vocabulary (30522) plus BLIP's added special tokens
        assert output.intermediate_output.decoder_output.logits.shape == torch.Size(
            [4, 30, 30524]
        )

    def test_feature_extractor(self):
        caption = "a large fountain spewing water into the air"

        model, vis_processors, txt_processors = load_model_and_preprocess(
            "blip_feature_extractor", model_type="base", is_eval=True, device=device
        )

        image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
        text_input = txt_processors["eval"](caption)

        sample = {"image": image, "text_input": [text_input]}

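        # extract joint multimodal, text-only, and image-only features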
        features_multimodal = model.extract_features(sample)
        features_text = model.extract_features(sample, mode="text")
        features_image = model.extract_features(sample, mode="image")

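        # the *_proj tensors are the 256-dim projections used for image-text similarity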
        assert features_multimodal.image_embeds.shape == torch.Size([1, 197, 768])
        assert features_multimodal.multimodal_embeds.shape == torch.Size([1, 12, 768])

        assert features_text.text_embeds.shape == torch.Size([1, 12, 768])
        assert features_text.text_embeds_proj.shape == torch.Size([1, 12, 256])

        assert features_image.image_embeds.shape == torch.Size([1, 197, 768])
        assert features_image.image_embeds_proj.shape == torch.Size([1, 197, 256])

        assert torch.mean(features_multimodal.image_embeds).item() == pytest.approx(
            -0.02032, precision
        )
        assert torch.mean(
            features_multimodal.multimodal_embeds
        ).item() == pytest.approx(-0.00095, precision)

        assert torch.mean(features_text.text_embeds).item() == pytest.approx(
            -6.6098e-5, precision
        )
        assert torch.mean(features_text.text_embeds_proj).item() == pytest.approx(
            -0.002149, precision
        )

        assert torch.mean(features_image.image_embeds).item() == pytest.approx(
            -0.02032, precision
        )
        assert torch.mean(features_image.image_embeds_proj).item() == pytest.approx(
            -0.0023, precision
        )

    def test_itm(self):
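        # the helpers close over model, the processors, and caption, which are
        # (re)assigned below before each call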
        def compute_itm():
            img = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
            txt = txt_processors["eval"](caption)

            # ITM head: two-way match / no-match classification logits
            itm_output = model({"image": img, "text_input": [txt]}, match_head="itm")
            itm_scores = torch.nn.functional.softmax(itm_output, dim=1)

            return itm_scores

        def compute_itc():
            img = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
            txt = txt_processors["eval"](caption)

            # ITC head: contrastive similarity of projected image and text features
            itc_score = model({"image": img, "text_input": [txt]}, match_head="itc")

            return itc_score

        model, vis_processors, txt_processors = load_model_and_preprocess(
            "blip_image_text_matching", model_type="base", is_eval=True, device=device
        )

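        # a caption that matches the image should score high under both heads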
| | caption = "merlion in Singapore" |
| | itm_scores = compute_itm() |
| | itc_score = compute_itc() |
| |
|
        assert itm_scores[:, 1].item() == pytest.approx(0.98613, abs=1e-5)
        assert itc_score.item() == pytest.approx(0.4633, abs=1e-4)

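        # an unrelated caption should score low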
| | caption = "a random irrelevant caption" |
| | itm_scores = compute_itm() |
| | itc_score = compute_itc() |
| |
|
| | assert itm_scores[:, 1].item() == pytest.approx(0.05704, abs=1e-5) |
| | assert itc_score.item() == pytest.approx(0.23282, abs=1e-5) |
| |
|
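        # repeat both checks with the large model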
        model, vis_processors, txt_processors = load_model_and_preprocess(
            "blip_image_text_matching", model_type="large", is_eval=True, device=device
        )

| | caption = "merlion in Singapore" |
| | itm_scores = compute_itm() |
| | itc_score = compute_itc() |
| |
|
| | assert itm_scores[:, 1].item() == pytest.approx(0.99466, abs=1e-5) |
| | assert itc_score.item() == pytest.approx(0.4474, abs=1e-4) |
| |
|
| | caption = "a random irrelevant caption" |
| | itm_scores = compute_itm() |
| | itc_score = compute_itc() |
| |
|
| | assert itm_scores[:, 1].item() == pytest.approx(0.04744, abs=1e-5) |
| | assert itc_score.item() == pytest.approx(0.12821, abs=1e-5) |
| |
|