import sys

sys.path.append(".")

from typing import Sequence

import numpy as np
import pytest
import requests
import torch
import transformers
from PIL import Image
from torch.nn.utils.rnn import pad_sequence

from src.models.sca import ScaConfig, ScaModel, ScaProcessor

cache_dir = ".model.cache"
device = "cuda" if torch.cuda.is_available() else "cpu"
sam_model_name = "facebook/sam-vit-base"
text_model_name = "gpt2"
additional_num_hidden_layers = 2
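

# Shared fixtures: the SCA model under test, its paired processor, and a
# standalone SAM model used as a reference in test_modeling_with_sam.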
@pytest.fixture
def model():
    model = ScaModel.from_sam_text_pretrained(
        sam_model_name, text_model_name, additional_num_hidden_layers, cache_dir=cache_dir
    ).to(device)
    return model


@pytest.fixture
def processor():
    processor = ScaProcessor.from_sam_text_pretrained(sam_model_name, text_model_name, cache_dir=cache_dir)
    return processor


@pytest.fixture
def sam_model():
    model = transformers.AutoModel.from_pretrained(sam_model_name, cache_dir=cache_dir).to(device)
    return model
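

# End-to-end test of the captioning path: prepare image/point/text inputs,
# run a train-mode forward pass and eval-mode generation, then check that the
# SCA text head matches the original GPT-2 weights.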
@pytest.mark.parametrize("batch_size", [1, 3])
@pytest.mark.parametrize("num_masks", [4, 7])
def test_modeling(batch_size, num_masks, model, processor):
    img_url = "https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg"
    raw_image = [Image.open(requests.get(img_url, stream=True).raw).convert("RGB")]
    input_points = [[[[500, 375]]]]
    raw_text = [["This is a test sentence."]]

    # Duplicate the image across the batch.
    raw_image = raw_image * batch_size

    # Tile the single point prompt and caption to (batch_size, num_masks).
    input_points = np.array(input_points)
    raw_text = np.array(raw_text, dtype=object)
    input_points = input_points.repeat(batch_size, axis=0).repeat(num_masks, axis=1).tolist()
    raw_text = raw_text.repeat(batch_size, axis=0).repeat(num_masks, axis=1).reshape(-1).tolist()

    inputs = processor(raw_image, input_points=input_points, return_tensors="pt")
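
    # Tokenize the captions by hand: append EOS, truncate to the model maximum,
    # pad to a common length, and mask padding out of the labels with -100.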
    tokenizer = processor.tokenizer
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        raise ValueError("tokenizer does not have an eos token id")
    # GPT-2 has no pad token, so fall back to EOS for padding.
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_token_id
    label_pad_token_id = -100

    tokenized_inputs = tokenizer(raw_text)
    raw_input_ids = tokenized_inputs["input_ids"]
    raw_attention_mask = tokenized_inputs["attention_mask"]

    for i in range(len(raw_input_ids)):
        raw_input_ids[i] += [eos_token_id]
        raw_attention_mask[i] += [1]

    max_length = tokenizer.model_max_length
    for i in range(len(raw_input_ids)):
        raw_input_ids[i] = raw_input_ids[i][:max_length]
        raw_attention_mask[i] = raw_attention_mask[i][:max_length]

    input_ids = pad_sequence([torch.tensor(x) for x in raw_input_ids], batch_first=True, padding_value=pad_token_id)
    attention_mask = pad_sequence([torch.tensor(x) for x in raw_attention_mask], batch_first=True, padding_value=0)

    # Pad one extra ignore position at the front of the labels.
    labels = pad_sequence([torch.tensor(x) for x in raw_input_ids], batch_first=True, padding_value=label_pad_token_id)
    labels = torch.nn.functional.pad(labels, (1, 0), value=label_pad_token_id)

    text_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
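
    # Fold the flat (batch_size * num_masks) text batch back into
    # (batch_size, num_masks, seq_len).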
    for k in text_inputs:
        text_inputs[k] = text_inputs[k].view(batch_size, num_masks, -1)

    # Log the shape of every prepared input for easier debugging.
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            print(k, v.shape)
        elif isinstance(v, Sequence):
            print(k, "sequence of ", type(v[0]), len(v))
        else:
            print(k, type(v), len(v))
    for k, v in text_inputs.items():
        if isinstance(v, torch.Tensor):
            print(k, v.shape)
        elif isinstance(v, Sequence):
            print(k, "sequence of ", type(v[0]), len(v))
        else:
            print(k, type(v), len(v))

    # Move all tensors to the target device.
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            inputs[k] = v.to(device)
    for k, v in text_inputs.items():
        if isinstance(v, torch.Tensor):
            text_inputs[k] = v.to(device)
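
    # Train-mode (teacher-forced) forward pass.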
    model.train()
    outputs = model(**inputs, **text_inputs)
    for k, v in outputs.items():
        if isinstance(v, torch.Tensor):
            print(k, v.shape)
    # Greedy-decode the logits and show the first caption.
    sequence_texts = tokenizer.batch_decode(outputs["logits"].argmax(dim=-1))
    sequence_texts = sequence_texts[:1]
    print(sequence_texts)
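
    # Eval-mode generation over the same inputs.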
    model.eval()
    with torch.no_grad():
        outputs = model.generate(**inputs, **text_inputs)
    for k, v in outputs.items():
        if isinstance(v, torch.Tensor):
            print(k, v.shape)
    # Flatten (batch_size, num_masks, seq_len) to (batch_size * num_masks, seq_len) for decoding.
    outputs["sequences"] = outputs["sequences"].view(-1, outputs["sequences"].shape[-1])
    sequence_texts = tokenizer.batch_decode(outputs["sequences"])
    sequence_texts = sequence_texts[:1]
    print(sequence_texts)

    # Take the first projected query embedding as a single-token prompt.
    inputs_embeds = outputs["projected_query_logits"][0, 0, 0:1]
    attention_masks = torch.tensor([[1]]).to(device)
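
    # The SCA language model should generate exactly the same tokens as the
    # pretrained GPT-2 it was initialized from, with and without the custom
    # text config.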
    language_model = transformers.AutoModelForCausalLM.from_pretrained(
        text_model_name, config=model.config.text_config, cache_dir=cache_dir
    ).to(device)
    language_model.eval()
    with torch.no_grad():
        original_output = language_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_masks)
        sca_text_output = model.language_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_masks)
    assert torch.allclose(original_output, sca_text_output)

    language_model = transformers.AutoModelForCausalLM.from_pretrained(text_model_name, cache_dir=cache_dir).to(device)
    language_model.eval()
    with torch.no_grad():
        original_output = language_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_masks)
    assert torch.allclose(original_output, sca_text_output)

    validate_texts = tokenizer.batch_decode(sca_text_output)
    validate_texts = validate_texts[:1]
    print(validate_texts)

    assert validate_texts[0] == sequence_texts[0]
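

# Consistency test for the segmentation path: the SAM outputs embedded in the
# SCA results should match a standalone SAM model on identical inputs.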
@pytest.mark.parametrize("batch_size", [1, 3])
@pytest.mark.parametrize("num_masks", [4, 7])
def test_modeling_with_sam(batch_size, num_masks, model, sam_model, processor):
    img_url = "https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg"
    raw_image = [Image.open(requests.get(img_url, stream=True).raw).convert("RGB")]
    input_points = [[[[500, 375]]]]
    raw_text = [["This is a test sentence."]]

    raw_image = raw_image * batch_size

    input_points = np.array(input_points)
    raw_text = np.array(raw_text, dtype=object)
    input_points = input_points.repeat(batch_size, axis=0).repeat(num_masks, axis=1).tolist()
    raw_text = raw_text.repeat(batch_size, axis=0).repeat(num_masks, axis=1).reshape(-1).tolist()

    inputs = processor(raw_image, input_points=input_points, return_tensors="pt")

    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            inputs[k] = v.to(device)
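
    # Compare every tensor of the standalone SAM output against the
    # segmentation outputs returned by SCA, in train and eval mode.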
    sam_model.train()
    model.train()
    sam_output = sam_model(**inputs)
    sca_output = model(**inputs)
    sam_output_from_sca = sca_output.segmentation_outputs
    for k in sam_output:
        if isinstance(sam_output[k], torch.Tensor):
            assert torch.allclose(sam_output[k], sam_output_from_sca[k])

    sam_model.eval()
    model.eval()
    with torch.no_grad():
        sam_output = sam_model(**inputs)
        sca_output = model.generate(**inputs)
    sam_output_from_sca = sca_output.segmentation_outputs
    for k in sam_output:
        if isinstance(sam_output[k], torch.Tensor):
            assert torch.allclose(sam_output[k], sam_output_from_sca[k])