# inference.py
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import T5TokenizerFast
from transformers.modeling_outputs import BaseModelOutput
from models.vision_t5 import VisionT5
from src.utils import load_experiment
from data.transforms import build_coco_transform
def load_image(path, preprocess):
    """Open *path* as an RGB image, run *preprocess*, and return a (1, 3, H, W) batch tensor."""
    rgb = Image.open(path).convert("RGB")
    tensor = preprocess(rgb)
    return tensor.unsqueeze(0)
@torch.no_grad()
def generate_caption(model, tokenizer, image_tensor, max_new_tokens=32, num_beams=1, device=None):
    """Generate a caption string for a single preprocessed image.

    Args:
        model: wrapper exposing ``vision_encoder``, ``projector`` and ``t5``
            (a HuggingFace T5 model with ``generate`` and a ``config``).
        tokenizer: tokenizer used to decode the generated token ids.
        image_tensor: image batch — (1, 3, H, W) as produced by load_image().
        max_new_tokens: cap on the number of generated tokens.
        num_beams: beam width for generation (1 = greedy decoding).
        device: target device; defaults to the device of the model's parameters.

    Returns:
        Decoded caption with special tokens stripped.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    image_tensor = image_tensor.to(device)
    # Encode image
    vision_out = model.vision_encoder(image_tensor)
    img_embeds = vision_out["image_embeds"]
    if img_embeds.dim() == 2:
        # Pooled (1, D) features: add a sequence axis so T5 cross-attention
        # receives (1, S, D).
        img_embeds = img_embeds.unsqueeze(1)
    projected = model.projector(img_embeds)
    # Wrap the projected features so generate() treats them as precomputed
    # encoder states and skips running T5's own text encoder.
    encoder_outputs = BaseModelOutput(last_hidden_state=projected)
    start_token = model.t5.config.decoder_start_token_id
    # explicit decoder inputs & mask (FIXES THE ERROR)
    # NOTE(review): with encoder_outputs already supplied, these input_ids are
    # presumably consumed as decoder-side inputs by generate(); the exact
    # behavior depends on the installed transformers version — confirm.
    input_ids = torch.tensor([[start_token]], device=device)
    attention_mask = torch.tensor([[1]], device=device)
    output_ids = model.t5.generate(
        encoder_outputs=encoder_outputs,
        decoder_start_token_id=start_token,
        input_ids=input_ids,
        attention_mask=attention_mask,
        num_beams=num_beams,
        max_new_tokens=max_new_tokens,
    )
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return caption
# Batched evaluation helpers (non-breaking)
@torch.no_grad()
def load_images_batch(paths, preprocess, image_size):
    """Load each path as RGB, resize to (image_size, image_size), preprocess, and
    stack the results into a single (B, 3, H, W) tensor."""
    resize = transforms.Resize((image_size, image_size))
    batch = [preprocess(resize(Image.open(p).convert("RGB"))) for p in paths]
    return torch.stack(batch, dim=0)
@torch.no_grad()
@torch.no_grad()
def generate_captions_batch(
    model,
    tokenizer,
    image_batch,  # (B, 3, H, W)
    max_new_tokens=32,
    num_beams=1,
    device=None,
):
    """
    Batched version of generate_caption().
    Does NOT replace or modify existing generate_caption().
    """
    target = device if device is not None else next(model.parameters()).device
    model.eval()
    batch = image_batch.to(target)

    # Run the vision encoder over the whole batch at once.
    encoded = model.vision_encoder(batch)
    feats = encoded["image_embeds"]  # (B, D) or (B, S, D)
    if feats.dim() == 2:
        feats = feats.unsqueeze(1)  # give pooled features a sequence axis
    enc_states = model.projector(feats)  # (B, S, d_model)
    enc_out = BaseModelOutput(last_hidden_state=enc_states)

    # One decoder-start token per sample, plus a matching all-ones mask.
    bos = model.t5.config.decoder_start_token_id
    n = batch.size(0)
    dec_ids = torch.full((n, 1), bos, dtype=torch.long, device=target)
    dec_mask = torch.ones((n, 1), dtype=torch.long, device=target)

    generated = model.t5.generate(
        encoder_outputs=enc_out,
        decoder_start_token_id=bos,
        input_ids=dec_ids,
        attention_mask=dec_mask,
        num_beams=num_beams,
        max_new_tokens=max_new_tokens,
    )

    # Decode each sample's token ids to a caption string.
    captions = []
    for row in generated:
        captions.append(tokenizer.decode(row, skip_special_tokens=True))
    return captions
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--image", type=str, required=True, help="Path to image")
    cli.add_argument("--checkpoint", type=str, default="checkpoints/vision_t5")
    opts = cli.parse_args()

    # Restore the model, tokenizer, and experiment config from the checkpoint.
    model, tokenizer, meta, config = load_experiment(opts.checkpoint)
    image_size = config["model"].get("image_size", 224)
    preprocess = build_coco_transform(image_size)

    # Preprocess the requested image and caption it.
    image_tensor = load_image(opts.image, preprocess)
    caption = generate_caption(model, tokenizer, image_tensor)
    print("\nCaption:", caption)