# NOTE(review): the three lines below were extraction residue ("Spaces:",
# "Runtime error" x2) from the paste this file came from; commented out so
# the module parses.
from collections import defaultdict
import glob
import json
import os
from typing import Callable, Dict, List, Tuple

import cv2
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

from virtex.data import transforms as T
class ZeroShotDataset(Dataset):
    """Image dataset for zero-shot classification via caption prompts.

    Walks ``<data_root><split><folder>/`` for every folder listed in the JSON
    ``label_map`` and records one ``(image_id, integer label)`` instance per
    ``.jpg`` file.  At construction it tokenizes one prompt caption per class
    ("itap of a/an <class name>"), padded into a single forward and a single
    backward token bank; every ``__getitem__`` call returns the image plus
    both full banks, so a captioning model can score the image against all
    class prompts.
    """

    def __init__(
        self,
        data_root: str = "datasets/inaturalist",
        split: str = "train",
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
        label_map: str = None,
        tokenizer=None,
        model_dataset: str = 'redcaps',
        prompt_cls_sos=None,
        prompt_sos_eos=None,
    ):
        """
        Args:
            data_root: Dataset root.  NOTE(review): paths below are built by
                plain string concatenation, so this (and ``split``) is
                presumably expected to carry its trailing "/" — confirm
                against callers before switching to os.path.join.
            split: Split name, concatenated into each image path.
            image_transform: Albumentations-style callable, invoked as
                ``transform(image=...)["image"]`` in ``__getitem__``.
            label_map: Path to a JSON file mapping folder name ->
                ``[class name, integer label]``.  Required in practice
                despite the ``None`` default (``open(None)`` raises).
            tokenizer: Object exposing ``token_to_id`` and ``encode``.
            model_dataset: Stored on the instance; unused in this class.
            prompt_cls_sos: Stored on the instance; unused in this class.
            prompt_sos_eos: Stored on the instance; unused in this class.
        """
        self.data_root = data_root
        self.split = split
        # Fixed: use a context manager so the label-map file handle is closed
        # (the original `json.load(open(label_map))` leaked it).
        with open(label_map) as fp:
            self.label_map = json.load(fp)
        self.tokenizer = tokenizer
        self.image_transform = image_transform
        self.model_dataset = model_dataset
        self.prompt_cls_sos = prompt_cls_sos
        self.prompt_sos_eos = prompt_sos_eos

        # Index every .jpg per class folder.  labelname is
        # [class name, integer label] per the label_map format used below.
        im_id = 0
        self.image_id_to_file_path: Dict[int, str] = {}
        self.instances: List[Tuple[int, int]] = []
        for folder_name, labelname in self.label_map.items():
            image_folder = self.data_root + self.split + folder_name + "/"
            # Fixed idiom: str.endswith instead of x[-4:] == '.jpg'
            # (identical behavior, including for names shorter than 4 chars).
            for image_file in [x for x in os.listdir(image_folder) if x.endswith(".jpg")]:
                path = image_folder + image_file
                self.image_id_to_file_path[im_id] = path
                self.instances.append((im_id, labelname[1]))
                im_id += 1

        # Class names ordered by their integer label; "_" -> " ", lowercased.
        im_net_list = [
            x[0].replace('_', ' ').lower()
            for x in sorted(self.label_map.values(), key=lambda x: x[1])
        ]
        print(im_net_list)

        cls_token = [tokenizer.token_to_id("[CLS]")]
        sos_token = [tokenizer.token_to_id("[SOS]")]
        eos_token = [tokenizer.token_to_id("[EOS]")]
        # Crude article choice from the class name's first letter.
        a_an_dets = [
            " an " if cat[0].lower() in ["a", "e", "i", "o", "u"] else " a "
            for cat in im_net_list
        ]
        # Forward prompt per class:
        #   [CLS] "i took a picture" [SOS] "itap of a/an X" [EOS]
        imagenet_tensors = [
            cls_token
            + tokenizer.encode("i took a picture")
            + sos_token
            + tokenizer.encode("itap of " + a_an_dets[i] + im_net_list[i])
            + eos_token
            for i in range(len(im_net_list))
        ]
        # Backward prompt: same prefix, then [EOS], reversed caption, [SOS].
        imagenet_tensors_backward = [
            cls_token
            + tokenizer.encode("i took a picture")
            + eos_token
            + tokenizer.encode("itap of " + a_an_dets[i] + im_net_list[i])[::-1]
            + sos_token
            for i in range(len(im_net_list))
        ]
        tensor_lengths = torch.tensor([len(x) for x in imagenet_tensors])
        imagenet_tensors_forward = [torch.tensor(x) for x in imagenet_tensors]
        imagenet_tensors_backward = [torch.tensor(x) for x in imagenet_tensors_backward]
        # Pad each bank to the longest prompt -> shape (num_classes, max_len).
        imagenet_tensors_forward = pad_sequence(imagenet_tensors_forward, batch_first=True)
        imagenet_tensors_backward = pad_sequence(imagenet_tensors_backward, batch_first=True)
        print("imagenet_tensors_forward.shape: ", imagenet_tensors_forward.shape)
        print("imagenet_tensors_backward.shape: ", imagenet_tensors_backward.shape)
        print("tensor_lengths.shape: ", tensor_lengths.shape)
        self.imagenet_tensors_forward = imagenet_tensors_forward
        self.imagenet_tensors_backward = imagenet_tensors_backward
        self.tensor_lengths = tensor_lengths.long()

    def __len__(self) -> int:
        return len(self.instances)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one image/label pair plus the shared per-class prompt banks."""
        image_id, label = self.instances[idx]
        image_path = self.image_id_to_file_path[image_id]
        try:
            image = cv2.imread(image_path)
            # cv2 loads BGR; convert to RGB before the transform pipeline.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = self.image_transform(image=image)["image"]
            image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        except Exception:
            # Fixed: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.  Unreadable or corrupt images
            # deliberately fall back to random noise so iteration continues.
            print("$#%@#$%#image_path$@%:", image_path)
            image = np.random.rand(234, 325, 3)
            image = self.image_transform(image=image)["image"]
            image = np.transpose(image, (2, 0, 1))
        return {
            "image": torch.tensor(image, dtype=torch.float),
            "label": torch.tensor(label, dtype=torch.long),
            "caption_tokens": self.imagenet_tensors_forward,
            "noitpac_tokens": self.imagenet_tensors_backward,
            "caption_lengths": self.tensor_lengths,
        }
def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Collate a batch of samples into a single batch dict.

    Per-sample tensors (``image``, ``label``) are stacked along a new batch
    dimension; the caption banks are identical for every sample, so the
    first sample's copies are passed through untouched.
    """
    images = [sample["image"] for sample in data]
    labels = [sample["label"] for sample in data]
    first = data[0]
    batch = {
        "image": torch.stack(images, dim=0),
        "label": torch.stack(labels, dim=0),
        "caption_tokens": first['caption_tokens'],
        "noitpac_tokens": first['noitpac_tokens'],
        "caption_lengths": first['caption_lengths'],
    }
    return batch