# U-Past/modules/head.py
# lycaoduong's picture
# Initial space
# e8160b2 verified
import torch
import torch.nn as nn
from .blocks.complexblock import ComplexConv2d, ComplexBatchNorm
class CVHead(nn.Module):
    """Output head that projects features with a 1x1 complex convolution.

    Args:
        in_channels: Channel count passed to ``ComplexConv2d`` as its input size.
        out_channels: Channel count passed to ``ComplexConv2d`` as its output
            size, and to ``ComplexBatchNorm``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # NOTE: a ConvTranspose2d upsampling step (kernel 2x2, stride 2x2) was
        # sketched here at some point but is not enabled.
        pointwise_kernel = (1, 1)
        self.conv_last = ComplexConv2d(
            in_channels, out_channels, kernel_size=pointwise_kernel
        )
        # Registered on the module but never called in forward().
        # NOTE(review): presumably retained so existing checkpoints keep
        # loading -- confirm before removing or wiring it in.
        self.bn = ComplexBatchNorm(out_channels)

    def forward(self, x):
        """Return ``conv_last(x)``; the batch-norm layer is not applied."""
        return self.conv_last(x)
class CVVOCOSHead(nn.Module):
    """Head with separate linear projections for the real and imaginary halves.

    The input is split into two chunks along ``complex_axis``; the first chunk
    goes through ``real_out`` and the second through ``imag_out``. The two
    results are then *stacked* along a new axis inserted at ``complex_axis``,
    so the output has one more dimension than each projected half.

    Args:
        in_channels: Input feature size of each ``nn.Linear``.
        out_channels: Output feature size of each ``nn.Linear``.
        complex_axis: Axis along which the input is split and along which the
            outputs are stacked. Defaults to 1.
    """

    def __init__(self, in_channels, out_channels, complex_axis=1):
        super().__init__()
        self.real_out = nn.Linear(in_channels, out_channels)
        self.imag_out = nn.Linear(in_channels, out_channels)
        self.complex_axis = complex_axis

    def forward(self, x):
        """Project both halves of ``x`` and stack them on a new axis.

        assumes x is 3-D, (batch, 2 * in_channels, time), when
        complex_axis == 1 -- TODO confirm against callers.
        """
        # First chunk is treated as the real part, second as the imaginary.
        real, imag = torch.chunk(x, 2, dim=self.complex_axis)

        # nn.Linear acts on the last dimension, so dims 1 and 2 are swapped
        # before each projection and swapped back afterwards.
        projected = [
            layer(part.transpose(1, 2)).transpose(1, 2)
            for layer, part in ((self.real_out, real), (self.imag_out, imag))
        ]

        # stack (not cat): a new size-2 axis is created at complex_axis.
        return torch.stack(projected, dim=self.complex_axis)