Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
def mi(x: torch.Tensor) -> torch.Tensor:
    """Return the mean of ``x`` over dimensions 2 and 3.

    The reduced dimensions are kept as size-1 axes (``keepdim=True``) so the
    result broadcasts back against ``x``. The sum is divided by
    ``x.shape[2] * x.shape[3]``, i.e. the number of elements per slice.
    Presumably ``x`` is laid out as (batch, channels, height, width), making
    this a per-channel spatial mean -- TODO confirm with callers.
    """
    spatial_size = x.shape[2] * x.shape[3]
    spatial_total = x.sum(dim=(2, 3), keepdim=True)
    return spatial_total / spatial_size
def sigma(x: torch.Tensor, epsilon=1e-5) -> torch.Tensor:
    """Return the epsilon-stabilized standard deviation of ``x`` over dims 2 and 3.

    Computes ``sqrt(mean((x - mi(x)) ** 2) + epsilon)``, where the mean runs
    over dimensions 2 and 3. This is the biased (population) variance, since
    it divides by ``x.shape[2] * x.shape[3]`` rather than N - 1.

    Args:
        x: Tensor with at least 4 dimensions. Statistics are taken over
            dims 2 and 3, presumably height and width (confirm with callers).
        epsilon: Added once to the variance before the square root. When it
            is positive, the result is strictly positive and safe to divide by,
            even for constant feature maps.

    Returns:
        Tensor shaped like ``x`` with dims 2 and 3 reduced to size 1.
    """
    spatial_size = x.shape[2] * x.shape[3]
    squared_dev = (x - mi(x)) ** 2
    variance = torch.sum(squared_dev, dim=(2, 3), keepdim=True) / spatial_size
    # Add epsilon once, after averaging. Adding it to every element before
    # the sum gives the same value mathematically, but it costs an extra
    # full-size elementwise op and hides the intent.
    return torch.sqrt(variance + epsilon)
class AdaIN(nn.Module):
    """Adaptive Instance Normalization.

    Rescales ``content`` so that its statistics over dims 2 and 3 (mean from
    ``mi``, standard deviation from ``sigma``) match those of ``style``::

        sigma(style) * (content - mi(content)) / sigma(content) + mi(style)

    The module holds no learnable parameters. ``epsilon`` is passed through
    to ``sigma`` for numerical stability.
    """

    def __init__(self, epsilon=1e-5):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, content: torch.Tensor, style: torch.Tensor) -> torch.Tensor:
        """Transfer the statistics of ``style`` onto ``content``.

        Args:
            content: Tensor to be re-normalized.
            style: Tensor whose mean and standard deviation are imposed on
                ``content``. Its statistics broadcast over ``content``, so it
                presumably shares the leading (batch, channel) dims -- confirm.

        Returns:
            Tensor broadcast from ``content`` and the style statistics.
        """
        eps = self.epsilon
        # Standardize content to zero mean / unit std over dims 2 and 3.
        normalized = (content - mi(content)) / sigma(content, eps)
        # Re-scale and re-shift using the style statistics.
        scaled = torch.mul(sigma(style, eps), normalized)
        return scaled + mi(style)