Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| from .blocks.complexblock import ComplexConv2d, ComplexBatchNorm | |
class CVHead(nn.Module):
    """Complex-valued output head: a single 1x1 ``ComplexConv2d`` projection.

    Args:
        in_channels: channel count handed to ``ComplexConv2d`` as its input size.
        out_channels: channel count produced by the projection.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_last = ComplexConv2d(in_channels, out_channels, kernel_size=(1, 1))
        # NOTE(review): ``bn`` is constructed but ``forward`` never calls it.
        # It is kept so existing checkpoints (which may contain ``bn.*`` keys)
        # still load. If ComplexBatchNorm registers parameters and the model is
        # wrapped in DistributedDataParallel, those unused parameters may require
        # ``find_unused_parameters=True`` — confirm.
        self.bn = ComplexBatchNorm(out_channels)

    def forward(self, x):
        """Return ``conv_last(x)``; no normalization or activation is applied."""
        return self.conv_last(x)
class CVVOCOSHead(nn.Module):
    """Head that projects real and imaginary feature halves with separate Linears.

    The input holds the real and imaginary parts concatenated along
    ``complex_axis``. ``torch.chunk`` splits it into two halves. Each half is
    mapped from ``in_channels`` to ``out_channels`` features over dimension 1
    by its own ``nn.Linear`` (``real_out`` / ``imag_out``). The two parts never
    interact, so this is two independent real-valued projections, not a
    complex-valued linear map.

    Args:
        in_channels: size of dimension 1 of each half after splitting.
        out_channels: size of dimension 1 of each projected half.
        complex_axis: axis along which the input is split into real/imag
            halves, and along which the two results are stacked in the output.

    Shapes (default ``complex_axis=1``):
        input:  ``(N, 2 * in_channels, *rest)``, with ``rest`` possibly empty.
        output: ``(N, 2, out_channels, *rest)``. Note that the output *stacks*
            the parts on a new size-2 axis rather than concatenating them like
            the input does.
    """

    def __init__(self, in_channels, out_channels, complex_axis=1):
        super().__init__()
        self.real_out = nn.Linear(in_channels, out_channels)
        self.imag_out = nn.Linear(in_channels, out_channels)
        self.complex_axis = complex_axis

    @staticmethod
    def _project(linear, part):
        """Apply ``linear`` over dimension 1 of ``part``, keeping other dims in place."""
        # nn.Linear acts on the last dimension, so swap dim 1 to the end and back.
        # transpose(1, -1) equals the former transpose(1, 2) for 3-D input. It is
        # also correct for 2-D (a no-op swap) and >3-D channel-first tensors.
        return linear(part.transpose(1, -1)).transpose(1, -1)

    def forward(self, x):
        real, imag = torch.chunk(x, 2, dim=self.complex_axis)
        real_out = self._project(self.real_out, real)
        imag_out = self._project(self.imag_out, imag)
        # Stack (not concatenate): the output gains a new axis of size 2 at
        # ``complex_axis`` holding (real, imag).
        return torch.stack([real_out, imag_out], dim=self.complex_axis)