66a2b45
1
2
3
4
5
6
7
8
9
10
11
"""VAE utilities."""
import torch.nn as nn


def zero_module(module):
    """Set every parameter of ``module`` to zero in place and hand it back.

    Each parameter is zeroed through a detached view, so the write shares the
    parameter's storage but is not recorded by autograd. The parameter objects
    themselves are kept, so attributes such as ``requires_grad`` are unchanged.
    Only parameters are touched; registered buffers are left as they are.

    Args:
        module: The ``nn.Module`` whose parameters should be zeroed.

    Returns:
        The same ``module`` object, which allows inline use when building layers.
    """
    for _name, tensor in module.named_parameters():
        # detach() returns a view on the same storage, so zero_() clears the
        # real parameter without the in-place op entering the autograd graph.
        tensor.detach().zero_()
    return module