Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import random | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| import torch | |
class CaptionDataset(Dataset):
    """Image/caption dataset backed by a JSON annotation file.

    The JSON file is expected to contain a list of records, each providing a
    ``"file_name"`` (path relative to ``image_dir``) and a ``"captions"``
    list of strings.  Every sample is normalised to exactly five reference
    captions: extra captions are dropped and short lists are padded by
    repeating the last caption.

    In validation mode (``split == "val"``) ``__getitem__`` returns
    ``(image, tokens, captions, file_name)`` where ``tokens`` encodes one
    randomly chosen caption.  In every other mode it returns
    ``(images, tokens)`` with ``train_num_caption`` randomly sampled captions
    stacked along a new leading dimension (the image is repeated once per
    caption).
    """

    # Number of reference captions every sample is normalised to.
    _CAPTIONS_PER_IMAGE = 5

    def __init__(
        self,
        json_path,
        image_dir,
        w2i,
        tokenizer: callable,
        split='train',
        transform=None,
        max_len=30,
        train_num_caption=1,
        debug=False,
        use_subword=False,
        sp_model_path="tokenizer.model"
    ):
        """Load the annotations and store the encoding configuration.

        Args:
            json_path: Path of the JSON annotation file.
            image_dir: Directory that the records' ``file_name`` values are
                joined onto.
            w2i: Mapping from token string to integer id.  Must contain the
                special tokens ``<sos>``, ``<eos>``, ``<unk>`` and ``<pad>``.
            tokenizer: Callable turning a caption string into an iterable of
                token strings; used when ``use_subword`` is false.
            split: ``"val"`` selects validation behaviour; any other value
                selects training behaviour.
            transform: Optional callable applied to the loaded RGB image.
            max_len: Fixed length of every encoded caption.
            train_num_caption: Number of captions sampled per image in
                training mode.
            debug: When true, only the first 10 records are kept.
            use_subword: When true, captions are lower-cased and tokenised
                with a SentencePiece model instead of ``tokenizer``.
            sp_model_path: Path of the SentencePiece model; only read when
                ``use_subword`` is true.
        """
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

        # Debugging aid: work on a tiny subset of the data.
        if debug:
            self.data = self.data[:10]

        self.is_val = split == "val"
        self.image_dir = image_dir
        self.w2i = w2i
        self.transform = transform
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.train_num_caption = train_num_caption
        self.use_subword = use_subword

        if self.use_subword:
            # Imported lazily so sentencepiece is only required when used.
            import sentencepiece as spm
            self.sp = spm.SentencePieceProcessor()
            self.sp.load(sp_model_path)

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.data)

    def encode_caption(self, caption):
        """Encode a caption as a ``torch.long`` tensor of length ``max_len``.

        The sequence is ``<sos>``, the token ids, ``<eos>``.  Tokens missing
        from ``w2i`` map to ``<unk>``.  Sequences longer than ``max_len`` are
        cut so that the last kept position is ``<eos>``; shorter ones are
        right-padded with ``<pad>``.
        """
        if self.use_subword:
            words = self.sp.encode(caption.lower(), out_type=str)
        else:
            words = self.tokenizer(caption)

        tokens = (
            [self.w2i["<sos>"]] +
            [self.w2i.get(w, self.w2i["<unk>"]) for w in words] +
            [self.w2i["<eos>"]]
        )

        if len(tokens) > self.max_len:
            # Truncate, but keep the end-of-sequence marker as final token.
            tokens = tokens[:self.max_len - 1]
            tokens.append(self.w2i["<eos>"])
        else:
            tokens += [self.w2i["<pad>"]] * (self.max_len - len(tokens))

        return torch.tensor(tokens, dtype=torch.long)

    def __getitem__(self, index):
        """Load the image at ``index`` together with its encoded caption(s).

        Returns:
            Validation mode: ``(image, tokens, captions, file_name)`` where
            ``captions`` is the list of five reference caption strings.
            Otherwise: ``(images, tokens)``, both built with ``torch.stack``
            over ``train_num_caption`` entries.

        Raises:
            ValueError: If the record has no captions.

        Note:
            The training path stacks the image with ``torch.stack``, so
            ``transform`` must return a tensor in that mode.
        """
        data = self.data[index]
        file_name = data["file_name"]
        image_path = os.path.join(self.image_dir, file_name)

        # convert() returns a fully loaded copy, so the file handle can be
        # released immediately instead of lingering until garbage collection.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Slicing copies the list, so padding below never mutates self.data.
        # More than five captions: keep only the first five.
        captions = data["captions"][:self._CAPTIONS_PER_IMAGE]
        if not captions:
            # An IndexError escaping __getitem__ would silently terminate a
            # plain ``for sample in dataset`` loop; fail loudly instead.
            raise ValueError(f"no captions available for image {file_name!r}")
        # Fewer than five captions: repeat the last one to fill up.
        missing = self._CAPTIONS_PER_IMAGE - len(captions)
        captions += [captions[-1]] * missing

        # validation
        if self.is_val:
            caption = random.choice(captions)
            tokens = self.encode_caption(caption)
            return image, tokens, captions, file_name

        # train
        selected_captions = random.sample(captions, k=self.train_num_caption)
        token_list = [self.encode_caption(c) for c in selected_captions]
        images = torch.stack([image] * len(selected_captions))
        tokens = torch.stack(token_list)
        return images, tokens