Spaces:
Build error
Build error
| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT license. | |
| import os | |
| import torch | |
| from torch.utils.tensorboard import SummaryWriter | |
| import random | |
| import numpy as np | |
| from pkg_resources import parse_version | |
| from model.third_party.HMNet.Models.Trainers.BaseTrainer import BaseTrainer | |
| from model.third_party.HMNet.Utils.GeneralUtils import bcolors | |
| from model.third_party.HMNet.Utils.distributed import distributed | |
| class DistributedTrainer(BaseTrainer): | |
    def __init__(self, opt):
        """Set up a (possibly multi-process) trainer.

        Seeds all RNGs, initializes the distributed runtime, creates the
        per-run save folder, and writes the config (rank 0 only).

        Args:
            opt: configuration dict-like object; read keys include 'SEED',
                'MINI_BATCH', 'GRADIENT_ACCUMULATE_STEP'; written keys include
                'device', 'world_size', 'local_size', 'rank', 'local_rank',
                'run', and 'logFile'.
        """
        super().__init__(opt)

        # Seed every RNG source from opt['SEED'] (default 0) so runs are
        # reproducible across python / numpy / torch.
        self.seed = int(self.opt["SEED"]) if "SEED" in self.opt else 0
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)

        # distributed() initializes the process group and reports this
        # process's placement; two returned fields are unused here.
        # NOTE(review): exact tuple semantics defined in Utils.distributed —
        # confirm field order there if this unpacking is ever changed.
        (
            self.opt["device"],
            _,
            self.opt["world_size"],
            self.opt["local_size"],
            self.opt["rank"],
            self.opt["local_rank"],
            _,
            self.opt["run"],
        ) = distributed(opt, not self.use_cuda)

        # Folder must exist before saveConf() writes into it; each rank gets
        # its own log file name so logs don't interleave.
        self.getSaveFolder()
        self.opt["logFile"] = f"log_{self.opt['rank']}.txt"
        self.saveConf()

        # Feature gate: some APIs used elsewhere require torch >= 1.2.0.
        self.high_pytorch_version = parse_version(torch.__version__) >= parse_version(
            "1.2.0"
        )

        # Only rank 0 prints status, to avoid N-way duplicated console output.
        if self.opt["rank"] == 0:
            print(
                bcolors.OKGREEN,
                torch.__version__,
                bcolors.ENDC,
                "is",
                "high" if self.high_pytorch_version else "low",
            )

        if self.use_cuda:
            # torch.cuda.manual_seed_all(self.seed)
            # ddp: only set seed on GPU associated with this process
            torch.cuda.manual_seed(self.seed)

        # ddp: print stats and update learning rate
        if self.opt["rank"] == 0:
            print(
                "Number of GPUs is",
                bcolors.OKGREEN,
                self.opt["world_size"],
                bcolors.ENDC,
            )
            # print('Boost learning rate from', bcolors.OKGREEN, self.opt['START_LEARNING_RATE'], bcolors.ENDC, 'to',
            #       bcolors.OKGREEN, self.opt['START_LEARNING_RATE'] * self.opt['world_size'], bcolors.ENDC)
            print(
                "Effective batch size is increased from",
                bcolors.OKGREEN,
                self.opt["MINI_BATCH"],
                bcolors.ENDC,
                "to",
                bcolors.OKGREEN,
                self.opt["MINI_BATCH"] * self.opt["world_size"],
                bcolors.ENDC,
            )

        # Gradient accumulation multiplies the effective batch size again.
        self.grad_acc_steps = 1
        if "GRADIENT_ACCUMULATE_STEP" in self.opt:
            if self.opt["rank"] == 0:
                print(
                    "Gradient accumulation steps =",
                    bcolors.OKGREEN,
                    self.opt["GRADIENT_ACCUMULATE_STEP"],
                    bcolors.ENDC,
                )
                # print('Boost learning rate from', bcolors.OKGREEN, self.opt['START_LEARNING_RATE'], bcolors.ENDC, 'to',
                #       bcolors.OKGREEN, self.opt['START_LEARNING_RATE'] * self.opt['world_size'] * self.opt['GRADIENT_ACCUMULATE_STEP'], bcolors.ENDC)
                print(
                    "Effective batch size =",
                    bcolors.OKGREEN,
                    self.opt["MINI_BATCH"]
                    * self.opt["world_size"]
                    * self.opt["GRADIENT_ACCUMULATE_STEP"],
                    bcolors.ENDC,
                )
            self.grad_acc_steps = int(self.opt["GRADIENT_ACCUMULATE_STEP"])
        # self.opt['START_LEARNING_RATE'] *= self.opt['world_size'] * self.grad_acc_steps
| def tb_log_scalar(self, name, value, step): | |
| if self.opt["rank"] == 0: | |
| if self.tb_writer is None: | |
| self.tb_writer = SummaryWriter( | |
| os.path.join(self.saveFolder, "tensorboard") | |
| ) | |
| self.tb_writer.add_scalar(name, value, step) | |
| def log(self, s): | |
| # When 'OFFICIAL' flag is set in the config file, the program does not output logs | |
| if self.is_official: | |
| return | |
| try: | |
| if self.logFileHandle is None: | |
| self.logFileHandle = open( | |
| os.path.join(self.saveFolder, self.opt["logFile"]), "a" | |
| ) | |
| self.logFileHandle.write(s + "\n") | |
| except Exception as e: | |
| print("ERROR while writing log file:", e) | |
| print(s) | |
| def getSaveFolder(self): | |
| runid = 1 | |
| while True: | |
| saveFolder = os.path.join( | |
| self.opt["datadir"], | |
| self.opt["basename"] + "_conf~", | |
| "run_" + str(runid), | |
| ) | |
| if not os.path.isdir(saveFolder): | |
| if self.opt["world_size"] > 1: | |
| torch.distributed.barrier() | |
| if self.opt["rank"] == 0: | |
| os.makedirs(saveFolder) | |
| self.saveFolder = saveFolder | |
| if self.opt["world_size"] > 1: | |
| torch.distributed.barrier() | |
| print( | |
| "Saving logs, model, checkpoint, and evaluation in " | |
| + self.saveFolder | |
| ) | |
| return | |
| runid = runid + 1 | |
| def saveConf(self): | |
| if self.opt["rank"] == 0: | |
| super().saveConf() | |