"""Train or evaluate a DGL graph neural network described by a config file.

Run with ``--config <file>``. Supports single-device training, single-node
multi-GPU training (``--multigpu``, processes started with
``torch.multiprocessing.spawn``) and multi-node training (``--multinode``,
reading ``RANK`` / ``LOCAL_RANK`` / ``WORLD_SIZE`` from the environment).
"""
import argparse
import time
import datetime
import yaml
import os

# Taken before the heavy imports below so their cost can be reported.
start_time = time.time()

import dgl
import torch
import torch.nn as nn
import sys

# Make project-local packages (root_gnn_base, models) importable when the
# script is launched from the repository root.
file_path = os.getcwd()
sys.path.append(file_path)

import root_gnn_base.batched_dataset as datasets
from root_gnn_base import utils
import root_gnn_base.custom_scheduler as lr_utils
from models import GCN
import numpy as np
from sklearn.metrics import roc_auc_score
import resource
import gc
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

print("import time: {:.4f} s".format(time.time() - start_time))


def _unwrap_model(model):
    """Return the plain ``nn.Module`` underneath DDP and ``torch.compile`` wrappers.

    ``main`` compiles the model first and wraps it in DDP second, so the DDP
    wrapper (``.module``) is removed first and the compile wrapper
    (``._orig_mod``) second. A model with neither wrapper is returned as-is.
    """
    if isinstance(model, DDP):
        model = model.module
    return getattr(model, '_orig_mod', model)


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``'module.'`` removed from each key.

    DDP prefixes every parameter name with ``'module.'``. Only a leading prefix
    is removed, so keys of submodules whose names merely end in "module"
    (e.g. ``'global_module.weight'``) are left intact.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _label_normalized_loss(per_sample_loss, labels, weights):
    """Average a per-sample loss so that every label value contributes equally.

    For each distinct value in *labels*, computes the weighted mean
    ``sum(w * loss) / sum(w)`` over the samples carrying that label, then
    returns the plain mean of those per-label values. A label whose weights
    sum to zero produces NaN.
    """
    unique_labels = torch.unique(labels)
    normalized_loss = 0.0
    for label in unique_labels:
        label_mask = labels == label
        label_weights = weights[label_mask]
        label_losses = per_sample_loss[label_mask]
        normalized_loss += torch.sum(label_weights * label_losses) / torch.sum(label_weights)
    return normalized_loss / len(unique_labels)


def mem():
    """Print the peak resident set size of this process."""
    # NOTE: ru_maxrss is the *peak* RSS (KiB on Linux, bytes on macOS), not the
    # current usage; the conversion below assumes KiB.
    print(f'Current memory usage: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 / 1024} GB')


def gpu_mem():
    """Print the GPU memory allocated by tensors, then host memory via ``mem``."""
    print()
    print('GPU Memory Usage:')
    print(f'Current GPU memory usage: {torch.cuda.memory_allocated() / 1024 / 1024 / 1024} GB')
    mem()


def evaluate(val_loaders, model, config, device, epoch=-1):
    """Run ``model.representation`` on *val_loaders* and save the outputs.

    Loads the checkpoint for *epoch* (the latest one when ``epoch == -1``),
    processes at most ``ceil(1e5 / batch_size)`` batches per loader and writes
    logits, sigmoid scores, labels, the pre/post global-decoder features and
    the tracking tensors to ``<Training_Directory>/evaluation_<epoch>.npz``.

    Args:
        val_loaders: loaders yielding ``(graph, label, track, global_feats)``.
        model: the model, possibly wrapped by ``torch.compile`` and/or DDP.
        config: parsed training config.
        device: device the inputs are moved to.
        epoch: checkpoint epoch to evaluate; -1 selects the latest checkpoint.

    Returns:
        ``"No validation data"`` if *val_loaders* is empty, otherwise ``None``.
    """
    print("Evaluating")
    # Checkpoints store bare-module keys (no DDP / compile prefixes), and DDP
    # does not forward `representation`, so work on the unwrapped module.
    base_model = _unwrap_model(model)
    if epoch != -1:
        print(f"Evalulating at epoch {epoch}")
        last_ep, checkpoint = utils.get_specific_epoch(config, epoch, from_ryan=False)
        print(f"Evaluating at epoch = {last_ep}")
    else:
        _, checkpoint = utils.get_last_epoch(config)
    if checkpoint is not None:
        base_model.load_state_dict(_strip_module_prefix(checkpoint['model_state_dict']))
        print(f"Loaded epoch {checkpoint['epoch']} from checkpoint")
    else:
        print("WARNING: no checkpoint found; evaluating an untrained model.")

    if len(val_loaders) == 0:
        return "No validation data"

    scores = []
    labels = []
    before_decoder = []
    after_decoder = []
    tracking = []
    batch_size = config["Training"]["batch_size"]
    batch_limit = int(np.ceil(1e5 / batch_size))
    model.eval()
    with torch.no_grad():
        for loader in val_loaders:
            for batch_count, (batch, label, track, global_feats) in enumerate(loader, start=1):
                # The compiled model is not used for evaluation because the
                # batch size cannot be controlled here.
                before_global_decoder, after_global_decoder, after_classify = base_model.representation(
                    batch.to(device), global_feats.to(device))
                scores.append(after_classify.to("cpu"))
                before_decoder.append(before_global_decoder.to("cpu"))
                after_decoder.append(after_global_decoder.to("cpu"))
                labels.append(label.to("cpu"))
                tracking.append(track.to("cpu"))
                if batch_count >= batch_limit:
                    break
    if not scores:  # Validation set is empty.
        return

    logits = torch.concatenate(scores)
    outfile = f"{config['Training_Directory']}/evaluation_{epoch}.npz"
    np.savez(
        outfile,
        logits=logits.numpy(),
        scores=torch.sigmoid(logits).numpy(),
        labels=torch.concatenate(labels).numpy(),
        before_decoder=torch.concatenate(before_decoder).numpy(),
        after_decoder=torch.concatenate(after_decoder).numpy(),
        tracking=torch.concatenate(tracking).numpy(),
    )
    print(f"saved scores to {outfile}")
    return


def train(train_loaders, test_loaders, model, device, config, args, rank):
    """Train *model*, validating and checkpointing after every epoch.

    Each optimisation step draws one batch from every loader in
    *train_loaders* (via ``utils.cycler``); the number of steps per epoch is
    the number of batches in ``train_loaders[0]``. After each epoch the model
    is evaluated on *test_loaders* (predictions gathered on rank 0 in
    distributed runs), a line is appended to
    ``<Training_Directory>/training.log`` and model, optimizer and early-stop
    state are saved to ``model_epoch_<epoch>.pt``. Unless ``args.restart`` is
    set, training resumes from the latest checkpoint.

    Args:
        train_loaders: loaders yielding ``(graph, label, track, global_feats)``;
            ``track[:, 1]`` is used as the per-event weight.
        test_loaders: evaluation loaders with the same item layout.
        model: model to train, possibly compiled and/or DDP-wrapped.
        device: device batches are moved to.
        config: parsed training config.
        args: parsed CLI arguments.
        rank: local process rank.
    """
    nocompile = args.nocompile
    restart = args.restart
    is_distributed = args.multigpu or args.multinode

    # Loss function, output activation and optimizer.
    if 'Loss' not in config:
        loss_fcn = nn.BCEWithLogitsLoss(reduction='none')
        finish_fn = torch.nn.Sigmoid()
    else:
        loss_fcn = utils.buildFromConfig(config['Loss'], {'reduction': 'none'})
        finish_fn = utils.buildFromConfig(config['Loss']['finish'])
    loss_name = type(loss_fcn).__name__
    optimizer = torch.optim.Adam(model.parameters(), lr=config['Training']['learning_rate'])
    gamma = config['Training']['gamma'] if 'gamma' in config['Training'] else 1

    early_termination = utils.EarlyStop()
    if 'early_termination' in config['Training']:
        et_conf = config['Training']['early_termination']
        early_termination.patience = et_conf['patience']
        early_termination.threshold = et_conf['threshold']
        early_termination.mode = et_conf['mode']

    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
    custom_scheduler = None
    if 'custom_scheduler' in config['Training']:
        run_time_args = {}
        scheduler_class = config['Training']['custom_scheduler']['class']
        # These scheduler classes are constructed with the optimizer.
        if scheduler_class in ('Dynamic_LR', 'Dynamic_LR_AND_Partial_Reset', 'Dynamic_LR_AND_Full_Reset'):
            run_time_args = {'optimizer': optimizer}
        custom_scheduler = utils.buildFromConfig(config['Training']['custom_scheduler'], run_time_args=run_time_args)

    starting_epoch = 0
    if not restart:
        _, checkpoint = utils.get_last_epoch(config)
        if checkpoint is not None:
            # Checkpoints hold bare-module keys (see the save below), so load
            # into the bare module regardless of the wrappers `model` carries.
            _unwrap_model(model).load_state_dict(_strip_module_prefix(checkpoint['model_state_dict']))
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            starting_epoch = checkpoint['epoch'] + 1
            if 'early_stop' in checkpoint:
                early_termination = utils.EarlyStop.load_from_dict(checkpoint['early_stop'])
                print(early_termination.to_str())
                print("EarlyStop state restored successfully.")
                if early_termination.should_stop:
                    print(f"Early Termination at Epoch {checkpoint['epoch']}")
                    return
            else:
                # Keep the instance configured from config['Training'] above
                # rather than replacing it with a default-configured one.
                print("'early_stop' not found in checkpoint. Initializing a new EarlyStop instance.")
            print(f"Loaded epoch {checkpoint['epoch']} from checkpoint")
        log = open(config['Training_Directory'] + '/training.log', 'a', buffering=1)
    else:
        log = open(config['Training_Directory'] + '/training.log', 'w', buffering=1)

    train_cyclers = [utils.cycler(loader) for loader in train_loaders]

    if args.savecache:
        # Push the largest batch of every dataset through the model once.
        max_batch = [None] * len(train_loaders)
        for dset_i, loader in enumerate(train_loaders):
            mbs = 0
            for batch in loader:
                if batch[0].num_nodes() > mbs:
                    mbs = batch[0].num_nodes()
                    max_batch[dset_i] = batch[0]
            print(f'Max batch size for dataset {dset_i}: {mbs}')
        big_batch = dgl.batch(max_batch).to(device)
        with torch.no_grad():
            # NOTE(review): the model is called with (graph, global_feats)
            # everywhere else; this single-argument call looks inconsistent.
            model(big_batch)

    # Accumulated seconds, reported at the end as Load / Batch / Train / Eval / Save.
    cumulative_times = [0, 0, 0, 0, 0]
    log.write(f'Training {config["Training_Name"]} {datetime.datetime.now()} \n')
    print(f"Starting training for {config['Training']['epochs']} epochs")
    # Padded batches carry one extra dummy graph whose output row is dropped
    # below; 'NODE' padding is treated as unpadded.
    dataset_padding_mode = getattr(train_loaders[0].dataset, 'padding_mode', 'NONE')
    is_padded = dataset_padding_mode not in ('NONE', 'NODE')
    lr_utils.print_LR(optimizer)

    for epoch in range(starting_epoch, config['Training']['epochs']):
        start = time.time()
        run = start
        if args.profile:
            if epoch == 0:
                torch.cuda.cudart().cudaProfilerStart()
            torch.cuda.nvtx.range_push("Epoch Start")
        if is_distributed:
            dist.barrier()

        # ---- training ----
        model.train()
        ibatch = 0
        total_loss = 0
        # train_loaders[0] only sets the number of steps per epoch; the data
        # for each step comes from the cyclers (one batch per training dataset).
        for _ in train_loaders[0]:
            batch_start = time.time()
            logits = torch.tensor([])
            tlabels = torch.tensor([])
            weights = torch.tensor([])
            batch_lengths = []
            for cycler in train_cyclers:
                graph, label, track, global_feats = next(cycler)
                graph = graph.to(device)
                label = label.to(device)
                track = track.to(device)
                global_feats = global_feats.to(device)
                if is_padded:
                    # Pad the globals to match the padded graphs.
                    global_feats = torch.concatenate((global_feats, torch.zeros(1, len(global_feats[0])).to(device)))
                load = time.time()
                if args.profile:
                    torch.cuda.nvtx.range_push("Model Forward")
                if len(logits) == 0:
                    logits = model(graph, global_feats)
                    tlabels = label
                    weights = track[:, 1]
                else:
                    logits = torch.concatenate((logits, model(graph, global_feats)), dim=0)
                    tlabels = torch.concatenate((tlabels, label), dim=0)
                    weights = torch.concatenate((weights, track[:, 1]), dim=0)
                # Index of the last output row so far (the padding graph's row when padded).
                batch_lengths.append(logits.shape[0] - 1)
                if args.profile:
                    torch.cuda.nvtx.range_pop()  # popping model forward
            if is_padded:
                keepmask = torch.full_like(logits[:, 0], True, dtype=torch.bool)
                keepmask[batch_lengths] = False
                logits = logits[keepmask]
            tlabels = tlabels.to(torch.float)
            if logits.shape[1] == 1 and loss_name == 'BCEWithLogitsLoss':
                logits = logits[:, 0]
            if loss_name == 'CrossEntropyLoss':
                tlabels = tlabels.to(torch.long)
            if args.abs:
                weights = torch.abs(weights)
            loss = loss_fcn(logits, tlabels.to(device))
            loss = _label_normalized_loss(loss, tlabels, weights)

            if args.profile:
                torch.cuda.nvtx.range_push("Model Backward")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.detach().cpu().item()
            if args.profile:
                torch.cuda.nvtx.range_pop()  # pop model backward
            ibatch += 1
            cumulative_times[0] += batch_start - run
            cumulative_times[1] += load - batch_start
            run = time.time()
            cumulative_times[2] += run - load
            if ibatch % 1000 == 0:
                print(f'Batch {ibatch} out of {len(train_loaders[0])}', end='\r')

        if args.multigpu:
            print(f'Rank {rank} Epoch Done.')
        elif args.multinode:
            print(f'Rank {args.global_rank} Epoch Done.')
        else:
            print("Epoch Done.")

        # ---- validation ----
        scores = []
        labels = []
        weights = []
        model.eval()
        if args.profile:
            torch.cuda.nvtx.range_push("Model Evaluation")
        # The compiled graph is bypassed for evaluation because batch sizes are
        # not controlled here; _unwrap_model also strips DDP so this path works
        # for compiled multi-GPU runs.
        eval_model = model if nocompile else _unwrap_model(model)
        with torch.no_grad():
            for loader in test_loaders:
                for batch, label, track, global_feats in loader:
                    if is_padded:
                        global_feats = torch.cat([global_feats, torch.zeros(1, len(global_feats[0]))])
                    batch_scores = eval_model(batch.to(device), global_feats.to(device))
                    # When padded, the last row belongs to the padding graph.
                    scores.append(batch_scores[:-1, :] if is_padded else batch_scores)
                    labels.append(label)
                    weights.append(track[:, 1])
        eval_end = time.time()
        cumulative_times[3] += eval_end - run
        if args.profile:
            torch.cuda.nvtx.range_pop()  # pop evaluation
        if not scores:  # Validation set is empty.
            continue

        logits = torch.concatenate(scores).to(device)
        labels = torch.concatenate(labels).to(device)
        weights = torch.concatenate(weights).to(device)
        if is_distributed:
            # NOTE(review): dist.gather needs equally-shaped tensors on every
            # rank; this relies on each rank producing the same number of rows.
            world_size = dist.get_world_size()
            gathered_logits = [torch.zeros_like(logits) for _ in range(world_size)]
            gathered_labels = [torch.zeros_like(labels) for _ in range(world_size)]
            gathered_weights = [torch.zeros_like(weights) for _ in range(world_size)]
            dist.barrier()
            if (args.multigpu and rank != 0) or (args.multinode and args.global_rank != 0):
                # Non-root ranks only contribute predictions; metrics, logging
                # and checkpointing happen on rank 0.
                dist.gather(logits, dst=0)
                dist.gather(labels, dst=0)
                dist.gather(weights, dst=0)
                # Keep the per-epoch LR decay in step with rank 0; otherwise the
                # ranks' optimizers use different learning rates.
                # NOTE(review): custom_scheduler and early stopping still run on
                # rank 0 only; if rank 0 stops early, the other ranks will block
                # in the next collective.
                scheduler.step()
                continue
            dist.gather(logits, gather_list=gathered_logits)
            dist.gather(labels, gather_list=gathered_labels)
            dist.gather(weights, gather_list=gathered_weights)
            logits = torch.concatenate(gathered_logits)
            labels = torch.concatenate(gathered_labels)
            weights = torch.concatenate(gathered_weights)

        wgt_mask = weights > 0
        if args.abs:
            weights = torch.abs(weights)
        print(f"Num batches trained = {ibatch}")

        # Metric computation depends on the loss type.
        if loss_name == "ContrastiveClusterLoss":
            scores = logits
            test_auc = 0
            acc = 0
            contrastive_cluster_loss = finish_fn(logits)
        elif loss_name == "MultiLabelLoss":
            scores = finish_fn(logits)
            preds = torch.round(scores)
            # Per-output exact-match accuracy of the rounded scores.
            multilabel_accuracy = [
                torch.sum(preds[:, i].to("cpu") == labels[:, i].to("cpu")) / len(labels)
                for i in range(len(labels[0]))
            ]
            test_auc = 0
            acc = np.mean(multilabel_accuracy)
        elif logits.shape[1] == 1 and loss_name == 'BCEWithLogitsLoss':
            # Proxy for binary classification.
            logits = logits[:, 0]
            scores = finish_fn(logits)
            labels = labels.to(torch.float)
            preds = scores > 0.5
            acc = torch.sum(preds.to("cpu") == labels.to("cpu")) / len(labels)
            try:
                test_auc = roc_auc_score(labels[wgt_mask].to("cpu") == 1, scores[wgt_mask].to("cpu"),
                                         sample_weight=weights[wgt_mask].to("cpu"))
            except ValueError:
                # e.g. only one class present among the positively-weighted events
                test_auc = np.nan
        elif logits.shape[1] == 1 and loss_name == 'MSELoss':
            logits = logits[:, 0]
            scores = finish_fn(logits)
            labels = labels.to(torch.float)
            acc = 0
            test_auc = 0
        else:
            preds = torch.argmax(logits, dim=1)
            scores = finish_fn(logits)
            if labels.dim() == 1:  # Multi-class
                acc = torch.sum(preds.to("cpu") == labels.to("cpu")) / len(labels)
                # TODO: Make each class weighted equally?
                labels = labels.to("cpu")
                weights = weights.to("cpu")
                logits = logits.to("cpu")
                wgt_mask = wgt_mask.to("cpu")
                labels_onehot = np.zeros((len(labels), len(scores[0])))
                labels_onehot[np.arange(len(labels)), labels] = 1
                try:
                    if len(scores[0]) != config["Model"]["args"]["out_size"]:
                        print("ERROR: The out_size and the number of class labels don't match! Please check config.")
                    test_auc = roc_auc_score(labels_onehot[wgt_mask], scores[wgt_mask].to("cpu"), multi_class='ovr',
                                             sample_weight=weights[wgt_mask].to("cpu"))
                except ValueError:
                    test_auc = np.nan
            else:  # Multi-loss
                acc = torch.sum(preds.to("cpu") == labels[:, 0].to("cpu")) / len(labels)
                try:
                    test_auc = roc_auc_score(labels[:, 0][wgt_mask].to("cpu") == 1, scores[wgt_mask].to("cpu"),
                                             multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
                except ValueError:
                    test_auc = np.nan

        if loss_name == "MultiLabelLoss":
            multilabel_log_str = "MultiLabel_Accuracy " + "".join(f" | {accuracy:.4f}" for accuracy in multilabel_accuracy)
            log.write(multilabel_log_str + '\n')
            print(multilabel_log_str, flush=True)
        elif loss_name == "ContrastiveClusterLoss":
            contrastive_cluster_log_str = (
                "ContrastiveClusterLoss "
                f"Contrastive Loss: {contrastive_cluster_loss[0]:.4f}, "
                f"Clustering Loss: {contrastive_cluster_loss[1]:.4f}, "
                f"Variance Loss: {contrastive_cluster_loss[2]:.4f}"
            )
            log.write(contrastive_cluster_log_str + '\n')
            print(contrastive_cluster_log_str, flush=True)

        test_loss = _label_normalized_loss(loss_fcn(logits, labels), labels, weights)
        end = time.time()
        log_str = "Epoch {:05d} | LR {:.4e} | Loss {:.4f} | Accuracy {:.4f} | Test_Loss {:.4f} | Test_AUC {:.4f} | Time {:.4f} s".format(
            epoch,
            optimizer.param_groups[0]['lr'],
            total_loss / max(ibatch, 1),  # guard against an epoch with no batches
            acc,
            test_loss,
            test_auc,
            end - start,
        )
        log.write(log_str + '\n')
        print(log_str, flush=True)

        # Update before saving so the checkpoint records this epoch's
        # early-stop state (including should_stop, which is checked on resume).
        early_termination.update(test_loss)

        # Save the bare module so checkpoint keys carry no DDP/compile prefixes.
        state_dict = _unwrap_model(model).state_dict()
        torch.save({
            'epoch': epoch,
            'model_state_dict': state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'early_stop': early_termination.to_dict(),
        }, os.path.join(config['Training_Directory'], f"model_epoch_{epoch}.pt"))
        np.savez(os.path.join(config['Training_Directory'], f'model_epoch_{epoch}.npz'),
                 scores=scores.to("cpu"), labels=labels.to("cpu"))
        save_end = time.time()
        cumulative_times[4] += save_end - eval_end

        if early_termination.should_stop:
            log_str = f"Early Termination at Epoch {epoch}"
            log.write(log_str + "\n")
            print(log_str)
            log_str = early_termination.to_str()
            log.write(log_str + "\n")
            print(log_str)
            break
        if custom_scheduler:
            custom_scheduler.step(model, {'test_auc': test_auc})
        scheduler.step()
        if args.profile:
            torch.cuda.nvtx.range_pop()  # pop epoch

    print(f"Load: {cumulative_times[0]:.4f} s")
    print(f"Batch: {cumulative_times[1]:.4f} s")
    print(f"Train: {cumulative_times[2]:.4f} s")
    print(f"Eval: {cumulative_times[3]:.4f} s")
    print(f"Save: {cumulative_times[4]:.4f} s")
    log.close()


def find_free_port():
    """Return, as a string, a TCP port number that was free when checked.

    The OS picks the port (bind to port 0); the socket is closed before
    returning, so another process could still claim the port afterwards.
    """
    import socket
    from contextlib import closing
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(('', 0))
        return str(s.getsockname()[1])


def init_process_group(world_size, rank, port):
    """Initialise the NCCL process group for single-node multi-GPU training.

    Args:
        world_size: number of processes.
        rank: rank of this process.
        port: MASTER_PORT, as a str or int.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    # Environment values must be strings; main()'s default port is an int.
    os.environ['MASTER_PORT'] = str(port)
    dist.init_process_group(
        backend="nccl",
        init_method='env://',
        world_size=world_size,
        rank=rank,
        timeout=datetime.timedelta(seconds=300),
    )


def main(rank=0, args=None, world_size=1, port=24500, seed=12345):
    """Set up device, data loaders and model, then train (or evaluate / plot).

    Args:
        rank: local rank; first positional parameter because ``mp.spawn``
            passes the process index there.
        args: parsed CLI arguments.
        world_size: total number of processes.
        port: MASTER_PORT used by ``--multigpu`` process-group initialisation.
        seed: forwarded to the model constructor as ``seed``.
    """
    # Load config file
    config = utils.load_config(args.config)
    if args.directory:
        print(f"New training directory: { config['Training_Directory'] + args.directory}")
        config['Training_Directory'] = config['Training_Directory'] + args.directory
    os.makedirs(config['Training_Directory'], exist_ok=True)
    with open(config['Training_Directory'] + '/config.yaml', 'w') as f:
        yaml.dump(config, f)
    batch_size = config["Training"]["batch_size"]

    if args.plot:
        rl = utils.read_log(config)
        utils.plot_log(rl, config['Training_Directory'] + '/training.png')
        print('Log at ' + config['Training_Directory'] + '/training.log')
        print('Plotted at ' + config['Training_Directory'] + '/training.png')
        sys.exit()

    # torch.cuda.device(...) is only a context manager; set_device actually
    # selects this rank's GPU, as PyTorch's distributed docs require for NCCL.
    if args.multigpu:
        print("Setting up multigpu")
        setup_start = time.time()
        init_process_group(world_size, rank, port)
        print("multigpu setup time: {:.4f} s".format(time.time() - setup_start))
        device = torch.device(f'cuda:{rank}')
        torch.cuda.set_device(device)
    elif args.multinode:
        device = torch.device(f'cuda:{rank}')
        torch.cuda.set_device(device)
        print(f"global rank = {args.global_rank}, local rank = {rank}, device = {device}")
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.cpu:
        print("Using CPU")
        device = "cpu"

    train_loaders = []
    test_loaders = []
    val_loaders = []
    load_start = time.time()
    torch.backends.cuda.matmul.allow_tf32 = True
    ldr_type = datasets.LazyPreBatchedDataset if args.lazy else datasets.PreBatchedDataset

    # Use `args`, not the module-level `pargs`: processes started by mp.spawn
    # re-import this file without running the __main__ block.
    if args.statistics:
        args.statistics = int(args.statistics)
        print(f"Training Dataset Size: {args.statistics}")
        num_batches = int(np.ceil(args.statistics / batch_size))
    np.random.seed(args.seed)

    # Load datasets
    for dset_conf in config["Datasets"]:
        dset_cfg = config["Datasets"][dset_conf]
        dset = utils.buildFromConfig(dset_cfg)
        if 'batch_size' in dset_cfg:
            batch_size = dset_cfg['batch_size']
        fold_conf = dset_cfg["folding"]
        shuffle_chunks = dset_cfg.get("shuffle_chunks", 10)
        padding_mode = dset_cfg.get("padding_mode", "STEPS")
        mask_fn = utils.fold_selection(fold_conf, "train")
        if args.preshuffle:
            ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn,
                           suffix=utils.fold_selection_name(fold_conf, 'train'), chunks=shuffle_chunks,
                           padding_mode=padding_mode, hidden_size=config["Model"]["args"]["hid_size"])
            gsamp, _, _, global_samp = ldr[0]
            sampler = None
            if args.statistics:
                sampler = np.random.choice(range(len(ldr)), size=num_batches)
            if args.multigpu:
                sampler = DistributedSampler(ldr, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True)
            if args.multinode:
                sampler = DistributedSampler(ldr, num_replicas=world_size, rank=args.global_rank, shuffle=False, drop_last=True)
            train_loaders.append(torch.utils.data.DataLoader(ldr, batch_size=None, num_workers=0, sampler=sampler))

            # The test split uses the *test* fold selection so that mask and
            # suffix agree (the training mask must not select test events).
            # NOTE(review): if preshuffled test files are cached on disk by
            # suffix, caches built with the training mask need regenerating.
            sampler = None
            ldr = ldr_type(start_dataset=dset, batch_size=batch_size, mask_fn=utils.fold_selection(fold_conf, "test"),
                           suffix=utils.fold_selection_name(fold_conf, 'test'), chunks=shuffle_chunks,
                           padding_mode=padding_mode, hidden_size=config['Model']['args']['hid_size'])
            if args.multigpu:
                sampler = DistributedSampler(ldr, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True)
            if args.multinode:
                sampler = DistributedSampler(ldr, num_replicas=world_size, rank=args.global_rank, shuffle=False, drop_last=True)
            test_loaders.append(torch.utils.data.DataLoader(ldr, batch_size=None, num_workers=0, sampler=sampler))

            if "validation" in fold_conf:
                # Built with world_size=1 and no sampler: the test loader's
                # DistributedSampler indexes a different dataset.
                val_ldr = ldr_type(start_dataset=dset, batch_size=batch_size,
                                   mask_fn=utils.fold_selection(fold_conf, "validation"),
                                   suffix=utils.fold_selection_name(fold_conf, 'validation'), chunks=shuffle_chunks,
                                   hidden_size=config['Model']['args']['hid_size'], padding_mode=padding_mode,
                                   rank=rank, world_size=1)
                val_loaders.append(torch.utils.data.DataLoader(val_ldr, batch_size=None, num_workers=0, sampler=None))
            else:
                print("No validation set for dataset ", dset_conf)
        else:
            train_loaders.append(datasets.GetBatchedLoader(dset, batch_size, utils.fold_selection(fold_conf, "train")))
            gsamp, _, _, global_samp = dset[0]
            test_loaders.append(datasets.GetBatchedLoader(dset, batch_size, utils.fold_selection(fold_conf, "test")))
            if "validation" in fold_conf:
                val_loaders.append(datasets.GetBatchedLoader(dset, batch_size, utils.fold_selection(fold_conf, "validation")))
            else:
                print("No validation set for dataset ", dset_conf)
    load_end = time.time()
    print("Load time: {:.4f} s".format(load_end - load_start))

    model = utils.buildFromConfig(config["Model"], {'sample_graph': gsamp, 'sample_global': global_samp, 'seed': seed}).to(device)
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of trainable parameters = {pytorch_total_params}")
    if not args.nocompile:
        model = torch.compile(model)
    if args.multigpu or args.multinode:
        print("Trying to create DDP model")
        ddp_start = time.time()
        model = DDP(model, device_ids=[device])
        print("model creation time: {:.4f} s".format(time.time() - ddp_start))

    if isinstance(_unwrap_model(model), GCN.Clustering):
        print("clustering")

    if args.evaluate is not None:
        evaluate(test_loaders, model, config, device, args.evaluate)
        sys.exit()

    # model training
    print("Training...")
    gpu_mem()
    train(train_loaders, test_loaders, model, device, config, args, rank)


if __name__ == "__main__":
    # Handle CLI arguments
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument
    add_arg("--config", type=str, help="Config file.", required=True)
    add_arg("--restart", action="store_true", help="Restart training from scratch.")
    add_arg("--preshuffle", action="store_true", help="Shuffle data before training.")
    add_arg("--lazy", action="store_true", help="Lazy loading of data.")
    add_arg("--nocompile", action="store_true", help="Disable JIT compilation.")
    add_arg("--evaluate", type=int, help="Skip training and go to evaluation.")
    add_arg("--plot", action="store_true", help="Plot training logs.")
    add_arg("--multigpu", action="store_true", help="Use multiple GPUs.")
    add_arg("--multinode", action="store_true", help="Use multiple nodes.")
    add_arg("--savecache", action="store_true", help="")
    add_arg("--cpu", action="store_true", help="Uses the cpu only")
    add_arg("--statistics", type=float, help="Size of training data")
    add_arg("--directory", type=str, help="Append to Training Directory")
    add_arg("--seed", type=int, default=2, help="Sets random seed")
    add_arg("--abs", action="store_true", help="Use abs value of per-event weight")
    add_arg("--profile", action="store_true", help="use nsight systems profiler")
    pargs = parser.parse_args()

    if pargs.multigpu:
        port = find_free_port()
        # NOTE(review): presumably this only affects the parent process;
        # workers started by mp.spawn begin in a fresh interpreter.
        torch.backends.cudnn.enabled = False
        # NOTE: the number of GPUs (4) is hard-coded.
        mp.spawn(main, args=(pargs, 4, port), nprocs=4, join=True)
    elif pargs.multinode:
        global_rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        print(f"global_rank = {global_rank}, local_rank = {local_rank}, world_size = {world_size}")
        dist.init_process_group(backend="nccl")
        torch.backends.cudnn.enabled = False
        pargs.global_rank = global_rank
        main(rank=local_rank, args=pargs, world_size=world_size)
    else:
        main(0, pargs)