"""
Trainer

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""

import os
import sys
import weakref

try:
    import wandb
except Exception:
    wandb = None

import torch
import torch.nn as nn
import torch.utils.data
from packaging import version
from functools import partial
from pathlib import Path

if sys.version_info >= (3, 10):
    from collections.abc import Iterator
else:
    from collections import Iterator

from tensorboardX import SummaryWriter

from .defaults import create_ddp_model, worker_init_fn
from .hooks import HookBase, build_hooks
import pointcept.utils.comm as comm
from pointcept.datasets import build_dataset, point_collate_fn, collate_fn
from pointcept.models import build_model
from pointcept.utils.logger import get_root_logger
from pointcept.utils.optimizer import build_optimizer
from pointcept.utils.scheduler import build_scheduler
from pointcept.utils.events import EventStorage, ExceptionWriter
from pointcept.utils.registry import Registry


TRAINERS = Registry("trainers")
AMP_DTYPE = dict(
    float16=torch.float16,
    bfloat16=torch.bfloat16,
)

# torch >= 2.4 provides the device-generic ``torch.amp.autocast`` /
# ``torch.amp.GradScaler``; older releases use the ``torch.cuda.amp`` variants,
# which newer releases flag as deprecated. The installed torch version cannot
# change at runtime, so evaluate this once instead of on every training step.
_TORCH_GE_2_4 = version.parse(torch.__version__) >= version.parse("2.4")


class TrainerBase:
    """Minimal hook-driven training loop.

    Subclasses implement :meth:`run_step`. Every other stage of the loop
    (before/after train, epoch and step) is dispatched to the registered hooks.
    """

    def __init__(self) -> None:
        self.hooks = []
        self.model = None
        self.epoch = 0
        self.start_epoch = 0
        self.max_epoch = 0
        self.max_iter = 0
        # Scratch space shared with hooks: the loop stores "iter" and
        # "input_dict" here for every step.
        self.comm_info = dict()
        self.data_iterator: Iterator = enumerate([])
        self.storage: EventStorage
        self.writer: SummaryWriter

    def register_hooks(self, hooks) -> None:
        """Build hooks from ``hooks`` and attach each one to this trainer."""
        hooks = build_hooks(hooks)
        for h in hooks:
            assert isinstance(h, HookBase)
            # To avoid circular reference, hooks and trainer cannot own each other.
            # This normally does not matter, but will cause memory leak if the
            # involved objects contain __del__:
            # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
            h.trainer = weakref.proxy(self)
        self.hooks.extend(hooks)

    def train(self):
        """Run the full loop from ``start_epoch`` up to (excluding) ``max_epoch``."""
        with EventStorage() as self.storage:
            # => before train
            self.before_train()
            for self.epoch in range(self.start_epoch, self.max_epoch):
                # => before epoch
                self.before_epoch()
                # => run_epoch
                for (
                    self.comm_info["iter"],
                    self.comm_info["input_dict"],
                ) in self.data_iterator:
                    # => before_step
                    self.before_step()
                    # => run_step
                    self.run_step()
                    # => after_step
                    self.after_step()
                # => after epoch
                self.after_epoch()
            # => after train
            self.after_train()

    def before_train(self):
        for h in self.hooks:
            h.before_train()

    def before_epoch(self):
        for h in self.hooks:
            h.before_epoch()

    def before_step(self):
        for h in self.hooks:
            h.before_step()

    def run_step(self):
        """Execute one optimization step; must be implemented by subclasses."""
        raise NotImplementedError

    def after_step(self):
        for h in self.hooks:
            h.after_step()

    def after_epoch(self):
        for h in self.hooks:
            h.after_epoch()
        self.storage.reset_histories()

    def after_train(self):
        """Run ``after_train`` hooks, then close the writer on the main process."""
        # Sync GPU before running train hooks
        comm.synchronize()
        for h in self.hooks:
            h.after_train()
        if comm.is_main_process():
            self.writer.close()


@TRAINERS.register_module("DefaultTrainer")
class Trainer(TrainerBase):
    """Default single/multi-GPU trainer configured entirely from ``cfg``.

    The constructor builds the logger, model, writer, data loaders, optimizer,
    scheduler, optional AMP grad scaler and hooks, in that order.
    """

    def __init__(self, cfg):
        super(Trainer, self).__init__()
        self.epoch = 0
        self.start_epoch = 0
        # Changed to cfg.epoch. Keep this in sync with build_scheduler, which
        # sizes the schedule as len(train_loader) * cfg.epoch.
        self.max_epoch = cfg.epoch
        self.best_metric_value = -torch.inf
        self.logger = get_root_logger(
            log_file=os.path.join(cfg.save_path, "train.log"),
            file_mode="a" if cfg.resume else "w",
        )
        self.logger.info("=> Loading config ...")
        self.cfg = cfg
        self.logger.info(f"Save path: {cfg.save_path}")
        self.logger.info(f"Config:\n{cfg.pretty_text}")
        self.logger.info("=> Building model ...")
        self.model = self.build_model()
        self.logger.info("=> Building writer ...")
        self.writer = self.build_writer()
        self.logger.info("=> Building train dataset & dataloader ...")
        self.train_loader = self.build_train_loader()
        self.logger.info("=> Building val dataset & dataloader ...")
        self.val_loader = self.build_val_loader()
        self.logger.info("=> Building optimize, scheduler, scaler(amp) ...")
        self.optimizer = self.build_optimizer()
        self.scheduler = self.build_scheduler()
        self.scaler = self.build_scaler()
        self.logger.info("=> Building hooks ...")
        self.register_hooks(self.cfg.hooks)

    def train(self):
        """Run training, re-creating the data iterator at every epoch."""
        with EventStorage() as self.storage, ExceptionWriter():
            # => before train
            self.before_train()
            self.logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>")
            for self.epoch in range(self.start_epoch, self.max_epoch):
                # => before epoch
                if comm.get_world_size() > 1:
                    # Reseed the distributed sampler so each epoch gets a new shuffle.
                    self.train_loader.sampler.set_epoch(self.epoch)
                self.model.train()
                self.data_iterator = enumerate(self.train_loader)
                self.before_epoch()
                # => run_epoch
                for (
                    self.comm_info["iter"],
                    self.comm_info["input_dict"],
                ) in self.data_iterator:
                    # => before_step
                    self.before_step()
                    # => run_step
                    self.run_step()
                    # => after_step
                    self.after_step()
                # => after epoch
                self.after_epoch()
            # => after train. Without this call, after_train hooks never fire
            # and the TensorBoard writer is never closed.
            self.after_train()

    def run_step(self):
        """Forward, backward and optimizer/scheduler step for one batch.

        Reads the batch from ``comm_info["input_dict"]`` and stores the model
        output in ``comm_info["model_output_dict"]``.
        """
        if _TORCH_GE_2_4:
            auto_cast = partial(torch.amp.autocast, device_type="cuda")
        else:
            # deprecated warning
            auto_cast = torch.cuda.amp.autocast

        input_dict = self.comm_info["input_dict"]
        # Move tensor entries to the GPU; non-tensor entries are left as-is.
        for key, value in input_dict.items():
            if isinstance(value, torch.Tensor):
                input_dict[key] = value.cuda(non_blocking=True)

        with auto_cast(
            enabled=self.cfg.enable_amp, dtype=AMP_DTYPE[self.cfg.amp_dtype]
        ):
            output_dict = self.model(input_dict)
            loss = output_dict["loss"]

        self.optimizer.zero_grad()
        if self.cfg.enable_amp:
            self.scaler.scale(loss).backward()
            # Unscale first so gradient clipping sees the true gradient magnitudes.
            self.scaler.unscale_(self.optimizer)
            if self.cfg.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.clip_grad
                )
            self.scaler.step(self.optimizer)

            # With AMP, GradScaler may skip optimizer.step() (non-finite
            # gradients) and then lowers the scale in update(). Only advance
            # the scheduler when the scale did not drop, i.e. the optimizer
            # really stepped. This avoids torch's "scheduler step before
            # optimizer step" warning.
            scale_before = self.scaler.get_scale()
            self.scaler.update()
            if scale_before <= self.scaler.get_scale():
                self.scheduler.step()
        else:
            loss.backward()
            if self.cfg.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.clip_grad
                )
            self.optimizer.step()
            self.scheduler.step()

        if self.cfg.empty_cache:
            torch.cuda.empty_cache()
        self.comm_info["model_output_dict"] = output_dict

    def after_epoch(self):
        for h in self.hooks:
            h.after_epoch()
        self.storage.reset_histories()
        if self.cfg.empty_cache_per_epoch:
            torch.cuda.empty_cache()

    def build_model(self):
        """Build the model, apply optional quantization/QAT/SyncBN, wrap for DDP."""
        model = build_model(self.cfg.model)

        if self.cfg.get("quantize", False):
            self.logger.info(
                "Quantization flag detected. Converting model to Bi-PTV3 before DDP."
            )
            from pointcept.models.quantization.quant_utils import convert_ptv3_to_bi_ptv3

            model = convert_ptv3_to_bi_ptv3(model, verbose=comm.is_main_process())

        # === QAT 0920 begin: minimal hook ===
        # Best-effort: any failure (including a missing module) is reported and
        # training continues with the unmodified model.
        try:
            from pointcept.utils.quant_0920 import install_qat_from_cfg_or_env_0920

            model = install_qat_from_cfg_or_env_0920(model, self.cfg)
        except Exception as e:
            self.logger.warning(f"[QAT-0920] attach failed: {e}")
        # === QAT 0920 end ===

        if self.cfg.sync_bn:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        self.logger.info(f"Num params: {n_parameters}")
        model = create_ddp_model(
            model.cuda(),
            broadcast_buffers=False,
            find_unused_parameters=self.cfg.find_unused_parameters,
        )
        return model

    def build_writer(self):
        """Create the TensorBoard writer (main process only) and optionally init wandb.

        Returns:
            A ``SummaryWriter`` on the main process, ``None`` on other ranks.

        Raises:
            ImportError: if ``cfg.enable_wandb`` is set but ``wandb`` failed to import.
        """
        writer = SummaryWriter(self.cfg.save_path) if comm.is_main_process() else None
        self.logger.info(f"Tensorboard writer logging dir: {self.cfg.save_path}")
        if self.cfg.enable_wandb and comm.is_main_process():
            if wandb is None:
                raise ImportError(
                    "cfg.enable_wandb is set but the `wandb` package could not be imported."
                )
            # The run is named after the last two components of save_path.
            tag, name = Path(self.cfg.save_path).parts[-2:]
            wandb.init(
                project=self.cfg.wandb_project,
                name=f"{tag}/{name}",
                tags=[tag],
                dir=self.cfg.save_path,
                settings=wandb.Settings(api_key=self.cfg.wandb_key),
                config=self.cfg,
            )
        return writer

    def build_train_loader(self):
        """Build the training DataLoader, using a DistributedSampler when world size > 1."""
        train_data = build_dataset(self.cfg.data.train)

        if comm.get_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
        else:
            train_sampler = None

        init_fn = (
            partial(
                worker_init_fn,
                num_workers=self.cfg.num_worker_per_gpu,
                rank=comm.get_rank(),
                seed=self.cfg.seed,
            )
            if self.cfg.seed is not None
            else None
        )

        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.cfg.batch_size_per_gpu,
            # DataLoader forbids shuffle=True together with an explicit sampler.
            shuffle=(train_sampler is None),
            num_workers=self.cfg.num_worker_per_gpu,
            sampler=train_sampler,
            collate_fn=partial(point_collate_fn, mix_prob=self.cfg.mix_prob),
            pin_memory=True,
            worker_init_fn=init_fn,
            drop_last=len(train_data) > self.cfg.batch_size,
            persistent_workers=False,
        )
        return train_loader

    def build_val_loader(self):
        """Build the validation DataLoader, or return ``None`` when ``cfg.evaluate`` is off."""
        val_loader = None
        if self.cfg.evaluate:
            val_data = build_dataset(self.cfg.data.val)
            if comm.get_world_size() > 1:
                val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
            else:
                val_sampler = None
            val_loader = torch.utils.data.DataLoader(
                val_data,
                batch_size=self.cfg.batch_size_val_per_gpu,
                shuffle=False,
                num_workers=self.cfg.num_worker_per_gpu,
                pin_memory=True,
                sampler=val_sampler,
                collate_fn=collate_fn,
            )
        return val_loader

    def build_optimizer(self):
        return build_optimizer(self.cfg.optimizer, self.model, self.cfg.param_dicts)

    def build_scheduler(self):
        """Build the LR scheduler; requires the optimizer and train loader to exist."""
        assert hasattr(self, "optimizer")
        assert hasattr(self, "train_loader")
        # Changed to self.cfg.epoch: one scheduler step per training iteration
        # over all epochs (must agree with max_epoch set in __init__).
        self.cfg.scheduler.total_steps = len(self.train_loader) * self.cfg.epoch
        return build_scheduler(self.cfg.scheduler, self.optimizer)

    def build_scaler(self):
        """Return a CUDA GradScaler when AMP is enabled, otherwise ``None``."""
        if _TORCH_GE_2_4:
            grad_scaler = partial(torch.amp.GradScaler, device="cuda")
        else:
            # deprecated warning
            grad_scaler = torch.cuda.amp.GradScaler
        scaler = grad_scaler() if self.cfg.enable_amp else None
        return scaler


@TRAINERS.register_module("MultiDatasetTrainer")
class MultiDatasetTrainer(Trainer):
    """Trainer variant whose training data comes from ``MultiDatasetDataloader``."""

    def build_train_loader(self):
        from pointcept.datasets import MultiDatasetDataloader

        train_data = build_dataset(self.cfg.data.train)
        train_loader = MultiDatasetDataloader(
            train_data,
            self.cfg.batch_size_per_gpu,
            self.cfg.num_worker_per_gpu,
            self.cfg.mix_prob,
            self.cfg.seed,
        )
        self.comm_info["iter_per_epoch"] = len(train_loader)
        return train_loader