Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| class AdaIN(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| def forward(self, x, y): | |
| ch = y.size(1) | |
| sigma, mu = torch.split(y.unsqueeze(-1).unsqueeze(-1), [ch // 2, ch // 2], dim=1) | |
| x_mu = x.mean(dim=[2, 3], keepdim=True) | |
| x_sigma = x.std(dim=[2, 3], keepdim=True) | |
| return sigma * ((x - x_mu) / x_sigma) + mu | |