Spaces:
Runtime error
Runtime error
| """ | |
| # | |
| # Copyright (c) 2022 salesforce.com, inc. | |
| # All rights reserved. | |
| # SPDX-License-Identifier: BSD-3-Clause | |
| # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| # | |
| Integration tests for BLIP2 models. | |
| """ | |
| import pytest | |
| import torch | |
| from lavis.models import load_model, load_model_and_preprocess | |
| from PIL import Image | |
# Pick the inference device: prefer CUDA when a GPU is present.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Sample image (the Singapore Merlion) shared by every test below.
raw_image = Image.open("docs/_static/merlion.png").convert("RGB")
class TestBlip2:
    """Integration tests for BLIP2 captioning models.

    Each test loads a pretrained checkpoint via LAVIS, captions the shared
    sample image, and checks (a) the exact caption produced by default
    decoding and (b) that ``num_captions=3`` returns three captions.

    NOTE(review): expected captions are checkpoint-specific golden values;
    they will change if the published weights are updated.
    """

    def _assert_captions(self, name, model_type, expected_caption):
        """Load ``name``/``model_type``, caption ``raw_image``, and assert results.

        Args:
            name: LAVIS model registry name (e.g. "blip2_opt").
            model_type: checkpoint variant (e.g. "pretrain_opt2.7b").
            expected_caption: exact caption expected from default generation.
        """
        model, vis_processors, _ = load_model_and_preprocess(
            name=name, model_type=model_type, is_eval=True, device=device
        )
        # vis_processors stores image transforms for "train" and "eval"
        # (validation / testing / inference).
        image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
        # Default decoding is deterministic for a fixed checkpoint, so an
        # exact-match assertion on the single caption is stable.
        caption = model.generate({"image": image})
        assert caption == [expected_caption]
        # Multiple-caption generation: only the count is checked, since the
        # extra candidates are not pinned.
        captions = model.generate({"image": image}, num_captions=3)
        assert len(captions) == 3

    def test_blip2_opt2p7b(self):
        # BLIP2-OPT-2.7b, pretrained only (no finetuning on coco).
        self._assert_captions(
            "blip2_opt", "pretrain_opt2.7b", "the merlion fountain in singapore"
        )

    def test_blip2_opt2p7b_coco(self):
        # BLIP2-OPT-2.7b, finetuned on coco captions.
        self._assert_captions(
            "blip2_opt",
            "caption_coco_opt2.7b",
            "a statue of a mermaid spraying water into the air",
        )

    def test_blip2_opt6p7b(self):
        # BLIP2-OPT-6.7b, pretrained only (no finetuning on coco).
        # (Original comment said 2.7b — copy-paste error; this is the 6.7b model.)
        self._assert_captions(
            "blip2_opt",
            "pretrain_opt6.7b",
            "a statue of a merlion in front of a water fountain",
        )

    def test_blip2_opt6p7b_coco(self):
        # BLIP2-OPT-6.7b, finetuned on coco captions.
        self._assert_captions(
            "blip2_opt",
            "caption_coco_opt6.7b",
            "a large fountain spraying water into the air",
        )

    def test_blip2_flant5xl(self):
        # BLIP2-FLAN-T5XL, pretrained only.
        self._assert_captions(
            "blip2_t5", "pretrain_flant5xl", "marina bay sands, singapore"
        )

    def test_blip2_flant5xxl(self):
        # BLIP2-FLAN-T5XXL, pretrained only.
        self._assert_captions(
            "blip2_t5", "pretrain_flant5xxl", "the merlion statue in singapore"
        )