Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
class DataParallel(nn.DataParallel):
    """
    Overview:
        A thin wrapper around ``nn.DataParallel`` whose ``parameters`` method delegates
        directly to the wrapped module.
    Interfaces:
        ``__init__``, ``parameters``
    """

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        """
        Overview:
            Initialize the DataParallel object, forwarding every argument to ``nn.DataParallel``.
        Arguments:
            - module (:obj:`nn.Module`): The module to be parallelized.
            - device_ids (:obj:`list`): The list of GPU ids; ``None`` lets ``nn.DataParallel`` \
                pick its own default.
            - output_device (:obj:`int`): The output GPU id; ``None`` lets ``nn.DataParallel`` \
                pick its own default.
            - dim (:obj:`int`): The dimension along which inputs are scattered.
        """
        # Forward the caller's arguments instead of hard-coding the defaults.
        # Previously device_ids/output_device/dim were silently discarded.
        super().__init__(module, device_ids=device_ids, output_device=output_device, dim=dim)
        self.module = module

    def parameters(self, recurse: bool = True):
        """
        Overview:
            Return the parameters of the wrapped module.
        Arguments:
            - recurse (:obj:`bool`): If ``True``, also yield the parameters of all submodules; \
                otherwise only the wrapped module's direct parameters.
        Returns:
            - params (:obj:`generator`): The generator of the parameters.
        """
        # Honour the caller's ``recurse`` flag (it was previously forced to True).
        return self.module.parameters(recurse=recurse)