"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

 Integration tests for BLIP2 models.
"""
| |
|
| | import pytest |
| | import torch |
| | from lavis.models import load_model, load_model_and_preprocess |
| | from PIL import Image |
| |
|
| | |
| | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| |
|
| | |
| | raw_image = Image.open("docs/_static/merlion.png").convert("RGB") |
| |
|
| |
|
| | class TestBlip2: |
| | def test_blip2_opt2p7b(self): |
| | |
| | model, vis_processors, _ = load_model_and_preprocess( |
| | name="blip2_opt", model_type="pretrain_opt2.7b", is_eval=True, device=device |
| | ) |
| |
|
| | |
| | |
| | image = vis_processors["eval"](raw_image).unsqueeze(0).to(device) |
| |
|
| | |
| | caption = model.generate({"image": image}) |
| |
|
| | assert caption == ["the merlion fountain in singapore"] |
| |
|
| | |
| | captions = model.generate({"image": image}, num_captions=3) |
| |
|
| | assert len(captions) == 3 |
| |
|
| | def test_blip2_opt2p7b_coco(self): |
| | |
| | model, vis_processors, _ = load_model_and_preprocess( |
| | name="blip2_opt", |
| | model_type="caption_coco_opt2.7b", |
| | is_eval=True, |
| | device=device, |
| | ) |
| |
|
| | |
| | |
| | image = vis_processors["eval"](raw_image).unsqueeze(0).to(device) |
| |
|
| | |
| | caption = model.generate({"image": image}) |
| |
|
| | assert caption == ["a statue of a mermaid spraying water into the air"] |
| |
|
| | |
| | captions = model.generate({"image": image}, num_captions=3) |
| |
|
| | assert len(captions) == 3 |
| |
|
| | def test_blip2_opt6p7b(self): |
| | |
| | model, vis_processors, _ = load_model_and_preprocess( |
| | name="blip2_opt", model_type="pretrain_opt6.7b", is_eval=True, device=device |
| | ) |
| |
|
| | |
| | |
| | image = vis_processors["eval"](raw_image).unsqueeze(0).to(device) |
| |
|
| | |
| | caption = model.generate({"image": image}) |
| |
|
| | assert caption == ["a statue of a merlion in front of a water fountain"] |
| |
|
| | |
| | captions = model.generate({"image": image}, num_captions=3) |
| |
|
| | assert len(captions) == 3 |
| |
|
| | def test_blip2_opt6p7b_coco(self): |
| | |
| | model, vis_processors, _ = load_model_and_preprocess( |
| | name="blip2_opt", |
| | model_type="caption_coco_opt6.7b", |
| | is_eval=True, |
| | device=device, |
| | ) |
| |
|
| | |
| | |
| | image = vis_processors["eval"](raw_image).unsqueeze(0).to(device) |
| |
|
| | |
| | caption = model.generate({"image": image}) |
| |
|
| | assert caption == ["a large fountain spraying water into the air"] |
| |
|
| | |
| | captions = model.generate({"image": image}, num_captions=3) |
| |
|
| | assert len(captions) == 3 |
| |
|
| | def test_blip2_flant5xl(self): |
| | |
| | model, vis_processors, _ = load_model_and_preprocess( |
| | name="blip2_t5", model_type="pretrain_flant5xl", is_eval=True, device=device |
| | ) |
| |
|
| | |
| | |
| | image = vis_processors["eval"](raw_image).unsqueeze(0).to(device) |
| |
|
| | |
| | caption = model.generate({"image": image}) |
| |
|
| | assert caption == ["marina bay sands, singapore"] |
| |
|
| | |
| | captions = model.generate({"image": image}, num_captions=3) |
| |
|
| | assert len(captions) == 3 |
| |
|
| | def test_blip2_flant5xxl(self): |
| | |
| | model, vis_processors, _ = load_model_and_preprocess( |
| | name="blip2_t5", |
| | model_type="pretrain_flant5xxl", |
| | is_eval=True, |
| | device=device, |
| | ) |
| |
|
| | |
| | |
| | image = vis_processors["eval"](raw_image).unsqueeze(0).to(device) |
| |
|
| | |
| | caption = model.generate({"image": image}) |
| |
|
| | assert caption == ["the merlion statue in singapore"] |
| |
|
| | |
| | captions = model.generate({"image": image}, num_captions=3) |
| |
|
| | assert len(captions) == 3 |
| |
|