File size: 4,163 Bytes
1809762
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# inference.py
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from transformers import T5TokenizerFast
from transformers.modeling_outputs import BaseModelOutput
from models.vision_t5 import VisionT5
from src.utils import load_experiment
from data.transforms import build_coco_transform



def load_image(path, preprocess):
    """Open *path* as an RGB image, preprocess it, and return a (1, 3, H, W) batch."""
    # Force RGB so grayscale / paletted files still yield 3 channels.
    image = Image.open(path).convert("RGB")
    tensor = preprocess(image)
    return tensor.unsqueeze(0)



@torch.no_grad()
def generate_caption(model, tokenizer, image_tensor, max_new_tokens=32, num_beams=1, device=None):
    """Generate a caption for a single preprocessed image.

    Args:
        model: VisionT5-style wrapper exposing ``vision_encoder``, ``projector``
            and a HuggingFace ``t5`` encoder-decoder model.
        tokenizer: tokenizer used to decode the generated token ids.
        image_tensor: preprocessed image batch of shape (1, 3, H, W).
        max_new_tokens: maximum number of tokens to generate.
        num_beams: beam width (1 = greedy decoding).
        device: target device; inferred from the model's parameters when None.

    Returns:
        The decoded caption string with special tokens stripped.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()

    image_tensor = image_tensor.to(device)

    # Encode the image into a sequence of embeddings for T5 cross-attention.
    vision_out = model.vision_encoder(image_tensor)
    img_embeds = vision_out["image_embeds"]  # (1, D) pooled or (1, S, D) sequence

    if img_embeds.dim() == 2:
        # Pooled vector -> sequence of length 1 so T5 sees (B, S, D).
        img_embeds = img_embeds.unsqueeze(1)

    projected = model.projector(img_embeds)  # (1, S, d_model)

    encoder_outputs = BaseModelOutput(last_hidden_state=projected)

    start_token = model.t5.config.decoder_start_token_id

    # With encoder_outputs supplied, generate() uses input_ids as the decoder
    # prompt; passing the start token explicitly avoids the missing-input error.
    input_ids = torch.tensor([[start_token]], device=device)

    # FIX: attention_mask here is the *encoder* attention mask consumed by
    # cross-attention, so it must cover every encoder position — shape (1, S),
    # not (1, 1). The old (1, 1) mask only worked via silent broadcasting.
    attention_mask = torch.ones(projected.shape[:2], dtype=torch.long, device=device)

    output_ids = model.t5.generate(
        encoder_outputs=encoder_outputs,
        decoder_start_token_id=start_token,
        input_ids=input_ids,
        attention_mask=attention_mask,
        num_beams=num_beams,
        max_new_tokens=max_new_tokens,
    )

    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return caption




# Batched evaluation helpers (non-breaking)

@torch.no_grad()
def load_images_batch(paths, preprocess, image_size):
    """Load several images from disk and stack them into one (B, 3, H, W) batch.

    Every image is resized to (image_size, image_size) before the shared
    ``preprocess`` transform so the resulting tensors are stackable.
    """
    resize = transforms.Resize((image_size, image_size))

    def _load_one(path):
        # Force RGB so grayscale / paletted files still yield 3 channels.
        image = Image.open(path).convert("RGB")
        return preprocess(resize(image))

    return torch.stack([_load_one(p) for p in paths], dim=0)


@torch.no_grad()
def generate_captions_batch(
    model,
    tokenizer,
    image_batch,          # (B, 3, H, W)
    max_new_tokens=32,
    num_beams=1,
    device=None,
):
    """Batched version of generate_caption().

    Does NOT replace or modify existing generate_caption().

    Args:
        model: VisionT5-style wrapper exposing ``vision_encoder``, ``projector``
            and a HuggingFace ``t5`` encoder-decoder model.
        tokenizer: tokenizer used to decode the generated token ids.
        image_batch: preprocessed images of shape (B, 3, H, W).
        max_new_tokens: maximum number of tokens to generate per image.
        num_beams: beam width (1 = greedy decoding).
        device: target device; inferred from the model's parameters when None.

    Returns:
        A list of B decoded caption strings, special tokens stripped.
    """

    if device is None:
        device = next(model.parameters()).device

    model.eval()

    image_batch = image_batch.to(device)

    # Encode all images in one forward pass.
    vision_out = model.vision_encoder(image_batch)
    img_embeds = vision_out["image_embeds"]   # (B, D) or (B, S, D)

    if img_embeds.dim() == 2:
        # Pooled vectors -> sequences of length 1 so T5 sees (B, S, D).
        img_embeds = img_embeds.unsqueeze(1)

    projected = model.projector(img_embeds)    # (B, S, d_model)
    encoder_outputs = BaseModelOutput(last_hidden_state=projected)

    # Build batched decoder inputs: one start token per image.
    start = model.t5.config.decoder_start_token_id
    B = image_batch.size(0)

    input_ids = torch.full((B, 1), start, dtype=torch.long, device=device)

    # FIX: attention_mask is the *encoder* attention mask consumed by
    # cross-attention, so it must cover every encoder position — shape (B, S),
    # not (B, 1). The old (B, 1) mask only worked via silent broadcasting.
    attention_mask = torch.ones(projected.shape[:2], dtype=torch.long, device=device)

    output_ids = model.t5.generate(
        encoder_outputs=encoder_outputs,
        decoder_start_token_id=start,
        input_ids=input_ids,
        attention_mask=attention_mask,
        num_beams=num_beams,
        max_new_tokens=max_new_tokens,
    )

    # Decode each row independently.
    return [
        tokenizer.decode(ids, skip_special_tokens=True)
        for ids in output_ids
    ]


if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--image", type=str, required=True, help="Path to image")
    cli.add_argument("--checkpoint", type=str, default="checkpoints/vision_t5")
    args = cli.parse_args()

    # Restore model, tokenizer and experiment config from the checkpoint dir.
    model, tokenizer, meta, config = load_experiment(args.checkpoint)

    # Build the same preprocessing pipeline the model was trained with.
    image_size = config["model"].get("image_size", 224)
    preprocess = build_coco_transform(image_size)

    # Caption the single requested image and print the result.
    image_tensor = load_image(args.image, preprocess)
    caption = generate_caption(model, tokenizer, image_tensor)
    print("\nCaption:", caption)