File size: 211 Bytes
66a2b45
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
"""VAE utilities."""

import torch.nn as nn


def zero_module(module):
    """Set every parameter of ``module`` to zero, in place, and hand it back.

    Parameters are collected with ``module.parameters()``, so the
    parameters of registered submodules are included as well. Buffers
    (e.g. running statistics) are not touched. The zeroing is not recorded
    by autograd, and each parameter keeps its ``requires_grad`` flag.

    Args:
        module: the ``nn.Module`` whose parameters should be zeroed.

    Returns:
        The same ``module`` object, so the call can wrap a constructor
        inline, e.g. ``zero_module(nn.Conv2d(...))``.
    """
    target = module
    # ``nn.init.zeros_`` fills the tensor in place inside ``torch.no_grad()``.
    # It writes the parameter's own storage without adding a graph node.
    for weight in target.parameters():
        nn.init.zeros_(weight)
    return target