import torch.nn as nn
from .mbd import MBDModule


class FeatureAlignmentBlock(nn.Module):
    """Feature alignment block applied to a stack of concatenated frames.

    Applies three 3x3 convolutions, each followed by an in-place ReLU, and
    then an ``MBDModule``. Every convolution uses ``padding=1`` with the
    default stride of 1, so the conv stack keeps the spatial size of its input.

    Args:
        in_channels: Channel count of the incoming tensor. The default of 9
            presumably corresponds to three RGB frames concatenated along the
            channel axis (TODO confirm against callers).
        out_channels: Channel count produced by every convolution. It is also
            passed as both constructor arguments of ``MBDModule``.
    """

    def __init__(self, in_channels=9, out_channels=64):
        super().__init__()

        stages = []
        width = in_channels
        for _ in range(3):
            # Only the first conv sees ``in_channels``. The remaining convs
            # map out_channels -> out_channels.
            stages.append(nn.Conv2d(width, out_channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
            width = out_channels

        # Layers are created in the same order as a hand-written Sequential.
        # Parameter initialisation and state_dict keys such as
        # "conv_layers.0.weight" therefore stay the same.
        self.conv_layers = nn.Sequential(*stages)
        self.mbd = MBDModule(out_channels, out_channels)

    def forward(self, x):
        """Run the conv stack and then the MBD module on ``x``.

        Args:
            x: Input tensor, assumed to be (B, in_channels, H, W), i.e. frames
                already concatenated along the channel axis (TODO confirm).

        Returns:
            Whatever ``MBDModule`` produces for the conv features. Its shape
            is defined by ``MBDModule``, which is not visible here.
        """
        features = self.conv_layers(x)
        return self.mbd(features)