| import os |
| import time |
| import logging |
|
|
# Shared module logger; the callbacks below report progress through it.
mainlogger = logging.getLogger(name="mainlogger")
|
|
| import torch |
| import torchvision |
| import pytorch_lightning as pl |
| from pytorch_lightning.callbacks import Callback |
| from pytorch_lightning.utilities import rank_zero_only |
| from pytorch_lightning.utilities import rank_zero_info |
| from utils.save_video import log_local, prepare_to_log |
|
|
|
|
class ImageLogger(Callback):
    """Periodically sample from the model via ``pl_module.log_images`` and log the results.

    Samples are either written to disk (``to_local=True``, under
    ``<save_dir>/images/{train,val}`` via ``log_local``) or sent to the
    experiment logger attached to the module (W&B when the logger class is
    named ``WandbLogger``, otherwise a TensorBoard-style ``add_*`` API).
    """

    def __init__(
        self,
        batch_frequency,
        max_images=8,
        clamp=True,
        rescale=True,
        save_dir=None,
        to_local=False,
        log_images_kwargs=None,
    ):
        """
        Args:
            batch_frequency: log every ``batch_frequency`` training batches;
                ``-1`` disables logging for both train and val.
            max_images: forwarded to ``prepare_to_log``.
            clamp: forwarded to ``prepare_to_log``.
            rescale: stored on the instance; not read by this class.
            save_dir: root directory for local dumps; required when ``to_local``.
            to_local: write samples to disk instead of the experiment logger.
            log_images_kwargs: extra keyword arguments for ``pl_module.log_images``.

        Raises:
            ValueError: if ``to_local`` is set but ``save_dir`` is None.
        """
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.to_local = to_local
        self.clamp = clamp
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.save_dir = None
        if self.to_local:
            if save_dir is None:
                raise ValueError("ImageLogger(to_local=True) requires a save_dir")
            self.save_dir = os.path.join(save_dir, "images")
            # One sub-directory per split written by log_batch_imgs.
            for split in ("train", "val"):
                os.makedirs(os.path.join(self.save_dir, split), exist_ok=True)

    def log_to_tensorboard(self, pl_module, batch_logs, filename, split, save_fps=10):
        """Log one batch of generated samples to ``pl_module.logger``.

        Each entry of ``batch_logs`` is dispatched on its type:

        * non-empty list whose first item is a ``str``: captions, joined and
          logged as text;
        * 5-D tensor: a video, logged as one grid per frame with the samples
          laid out side by side (assumes ``(b, c, t, h, w)`` layout - TODO confirm);
        * 4-D tensor: an image batch, logged as a single grid;
        * anything else (including empty lists): skipped.

        Tensor grids are mapped with ``(x + 1) / 2`` before logging, so inputs
        are presumably in ``[-1, 1]``.

        Args:
            pl_module: module whose ``logger`` and ``global_step`` are used.
            batch_logs: mapping of key -> value to log.
            filename: name fragment embedded in every tag.
            split: split name embedded in every tag.
            save_fps: frame rate for logged videos.
        """
        global_step = pl_module.global_step
        logger = pl_module.logger
        is_wandb = logger.__class__.__name__ == "WandbLogger"

        for key in batch_logs:
            value = batch_logs[key]
            tag = "gs%d-%s/%s-%s" % (global_step, split, filename, key)

            # `value and ...` guards against IndexError on an empty list.
            if isinstance(value, list) and value and isinstance(value[0], str):
                captions = " |------| ".join(value)
                if is_wandb:
                    # NOTE(review): LightningModule.log_dict normally expects
                    # numeric values - confirm a caption string is accepted.
                    pl_module.log_dict({tag: captions})
                else:
                    logger.experiment.add_text(tag, captions, global_step=global_step)

            elif isinstance(value, torch.Tensor) and value.dim() == 5:
                n = value.shape[0]
                # (b, c, t, h, w) -> (t, b, c, h, w): one make_grid per frame.
                frames = value.permute(2, 0, 1, 3, 4)
                grid = torch.stack(
                    [torchvision.utils.make_grid(frame, nrow=int(n)) for frame in frames],
                    dim=0,
                )
                grid = (grid + 1.0) / 2.0
                grid = grid.unsqueeze(dim=0)  # leading batch dim of 1 for the video API
                if is_wandb:
                    import wandb

                    logger.experiment.log({tag: wandb.Video(grid.cpu().numpy(), fps=save_fps)})
                else:
                    logger.experiment.add_video(
                        tag, grid, fps=save_fps, global_step=global_step
                    )

            elif isinstance(value, torch.Tensor) and value.dim() == 4:
                # Samples per row = batch size of *this* tensor; `n` must be
                # bound here, the video branch above does not run for images.
                n = value.shape[0]
                grid = torchvision.utils.make_grid(value, nrow=int(n))
                grid = (grid + 1.0) / 2.0
                if is_wandb:
                    import wandb

                    logger.experiment.log({tag: wandb.Image(grid.cpu().numpy())})
                else:
                    logger.experiment.add_image(tag, grid, global_step=global_step)

    @rank_zero_only
    def log_batch_imgs(self, pl_module, batch, batch_idx, split="train"):
        """Generate samples for ``batch`` and log them locally or to the logger.

        Wrapped with ``rank_zero_only``. Training batches are logged every
        ``batch_freq`` batches, validation batches every 5. The module is put
        in eval mode for sampling and its training mode is always restored.
        """
        skip_freq = self.batch_freq if split == "train" else 5
        # A non-positive frequency would divide by zero; treat it as disabled.
        if skip_freq <= 0 or (batch_idx + 1) % skip_freq != 0:
            return

        is_train = pl_module.training
        if is_train:
            pl_module.eval()
        try:
            with torch.no_grad():
                batch_logs = pl_module.log_images(
                    batch, split=split, **self.log_images_kwargs
                )
            batch_logs = prepare_to_log(batch_logs, self.max_images, self.clamp)
            torch.cuda.empty_cache()

            filename = "ep{}_idx{}_rank{}".format(
                pl_module.current_epoch, batch_idx, pl_module.global_rank
            )
            if self.to_local:
                mainlogger.info("Log [%s] batch <%s> to local ...", split, filename)
                filename = "gs{}_".format(pl_module.global_step) + filename
                log_local(
                    batch_logs,
                    os.path.join(self.save_dir, split),
                    filename,
                    save_fps=10,
                )
            else:
                mainlogger.info("Log [%s] batch <%s> to tensorboard ...", split, filename)
                self.log_to_tensorboard(
                    pl_module, batch_logs, filename, split, save_fps=10
                )
            mainlogger.info("Finish!")
        finally:
            # Restore training mode even if sampling or logging raised.
            if is_train:
                pl_module.train()

    def on_train_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
    ):
        """Log training samples unless disabled (``batch_freq == -1``) or no logdir."""
        if self.batch_freq != -1 and pl_module.logdir:
            self.log_batch_imgs(pl_module, batch, batch_idx, split="train")

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
    ):
        """Log validation samples and, optionally, gradients every 25 batches."""
        if self.batch_freq != -1 and pl_module.logdir:
            self.log_batch_imgs(pl_module, batch, batch_idx, split="val")
        # ImageLogger itself defines no log_gradients; call it only when a
        # subclass provides one instead of raising AttributeError.
        log_gradients = getattr(self, "log_gradients", None)
        if (
            log_gradients is not None
            and getattr(pl_module, "calibrate_grad_norm", False)
            and batch_idx > 0
            and batch_idx % 25 == 0
        ):
            log_gradients(trainer, pl_module, batch_idx=batch_idx)
|
|
|
|
# NOTE(review): disabled DataModeSwitcher callback, kept as a bare string
# literal so it never executes. Before re-enabling it:
#   * it calls random.random(), but this module does not import `random`;
#   * only global rank 0 would enter torch.distributed.barrier(), which
#     presumably blocks until other ranks also reach a barrier - verify.
"""
class DataModeSwitcher(Callback):
    def on_epoch_start(self, trainer, pl_module):
        mode = 'image' if random.random() <= 0.3 else 'video'
        trainer.datamodule.dataset.set_mode(mode)
        if trainer.global_rank == 0:
            torch.distributed.barrier()
"""
|
|
|
|
class CUDACallback(Callback):
    """Report wall-clock time and peak CUDA memory for each training epoch."""

    @staticmethod
    def _gpu_index(trainer):
        """Return the CUDA device index the trainer uses on this process.

        PL >= 1.7 exposes it through ``trainer.strategy.root_device``; older
        releases use ``trainer.root_gpu``. Major and minor are both compared so
        that PL 2.x (e.g. "2.0.1") is not mistaken for a pre-1.7 release.
        """
        major, minor = pl.__version__.split(".")[:2]
        if (int(major), int(minor)) >= (1, 7):
            return trainer.strategy.root_device.index
        return trainer.root_gpu

    @staticmethod
    def _reducer(trainer):
        """Return the object providing ``reduce``, or None if unavailable.

        Newer PL exposes it as ``trainer.strategy``; older releases use
        ``trainer.training_type_plugin``.
        """
        for attr in ("strategy", "training_type_plugin"):
            reducer = getattr(trainer, attr, None)
            if reducer is not None:
                return reducer
        return None

    def on_train_epoch_start(self, trainer, pl_module):
        """Reset peak-memory stats and start the epoch timer."""
        gpu_index = self._gpu_index(trainer)
        torch.cuda.reset_peak_memory_stats(gpu_index)
        # Wait for queued kernels so the timer starts from an idle device.
        torch.cuda.synchronize(gpu_index)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module):
        """Log epoch time (seconds) and peak allocated memory (MiB) on rank 0."""
        gpu_index = self._gpu_index(trainer)
        # Include still-running kernels in the epoch time.
        torch.cuda.synchronize(gpu_index)
        max_memory = torch.cuda.max_memory_allocated(gpu_index) / 2**20
        epoch_time = time.time() - self.start_time

        reducer = self._reducer(trainer)
        if reducer is None:
            # No reduction API on this trainer; skip the summary.
            return
        # NOTE(review): plain floats are passed to reduce(); some strategies
        # presumably only all-reduce tensors and return floats unchanged, in
        # which case these are rank-local rather than averaged values.
        max_memory = reducer.reduce(max_memory)
        epoch_time = reducer.reduce(epoch_time)

        rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
        rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
|
|