File size: 895 Bytes
5b9bb29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch.nn as nn
from .mbd import MBDModule
class FeatureAlignmentBlock(nn.Module):
    """Convolutional stem followed by an MBD module.

    The input is passed through three ``3x3 conv -> ReLU`` stages and the
    result is handed to ``MBDModule``. According to the original author's
    note, the input is a stack of concatenated frames shaped
    ``(B, in_channels, H, W)`` (9 channels by default).

    Args:
        in_channels: Channel count of the incoming tensor (default 9).
        out_channels: Channel count produced by every conv stage and
            passed as both arguments to ``MBDModule`` (default 64).
    """

    def __init__(self, in_channels=9, out_channels=64):
        super().__init__()

        # Only the first stage changes the channel count; the next two map
        # out_channels -> out_channels. kernel_size=3 with padding=1 (and
        # the default stride of 1) leaves the spatial size untouched.
        stages = []
        for stage_in in (in_channels, out_channels, out_channels):
            stages.append(
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1)
            )
            stages.append(nn.ReLU(inplace=True))
        self.conv_layers = nn.Sequential(*stages)

        # MBDModule is defined in the sibling ``mbd`` module; its internals
        # are not visible here. NOTE(review): presumably (in, out) channel
        # counts -- confirm against MBDModule's signature.
        self.mbd = MBDModule(out_channels, out_channels)

    def forward(self, x):
        """Run the conv stack, then the MBD module, and return its output.

        Args:
            x: Input tensor; expected as ``(B, 9, H, W)`` concatenated
                frames when the default ``in_channels`` is used.

        Returns:
            Whatever ``self.mbd`` returns for the conv features.
        """
        return self.mbd(self.conv_layers(x))