"""VAE utilities."""
import torch.nn as nn


def zero_module(module):
    """Reset every parameter of ``module`` to zero, in place.

    The function is meant to wrap a freshly built layer inline, for example
    ``zero_module(nn.Conv2d(...))``. That works because the very same object
    that was passed in is handed back.

    Args:
        module: an object exposing ``.parameters()``. Presumably a
            ``torch.nn.Module``; confirm against callers.

    Returns:
        The ``module`` argument itself, not a copy.
    """
    params = module.parameters()
    for weight in params:
        # detach() yields a view sharing storage with ``weight``, so the
        # in-place zero_() writes through to the parameter without being
        # recorded by autograd.
        weight.detach().zero_()
    return module