Spaces:
Build error
Build error
import os
import warnings
from glob import glob

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

from fake_face_detection.utils.compute_weights import compute_weights
class LionCheetahDataset(Dataset):
    """Binary image-classification dataset of lion and cheetah pictures.

    Every file found directly inside ``lion_path`` and ``cheetah_path`` is
    opened once at construction time, and is passed through ``transformer``
    when one is given. Files that fail either step are skipped with a warning.
    Each item is the transformer's output with an extra ``'labels'`` entry
    holding the integer class id taken from ``id_map``.

    Args:
        lion_path: Directory containing the lion images.
        cheetah_path: Directory containing the cheetah images.
        id_map: Mapping with ``'lion'`` and ``'cheetah'`` keys. Their values
            are converted with ``int`` to obtain the class ids.
        transformer: Callable applied to each opened PIL image, or a falsy
            value to skip the transformation.
        **transformer_kwargs: Extra keyword arguments forwarded to
            ``transformer`` on every call.

    Attributes:
        images: Paths of the kept images, lions first, then cheetahs.
        labels: Integer class id for each entry of ``images``.
        weights: Tensor built from ``compute_weights(self.labels)``.
        length: Number of kept images.
    """

    def __init__(self, lion_path: str, cheetah_path: str, id_map: dict,
                 transformer, **transformer_kwargs):
        self.transformer = transformer
        self.transformer_kwargs = transformer_kwargs

        # Keep only the files that PIL can open and the transformer accepts,
        # so that __getitem__ does not hit a bad image mid-training.
        self.lion_images = self._readable_images(lion_path)
        self.cheetah_images = self._readable_images(cheetah_path)
        self.images = self.lion_images + self.cheetah_images

        # One integer label per kept image, aligned with self.images.
        self.lion_labels = [int(id_map['lion'])] * len(self.lion_images)
        self.cheetah_labels = [int(id_map['cheetah'])] * len(self.cheetah_images)
        self.labels = self.lion_labels + self.cheetah_labels

        # NOTE(review): torch.from_numpy requires an ndarray, so this assumes
        # compute_weights returns a numpy array -- confirm against its source.
        self.weights = torch.from_numpy(compute_weights(self.labels))

        self.length = len(self.labels)

    def _readable_images(self, directory: str) -> list:
        """Return the sorted paths in ``directory`` that open and transform cleanly."""
        kept = []
        # sorted() makes the dataset order independent of the filesystem's
        # directory-listing order.
        for path in sorted(glob(os.path.join(directory, "*"))):
            try:
                with Image.open(path) as img:
                    if self.transformer:
                        # The output is discarded: this call only checks that
                        # the transformer accepts the image.
                        self.transformer(img, **self.transformer_kwargs)
            except Exception as err:  # transformer failures can be of any type
                warnings.warn(f"Skipping unreadable image {path!r}: {err}")
            else:
                kept.append(path)
        return kept

    def __getitem__(self, index):
        """Load, transform and label the image at ``index``.

        Returns:
            The transformer's output with a ``'labels'`` entry added. When no
            transformer is set, returns ``{'image': <PIL image>, 'labels': label}``.
        """
        path = self.images[index]
        label = self.labels[index]
        with Image.open(path) as img:
            if self.transformer:
                # NOTE(review): assumes the transformer returns a mutable
                # mapping (e.g. a Hugging Face BatchFeature) -- confirm.
                item = self.transformer(img, **self.transformer_kwargs)
            else:
                # Without a transformer there is no mapping to attach the label
                # to, so build one. copy() loads the pixel data before the file
                # is closed.
                item = {'image': img.copy()}
        item['labels'] = label
        return item

    def __len__(self):
        """Return the number of images kept at construction time."""
        return self.length