|
|
import sys |
|
|
|
|
|
sys.path.append(".") |
|
|
|
|
|
import pytest |
|
|
from PIL import Image |
|
|
import requests |
|
|
import torch |
|
|
import time |
|
|
|
|
|
from src.models.sam_captioner import SAMCaptionerConfig, SAMCaptionerModel, SAMCaptionerProcessor |
|
|
|
|
|
# Directory where HuggingFace pretrained weights are cached between runs.
cache_dir = ".model.cache"


# Run on GPU when available; the model and all input tensors are moved here.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Pretrained checkpoint id for the SAM (segment-anything) backbone.
sam_model = "facebook/sam-vit-base"


# Pretrained checkpoint id for the BLIP captioner head paired with SAM.
captioner_model = "Salesforce/blip-image-captioning-base"
|
|
|
|
|
|
|
|
@pytest.fixture
def model():
    """Fixture: SAMCaptioner model assembled from pretrained SAM + BLIP weights, on the test device."""
    return SAMCaptionerModel.from_sam_captioner_pretrained(
        sam_model, captioner_model, cache_dir=cache_dir
    ).to(device)
|
|
|
|
|
|
|
|
@pytest.fixture
def processor():
    """Fixture: SAMCaptioner processor loaded from the SAM checkpoint's preprocessing config."""
    return SAMCaptionerProcessor.from_pretrained(sam_model, cache_dir=cache_dir)
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("num_masks", [1, 2, 8])
@pytest.mark.parametrize("caption_mask_with_highest_iou", [False])
def test_modeling(
    batch_size,
    num_masks,
    caption_mask_with_highest_iou,
    processor: SAMCaptionerProcessor,
    model: SAMCaptionerModel,
):
    """Smoke-test SAMCaptioner generation end to end.

    Builds a batch of `batch_size` copies of one test image, each with
    `num_masks` point prompts, runs `model.generate`, times a second
    (post-warmup) call, and decodes the generated caption ids.
    """
    img_url = "https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg"
    # timeout so a network stall fails the test instead of hanging the suite
    raw_image = [Image.open(requests.get(img_url, stream=True, timeout=60).raw).convert("RGB")]
    # One point prompt (x=500, y=375) per mask; nesting is [image][mask][point][xy].
    input_points = [[[[500, 375]], [[500, 375]]]]

    # Replicate the single image/prompt to the requested batch and mask counts.
    raw_image = raw_image * batch_size
    input_points[0] *= num_masks
    input_points *= batch_size

    inputs = processor(raw_image, input_points=input_points, return_tensors="pt")
    # Move every tensor input onto the same device as the model.
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            inputs[k] = v.to(device)

    # Warmup call: excludes one-time costs (CUDA kernel compilation, caches)
    # from the timed run below.
    with torch.inference_mode():
        outputs = model.generate(**inputs, caption_mask_with_highest_iou=caption_mask_with_highest_iou)

    tic = time.perf_counter()
    with torch.inference_mode():
        outputs = model.generate(**inputs, caption_mask_with_highest_iou=caption_mask_with_highest_iou)
    toc = time.perf_counter()
    print(f"Time taken: {(toc - tic)*1000:0.4f} ms")

    print("tensor shapes")
    for k, v in outputs.items():
        if isinstance(v, torch.Tensor):
            print(f"{k}: {v.shape} {v.stride()}")

    # Use fresh names for the output shape so the parametrized
    # batch_size/num_masks arguments are not shadowed.
    out_batch, out_masks, num_heads, num_tokens = outputs.generate_ids.shape
    print(
        model.captioner_processor.batch_decode(outputs.generate_ids.reshape(-1, num_tokens), skip_special_tokens=True)
    )
|
|
|