import os
import json
import random
from typing import Callable

from PIL import Image
from torch.utils.data import Dataset
import torch


class CaptionDataset(Dataset):
    """Image-captioning dataset backed by a JSON annotation file.

    The JSON file must contain a list of records, each with a
    ``"file_name"`` (relative to ``image_dir``) and a ``"captions"`` list.
    Every sample is normalised to exactly ``NUM_CAPTIONS`` captions: extra
    captions are dropped and the last caption is repeated to fill gaps.

    Items returned by ``__getitem__``:

    * ``split == "val"``: ``(image, tokens, captions, file_name)`` where
      ``tokens`` encodes one randomly chosen caption and ``captions`` is
      the normalised list of caption strings.
    * any other split: ``(images, tokens)`` where the (transformed) image
      is stacked once per sampled caption and ``tokens`` stacks the
      encoded captions. Because the image is passed to ``torch.stack``,
      ``transform`` must yield a tensor in this mode.

    NOTE(review): every special-token lookup in ``encode_caption`` uses the
    empty-string key ``""``. This looks like angle-bracket token names
    (start / unknown / end / pad) lost during text extraction -- confirm
    against the vocabulary before relying on it.
    """

    # Number of captions every sample is normalised to.
    NUM_CAPTIONS = 5

    def __init__(
        self,
        json_path,
        image_dir,
        w2i,
        tokenizer: Callable[[str], list],
        split='train',
        transform=None,
        max_len=30,
        train_num_caption=1,
        debug=False,
        use_subword=False,
        sp_model_path="tokenizer.model"
    ):
        """Load the annotations and store the encoding configuration.

        Args:
            json_path: Path to the JSON annotation file.
            image_dir: Directory that the records' file names are joined to.
            w2i: Mapping from token string to integer id.
            tokenizer: Callable turning a caption into a sequence of tokens;
                used only when ``use_subword`` is false.
            split: ``"val"`` selects validation behaviour; anything else is
                treated as training.
            transform: Optional callable applied to the loaded RGB image.
            max_len: Fixed length of every encoded caption.
            train_num_caption: Number of captions sampled per training item.
            debug: When true, keep only the first 10 records.
            use_subword: When true, tokenise with a SentencePiece model.
            sp_model_path: SentencePiece model file, loaded only when
                ``use_subword`` is true.
        """
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

        # Debugging aid: work on a tiny subset.
        if debug:
            self.data = self.data[:10]

        self.is_val = split == "val"

        self.image_dir = image_dir
        self.w2i = w2i
        self.transform = transform
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.train_num_caption = train_num_caption
        self.use_subword = use_subword

        if self.use_subword:
            # Imported lazily so sentencepiece is only needed in subword mode.
            import sentencepiece as spm
            self.sp = spm.SentencePieceProcessor()
            self.sp.load(sp_model_path)

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.data)

    def encode_caption(self, caption):
        """Encode ``caption`` as a ``torch.long`` tensor of ``max_len`` ids.

        The caption is tokenised (SentencePiece on the lower-cased text in
        subword mode, ``self.tokenizer`` otherwise), wrapped in a leading
        and a trailing special token, and each word missing from ``w2i``
        is replaced by a fallback id. Sequences longer than ``max_len`` are
        cut to ``max_len - 1`` ids plus a closing special token; shorter
        ones are padded on the right.
        """
        if self.use_subword:
            words = self.sp.encode(caption.lower(), out_type=str)
        else:
            words = self.tokenizer(caption)

        # NOTE(review): all special-token keys are "" in this file; they are
        # presumably distinct start/unknown/end/pad tokens -- confirm.
        fallback_id = self.w2i[""]
        tokens = (
            [self.w2i[""]]
            + [self.w2i.get(w, fallback_id) for w in words]
            + [self.w2i[""]]
        )

        if len(tokens) > self.max_len:
            # Truncate, keeping the final slot for the closing token.
            tokens = tokens[:self.max_len - 1]
            tokens.append(self.w2i[""])
        else:
            tokens += [self.w2i[""]] * (self.max_len - len(tokens))

        return torch.tensor(tokens, dtype=torch.long)

    def __getitem__(self, index):
        """Return the sample at ``index`` (see the class docstring).

        Raises:
            ValueError: If the record has no captions. A bare ``IndexError``
                here would be mistaken for end-of-sequence by plain
                ``for item in dataset`` iteration.
        """
        data = self.data[index]
        file_name = data["file_name"]

        # Slicing copies, so the stored annotations are never mutated.
        captions = list(data["captions"][:self.NUM_CAPTIONS])
        if not captions:
            raise ValueError(f"Sample {file_name!r} has no captions")
        # Fewer than NUM_CAPTIONS captions: repeat the last one.
        captions += [captions[-1]] * (self.NUM_CAPTIONS - len(captions))

        image_path = os.path.join(self.image_dir, file_name)
        # The context manager releases the file handle once decoded.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        # Validation: one random caption plus all references.
        if self.is_val:
            caption = random.choice(captions)
            tokens = self.encode_caption(caption)
            return image, tokens, captions, file_name

        # Training: pair the same image with each sampled caption.
        selected_captions = random.sample(captions, k=self.train_num_caption)
        images = torch.stack([image] * len(selected_captions))
        tokens = torch.stack(
            [self.encode_caption(caption) for caption in selected_captions]
        )
        return images, tokens