File size: 3,215 Bytes
d31183e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from datasets import load_dataset
from build_vocab import Vocabulary
import torchvision.transforms as transforms
from PIL import Image

class Flickr8kDataset(Dataset):
    """Map-style dataset yielding ``(image, caption_ids)`` pairs from a captioning split.

    Each row of ``hf_dataset`` is expected to be a dict with an ``"image"`` entry
    (an object exposing ``mode`` and ``convert``, as PIL images do) and one caption
    column named by ``CAPTION_KEYS``. When that column holds a list of captions,
    only the first one is used.
    """

    # Caption column names used by different Hugging Face mirrors, checked in order.
    CAPTION_KEYS = ("caption", "captions", "text", "text_en", "caption_0")

    def __init__(self, hf_dataset, vocab, transform=None):
        """Store the data source, vocabulary and optional image transform.

        Args:
            hf_dataset: Indexable, sized collection of row dicts
                (e.g. a ``datasets.Dataset``).
            vocab: Object exposing ``stoi`` (must contain ``"<SOS>"`` and
                ``"<EOS>"``) and ``numericalize(text)``, presumably returning a
                list of integer token ids.
            transform: Optional callable applied to the RGB image.
        """
        self.dataset = hf_dataset
        self.vocab = vocab
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the underlying dataset."""
        return len(self.dataset)

    def _get_caption(self, item):
        """Return the caption string of *item* from the first matching caption column.

        Raises:
            KeyError: if *item* has none of ``CAPTION_KEYS``.
            ValueError: if the caption column holds an empty list.
        """
        for key in self.CAPTION_KEYS:
            if key in item:
                caption = item[key]
                break
        else:
            # Previously this fell through with caption=None and trained on the token "None".
            raise KeyError(
                f"No caption column found; expected one of {self.CAPTION_KEYS}, "
                f"got {list(item)}"
            )

        if isinstance(caption, list):
            if not caption:
                # ValueError rather than IndexError: an IndexError escaping
                # __getitem__ would silently stop plain `for x in dataset` loops.
                raise ValueError("Caption list is empty")
            caption = caption[0]
        return str(caption)

    def __getitem__(self, index):
        """Return ``(image, caption_ids)`` for row *index*.

        The image is converted to RGB if needed and then passed through
        ``transform`` when one was given. ``caption_ids`` is a 1-D tensor of
        ``[<SOS>] + vocab.numericalize(caption) + [<EOS>]``.
        """
        item = self.dataset[index]
        caption = self._get_caption(item)

        image = item["image"]
        # Grayscale/CMYK images are normalised to 3 channels so batches stack cleanly.
        if image.mode != "RGB":
            image = image.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        token_ids = [self.vocab.stoi["<SOS>"]]
        token_ids += self.vocab.numericalize(caption)
        token_ids.append(self.vocab.stoi["<EOS>"])
        return image, torch.tensor(token_ids)

class CapsCollate:
    """Collate ``(image, caption_ids)`` samples into one batch.

    Images are stacked along a new leading batch dimension. Caption id
    sequences of varying length are padded at the end with ``pad_idx``, giving
    a ``(batch, max_len)`` tensor for 1-D caption tensors.
    """

    def __init__(self, pad_idx):
        """Remember the token id used to fill padded caption positions."""
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Return ``(images, padded_captions)`` for a list of samples."""
        image_list = []
        caption_list = []
        for sample in batch:
            image_list.append(sample[0])
            caption_list.append(sample[1])

        stacked_images = torch.stack(image_list, dim=0)
        padded_captions = pad_sequence(
            caption_list,
            batch_first=True,
            padding_value=self.pad_idx,
        )
        return stacked_images, padded_captions

def _collect_captions(hf_dataset):
    """Return every caption string from the first recognised caption column of *hf_dataset*.

    List-valued cells contribute all of their captions, in order.

    Raises:
        KeyError: if none of the known caption column names is present.
    """
    caption_keys = ("caption", "captions", "text", "text_en", "caption_0")
    columns = hf_dataset.column_names
    key = next((k for k in caption_keys if k in columns), None)
    if key is None:
        raise KeyError(
            f"No caption column found; expected one of {caption_keys}, got {columns}"
        )

    captions = []
    # Read only the caption column: iterating whole rows would also decode every image.
    for caps in hf_dataset[key]:
        if isinstance(caps, list):
            captions.extend(str(c) for c in caps)
        else:
            captions.append(str(caps))
    return captions


def get_loader(dataset_name="jxie/flickr8k", split="train", transform=None, batch_size=32, num_workers=0, shuffle=True, vocab_threshold=5, vocab=None):
    """Load a Hugging Face captioning split and wrap it in a padded-batch DataLoader.

    Args:
        dataset_name: Dataset id passed to ``load_dataset``
            (``jxie/flickr8k`` is a common Flickr8k mirror).
        split: Split name passed to ``load_dataset``.
        transform: Optional image transform forwarded to ``Flickr8kDataset``.
        batch_size: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.
        shuffle: Forwarded to ``DataLoader``.
        vocab_threshold: Passed to ``Vocabulary(...)`` when a new vocabulary is built.
        vocab: Existing vocabulary to reuse. When ``None``, a new one is built from
            the captions of *this* split. For validation/test splits, pass the
            training vocabulary explicitly so token ids match the trained model.

    Returns:
        ``(loader, dataset)``: the ``DataLoader`` (batches collated by
        ``CapsCollate`` with the vocabulary's ``"<PAD>"`` id) and the underlying
        ``Flickr8kDataset``.

    Raises:
        KeyError: if a vocabulary must be built and the split has no recognised
            caption column.
    """
    hf_dataset = load_dataset(dataset_name, split=split)

    if vocab is None:
        vocab = Vocabulary(vocab_threshold)
        vocab.build_vocabulary(_collect_captions(hf_dataset))

    dataset = Flickr8kDataset(hf_dataset, vocab, transform=transform)
    pad_idx = dataset.vocab.stoi["<PAD>"]

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        collate_fn=CapsCollate(pad_idx=pad_idx),
    )

    return loader, dataset