Spaces:
Configuration error
Configuration error
File size: 2,305 Bytes
c29babb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | import os
import shutil
from ..config import Config
from . import logger
def _get_files_from_dict_values(d: list[str] | dict[str, list[str]]) -> list[str]:
    """Return a flat list of file paths from a list, or from a dict of lists (values concatenated)."""
    if isinstance(d, list):
        return d
    return [f for sublist in d.values() for f in sublist]


def _handle_existing_run_dir(config: Config, save_dir: str) -> None:
    """Resolve a pre-existing run folder: raise, auto-remove, or ask the user whether to replace it.

    Raises:
        FileExistsError: ``config.throw_exception_if_run_exists`` is true.
        SystemExit: (status 1) the user did not confirm replacement at the prompt.
    """
    if config.throw_exception_if_run_exists:
        raise FileExistsError(f"Folder {save_dir} exists, remove it or include 'tmp' in run name")

    logger.print()
    logger.print_warning(f"folder [magenta]{save_dir}[/] exists, remove it or include 'tmp' in run name")

    if not config.remove_if_run_exists:
        logger.print("Enter [green bold]R[/] to replace")
        try:
            # Ignore surrounding whitespace so "R " is accepted.
            key = input().strip()
        except EOFError:
            # No interactive stdin (closed/piped): treat as "do not replace".
            key = ""
        if key != "R":
            logger.print_error("Aborted")
            # `exit()` is a site-module helper meant for the interactive shell;
            # raise SystemExit directly, with a non-zero status so the abort is
            # not reported as success to calling scripts.
            raise SystemExit(1)

    logger.print_warning(f"Folder [magenta]{save_dir}[/] is removed")
    shutil.rmtree(save_dir)


def checks(config: Config):
    """Validate a run configuration before a run starts.

    Steps, in order:

    1. If ``config.run_name`` contains ``"tmp"``, set ``config.wandb = False``
       (the config object is mutated in place).
    2. If ``{config.run_dir}/{config.run_name}`` already exists and the run
       name does not contain ``"tmp"``: raise when
       ``config.throw_exception_if_run_exists`` is set; otherwise delete the
       folder, either directly when ``config.remove_if_run_exists`` is set or
       after the user types ``R`` at an interactive prompt.
    3. Reject ``config.binary_labels`` combined with ``config.num_classes != 2``.
    4. Check that every path in ``config.trn_files``, ``config.val_files`` and
       ``config.tst_files`` (each a list, or a dict of lists) exists.

    Args:
        config: Run configuration; read for the attributes above, and its
            ``wandb`` attribute may be set to False.

    Raises:
        FileExistsError: the run folder exists and
            ``config.throw_exception_if_run_exists`` is true.
        ValueError: binary labels requested with ``num_classes != 2``.
        FileNotFoundError: one or more train/val/test files are missing.
        SystemExit: (status 1) the user declined to replace an existing folder.
    """
    save_dir = f"{config.run_dir}/{config.run_name}"
    is_tmp_run = "tmp" in config.run_name

    if is_tmp_run:
        logger.print_warning("Using 'tmp' in run name. Wandb will not be used.")
        config.wandb = False

    # Decide on the run *name*, as the messages tell the user. Checking the full
    # path would disable overwrite protection whenever run_dir contains "tmp".
    if not is_tmp_run and os.path.exists(save_dir):
        _handle_existing_run_dir(config, save_dir)

    if config.binary_labels and config.num_classes != 2:
        raise ValueError("Binary labels is only supported for 2 classes")

    # Splits are checked in order and each config attribute is read lazily,
    # so the first split with missing files is the one reported.
    for split_label, attr_name in (("train", "trn_files"), ("val", "val_files"), ("test", "tst_files")):
        files = _get_files_from_dict_values(getattr(config, attr_name))
        not_found = [f for f in files if not os.path.exists(f)]
        if not_found:
            raise FileNotFoundError(f"Some {split_label} files are not found: {not_found}")
|