import sys
import glob
import os
import shutil
import time
import torch
import torch.utils.data
from collections import OrderedDict

# collections.abc.Sequence exists on every supported Python version;
# collections.Sequence was removed in Python 3.10.
from collections.abc import Sequence

from pointcept.utils.timer import Timer
from pointcept.utils.comm import is_main_process, synchronize, get_world_size
from pointcept.utils.cache import shared_dict
import pointcept.utils.comm as comm

from .default import HookBase
from .builder import HOOKS


@HOOKS.register_module()
class IterationTimer(HookBase):
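    """Track data-loading and batch time per iteration and report an ETA.

    Timing scalars are pushed to the trainer's storage; histories from the
    first ``warmup_iter`` iterations are discarded so they do not skew the
    running averages. Illustrative config entry (hypothetical values):
    ``dict(type="IterationTimer", warmup_iter=2)``.
    """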
    def __init__(self, warmup_iter=1):
        self._warmup_iter = warmup_iter
        self._start_time = time.perf_counter()
        self._iter_timer = Timer()
        self._remain_iter = 0

    def before_train(self):
        self._start_time = time.perf_counter()
        self._remain_iter = self.trainer.max_epoch * len(self.trainer.train_loader)

    def before_epoch(self):
        self._iter_timer.reset()

    def before_step(self):
        data_time = self._iter_timer.seconds()
        self.trainer.storage.put_scalar("data_time", data_time)

    def after_step(self):
        batch_time = self._iter_timer.seconds()
        self._iter_timer.reset()
        self.trainer.storage.put_scalar("batch_time", batch_time)
        self._remain_iter -= 1
        remain_time = self._remain_iter * self.trainer.storage.history("batch_time").avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = "{:02d}:{:02d}:{:02d}".format(int(t_h), int(t_m), int(t_s))
        if "iter_info" in self.trainer.comm_info.keys():
            info = (
                "Data {data_time_val:.3f} ({data_time_avg:.3f}) "
                "Batch {batch_time_val:.3f} ({batch_time_avg:.3f}) "
                "Remain {remain_time} ".format(
                    data_time_val=self.trainer.storage.history("data_time").val,
                    data_time_avg=self.trainer.storage.history("data_time").avg,
                    batch_time_val=self.trainer.storage.history("batch_time").val,
                    batch_time_avg=self.trainer.storage.history("batch_time").avg,
                    remain_time=remain_time,
                )
            )
            self.trainer.comm_info["iter_info"] += info
        if self.trainer.comm_info["iter"] <= self._warmup_iter:
            self.trainer.storage.history("data_time").reset()
            self.trainer.storage.history("batch_time").reset()


@HOOKS.register_module()
class InformationWriter(HookBase):
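    """Assemble the per-iteration log line and write scalars to the writer.

    Model outputs found in ``comm_info["model_output_dict"]`` are recorded
    per batch under ``train_batch/`` and averaged per epoch under ``train/``.
    """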
    def __init__(self):
        self.curr_iter = 0
        self.model_output_keys = []

    def before_train(self):
        self.trainer.comm_info["iter_info"] = ""
        self.curr_iter = self.trainer.start_epoch * len(self.trainer.train_loader)

    def before_step(self):
        self.curr_iter += 1
        info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] ".format(
            epoch=self.trainer.epoch + 1,
            max_epoch=self.trainer.max_epoch,
            iter=self.trainer.comm_info["iter"] + 1,
            max_iter=len(self.trainer.train_loader),
        )
        self.trainer.comm_info["iter_info"] += info

    def after_step(self):
        if "model_output_dict" in self.trainer.comm_info.keys():
            model_output_dict = self.trainer.comm_info["model_output_dict"]
            self.model_output_keys = model_output_dict.keys()
            for key in self.model_output_keys:
                self.trainer.storage.put_scalar(key, model_output_dict[key].item())

        for key in self.model_output_keys:
            self.trainer.comm_info["iter_info"] += "{key}: {value:.4f} ".format(
                key=key, value=self.trainer.storage.history(key).val
            )
        # Read the current lr of the first param group directly; equivalent to
        # optimizer.state_dict()["param_groups"][0]["lr"] without rebuilding
        # the state dict every step.
        lr = self.trainer.optimizer.param_groups[0]["lr"]
        self.trainer.comm_info["iter_info"] += "Lr: {lr:.5f}".format(lr=lr)
        self.trainer.logger.info(self.trainer.comm_info["iter_info"])
        self.trainer.comm_info["iter_info"] = ""
        if self.trainer.writer is not None:
            self.trainer.writer.add_scalar("lr", lr, self.curr_iter)
            for key in self.model_output_keys:
                self.trainer.writer.add_scalar(
                    "train_batch/" + key,
                    self.trainer.storage.history(key).val,
                    self.curr_iter,
                )

    def after_epoch(self):
        epoch_info = "Train result: "
        for key in self.model_output_keys:
            epoch_info += "{key}: {value:.4f} ".format(
                key=key, value=self.trainer.storage.history(key).avg
            )
        self.trainer.logger.info(epoch_info)
        if self.trainer.writer is not None:
            for key in self.model_output_keys:
                self.trainer.writer.add_scalar(
                    "train/" + key,
                    self.trainer.storage.history(key).avg,
                    self.trainer.epoch + 1,
                )


@HOOKS.register_module()
class CheckpointSaver(HookBase):
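    """Save a checkpoint after every epoch on the main process.

    Always writes ``model_last.pth``; additionally copies it to
    ``model_best.pth`` when the evaluation metric improves, and to
    ``epoch_{n}.pth`` every ``save_freq`` epochs if ``save_freq`` is set.
    """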
    def __init__(self, save_freq=None):
        self.save_freq = save_freq

    def after_epoch(self):
        if is_main_process():
            is_best = False
            if self.trainer.cfg.evaluate:
                current_metric_value = self.trainer.comm_info["current_metric_value"]
                current_metric_name = self.trainer.comm_info["current_metric_name"]
                if current_metric_value > self.trainer.best_metric_value:
                    self.trainer.best_metric_value = current_metric_value
                    is_best = True
                    self.trainer.logger.info(
                        "Best validation {} updated to: {:.4f}".format(
                            current_metric_name, current_metric_value
                        )
                    )
                self.trainer.logger.info(
                    "Currently Best {}: {:.4f}".format(
                        current_metric_name, self.trainer.best_metric_value
                    )
                )

            filename = os.path.join(
                self.trainer.cfg.save_path, "model", "model_last.pth"
            )
            self.trainer.logger.info("Saving checkpoint to: " + filename)
            # Write to a temporary file first and atomically replace the old
            # checkpoint, so an interrupted save cannot corrupt it.
            torch.save(
                {
                    "epoch": self.trainer.epoch + 1,
                    "state_dict": self.trainer.model.state_dict(),
                    "optimizer": self.trainer.optimizer.state_dict(),
                    "scheduler": self.trainer.scheduler.state_dict(),
                    "scaler": self.trainer.scaler.state_dict()
                    if self.trainer.cfg.enable_amp
                    else None,
                    "best_metric_value": self.trainer.best_metric_value,
                },
                filename + ".tmp",
            )
            os.replace(filename + ".tmp", filename)
            if is_best:
                shutil.copyfile(
                    filename,
                    os.path.join(self.trainer.cfg.save_path, "model", "model_best.pth"),
                )
            if self.save_freq and (self.trainer.epoch + 1) % self.save_freq == 0:
                shutil.copyfile(
                    filename,
                    os.path.join(
                        self.trainer.cfg.save_path,
                        "model",
                        f"epoch_{self.trainer.epoch + 1}.pth",
                    ),
                )


@HOOKS.register_module()
class CheckpointLoader(HookBase):
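    """Load model weights (and optimizer state when resuming) before training.

    ``keywords``/``replacement`` rename matching substrings in state-dict
    keys; e.g. a hypothetical ``keywords="backbone."`` with ``replacement=""``
    would strip that prefix from every matching key. ``module.`` prefixes are
    normalized so checkpoints load under both DDP and single-GPU models.
    """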
    def __init__(self, keywords="", replacement=None, strict=False):
        self.keywords = keywords
        self.replacement = replacement if replacement is not None else keywords
        self.strict = strict

    def before_train(self):
        self.trainer.logger.info("=> Loading checkpoint & weight ...")
        if self.trainer.cfg.weight and os.path.isfile(self.trainer.cfg.weight):
            self.trainer.logger.info(f"Loading weight at: {self.trainer.cfg.weight}")
            checkpoint = torch.load(
                self.trainer.cfg.weight,
                map_location=lambda storage, loc: storage.cuda(),
            )
            self.trainer.logger.info(
                f"Loading layer weights with keyword: {self.keywords}, "
                f"replace keyword with: {self.replacement}"
            )
            weight = OrderedDict()
            for key, value in checkpoint["state_dict"].items():
                if not key.startswith("module."):
                    key = "module." + key  # xxx.xxx -> module.xxx.xxx
                # Now every key carries the "module." prefix, DDP or not, so
                # the single-GPU strip below is always safe.
                if self.keywords in key:
                    key = key.replace(self.keywords, self.replacement)
                if comm.get_world_size() == 1:
                    key = key[7:]  # module.xxx.xxx -> xxx.xxx
                weight[key] = value
            load_state_info = self.trainer.model.load_state_dict(
                weight, strict=self.strict
            )
            self.trainer.logger.info(f"Missing keys: {load_state_info[0]}")
            if self.trainer.cfg.resume:
                self.trainer.logger.info(
                    f"Resuming training from epoch: {checkpoint['epoch']}"
                )
                self.trainer.start_epoch = checkpoint["epoch"]
                self.trainer.best_metric_value = checkpoint["best_metric_value"]
                self.trainer.optimizer.load_state_dict(checkpoint["optimizer"])
                self.trainer.scheduler.load_state_dict(checkpoint["scheduler"])
                if self.trainer.cfg.enable_amp:
                    self.trainer.scaler.load_state_dict(checkpoint["scaler"])
        else:
            self.trainer.logger.info(f"No weight found at: {self.trainer.cfg.weight}")


@HOOKS.register_module()
class DataCacheOperator(HookBase):
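    """Cache dataset files into shared memory before training starts.

    Only the main process loads and caches; all ranks meet at the final
    ``synchronize()``. For example (hypothetical paths), with
    ``data_root="/data/scannet"`` and ``split="train"``, the file
    ``/data/scannet/train/scene0000_00.pth`` maps to the cache key
    ``pointcept-scannet-train-scene0000_00``.
    """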
    def __init__(self, data_root, split):
        self.data_root = data_root
        self.split = split
        self.data_list = self.get_data_list()

    def get_data_list(self):
        if isinstance(self.split, str):
            data_list = glob.glob(os.path.join(self.data_root, self.split, "*.pth"))
        elif isinstance(self.split, Sequence):
            data_list = []
            for split in self.split:
                data_list += glob.glob(os.path.join(self.data_root, split, "*.pth"))
        else:
            raise NotImplementedError
        return data_list

    def get_cache_name(self, data_path):
        data_name = data_path.replace(os.path.dirname(self.data_root), "").split(".")[0]
        return "pointcept" + data_name.replace(os.path.sep, "-")

    def before_train(self):
        self.trainer.logger.info(
            f"=> Caching dataset: {self.data_root}, split: {self.split} ..."
        )
        if is_main_process():
            for data_path in self.data_list:
                cache_name = self.get_cache_name(data_path)
                data = torch.load(data_path)
                shared_dict(cache_name, data)
        synchronize()


@HOOKS.register_module()
class RuntimeProfiler(HookBase):
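    """Profile forward and/or backward passes with torch.profiler.

    Runs ``warm_up + 1`` training iterations, prints the profiler tables for
    the last iteration, and exports Chrome traces to the save path.
    Optionally interrupts training afterwards.
    """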
    def __init__(
        self,
        forward=True,
        backward=True,
        interrupt=False,
        warm_up=2,
        sort_by="cuda_time_total",
        row_limit=30,
    ):
        self.forward = forward
        self.backward = backward
        self.interrupt = interrupt
        self.warm_up = warm_up
        self.sort_by = sort_by
        self.row_limit = row_limit

    def before_train(self):
        self.trainer.logger.info("Profiling runtime ...")
        from torch.profiler import profile, record_function, ProfilerActivity

        for i, input_dict in enumerate(self.trainer.train_loader):
            if i == self.warm_up + 1:
                break
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            if self.forward:
                with profile(
                    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                    record_shapes=True,
                    profile_memory=True,
                    with_stack=True,
                ) as forward_prof:
                    with record_function("model_inference"):
                        output_dict = self.trainer.model(input_dict)
            else:
                output_dict = self.trainer.model(input_dict)
            loss = output_dict["loss"]
            if self.backward:
                with profile(
                    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                    record_shapes=True,
                    profile_memory=True,
                    with_stack=True,
                ) as backward_prof:
                    # Label the backward pass distinctly (was "model_inference",
                    # a copy-paste of the forward label).
                    with record_function("model_backward"):
                        loss.backward()
            self.trainer.logger.info(f"Profile: [{i + 1}/{self.warm_up + 1}]")
        if self.forward:
            self.trainer.logger.info(
                "Forward profile: \n"
                + str(
                    forward_prof.key_averages().table(
                        sort_by=self.sort_by, row_limit=self.row_limit
                    )
                )
            )
            forward_prof.export_chrome_trace(
                os.path.join(self.trainer.cfg.save_path, "forward_trace.json")
            )

        if self.backward:
            self.trainer.logger.info(
                "Backward profile: \n"
                + str(
                    backward_prof.key_averages().table(
                        sort_by=self.sort_by, row_limit=self.row_limit
                    )
                )
            )
            backward_prof.export_chrome_trace(
                os.path.join(self.trainer.cfg.save_path, "backward_trace.json")
            )
        if self.interrupt:
            sys.exit(0)


@HOOKS.register_module()
class RuntimeProfilerV2(HookBase):
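    """Profile training iterations with a torch.profiler schedule.

    Profiles ``(wait + warmup + active) * repeat`` iterations, streams traces
    to TensorBoard via ``tensorboard_trace_handler``, and prints an aggregated
    table. Optionally interrupts training afterwards.
    """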
    def __init__(
        self,
        interrupt=False,
        wait=1,
        warmup=1,
        active=10,
        repeat=1,
        sort_by="cuda_time_total",
        row_limit=30,
    ):
        self.interrupt = interrupt
        self.wait = wait
        self.warmup = warmup
        self.active = active
        self.repeat = repeat
        self.sort_by = sort_by
        self.row_limit = row_limit

    def before_train(self):
        self.trainer.logger.info("Profiling runtime ...")
        from torch.profiler import (
            profile,
            record_function,
            ProfilerActivity,
            schedule,
            tensorboard_trace_handler,
        )

        prof = profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=schedule(
                wait=self.wait,
                warmup=self.warmup,
                active=self.active,
                repeat=self.repeat,
            ),
            on_trace_ready=tensorboard_trace_handler(self.trainer.cfg.save_path),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        )
        prof.start()
        for i, input_dict in enumerate(self.trainer.train_loader):
            # Run exactly as many iterations as the profiler schedule needs.
            if i >= (self.wait + self.warmup + self.active) * self.repeat:
                break
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            with record_function("model_forward"):
                output_dict = self.trainer.model(input_dict)
                loss = output_dict["loss"]
            with record_function("model_backward"):
                loss.backward()
            prof.step()
            self.trainer.logger.info(
                f"Profile: [{i + 1}/{(self.wait + self.warmup + self.active) * self.repeat}]"
            )
        self.trainer.logger.info(
            "Profile: \n"
            + str(
                prof.key_averages().table(
                    sort_by=self.sort_by, row_limit=self.row_limit
                )
            )
        )
        prof.stop()

        if self.interrupt:
            sys.exit(0)
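

# Illustrative hook configuration (hypothetical values; the actual hook list
# lives in each experiment config and is instantiated through the HOOKS
# registry):
#
# hooks = [
#     dict(type="CheckpointLoader"),
#     dict(type="IterationTimer", warmup_iter=2),
#     dict(type="InformationWriter"),
#     dict(type="CheckpointSaver", save_freq=None),
# ]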