Spaces:
Build error
Build error
| """Dataset Split Utils. | |
| This module contains functions for splitting normal images in the training set | |
| and for creating validation sets from test sets. | |
| These functions are useful | |
| - when the test set does not contain any normal images. | |
| - when the dataset doesn't have a validation set. | |
| """ | |
| # Copyright (C) 2020 Intel Corporation | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, | |
| # software distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions | |
| # and limitations under the License. | |
| import random | |
| from pandas.core.frame import DataFrame | |
def split_normal_images_in_train_set(
    samples: DataFrame, split_ratio: float = 0.1, seed: int = 0, normal_label: str = "good"
) -> DataFrame:
    """Move a random fraction of the normal training images to the test set.

    This is useful when the test set does not contain any normal images:
    with a single class in the test set, AUC computation fails.

    The dataframe is modified in place and the same object is returned.

    Args:
        samples (DataFrame): Dataframe containing dataset info such as filenames, splits etc.
            It must provide a ``split`` and a ``label`` column.
        split_ratio (float, optional): Fraction of the normal training images that is
            reassigned to the test set. Must lie in ``[0, 1]``. Defaults to 0.1.
        seed (int, optional): Random seed to ensure reproducibility. Only values greater
            than 0 seed the selection; otherwise the shared ``random`` module generator
            is used in its current state. Defaults to 0.
        normal_label (str): Name of the normal label, e.g. ``"good"`` for MVTec AD.
            Defaults to "good".

    Returns:
        DataFrame: The input dataframe, where part of the normal training samples
        now has ``split == "test"``.

    Raises:
        ValueError: If ``split_ratio`` is outside the ``[0, 1]`` interval.
    """
    if not 0 <= split_ratio <= 1:
        raise ValueError(f"split_ratio must be in the [0, 1] interval, got {split_ratio}.")

    # A positive seed gets its own generator so the process-wide RNG state is not
    # overwritten as a side effect; Random(seed) draws the same sample that
    # random.seed(seed) followed by random.sample would.
    sample = random.Random(seed).sample if seed > 0 else random.sample

    normal_train_image_indices = samples.index[
        (samples.split == "train") & (samples.label == normal_label)
    ].to_list()

    # int() truncates, so small training sets may move fewer images than the exact ratio.
    num_images_to_move = int(len(normal_train_image_indices) * split_ratio)
    indices_to_move = sample(normal_train_image_indices, num_images_to_move)

    samples.loc[indices_to_move, "split"] = "test"
    return samples


def create_validation_set_from_test_set(samples: DataFrame, seed: int = 0, normal_label: str = "good") -> DataFrame:
    """Create a validation set from the test set.

    Half of the normal test samples and half of the abnormal test samples
    (rounded down in each group) are randomly reassigned to the ``"val"`` split.

    The dataframe is modified in place and the same object is returned.

    Args:
        samples (DataFrame): Dataframe containing dataset info such as filenames, splits etc.
            It must provide a ``split`` and a ``label`` column.
        seed (int, optional): Random seed to ensure reproducibility. Only values greater
            than 0 seed the selection; otherwise the shared ``random`` module generator
            is used in its current state. Defaults to 0.
        normal_label (str): Name of the normal label, e.g. ``"good"`` for MVTec AD.
            Every other label is treated as abnormal. Defaults to "good".

    Returns:
        DataFrame: The input dataframe, where part of the test samples
        now has ``split == "val"``.
    """
    # A positive seed gets its own generator so the process-wide RNG state is not
    # overwritten as a side effect. One generator serves both draws, so the
    # seeded selection matches seeding once and sampling twice.
    sample = random.Random(seed).sample if seed > 0 else random.sample

    # Split normal images.
    normal_test_image_indices = samples.index[
        (samples.split == "test") & (samples.label == normal_label)
    ].to_list()
    normal_indices_to_move = sample(normal_test_image_indices, len(normal_test_image_indices) // 2)
    samples.loc[normal_indices_to_move, "split"] = "val"

    # Split abnormal images. The normal rows moved above cannot match this mask,
    # because it selects on a different label.
    abnormal_test_image_indices = samples.index[
        (samples.split == "test") & (samples.label != normal_label)
    ].to_list()
    abnormal_indices_to_move = sample(abnormal_test_image_indices, len(abnormal_test_image_indices) // 2)
    samples.loc[abnormal_indices_to_move, "split"] = "val"

    return samples