def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of the given networks.

    Args:
        nets (nn.Module | list[nn.Module] | tuple[nn.Module, ...]): A single
            network, or a list/tuple of networks. ``None`` entries (and a bare
            ``None``) are skipped.
        requires_grad (bool): Whether the networks' parameters should require
            gradients. Defaults to ``False``.
    """
    # Accept tuples as well as lists; anything else is treated as one network.
    # Previously a tuple such as (netG, netD) was treated as a single network
    # and crashed on `tuple.parameters()`.
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad
def zero_module(module):
    """Fill every parameter of ``module`` with zeros, in place.

    Args:
        module: An object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        The same ``module`` object, so the call can be used inline.
    """
    params = list(module.parameters())
    for weight in params:
        # detach() returns a tensor sharing the parameter's storage but outside
        # autograd, so the in-place zero_() writes directly into the parameter.
        weight.detach().zero_()
    return module