| """VAE utilities.""" | |
| import torch.nn as nn | |
def zero_module(module):
    """Set every parameter of ``module`` to zero, in place.

    Each parameter is written through a detached view, so the write is not
    recorded by autograd. The parameters keep their ``requires_grad`` flag.

    Args:
        module: An ``nn.Module`` whose parameters should be zeroed.

    Returns:
        The same ``module`` object, so the call can be used inline when
        constructing layers.
    """
    # detach() shares storage with the parameter, so zero_() on the view
    # overwrites the parameter's data without going through autograd.
    for tensor in (param.detach() for param in module.parameters()):
        tensor.zero_()
    return module
| """VAE utilities.""" | |
| import torch.nn as nn | |
def zero_module(module):
    """Overwrite all parameters of ``module`` with zeros and hand it back.

    The writes go through ``detach()``-ed views of the parameters, so
    autograd does not record them. ``requires_grad`` on each parameter is
    left as it was.

    Args:
        module: An ``nn.Module`` to zero out.

    Returns:
        The very same ``module`` instance (not a copy).
    """
    param_iter = module.parameters()
    for weight in param_iter:
        # Detached view shares storage, so this zeroes the real parameter.
        view = weight.detach()
        view.zero_()
    return module