"""
step2_prepare_data.py
======================
Task 3 - Component 2: Prepare 500 COCO validation images for inference.

Loads 500 images from the MS-COCO 2017 validation split (via HuggingFace
Datasets) and wraps them in a standard PyTorch DataLoader.

Public API
----------
    load_val_data(processor, n=500, batch_size=8, seed=42)
        -> torch.utils.data.DataLoader

Each batch yields a dict:
    {
        "pixel_values" : FloatTensor (B, 3, 384, 384),
        "labels"       : LongTensor  (B, max_len),      # reference caption ids
        "captions"     : list[str]                       # raw reference strings
    }
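
Example
-------
A minimal usage sketch (it assumes, as in the entrypoint at the bottom of
this file, that step1_load_model.load_model() returns a 3-tuple whose second
element is the BlipProcessor):

    from task.task_03.step1_load_model import load_model
    from task.task_03.step2_prepare_data import load_val_data

    _, processor, _ = load_model()
    loader = load_val_data(processor, n=500, batch_size=8)
    batch  = next(iter(loader))
    batch["pixel_values"].shape   # torch.Size([8, 3, 384, 384])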

Standalone usage
----------------
    export PYTHONPATH=.
    venv/bin/python task/task_03/step2_prepare_data.py
"""

import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BlipProcessor


# ─────────────────────────────────────────────────────────────────────────────
# Dataset wrapper
# ─────────────────────────────────────────────────────────────────────────────

COCO_ID    = "phiyodr/coco2017"    # primary: MS-COCO 2017 validation split
DATASET_ID = "nlphuji/flickr30k"   # fallback if COCO is unavailable


class COCOValDataset(Dataset):
    """
    Wraps a HuggingFace dataset split into a torch Dataset.

    Args:
        hf_dataset : HuggingFace Dataset with 'image' and 'captions' (or 'caption') fields.
        processor  : BlipProcessor instance.
        max_len    : Maximum tokenization length for reference captions.
    """

    def __init__(self, hf_dataset, processor: BlipProcessor, max_len: int = 64):
        self.data      = hf_dataset
        self.processor = processor
        self.max_len   = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        image   = example["image"].convert("RGB")

        # Pick the first reference caption
        captions = example.get("captions", example.get("caption", ["<no caption>"]))
        if isinstance(captions, str):
            captions = [captions]
        caption = captions[0]

        enc = self.processor(
            images=image,
            text=caption,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
        )

        return {
            "pixel_values": enc["pixel_values"].squeeze(0),   # (3, H, W)
            "labels":       enc["input_ids"].squeeze(0),       # (max_len,)
            "caption":      caption,
        }


def _collate_fn(batch):
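    """
    Stack per-item tensors into batch tensors and collect the raw reference
    strings into a plain Python list under the plural 'captions' key
    (strings cannot be stacked with torch.stack, hence the custom collate_fn).
    """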
    return {
        "pixel_values": torch.stack([b["pixel_values"] for b in batch]),
        "labels":       torch.stack([b["labels"] for b in batch]),
        "captions":     [b["caption"] for b in batch],
    }


# ─────────────────────────────────────────────────────────────────────────────
# Public loader
# ─────────────────────────────────────────────────────────────────────────────

def load_val_data(
    processor: BlipProcessor,
    n: int         = 500,
    batch_size: int = 8,
    seed: int       = 42,
    max_len: int    = 64,
) -> DataLoader:
    """
    Download and prepare n COCO validation images.

    Falls back to Flickr30k if COCO is unavailable (e.g. firewall/proxy).

    Args:
        processor  : BlipProcessor (from step1_load_model)
        n          : Number of validation images to use (default 500)
        batch_size : DataLoader batch size
        seed       : Random seed for reproducible shuffle
        max_len    : Max caption token length for labels

    Returns:
        DataLoader that yields batches with keys:
            pixel_values, labels, captions
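
    Example (a sketch of consuming the loader; keys as described above):
        for batch in loader:
            images = batch["pixel_values"]   # (B, 3, 384, 384)
            refs   = batch["captions"]       # list[str], length B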
    """
    from datasets import load_dataset

    print("=" * 60)
    print("  Task 3 β€” Step 2: Prepare Validation Data")
    print("=" * 60)
    print(f"  Target images : {n}")
    print(f"  Batch size    : {batch_size}")

    # ── Try COCO first ────────────────────────────────────────────────────────
    ds = None
    try:
        print(f"  Loading dataset: {COCO_ID} ...")
        raw = load_dataset(COCO_ID, split="validation", trust_remote_code=True)
        ds  = raw.shuffle(seed=seed).select(range(min(n, len(raw))))
        print(f"  βœ… COCO loaded  ({len(ds)} images)")
    except Exception as e:
        print(f"  ⚠️  COCO unavailable ({e}). Falling back to Flickr30k …")

    # ── Fallback to Flickr30k ─────────────────────────────────────────────────
    if ds is None:
        raw = load_dataset(DATASET_ID, split="test", trust_remote_code=True)
        ds  = raw.shuffle(seed=seed).select(range(min(n, len(raw))))
        print(f"  βœ… Flickr30k loaded  ({len(ds)} images)")

    dataset    = COCOValDataset(ds, processor, max_len=max_len)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=False,
        collate_fn=_collate_fn,
    )

    print(f"  Batches       : {len(dataloader)}")
    print("=" * 60)
    return dataloader


# ─────────────────────────────────────────────────────────────────────────────
# Standalone entrypoint
# ─────────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    from step1_load_model import load_model

    _, processor, _ = load_model()
    loader = load_val_data(processor, n=500, batch_size=8)

    # Peek at first batch
    batch = next(iter(loader))
    print(f"\nβœ…  DataLoader ready!")
    print(f"   pixel_values shape : {batch['pixel_values'].shape}")
    print(f"   labels shape       : {batch['labels'].shape}")
    print(f"   Sample caption     : {batch['captions'][0][:80]}")
    print(f"\nImport in notebooks:")
    print("  from task.task_03.step2_prepare_data import load_val_data")