Spaces:
Running
Running
| import torch | |
| from PIL import Image | |
| import pandas as pd | |
class RetrievalDataset(torch.utils.data.Dataset):
    """Dataset of (query image, query text, target image) retrieval triplets.

    The annotations CSV must provide the columns ``query_image``,
    ``query_text`` and ``target_image``. Image columns hold file names that
    are resolved relative to ``img_dir_path``.

    Splits are derived from a deterministic shuffle of the annotations:
    ``"train"`` is the first ``TRAIN_FRACTION`` of the shuffled rows,
    ``"validation"`` is the remainder, and ``"test"`` is every row.
    """

    # Accepted values for ``split``.
    VALID_SPLITS = ("train", "validation", "test")
    # Share of the shuffled rows assigned to the train split.
    TRAIN_FRACTION = 0.9
    # Fixed seed so that train/validation membership is reproducible.
    SHUFFLE_SEED = 42

    def __init__(
        self,
        img_dir_path: str,
        annotations_file_path: str,
        split: str,
        transform=None,
        tokenizer=None,
    ) -> None:
        """Read the annotations CSV and keep only the rows of ``split``.

        Args:
            img_dir_path: Directory prefix prepended to every image name.
            annotations_file_path: Path of the CSV file with the annotations.
            split: One of ``"train"``, ``"validation"`` or ``"test"``.
            transform: Optional callable applied to both loaded images.
            tokenizer: Optional callable applied to the query text; its
                result must support ``.squeeze(0)``.

        Raises:
            ValueError: If ``split`` is not a supported split name.
        """
        self.img_dir_path = img_dir_path
        self.transform = transform
        self.tokenizer = tokenizer
        self.split = split
        raw_annotations = pd.read_csv(annotations_file_path)
        self.annotations = self.split_data(
            self.convert_image_names_to_path(raw_annotations)
        )

    def __len__(self) -> int:
        """Return the number of annotation rows in the selected split."""
        return len(self.annotations)

    def __getitem__(self, idx: int) -> tuple:
        """Load the sample at positional index ``idx``.

        Returns:
            A tuple ``(query_img, query_text, target_img, raw_query_text)``.
            Images are RGB PIL images passed through ``transform`` when one
            is set; ``query_text`` is the tokenizer output (squeezed on
            dim 0) when a tokenizer is set, otherwise the raw text;
            ``raw_query_text`` is always the untouched annotation text.
        """
        # Fetch the row once instead of re-indexing the frame per column.
        row = self.annotations.iloc[idx]
        raw_query_text = row['query_text']

        query_img = self._load_rgb(row['query_image'])
        target_img = self._load_rgb(row['target_image'])

        # Compare against None rather than relying on truthiness: a callable
        # that defines __len__/__bool__ could otherwise be skipped silently.
        if self.transform is not None:
            query_img = self.transform(query_img)
            target_img = self.transform(target_img)

        query_text = raw_query_text
        if self.tokenizer is not None:
            query_text = self.tokenizer(raw_query_text).squeeze(0)

        return query_img, query_text, target_img, raw_query_text

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy of it."""
        # convert() loads the pixel data and returns a new image, so the
        # underlying file can be closed as soon as the conversion is done.
        with Image.open(path) as img:
            return img.convert('RGB')

    def split_data(self, annotations):
        """Shuffle ``annotations`` deterministically and return ``self.split``.

        Args:
            annotations: DataFrame holding all annotation rows.

        Returns:
            The full shuffled frame for ``"test"``, the leading
            ``TRAIN_FRACTION`` of it for ``"train"``, and the remaining rows
            for ``"validation"``.

        Raises:
            ValueError: If ``self.split`` is not a supported split name.
        """
        # Validate before doing any work on the frame.
        if self.split not in self.VALID_SPLITS:
            raise ValueError(
                f"split is not valid: {self.split!r} "
                f"(expected one of {self.VALID_SPLITS})"
            )

        shuffled_df = annotations.sample(
            frac=1, random_state=self.SHUFFLE_SEED
        ).reset_index(drop=True)

        if self.split == "test":
            return shuffled_df

        boundary = int(self.TRAIN_FRACTION * len(shuffled_df))
        if self.split == "train":
            return shuffled_df.iloc[:boundary]
        return shuffled_df.iloc[boundary:]

    def load_queries(self):
        """Return the annotation rows of the selected split."""
        return self.annotations

    def load_database(self):
        """Return a one-column frame of the distinct target image paths."""
        return pd.DataFrame({'target_image': self.annotations["target_image"].unique()})

    def convert_image_names_to_path(self, df):
        """Prefix both image columns with ``img_dir_path`` and return ``df``.

        The frame is modified in place; the same object is returned.
        """
        for column in ("query_image", "target_image"):
            df[column] = self.img_dir_path + "/" + df[column]
        return df