# File size: 765 Bytes
# dfd1909
from torch import Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConditionalLayerNormalization(nn.Module):
    """Layer normalization whose gain and bias come from a style condition.

    The input is normalized with ``F.layer_norm`` over every non-batch
    dimension, without any learned affine parameters. The result is then
    scaled and shifted element-wise by two tensors that a kernel-size-1
    ``nn.Conv1d`` computes from the style condition.
    """

    def __init__(self,
                 input_channels: int,
                 style_condition_channels: int) -> None:
        """Create the projection that predicts gain and bias.

        Args:
            input_channels: Channel count of the tensor to be normalized.
            style_condition_channels: Channel count of the conditioning tensor.
        """
        super().__init__()
        # A single pointwise convolution emits 2 * input_channels channels;
        # forward() splits them in half into gain and bias.
        self.gain_bias_conv1d = nn.Conv1d(style_condition_channels, input_channels * 2, 1)

    def forward(self,
                input: Tensor,  # [batch, input_channels, N]
                style_condition: Tensor,  # [batch, style_condition_channels, N]
                ) -> Tensor:
        """Normalize ``input`` and modulate it with the style condition.

        Args:
            input: Tensor to normalize, ``[batch, input_channels, N]``.
            style_condition: Conditioning tensor,
                ``[batch, style_condition_channels, N]``.

        Returns:
            The normalized input multiplied by the predicted gain plus the
            predicted bias.
        """
        # normalized_shape covers all dimensions after batch, so the statistics
        # are computed per sample over channels and positions together.
        normalized_input: Tensor = F.layer_norm(input, normalized_shape=input.shape[1:])
        # First half of the projected channels is the gain, second half the bias.
        weight, bias = self.gain_bias_conv1d(style_condition).chunk(2, dim=1)
        return normalized_input * weight + bias