# inference.py
import torch
from PIL import Image
from torchvision import transforms
from transformers.modeling_outputs import BaseModelOutput

from models.vision_t5 import VisionT5
from src.utils import load_experiment
from data.transforms import build_coco_transform


def load_image(path, preprocess):
    img = Image.open(path).convert("RGB")
    return preprocess(img).unsqueeze(0)  # (1, 3, H, W)


@torch.no_grad()
def generate_caption(model, tokenizer, image_tensor, max_new_tokens=32, num_beams=1, device=None):
    """Generate a caption for a single preprocessed image tensor of shape (1, 3, H, W)."""
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    image_tensor = image_tensor.to(device)

    # Encode the image and project it into the T5 embedding space.
    vision_out = model.vision_encoder(image_tensor)
    img_embeds = vision_out["image_embeds"]
    if img_embeds.dim() == 2:  # pooled (1, D) -> (1, 1, D): give it a sequence axis
        img_embeds = img_embeds.unsqueeze(1)
    projected = model.projector(img_embeds)
    encoder_outputs = BaseModelOutput(last_hidden_state=projected)

    start_token = model.t5.config.decoder_start_token_id
    # Explicit start-token inputs and a matching attention mask: passing these
    # alongside encoder_outputs works around a generate() error when there are
    # no encoder input_ids for it to infer them from.
    input_ids = torch.tensor([[start_token]], device=device)
    attention_mask = torch.tensor([[1]], device=device)

    output_ids = model.t5.generate(
        encoder_outputs=encoder_outputs,
        decoder_start_token_id=start_token,
        input_ids=input_ids,
        attention_mask=attention_mask,
        num_beams=num_beams,
        max_new_tokens=max_new_tokens,
    )
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return caption
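
# Example with beam search (hypothetical paths; mirrors the __main__ block below):
#   model, tokenizer, meta, config = load_experiment("checkpoints/vision_t5")
#   preprocess = build_coco_transform(config["model"].get("image_size", 224))
#   img = load_image("example.jpg", preprocess)
#   print(generate_caption(model, tokenizer, img, num_beams=4))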


# Batched evaluation helpers (non-breaking additions)
@torch.no_grad()
def load_images_batch(paths, preprocess, image_size):
    """Load and preprocess a list of image paths into one (B, 3, H, W) batch."""
    # Resize up front so every tensor has the same spatial size and the batch
    # can be concatenated, even if `preprocess` does not resize on its own.
    resize = transforms.Resize((image_size, image_size))
    tensors = []
    for p in paths:
        img = Image.open(p).convert("RGB")
        img = resize(img)
        tensors.append(preprocess(img).unsqueeze(0))
    return torch.cat(tensors, dim=0)


@torch.no_grad()
def generate_captions_batch(
    model,
    tokenizer,
    image_batch,  # (B, 3, H, W)
    max_new_tokens=32,
    num_beams=1,
    device=None,
):
    """Batched version of generate_caption(); leaves the single-image path untouched."""
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    image_batch = image_batch.to(device)

    # Encode the whole batch in one forward pass.
    vision_out = model.vision_encoder(image_batch)
    img_embeds = vision_out["image_embeds"]  # (B, D) or (B, S, D)
    if img_embeds.dim() == 2:  # pooled embeddings: add a sequence axis -> (B, 1, D)
        img_embeds = img_embeds.unsqueeze(1)
    projected = model.projector(img_embeds)  # (B, S, d_model)
    encoder_outputs = BaseModelOutput(last_hidden_state=projected)

    # Build batched decoder inputs: one start token per image, plus a matching mask.
    start = model.t5.config.decoder_start_token_id
    B = image_batch.size(0)
    input_ids = torch.full((B, 1), start, dtype=torch.long, device=device)
    attention_mask = torch.ones((B, 1), dtype=torch.long, device=device)

    # Standard HF batched generation over the precomputed encoder outputs.
    output_ids = model.t5.generate(
        encoder_outputs=encoder_outputs,
        decoder_start_token_id=start,
        input_ids=input_ids,
        attention_mask=attention_mask,
        num_beams=num_beams,
        max_new_tokens=max_new_tokens,
    )
    # Decode each sequence in the batch individually.
    return [tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
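

# A minimal end-to-end sketch for the batched path above. This helper is an
# illustrative addition, not part of the original API: the name
# caption_image_files and the batch_size chunking are assumptions layered on
# load_images_batch() and generate_captions_batch().
@torch.no_grad()
def caption_image_files(model, tokenizer, paths, preprocess, image_size, batch_size=16, **gen_kwargs):
    """Caption a list of image paths in fixed-size chunks; returns one caption per path."""
    captions = []
    for i in range(0, len(paths), batch_size):
        chunk = paths[i:i + batch_size]
        batch = load_images_batch(chunk, preprocess, image_size)
        captions.extend(generate_captions_batch(model, tokenizer, batch, **gen_kwargs))
    return captions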


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, required=True, help="Path to image")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/vision_t5")
    args = parser.parse_args()

    # Load model + tokenizer + config
    model, tokenizer, meta, config = load_experiment(args.checkpoint)
    image_size = config["model"].get("image_size", 224)
    preprocess = build_coco_transform(image_size)

    # Load the image and generate a caption
    image_tensor = load_image(args.image, preprocess)
    caption = generate_caption(model, tokenizer, image_tensor)
    print("\nCaption:", caption)
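
# One plausible invocation (layout-dependent; assumes the repo root is the
# working directory and `src` is importable as a package):
#   python -m src.inference --image samples/dog.jpg --checkpoint checkpoints/vision_t5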