Spaces:
Sleeping
Sleeping
| import os | |
| import ast | |
| import sys | |
| import shutil | |
| import glob | |
| import functools | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from src.dataset.dataset import SimpleIterDataset | |
| from src.utils.import_tools import import_module | |
| from src.dataset.functions_graph import graph_batch_func | |
def set_gpus(args):
    """Parse the ``--gpus`` argument into a GPU index list and a primary device.

    :param args: parsed CLI arguments; ``args.gpus`` is a comma-separated
        string of GPU indices (e.g. ``"0,1"``).
    :return: tuple ``(gpus, dev)`` where ``gpus`` is a list of ints and
        ``dev`` is the ``torch.device`` for the first listed GPU.
    :raises Exception: if ``args.gpus`` is empty/falsy -- an explicit GPU
        list is required.
    """
    if not args.gpus:
        # Bug fix: the old else-branch set gpus=[0]/dev and printed
        # "Setting GPUs to [0]" before raising anyway -- dead code with a
        # misleading message. The raise is the actual behavior; keep only it.
        raise Exception("Please provide GPU number")
    gpus = [int(i) for i in args.gpus.split(",")]
    dev = torch.device(gpus[0])
    print("Using GPUs:", gpus)
    return gpus, dev
def get_gpu_dev(args):
    """Translate the ``--gpus`` CLI string into an (accelerator, devices) pair.

    :param args: parsed CLI arguments; ``args.gpus`` is a comma-separated
        string, possibly empty.
    :return: ``("gpu", args.gpus)`` when a GPU string was supplied, otherwise
        the historical CPU fallback ``(0, 0)``.
    """
    if args.gpus == "":
        return 0, 0
    return "gpu", args.gpus
# TODO change this to use it from config file
def model_setup(args, data_config):
    """
    Loads the model.

    Imports the network module named by ``args.network_config``, selects a
    device from ``args.gpus`` (first listed GPU index, or CPU when unset),
    and builds the model through the module's ``get_model`` factory.

    :param args: parsed CLI arguments (``network_config``, ``gpus``).
    :param data_config: forwarded verbatim to ``get_model``.
    :return: ``model.mod`` -- only the inner module; ``model_info`` and
        ``network_module`` are discarded by this function.
    """
    network_module = import_module(args.network_config, name="_network_module")
    if args.gpus:
        gpus = [int(i) for i in args.gpus.split(",")]  # ?
        dev = torch.device(gpus[0])
        print("using GPUs:", gpus)
    else:
        gpus = None
        # NOTE(review): local_rank is assigned but never used below -- dead store.
        local_rank = 0
        dev = torch.device("cpu")
    model, model_info = network_module.get_model(
        data_config, args=args, dev=dev
    )
    # NOTE(review): model_info is thrown away; callers receive only model.mod.
    return model.mod
def get_samples_steps_per_epoch(args):
    """Normalize the per-epoch step counts on ``args`` in place.

    Derives ``steps_per_epoch`` / ``steps_per_epoch_val`` from the
    ``samples_per_epoch`` counterparts (integer-divided by ``batch_size``),
    infers a validation step count from the train/val split when only the
    training count is known, and maps a negative validation count to ``None``.

    :param args: parsed CLI arguments; mutated and also returned.
    :return: the same ``args`` object.
    :raises RuntimeError: when both a samples- and a steps-based flag are
        given for the same split.
    """
    if args.samples_per_epoch is not None:
        if args.steps_per_epoch is not None:
            raise RuntimeError(
                "Please use either `--steps-per-epoch` or `--samples-per-epoch`, but not both!"
            )
        args.steps_per_epoch = args.samples_per_epoch // args.batch_size

    if args.samples_per_epoch_val is not None:
        if args.steps_per_epoch_val is not None:
            raise RuntimeError(
                "Please use either `--steps-per-epoch-val` or `--samples-per-epoch-val`, but not both!"
            )
        args.steps_per_epoch_val = args.samples_per_epoch_val // args.batch_size

    # Derive the validation step count from the split ratio when unspecified.
    if args.steps_per_epoch is not None and args.steps_per_epoch_val is None:
        split = args.train_val_split
        args.steps_per_epoch_val = round(args.steps_per_epoch * (1 - split) / split)

    # A negative validation count means "disabled".
    if args.steps_per_epoch_val is not None and args.steps_per_epoch_val < 0:
        args.steps_per_epoch_val = None
    return args
def to_filelist(args, mode="train"):
    """Expand the train/val data arguments into a name->files dict and a flat list.

    Each entry is either keyword-based (``'name:/glob/pattern'``) or a bare
    glob pattern (stored under the key ``"_"``). Patterns are globbed,
    grouped by name and sorted. When ``args.local_rank`` is set and ``mode``
    is ``"train"``, each rank keeps an interleaved slice of the files
    (stride = number of GPUs) which is then shuffled.

    :param args: parsed CLI arguments (``data_train``/``data_val``,
        ``local_rank`` and the GPU flags consumed by ``set_gpus``).
    :param mode: ``"train"`` or ``"val"``.
    :return: ``(file_dict, filelist)``; asserts the flat list has no duplicates.
    """
    if mode == "train":
        raw_entries = args.data_train
    elif mode == "val":
        raw_entries = args.data_val
    else:
        raise NotImplementedError("Invalid mode %s" % mode)

    # keyword-based: 'a:/path/to/a b:/path/to/b'
    file_dict = {}
    for entry in raw_entries:
        name, pattern = entry.split(":") if ":" in entry else ("_", entry)
        file_dict.setdefault(name, []).extend(glob.glob(pattern))

    # sort files
    file_dict = {name: sorted(files) for name, files in file_dict.items()}

    if args.local_rank is not None:
        if mode == "train":
            gpus_list, _ = set_gpus(args)
            local_world_size = len(gpus_list)  # int(os.environ['LOCAL_WORLD_SIZE'])
            sharded = {}
            for name, files in file_dict.items():
                rank_files = files[args.local_rank :: local_world_size]
                assert len(rank_files) > 0
                np.random.shuffle(rank_files)
                sharded[name] = rank_files
            file_dict = sharded
            print(args.local_rank, len(file_dict["_"]))

    filelist = sum(file_dict.values(), [])
    assert len(filelist) == len(set(filelist))
    return file_dict, filelist
def train_load(args):
    """
    Loads the training data.

    Builds the train/val file lists, wraps them in ``SimpleIterDataset``
    instances and returns the corresponding ``DataLoader`` objects. When no
    dedicated validation files are given, train and val share the same files
    and are separated by load range at ``args.train_val_split``.

    :param args: parsed CLI arguments (data paths, batch/worker settings,
        fractions, ``steps_per_epoch*``, ``local_rank``, ...).
    :return: train_loader, val_loader, data_config, train_inputs
        (the last two are currently hard-coded ``0`` placeholders -- see the
        review notes at the bottom).
    """
    train_file_dict, train_files = to_filelist(args, "train")
    if args.data_val:
        # Dedicated validation files: both sets use their full range.
        val_file_dict, val_files = to_filelist(args, "val")
        train_range = val_range = (0, 1)
    else:
        # No validation files: carve train/val out of the same files by range.
        val_file_dict, val_files = train_file_dict, train_files
        train_range = (0, args.train_val_split)
        val_range = (args.train_val_split, 1)
    train_data = SimpleIterDataset(
        train_file_dict,
        args.data_config,
        for_training=True,
        extra_selection=None,
        remake_weights=False,
        load_range_and_fraction=(train_range, args.data_fraction),
        file_fraction=args.file_fraction,
        fetch_by_files=args.fetch_by_files,
        fetch_step=args.fetch_step,
        # Infinite iteration only when an explicit step budget exists.
        infinity_mode=args.steps_per_epoch is not None,
        name="train" + ("" if args.local_rank is None else "_rank%d" % args.local_rank),
        args_parse=args
    )
    # NOTE(review): for_training=True also for the validation dataset --
    # presumably intentional (shared preprocessing/weighting); confirm.
    val_data = SimpleIterDataset(
        val_file_dict,
        args.data_config,
        for_training=True,
        extra_selection=None,
        load_range_and_fraction=(val_range, args.data_fraction),
        file_fraction=args.file_fraction,
        fetch_by_files=args.fetch_by_files,
        fetch_step=args.fetch_step,
        infinity_mode=args.steps_per_epoch_val is not None,
        name="val" + ("" if args.local_rank is None else "_rank%d" % args.local_rank),
        args_parse=args
    )
    collator_func = graph_batch_func
    # train_data_arg = train_data
    # val_data_arg = val_data
    # if args.train_cap == 1:
    # train_data_arg = [next(iter(train_data_arg))]
    # if args.val_cap == 1:
    # val_data_arg = [next(iter(val_data_arg))]
    # prefetch_factor must remain None when num_workers == 0 (DataLoader requirement).
    prefetch_factor = None
    if args.num_workers > 0:
        prefetch_factor = args.prefetch_factor
    train_loader = DataLoader(
        train_data,
        batch_size=args.batch_size,
        drop_last=True,
        pin_memory=True,
        # Never spawn more workers than the (fractional) number of files.
        num_workers=min(args.num_workers, int(len(train_files) * args.file_fraction)),
        collate_fn=collator_func,
        persistent_workers=False,
        prefetch_factor=prefetch_factor
    )
    val_loader = DataLoader(
        val_data,
        batch_size=args.batch_size,
        drop_last=True,
        pin_memory=True,
        collate_fn=collator_func,
        num_workers=min(args.num_workers, int(len(val_files) * args.file_fraction)),
        # NOTE(review): persistent workers only here, not on the train loader
        # (which hard-codes False) -- confirm this asymmetry is intended.
        persistent_workers=args.num_workers > 0
        and args.steps_per_epoch_val is not None,
        prefetch_factor=prefetch_factor
    )
    # NOTE(review): real config/name lookups are disabled; 0 placeholders below.
    data_config = 0  # train_data.config
    train_input_names = 0  # train_data.config.input_names
    train_label_names = 0  # train_data.config.label_names (unused dead store)
    return train_loader, val_loader, data_config, train_input_names
| def test_load(args): | |
| """ | |
| Loads the test data. | |
| :param args: | |
| :return: test_loaders, data_config | |
| """ | |
| # keyword-based --data-test: 'a:/path/to/a b:/path/to/b' | |
| # split --data-test: 'a%10:/path/to/a/*' | |
| file_dict = {} | |
| split_dict = {} | |
| for f in args.data_test: | |
| if ":" in f: | |
| name, fp = f.split(":") | |
| if "%" in name: | |
| name, split = name.split("%") | |
| split_dict[name] = int(split) | |
| else: | |
| name, fp = "", f | |
| files = glob.glob(fp) | |
| if name in file_dict: | |
| file_dict[name] += files | |
| else: | |
| file_dict[name] = files | |
| # sort files | |
| for name, files in file_dict.items(): | |
| file_dict[name] = sorted(files) | |
| # apply splitting | |
| for name, split in split_dict.items(): | |
| files = file_dict.pop(name) | |
| for i in range((len(files) + split - 1) // split): | |
| file_dict[f"{name}_{i}"] = files[i * split : (i + 1) * split] | |
| def get_test_loader(name): | |
| filelist = file_dict[name] | |
| num_workers = min(args.num_workers, len(filelist)) | |
| test_data = SimpleIterDataset( | |
| {name: filelist}, | |
| args.data_config, | |
| for_training=False, | |
| extra_selection=None, | |
| load_range_and_fraction=((0, 1), args.data_fraction), | |
| fetch_by_files=True, | |
| fetch_step=1, | |
| name="test_" + name, | |
| args_parse=args | |
| ) | |
| test_loader = DataLoader( | |
| test_data, | |
| num_workers=num_workers, | |
| batch_size=args.batch_size, | |
| drop_last=False, | |
| pin_memory=True, | |
| collate_fn=graph_batch_func, | |
| ) | |
| return test_loader | |
| test_loaders = { | |
| name: functools.partial(get_test_loader, name) for name in file_dict | |
| } | |
| #data_config = SimpleIterDataset({}, args.data_config, for_training=False).config | |
| data_config = 0 | |
| return test_loaders, data_config | |
def count_parameters(model):
    """Count the trainable parameters of the wrapped network.

    :param model: wrapper object exposing the actual network as ``model.mod``.
    :return: total number of elements across all parameters with
        ``requires_grad`` set.
    """
    trainable = (p for p in model.mod.parameters() if p.requires_grad)
    return sum(p.numel() for p in trainable)