ChristophSchuhmann's picture
Add model code, inference script, and examples
dfd1909 verified
from torch import Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConditionalLayerNormalization(nn.Module):
def __init__(self,
input_channels:int,
style_condition_channels:int) -> None:
super().__init__()
self.gain_bias_conv1d = nn.Conv1d(style_condition_channels, input_channels * 2, 1)
def forward(self,
input:Tensor, #[batch,input_channels,N]
style_condition:Tensor #[batch,style_condition_channels,N]
):
normalized_input:Tensor = F.layer_norm(input,normalized_shape=input.shape[1:])
weight, bias = self.gain_bias_conv1d(style_condition).chunk(2, dim=1)
return normalized_input * weight + bias