# styletalk/core/networks/dynamic_linear.py
# Uploader: ameerazam08
# Commit message: Upload folder using huggingface_hub
# Commit: 9a973f2
import math
import torch
from torch import nn
from torch.nn import functional as F
from core.networks.dynamic_conv import DynamicConv
class DynamicLinear(nn.Module):
    """Condition-dependent linear layer for sequence-shaped inputs.

    This is a thin wrapper around ``DynamicConv`` (imported from
    ``core.networks.dynamic_conv``) configured with ``kernel_size=1``,
    ``stride=1`` and ``padding=0``. The sequence tensor is reshaped into a
    4-D layout with a dummy trailing axis before being handed to the
    convolution, and reshaped back afterwards.

    NOTE(review): assuming ``DynamicConv`` behaves like a regular 2-D
    convolution, the 1x1 kernel makes this a per-position linear map whose
    weights depend on ``cond`` -- confirm against ``DynamicConv``.

    Args:
        in_planes: Channel size of the input features (``C_in``).
        out_planes: Channel size of the output features (``C_out``).
        cond_planes: Channel size of the conditioning vector.
        bias: Forwarded unchanged to ``DynamicConv``.
        K: Forwarded unchanged to ``DynamicConv``.
        temperature: Forwarded unchanged to ``DynamicConv``.
        ratio: Forwarded unchanged to ``DynamicConv``.
        init_weight: Forwarded unchanged to ``DynamicConv``.
    """

    def __init__(
        self,
        in_planes,
        out_planes,
        cond_planes,
        bias=True,
        K=4,
        temperature=30,
        ratio=4,
        init_weight=True,
    ):
        super().__init__()
        self.dynamic_conv = DynamicConv(
            in_planes,
            out_planes,
            cond_planes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
            K=K,
            ratio=ratio,
            temperature=temperature,
            init_weight=init_weight,
        )

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Apply the conditioned 1x1 convolution to a sequence tensor.

        Args:
            x: Input features of shape ``(L, B, C_in)``.
            cond: Conditioning tensor of shape ``(B, C_style)``.

        Returns:
            Tensor of shape ``(L, B, C_out)``.
        """
        # (L, B, C_in) -> (B, C_in, L) -> (B, C_in, L, 1)
        x = x.permute(1, 2, 0).unsqueeze(-1)
        # Expected to be (B, C_out, L, 1).
        out = self.dynamic_conv(x, cond)
        # Drop ONLY the trailing dummy axis. A bare ``squeeze()`` would also
        # remove B, C_out or L whenever one of them equals 1, leaving fewer
        # than three dims and making the permute below fail.
        out = out.squeeze(-1).permute(2, 0, 1)
        return out