Spaces:
Configuration error
Configuration error
| import os | |
| import shutil | |
| from ..config import Config | |
| from . import logger | |
def _flatten_file_lists(files: list[str] | dict[str, list[str]]) -> list[str]:
    """Return *files* as a flat list of paths.

    Accepts either a plain list of paths (returned as-is) or a mapping whose
    values are lists of paths; the mapping keys are ignored.
    """
    if isinstance(files, list):
        return files
    return [path for group in files.values() for path in group]


def _require_files_exist(files: list[str] | dict[str, list[str]], split: str) -> None:
    """Raise FileNotFoundError listing every path in *files* that does not exist.

    Args:
        files: list of paths, or mapping of name -> list of paths.
        split: split label used in the error message ("train", "val", "test").
    """
    # Single pass: each path is stat'ed once, and all missing paths are reported together.
    missing = [path for path in _flatten_file_lists(files) if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError(f"Some {split} files are not found: {missing}")


def _handle_existing_run_dir(config: Config, save_dir: str) -> None:
    """Deal with an already-existing run folder *save_dir*.

    Raises FileExistsError if ``config.throw_exception_if_run_exists`` is set.
    Otherwise, removes the folder directly when ``config.remove_if_run_exists``
    is set, or asks on stdin and removes it only if the user enters "R".
    Any other answer (or a closed stdin) aborts the process with exit status 1.
    """
    if config.throw_exception_if_run_exists:
        raise FileExistsError(f"Folder {save_dir} exists, remove it or include 'tmp' in run name")
    logger.print()
    logger.print_warning(f"folder [magenta]{save_dir}[/] exists, remove it or include 'tmp' in run name")
    if not config.remove_if_run_exists:
        logger.print("Enter [green bold]R[/] to replace")
        # Interactively ask; a non-interactive stdin (EOF) is treated as "no"
        # instead of crashing with an unhandled EOFError.
        try:
            key = input().strip()
        except EOFError:
            key = ""
        if key != "R":
            logger.print_error("Aborted")
            # Non-zero status so wrapper scripts don't continue after an abort;
            # raising SystemExit avoids relying on the site-injected exit() builtin.
            raise SystemExit(1)
    logger.print_warning(f"Folder [magenta]{save_dir}[/] is removed")
    shutil.rmtree(save_dir)


def checks(config: Config):
    """Validate *config* before a run starts, adjusting it where needed.

    - If ``config.run_name`` contains "tmp", wandb is disabled (``config.wandb = False``).
    - If the folder ``{config.run_dir}/{config.run_name}`` already exists and the run
      name does not contain "tmp", the folder is either rejected, removed, or the user
      is asked whether to replace it (see ``_handle_existing_run_dir``).
    - ``binary_labels`` is only accepted together with ``num_classes == 2``.
    - Every path in ``trn_files``, ``val_files`` and ``tst_files`` must exist.

    Raises:
        FileExistsError: the run folder exists and ``config.throw_exception_if_run_exists`` is set.
        SystemExit: the user declined to replace an existing run folder.
        ValueError: ``binary_labels`` is set but ``num_classes != 2``.
        FileNotFoundError: a listed train/val/test file is missing.
    """
    save_dir = f"{config.run_dir}/{config.run_name}"
    is_tmp_run = "tmp" in config.run_name
    if is_tmp_run:
        logger.print_warning("Using 'tmp' in run name. Wandb will not be used.")
        config.wandb = False

    # Only the run *name* opts out of the overwrite guard (as the messages say).
    # Testing the full path would also skip the guard whenever run_dir merely
    # contains "tmp", e.g. "/tmp/runs".
    if os.path.exists(save_dir) and not is_tmp_run:
        _handle_existing_run_dir(config, save_dir)

    if config.binary_labels and config.num_classes != 2:
        raise ValueError("Binary labels is only supported for 2 classes")

    _require_files_exist(config.trn_files, "train")
    _require_files_exist(config.val_files, "val")
    _require_files_exist(config.tst_files, "test")