import sys

sys.path.append(".")  # make the repo-local `src` package importable

import time

import pytest
import requests
import torch
from PIL import Image

from src.models.sam_captioner import SAMCaptionerConfig, SAMCaptionerModel, SAMCaptionerProcessor

cache_dir = ".model.cache"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face Hub checkpoints: SAM backbone and BLIP captioning head.
sam_model = "facebook/sam-vit-base"
captioner_model = "Salesforce/blip-image-captioning-base"


@pytest.fixture
def model():
    model = SAMCaptionerModel.from_sam_captioner_pretrained(sam_model, captioner_model, cache_dir=cache_dir).to(device)
    return model


@pytest.fixture
def processor():
    # FIXME(xiaoke): use `from_sam_captioner_pretrained`
    processor = SAMCaptionerProcessor.from_pretrained(sam_model, cache_dir=cache_dir)
    return processor


@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("num_masks", [1, 2, 8])
# FIXME(xiaoke): no more `caption_mask_with_highest_iou`. Remove it.
@pytest.mark.parametrize("caption_mask_with_highest_iou", [False])
def test_modeling(
    batch_size,
    num_masks,
    caption_mask_with_highest_iou,
    processor: SAMCaptionerProcessor,
    model: SAMCaptionerModel,
):
    img_url = "https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg"
    raw_image = [Image.open(requests.get(img_url, stream=True).raw).convert("RGB")]
    input_points = [[[[500, 375]], [[500, 375]]]]  # 2D location of a window in the image
    # Tile the single image and its point prompts out to the parametrized sizes.
    raw_image = raw_image * batch_size
    input_points[0] *= num_masks
    input_points *= batch_size
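    # The nesting is assumed to follow the SamProcessor layout that
    # SAMCaptionerProcessor wraps: images -> prompts per image -> points per
    # prompt -> (x, y). After tiling, each of the `batch_size` images carries
    # 2 * num_masks copies of the same point prompt, since the initial list
    # already holds two entries.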
    inputs = processor(raw_image, input_points=input_points, return_tensors="pt")
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            inputs[k] = v.to(device)

    # Warm up: the first call pays one-off costs (CUDA context setup, kernel
    # selection, allocator growth), so only the second call is timed.
    with torch.inference_mode():
        outputs = model.generate(**inputs, caption_mask_with_highest_iou=caption_mask_with_highest_iou)
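
    # Editorial caveat: CUDA kernels launch asynchronously, so a wall-clock
    # delta can under-report GPU time; a stricter benchmark would call
    # torch.cuda.synchronize() before reading each perf_counter().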
    tic = time.perf_counter()
    with torch.inference_mode():
        outputs = model.generate(**inputs, caption_mask_with_highest_iou=caption_mask_with_highest_iou)
    toc = time.perf_counter()
    print(f"Time taken: {(toc - tic) * 1000:0.4f} ms")
print("tensor shapes")
for k, v in outputs.items():
if isinstance(v, torch.Tensor):
print(f"{k}: {v.shape} {v.stride()}")
    batch_size, num_masks, num_heads, num_tokens = outputs.generate_ids.shape
    print(
        model.captioner_processor.batch_decode(outputs.generate_ids.reshape(-1, num_tokens), skip_special_tokens=True)
    )
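

if __name__ == "__main__":
    # Convenience entry point (an editorial addition, not in the original
    # file): run this test module directly; `-s` keeps the prints visible.
    raise SystemExit(pytest.main([__file__, "-s"]))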