"""
step2_prepare_data.py
======================
Task 3 - Component 2: Prepare 500 COCO validation images for inference.
Loads 500 images from the MS-COCO 2017 validation split (via HuggingFace
Datasets) and wraps them in a standard PyTorch DataLoader.
Public API
----------
load_val_data(processor, n=500, batch_size=8, seed=42)
-> torch.utils.data.DataLoader
Each batch yields a dict:
{
"pixel_values" : FloatTensor (B, 3, 384, 384),
"labels" : LongTensor (B, max_len), # reference caption ids
"captions" : list[str] # raw reference strings
}
Standalone usage
----------------
export PYTHONPATH=.
venv/bin/python task/task_03/step2_prepare_data.py
"""
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BlipProcessor
# -----------------------------------------------------------------------------
# Dataset wrapper
# -----------------------------------------------------------------------------
DATASET_ID = "nlphuji/flickr30k" # fallback if COCO unavailable
COCO_ID = "phiyodr/coco2017"  # primary source: MS-COCO 2017 validation split
class COCOValDataset(Dataset):
    """
    Wraps a HuggingFace dataset split into a torch Dataset.

    Each item is tokenized/processed on the fly, so no pre-processed copy
    of the dataset is held in memory.

    Args:
        hf_dataset : HuggingFace Dataset object with 'image' and 'captions'
                     (or singular 'caption') fields.
        processor  : BlipProcessor instance.
        max_len    : Maximum tokenization length for reference captions.
    """
    def __init__(self, hf_dataset, processor: "BlipProcessor", max_len: int = 64):
        self.data = hf_dataset
        self.processor = processor
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        image = example["image"].convert("RGB")
        # Accept either a list of reference captions or a single string;
        # normalize to a list so indexing below is uniform.
        captions = example.get("captions", example.get("caption", ["<no caption>"]))
        if isinstance(captions, str):
            captions = [captions]
        # Pick the first reference caption; guard against an empty list so a
        # malformed dataset row cannot crash the whole evaluation run.
        caption = captions[0] if captions else "<no caption>"
        enc = self.processor(
            images=image,
            text=caption,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
        )
        return {
            "pixel_values": enc["pixel_values"].squeeze(0),  # (3, H, W)
            "labels": enc["input_ids"].squeeze(0),           # (max_len,)
            "caption": caption,
        }
def _collate_fn(batch):
return {
"pixel_values": torch.stack([b["pixel_values"] for b in batch]),
"labels": torch.stack([b["labels"] for b in batch]),
"captions": [b["caption"] for b in batch],
}
# -----------------------------------------------------------------------------
# Public loader
# -----------------------------------------------------------------------------
def load_val_data(
    processor: "BlipProcessor",
    n: int = 500,
    batch_size: int = 8,
    seed: int = 42,
    max_len: int = 64,
) -> DataLoader:
    """
    Download and prepare n COCO validation images.

    Falls back to Flickr30k if COCO is unavailable (e.g. firewall/proxy).

    Args:
        processor  : BlipProcessor (from step1_load_model)
        n          : Number of validation images to use (default 500)
        batch_size : DataLoader batch size
        seed       : Random seed for reproducible shuffle
        max_len    : Max caption token length for labels

    Returns:
        DataLoader that yields batches with keys:
            pixel_values, labels, captions
    """
    # Imported lazily so merely importing this module does not require the
    # `datasets` package or network access.
    from datasets import load_dataset

    print("=" * 60)
    print(" Task 3 - Step 2: Prepare Validation Data")
    print("=" * 60)
    print(f"   Target images : {n}")
    print(f"   Batch size    : {batch_size}")

    # -- Try COCO first -------------------------------------------------------
    ds = None
    try:
        print(f"   Loading dataset: {COCO_ID} ...")
        raw = load_dataset(COCO_ID, split="validation", trust_remote_code=True)
        # Shuffle with a fixed seed, then keep the first n rows (or all of
        # them if the split is smaller than n) for a reproducible subset.
        ds = raw.shuffle(seed=seed).select(range(min(n, len(raw))))
        print(f"   [OK] COCO loaded ({len(ds)} images)")
    except Exception as e:
        # Best-effort: any failure (network, auth, missing script) triggers
        # the Flickr30k fallback below instead of aborting the run.
        print(f"   [WARN] COCO unavailable ({e}). Falling back to Flickr30k ...")

    # -- Fallback to Flickr30k ------------------------------------------------
    if ds is None:
        raw = load_dataset(DATASET_ID, split="test", trust_remote_code=True)
        ds = raw.shuffle(seed=seed).select(range(min(n, len(raw))))
        print(f"   [OK] Flickr30k loaded ({len(ds)} images)")

    dataset = COCOValDataset(ds, processor, max_len=max_len)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,   # evaluation data: keep deterministic order
        num_workers=0,
        pin_memory=False,
        collate_fn=_collate_fn,
    )
    print(f"   Batches       : {len(dataloader)}")
    print("=" * 60)
    return dataloader
# -----------------------------------------------------------------------------
# Standalone entrypoint
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from step1_load_model import load_model

    # Only the processor is needed here; model and device are discarded.
    _, processor, _ = load_model()
    loader = load_val_data(processor, n=500, batch_size=8)

    # Peek at the first batch to confirm tensor shapes before a full run.
    batch = next(iter(loader))
    print("\n[OK] DataLoader ready!")
    print(f"   pixel_values shape : {batch['pixel_values'].shape}")
    print(f"   labels shape       : {batch['labels'].shape}")
    print(f"   Sample caption     : {batch['captions'][0][:80]}")
    print("\nImport in notebooks:")
    print("   from task.task_03.step2_prepare_data import load_val_data")