| from typing import * |
| import copy |
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
| from easydict import EasyDict as edict |
|
|
| from ..basic import BasicTrainer |
|
|
|
|
class SparseStructureVaeTrainer(BasicTrainer):
    """
    Trainer for Sparse Structure VAE.

    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
        batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
        batch_split (int): Split batch with gradient accumulation.
        max_steps (int): Max steps.
        optimizer (dict): Optimizer config.
        lr_scheduler (dict): Learning rate scheduler config.
        elastic (dict): Elastic memory management config.
        grad_clip (float or dict): Gradient clip config.
        ema_rate (float or list): Exponential moving average rates.
        fp16_mode (str): FP16 mode.
            - None: No FP16.
            - 'inflat_all': Hold an inflated fp32 master param for all params.
            - 'amp': Automatic mixed precision.
        fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
        finetune_ckpt (dict): Finetune checkpoint.
        log_param_stats (bool): Log parameter stats.
        i_print (int): Print interval.
        i_log (int): Log interval.
        i_sample (int): Sample interval.
        i_save (int): Save interval.
        i_ddpcheck (int): DDP check interval.

        loss_type (str): Reconstruction loss type. 'bce' for binary cross entropy,
            'l1' for L1 loss on sigmoid probabilities, 'dice' for soft Dice loss.
        lambda_kl (float): Weight of the KL divergence term added to the loss.
    """

    def __init__(
        self,
        *args,
        loss_type='bce',
        lambda_kl=1e-6,
        **kwargs
    ):
        # Every other argument is forwarded unchanged to BasicTrainer.
        super().__init__(*args, **kwargs)
        self.loss_type = loss_type
        self.lambda_kl = lambda_kl

    def training_losses(
        self,
        ss: torch.Tensor,
        **kwargs
    ) -> Tuple[Dict, Dict]:
        """
        Compute training losses.

        Args:
            ss: The [N x 1 x H x W x D] tensor of binary sparse structure.
            **kwargs: Any other batch entries; ignored.

        Returns:
            A tuple ``(terms, status)``. ``terms`` is an EasyDict whose "loss"
            key holds the total loss (reconstruction term +
            ``lambda_kl`` * KL term). It also stores the individual terms under
            "bce" / "l1" / "dice" (depending on ``loss_type``) and "kl".
            ``status`` is always an empty dict.

        Raises:
            ValueError: If ``self.loss_type`` is not 'bce', 'l1' or 'dice'.
        """
        # Convert once: the same float target feeds the encoder and every loss branch.
        target = ss.float()
        z, mean, logvar = self.training_models['encoder'](target, sample_posterior=True, return_raw=True)
        logits = self.training_models['decoder'](z)

        terms = edict(loss=0.0)
        if self.loss_type == 'bce':
            terms["bce"] = F.binary_cross_entropy_with_logits(logits, target, reduction='mean')
            terms["loss"] = terms["loss"] + terms["bce"]
        elif self.loss_type == 'l1':
            terms["l1"] = F.l1_loss(torch.sigmoid(logits), target, reduction='mean')
            terms["loss"] = terms["loss"] + terms["l1"]
        elif self.loss_type == 'dice':
            # Soft Dice computed over the whole batch at once; the +1 terms
            # keep the denominator nonzero even when prediction and target are empty.
            probs = torch.sigmoid(logits)
            intersection = (probs * target).sum()
            terms["dice"] = 1 - (2 * intersection + 1) / (probs.sum() + target.sum() + 1)
            terms["loss"] = terms["loss"] + terms["dice"]
        else:
            raise ValueError(f'Invalid loss type {self.loss_type}')

        # Closed-form KL(N(mean, exp(logvar)) || N(0, I)), averaged over all elements.
        terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
        terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]

        return terms, {}

    @torch.no_grad()
    def snapshot(self, suffix=None, num_samples=64, batch_size=1, verbose=False):
        # Only changes the defaults (notably batch_size=1); the work is done by the base class.
        super().snapshot(suffix=suffix, num_samples=num_samples, batch_size=batch_size, verbose=verbose)

    @torch.no_grad()
    def run_snapshot(
        self,
        num_samples: int,
        batch_size: int,
        verbose: bool = False,
    ) -> Dict:
        """
        Reconstruct randomly drawn dataset samples for inspection.

        Args:
            num_samples: Total number of samples to reconstruct.
            batch_size: Number of samples per encoder/decoder forward pass.
            verbose: Unused here; kept for interface compatibility.

        Returns:
            A dict with keys 'gt' (the input structures) and 'recon' (the
            decoder output binarized at logit 0, as a long tensor). Each maps to
            ``{'value': Tensor, 'type': 'sample'}``, with values concatenated
            along dim 0.
        """
        dataset = copy.deepcopy(self.dataset)
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            # Drop a short trailing batch so every batch is full. This is skipped
            # when the dataset is smaller than one batch, since nothing would remain.
            drop_last=len(dataset) >= batch_size,
            num_workers=0,
            collate_fn=getattr(self.dataset, 'collate_fn', None),
        )

        # Keep a single shuffled iterator so one snapshot draws distinct samples.
        # Creating a fresh iterator per batch re-shuffles every time and can
        # repeat samples. Restart (reshuffle) only when the epoch is exhausted.
        data_iter = iter(dataloader)

        gts = []
        recons = []
        for start in range(0, num_samples, batch_size):
            take = min(batch_size, num_samples - start)
            try:
                data = next(data_iter)
            except StopIteration:
                data_iter = iter(dataloader)
                data = next(data_iter)
            args = {
                k: v[:take].cuda() if isinstance(v, torch.Tensor) else v[:take]
                for k, v in data.items()
            }
            ss = args['ss']
            z = self.models['encoder'](ss.float(), sample_posterior=False)
            logits = self.models['decoder'](z)
            gts.append(ss)
            recons.append((logits > 0).long())

        sample_dict = {
            'gt': {'value': torch.cat(gts, dim=0), 'type': 'sample'},
            'recon': {'value': torch.cat(recons, dim=0), 'type': 'sample'},
        }
        return sample_dict
|
|