| from collections import OrderedDict |
| import torch |
| import torch.distributed as dist |
| from mmcv.runner import BaseModule |
| from typing import Dict, Tuple, List |
|
|
|
|
def to_cpu(x: torch.Tensor) -> torch.Tensor:
    """Detach a tensor from the autograd graph and place it on the CPU.

    Args:
        x (torch.Tensor): Value to convert. Anything that is not a
            ``torch.Tensor`` is handed back untouched.

    Returns:
        torch.Tensor: ``x.detach().cpu()`` for tensor input; otherwise ``x``
        itself.
    """
    if not isinstance(x, torch.Tensor):
        # Non-tensor values pass straight through.
        return x
    detached = x.detach()
    return detached.cpu()
|
|
|
|
class BaseArchitecture(BaseModule):
    """Base class for mogen architecture.

    Supplies the shared plumbing around a model: mode-based dispatch in
    ``forward``, loss parsing, the train/val iteration steps, and splitting
    batched results into per-sample dictionaries. ``forward_train`` and
    ``forward_test`` are no-ops here and are meant to be overridden.

    Args:
        init_cfg (dict, optional): Initialization config for the module.
    """

    def __init__(self, init_cfg: dict = None):
        super().__init__(init_cfg)

    def forward_train(self, **kwargs):
        """Forward computation during training.

        The base implementation does nothing and returns ``None``.
        ``train_step`` feeds the result of ``forward`` to ``_parse_losses``,
        so an override should return a dict of losses.
        """
        pass

    def forward_test(self, **kwargs):
        """Forward computation during testing.

        The base implementation does nothing and returns ``None``.
        """
        pass

    def _parse_losses(
        self, losses: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contains
                losses and other necessary information. Every value must be
                a tensor or a list of tensors.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars)
                - loss is the sum of every (mean-reduced) entry whose key
                  contains the substring ``'loss'``.
                - log_vars maps each entry name, plus ``'loss'``, to a Python
                  float; values are averaged over all processes when
                  ``torch.distributed`` is initialized.

        Raises:
            TypeError: If a value is neither a tensor nor a list of tensors.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(f'{loss_name} is not a tensor or list of tensors')

        # Only entries whose name contains 'loss' contribute to the total;
        # everything else is logged but not optimized.
        loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
        log_vars['loss'] = loss

        # The distributed state cannot change inside the loop, so query it once.
        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        for loss_name, loss_value in log_vars.items():
            if distributed:
                # Reduce a detached copy so the tensor returned as `loss`
                # (used for backpropagation) is left untouched.
                loss_value = loss_value.detach().clone()
                dist.all_reduce(loss_value.div_(world_size))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars

    def _forward_step(self, data: Dict) -> Dict:
        """Run the model on one batch and package the parsed losses."""
        losses = self(**data)
        loss, log_vars = self._parse_losses(losses)
        return {
            'loss': loss,
            'log_vars': log_vars,
            'num_samples': len(data['motion']),
        }

    def train_step(self, data: Dict, optimizer: torch.optim.Optimizer) -> Dict:
        """The iteration step during training.

        This method defines an iteration step during training, excluding
        backpropagation and optimizer updating, which are handled by an
        optimizer hook.

        Args:
            data (dict): The output of the dataloader. Must contain the key
                ``'motion'``; it is unpacked as keyword arguments to
                ``forward``.
            optimizer (torch.optim.Optimizer): The optimizer object (unused).

        Returns:
            dict: A dictionary containing the loss, log_vars for logging, and
            the number of samples.
                - ``loss``: A tensor for backpropagation, which may be a
                  weighted sum of multiple losses.
                - ``log_vars``: All the variables to be logged.
                - ``num_samples``: ``len(data['motion'])``.
        """
        return self._forward_step(data)

    def val_step(self, data: Dict, optimizer: torch.optim.Optimizer = None) -> Dict:
        """The iteration step during validation.

        Args:
            data (dict): The output of the dataloader. Must contain the key
                ``'motion'``; it is unpacked as keyword arguments to
                ``forward``.
            optimizer (torch.optim.Optimizer, optional): The optimizer object
                (unused).

        Returns:
            dict: A dictionary containing the loss, log_vars for logging, and
            the number of samples (same layout as ``train_step``).
        """
        return self._forward_step(data)

    def forward(self, **kwargs):
        """Forward computation based on the training or testing mode.

        Dispatches to ``forward_train`` when ``self.training`` is true and to
        ``forward_test`` otherwise.
        """
        if self.training:
            return self.forward_train(**kwargs)
        return self.forward_test(**kwargs)

    @staticmethod
    def _owned_cpu_copy(x):
        """Like ``to_cpu``, but the returned tensor never aliases ``x``."""
        if not isinstance(x, torch.Tensor):
            # Same pass-through behavior as to_cpu for non-tensor values.
            return x
        cpu_x = to_cpu(x)
        # detach() shares storage and cpu() is a no-op for a tensor that is
        # already on the CPU, so only that case needs an explicit copy.
        if x.device.type == 'cpu':
            cpu_x = cpu_x.clone()
        return cpu_x

    def split_results(self, results: Dict[str, torch.Tensor]) -> List[Dict]:
        """Split batched results into individual outputs.

        Each sample's ``motion`` and ``pred_motion`` are zeroed from their
        respective length onward. The zeroing is done on private copies, so
        the tensors in ``results`` are not modified.

        Args:
            results (dict): The batched results from the model. Must contain
                ``'motion'``, ``'pred_motion'``, ``'motion_length'`` and
                ``'motion_mask'``. ``'pred_motion_length'`` and
                ``'pred_motion_mask'`` are optional and fall back to the
                ground-truth length / mask. ``'motion_metas'`` is optional.

        Returns:
            list: A list of dictionaries where each dictionary contains
            results for a single instance.
        """
        batch_size = results['motion'].shape[0]
        # Fall back to the ground-truth length / mask when no predicted one
        # is supplied.
        if 'pred_motion_length' in results:
            pred_length_key = 'pred_motion_length'
        else:
            pred_length_key = 'motion_length'
        if 'pred_motion_mask' in results:
            pred_mask_key = 'pred_motion_mask'
        else:
            pred_mask_key = 'motion_mask'
        has_metas = 'motion_metas' in results

        output = []
        for i in range(batch_size):
            batch_output = {}
            # Copies are required: both tensors are zeroed in place below and
            # would otherwise write through to the caller's batch on CPU.
            batch_output['motion'] = self._owned_cpu_copy(results['motion'][i])
            batch_output['pred_motion'] = self._owned_cpu_copy(results['pred_motion'][i])
            batch_output['motion_length'] = to_cpu(results['motion_length'][i])
            batch_output['motion'][batch_output['motion_length']:, :] = 0
            batch_output['motion_mask'] = to_cpu(results['motion_mask'][i])
            batch_output['pred_motion_length'] = to_cpu(results[pred_length_key][i])
            batch_output['pred_motion'][batch_output['pred_motion_length']:, :] = 0
            batch_output['pred_motion_mask'] = to_cpu(results[pred_mask_key][i])
            if has_metas:
                motion_metas = results['motion_metas'][i]
                # Lift selected metadata fields to the top level of the sample.
                if 'text' in motion_metas:
                    batch_output['text'] = motion_metas['text']
                if 'token' in motion_metas:
                    batch_output['token'] = motion_metas['token']
                if 'meta_data' in motion_metas and 'category_id' in motion_metas['meta_data']:
                    batch_output['category_id'] = motion_metas['meta_data']['category_id']
                batch_output['motion_metas'] = motion_metas
            output.append(batch_output)
        return output
|
|