Spaces:
Runtime error
Runtime error
| import pickle | |
| import random | |
| import numpy as np | |
| import torch | |
| from torchvision import transforms | |
| from .points_sampler import MultiPointSampler | |
| from .sample import DSample | |
class ISDataset(torch.utils.data.dataset.Dataset):
    """Base dataset that yields an image, sampled points and an instance mask.

    Subclasses must implement :meth:`get_sample` and are expected to fill
    ``self.dataset_samples`` (it is initialised to ``None`` here and its
    length is used by :meth:`get_samples_number` and :meth:`__getitem__`).
    """

    def __init__(
        self,
        augmentator=None,
        points_sampler=None,
        min_object_area=0,
        keep_background_prob=0.0,
        with_image_info=False,
        samples_scores_path=None,
        samples_scores_gamma=1.0,
        epoch_len=-1,
    ):
        """Store the configuration and optionally load per-sample weights.

        Args:
            augmentator: Object handed to ``sample.augment``; ``None``
                disables augmentation.
            points_sampler: Object providing ``sample_object(sample)``,
                ``sample_points()`` and a ``selected_mask`` attribute.
                ``None`` (the default) creates a new
                ``MultiPointSampler(max_num_points=12)`` for this instance.
            min_object_area: Value passed to ``sample.remove_small_objects``.
            keep_background_prob: Probability of accepting an augmented
                sample whose ``len()`` is 0; a negative value always
                accepts it.
            with_image_info: When true, ``sample.sample_id`` is added to the
                output under the ``"image_info"`` key.
            samples_scores_path: Path to a pickle file with per-sample
                scores, or ``None`` to disable weighted sampling.
            samples_scores_gamma: Exponent applied to ``1 - score`` when the
                scores are converted to sampling weights.
            epoch_len: If positive, the value reported by ``len()`` and a
                switch that makes :meth:`__getitem__` pick a random index.
        """
        super().__init__()
        # Build the default sampler per instance: a default created in the
        # signature would be evaluated once and shared by every dataset, and
        # __getitem__ calls sample_object() on it before reading
        # selected_mask.
        if points_sampler is None:
            points_sampler = MultiPointSampler(max_num_points=12)

        self.epoch_len = epoch_len
        self.augmentator = augmentator
        self.min_object_area = min_object_area
        self.keep_background_prob = keep_background_prob
        self.points_sampler = points_sampler
        self.with_image_info = with_image_info
        self.samples_precomputed_scores = self._load_samples_scores(
            samples_scores_path, samples_scores_gamma
        )
        self.to_tensor = transforms.ToTensor()
        self.dataset_samples = None

    def __getitem__(self, index):
        """Return a dict with ``"images"``, ``"points"`` and ``"instances"``.

        The incoming ``index`` is replaced by a weighted random draw when
        precomputed scores are loaded, or by a uniform random draw over
        ``self.dataset_samples`` when ``epoch_len`` is positive; otherwise it
        is used as given.
        """
        scores = self.samples_precomputed_scores
        if scores is not None:
            index = np.random.choice(scores["indices"], p=scores["probs"])
        elif self.epoch_len > 0:
            index = random.randrange(0, len(self.dataset_samples))

        sample = self.get_sample(index)
        sample = self.augment_sample(sample)
        sample.remove_small_objects(self.min_object_area)

        # The sampler is pointed at the sample first; points and the
        # selected mask are then read back from it.
        self.points_sampler.sample_object(sample)
        points = np.array(self.points_sampler.sample_points())
        mask = self.points_sampler.selected_mask

        output = {
            "images": self.to_tensor(sample.image),
            "points": points.astype(np.float32),
            "instances": mask,
        }
        if self.with_image_info:
            output["image_info"] = sample.sample_id
        return output

    def augment_sample(self, sample) -> "DSample":
        """Augment ``sample`` until the result is acceptable and return it.

        With no augmentator the sample is returned untouched. Otherwise
        ``sample.augment`` is repeated until ``len(sample) > 0`` or the
        empty result is accepted according to ``keep_background_prob``.
        """
        if self.augmentator is None:
            return sample

        valid_augmentation = False
        while not valid_augmentation:
            sample.augment(self.augmentator)
            keep_sample = (
                self.keep_background_prob < 0.0
                or random.random() < self.keep_background_prob
            )
            valid_augmentation = len(sample) > 0 or keep_sample
        return sample

    def get_sample(self, index) -> "DSample":
        """Return the sample for ``index``; must be provided by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return ``epoch_len`` if it is positive, else the sample count."""
        if self.epoch_len > 0:
            return self.epoch_len
        return self.get_samples_number()

    def get_samples_number(self):
        """Return the number of entries in ``self.dataset_samples``."""
        return len(self.dataset_samples)

    @staticmethod
    def _load_samples_scores(samples_scores_path, samples_scores_gamma):
        """Load per-sample scores and convert them to sampling probabilities.

        Each entry ``x`` of the pickled sequence contributes ``x[0]`` as the
        index and ``(1 - x[2]) ** samples_scores_gamma`` as the unnormalised
        weight.

        Args:
            samples_scores_path: Path to the pickle file, or ``None``.
            samples_scores_gamma: Exponent applied to ``1 - x[2]``.

        Returns:
            ``None`` when no path is given, otherwise a dict with
            ``"indices"`` (list) and ``"probs"`` (array normalised to sum
            to 1).

        Raises:
            ValueError: If the weights do not sum to a positive number.
        """
        if samples_scores_path is None:
            return None

        # NOTE: pickle.load can execute arbitrary code; only point this at
        # score files from a trusted source.
        with open(samples_scores_path, "rb") as f:
            images_scores = pickle.load(f)

        probs = np.array([(1.0 - x[2]) ** samples_scores_gamma for x in images_scores])
        total = probs.sum()
        # Fail here rather than hand NaN/empty probabilities to
        # np.random.choice on the first __getitem__ call.
        if not total > 0:
            raise ValueError(
                f"Sample weights from {samples_scores_path!r} must sum to a "
                f"positive value, got {total}"
            )
        probs /= total

        samples_scores = {"indices": [x[0] for x in images_scores], "probs": probs}
        print(f"Loaded {len(probs)} weights with gamma={samples_scores_gamma}")
        return samples_scores