# SeeSharp / ersvr / models / feature_alignment.py
# Uploaded by Abhinavexists via huggingface_hub (commit 5b9bb29, verified).
import torch.nn as nn
from .mbd import MBDModule
class FeatureAlignmentBlock(nn.Module):
    """Feature Alignment Block for processing concatenated frames.

    Applies three 3x3 convolution + ReLU stages, then an ``MBDModule``.
    The first convolution maps ``in_channels`` to ``out_channels``, and the
    other two keep ``out_channels``.

    Args:
        in_channels: Number of channels in the stacked input frames. The
            default of 9 presumably corresponds to three concatenated RGB
            frames (TODO confirm against the caller).
        out_channels: Feature width produced by the convolution stack. It is
            passed to ``MBDModule`` as both of its channel arguments.
    """

    def __init__(self, in_channels=9, out_channels=64):
        super().__init__()
        # Build conv/ReLU pairs in order. The resulting Sequential child
        # indices (conv at 0, 2, 4; ReLU at 1, 3, 5) keep the state_dict keys
        # "conv_layers.{0,2,4}.*" unchanged. conv_layers is created before mbd
        # so parameter initialisation order is also unchanged.
        stages = []
        channels_in = in_channels
        for _ in range(3):
            stages.append(
                nn.Conv2d(channels_in, out_channels, kernel_size=3, padding=1)
            )
            stages.append(nn.ReLU(inplace=True))
            channels_in = out_channels
        self.conv_layers = nn.Sequential(*stages)
        # NOTE(review): MBDModule comes from .mbd, which is not visible here.
        # Its output is assumed to keep out_channels; confirm.
        self.mbd = MBDModule(out_channels, out_channels)

    def forward(self, x):
        """Run the convolution stack on ``x``, then the MBD module.

        ``x`` is expected to be (B, in_channels, H, W), i.e. the frames
        concatenated along the channel axis. The padded 3x3 convolutions
        preserve H and W.
        """
        features = self.conv_layers(x)
        return self.mbd(features)