|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from PIL import Image |
|
|
from torchvision import transforms |
|
|
|
|
|
from transformers import T5TokenizerFast |
|
|
from transformers.modeling_outputs import BaseModelOutput |
|
|
from models.vision_t5 import VisionT5 |
|
|
from src.utils import load_experiment |
|
|
from data.transforms import build_coco_transform |
|
|
|
|
|
|
|
|
|
|
|
def load_image(path, preprocess):
    """Load *path* as an RGB image, apply *preprocess*, and return a batch of one.

    Args:
        path: filesystem path to the image file.
        preprocess: callable mapping a PIL image to a (C, H, W) tensor.

    Returns:
        Tensor of shape (1, C, H, W) ready for the model.
    """
    rgb_img = Image.open(path).convert("RGB")
    tensor = preprocess(rgb_img)
    return tensor.unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def generate_caption(model, tokenizer, image_tensor, max_new_tokens=32, num_beams=1, device=None):
    """Generate a caption for a single preprocessed image.

    Args:
        model: VisionT5-style wrapper exposing ``vision_encoder``, ``projector``
            and ``t5`` submodules.
        tokenizer: tokenizer used to decode the generated token ids.
        image_tensor: preprocessed image of shape (1, C, H, W) — TODO confirm
            the expected layout against build_coco_transform.
        max_new_tokens: generation length budget.
        num_beams: beam width (1 = greedy decoding).
        device: target device; inferred from the model's parameters when None.

    Returns:
        The decoded caption string, special tokens stripped.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()

    image_tensor = image_tensor.to(device)

    vision_out = model.vision_encoder(image_tensor)
    img_embeds = vision_out["image_embeds"]

    # A pooled (B, D) embedding becomes a length-1 token "sequence" the
    # decoder can cross-attend over.
    if img_embeds.dim() == 2:
        img_embeds = img_embeds.unsqueeze(1)

    projected = model.projector(img_embeds)

    # Hand the projected image tokens to T5 as precomputed encoder output.
    encoder_outputs = BaseModelOutput(last_hidden_state=projected)

    start_token = model.t5.config.decoder_start_token_id

    # BUG FIX: the encoder attention mask must cover every projected image
    # token. The previous hard-coded (1, 1) mask truncated cross-attention
    # whenever the vision encoder emitted more than one token; building it
    # from the projected shape is identical in the single-token case.
    attention_mask = torch.ones(
        projected.shape[:2], dtype=torch.long, device=projected.device
    )

    # NOTE: `input_ids` is not passed — with `encoder_outputs` supplied,
    # generate() derives the batch size from it and builds decoder inputs
    # from `decoder_start_token_id`.
    output_ids = model.t5.generate(
        encoder_outputs=encoder_outputs,
        decoder_start_token_id=start_token,
        attention_mask=attention_mask,
        num_beams=num_beams,
        max_new_tokens=max_new_tokens,
    )

    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return caption
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def load_images_batch(paths, preprocess, image_size):
    """Load, square-resize, and preprocess every image in *paths*.

    The up-front resize to (image_size, image_size) guarantees all per-image
    tensors share a shape so they can be concatenated into one batch.

    Args:
        paths: iterable of image file paths.
        preprocess: callable mapping a PIL image to a (C, H, W) tensor.
        image_size: target side length for the square resize.

    Returns:
        Tensor of shape (N, C, H, W), one row per input path.
    """
    square = transforms.Resize((image_size, image_size))
    batch = [
        preprocess(square(Image.open(p).convert("RGB"))).unsqueeze(0)
        for p in paths
    ]
    return torch.cat(batch, dim=0)
|
|
|
|
|
|
|
|
@torch.no_grad()
def generate_captions_batch(
    model,
    tokenizer,
    image_batch,
    max_new_tokens=32,
    num_beams=1,
    device=None,
):
    """
    Batched version of generate_caption().

    Does NOT replace or modify existing generate_caption().

    Args:
        model: VisionT5-style wrapper exposing ``vision_encoder``,
            ``projector`` and ``t5`` submodules.
        tokenizer: tokenizer used to decode the generated token ids.
        image_batch: preprocessed images of shape (B, C, H, W).
        max_new_tokens: generation length budget per caption.
        num_beams: beam width (1 = greedy decoding).
        device: target device; inferred from the model's parameters when None.

    Returns:
        List of B decoded caption strings, in input order.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()

    image_batch = image_batch.to(device)

    vision_out = model.vision_encoder(image_batch)
    img_embeds = vision_out["image_embeds"]

    # A pooled (B, D) embedding becomes a length-1 token "sequence" the
    # decoder can cross-attend over.
    if img_embeds.dim() == 2:
        img_embeds = img_embeds.unsqueeze(1)

    projected = model.projector(img_embeds)
    encoder_outputs = BaseModelOutput(last_hidden_state=projected)

    start = model.t5.config.decoder_start_token_id

    # BUG FIX: the encoder attention mask must cover every projected image
    # token. The previous hard-coded (B, 1) mask truncated cross-attention
    # whenever the vision encoder emitted more than one token; building it
    # from the projected shape is identical in the single-token case.
    attention_mask = torch.ones(
        projected.shape[:2], dtype=torch.long, device=projected.device
    )

    # NOTE: `input_ids` is not passed — with `encoder_outputs` supplied,
    # generate() derives the batch size from it and builds decoder inputs
    # from `decoder_start_token_id`.
    output_ids = model.t5.generate(
        encoder_outputs=encoder_outputs,
        decoder_start_token_id=start,
        attention_mask=attention_mask,
        num_beams=num_beams,
        max_new_tokens=max_new_tokens,
    )

    return [
        tokenizer.decode(ids, skip_special_tokens=True)
        for ids in output_ids
    ]
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Minimal CLI: caption one image with a trained checkpoint.
    cli = argparse.ArgumentParser()
    cli.add_argument("--image", type=str, required=True, help="Path to image")
    cli.add_argument("--checkpoint", type=str, default="checkpoints/vision_t5")
    opts = cli.parse_args()

    # Restore the model, tokenizer, and experiment config from disk.
    model, tokenizer, meta, config = load_experiment(opts.checkpoint)
    image_size = config["model"].get("image_size", 224)
    preprocess = build_coco_transform(image_size)

    tensor = load_image(opts.image, preprocess)
    caption = generate_caption(model, tokenizer, tensor)
    print("\nCaption:", caption)
|
|
|
|
|
|