import torch
import torch.nn as nn

from .blocks.complexblock import ComplexConv2d, ComplexBatchNorm


class CVHead(nn.Module):
    """Complex-valued output head: a single 1x1 ``ComplexConv2d`` projection.

    Args:
        in_channels: Input channel count passed to ``ComplexConv2d``.
        out_channels: Output channel count of the projection.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_last = ComplexConv2d(in_channels, out_channels, kernel_size=(1, 1))
        # NOTE(review): ``self.bn`` is constructed but never used in ``forward``.
        # It is kept so the module's attributes / state_dict layout stay the same
        # for existing checkpoints. If ComplexBatchNorm holds parameters they get
        # no gradient, which DistributedDataParallel reports unless
        # ``find_unused_parameters=True``. Confirm whether it was meant to be applied.
        self.bn = ComplexBatchNorm(out_channels)

    def forward(self, x):
        """Apply the 1x1 complex convolution to ``x`` and return the result."""
        return self.conv_last(x)


class CVVOCOSHead(nn.Module):
    """Projection head with separate linear layers for real and imaginary parts.

    The input holds the real and imaginary halves concatenated along
    ``complex_axis``. That same axis is the feature axis the linear layers act
    on. This is not a full complex-valued linear map: the real half only goes
    through ``real_out`` and the imaginary half only through ``imag_out``, with
    no cross terms between them.

    Args:
        in_channels: Size of each half along ``complex_axis`` (the input has
            ``2 * in_channels`` entries on that axis).
        out_channels: Output feature size of each half.
        complex_axis: Axis along which the real/imaginary halves are
            concatenated in the input and along which they are stacked in the
            output.
    """

    def __init__(self, in_channels, out_channels, complex_axis=1):
        super().__init__()
        self.real_out = nn.Linear(in_channels, out_channels)
        self.imag_out = nn.Linear(in_channels, out_channels)
        self.complex_axis = complex_axis

    def _project(self, layer, part):
        """Apply ``layer`` over ``complex_axis`` of ``part`` and restore the axis order."""
        # nn.Linear acts on the last dimension, so move the feature axis
        # (complex_axis) there and back. For the default complex_axis=1 on a
        # 3-D input this is exactly the former transpose(1, 2); unlike a
        # hard-coded transpose it also honours any other complex_axis.
        moved = part.movedim(self.complex_axis, -1)
        return layer(moved).movedim(-1, self.complex_axis)

    def forward(self, x):
        """Project the real and imaginary halves of ``x`` independently.

        Returns:
            A tensor with a NEW axis of size 2 inserted at ``complex_axis``
            (index 0 = real projection, index 1 = imaginary projection). Each
            projection has ``out_channels`` entries on the feature axis.
        """
        # Split the concatenated input into its real and imaginary halves.
        real, imag = torch.chunk(x, 2, dim=self.complex_axis)
        real_out = self._project(self.real_out, real)
        imag_out = self._project(self.imag_out, imag)
        # torch.stack inserts a new axis rather than concatenating back into the
        # input's layout; kept as-is because callers may depend on this shape.
        return torch.stack([real_out, imag_out], dim=self.complex_axis)