| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
| from .feature_alignment import FeatureAlignmentBlock |
| from .sr_network import SRNetwork |
|
|
class ERSVR(nn.Module):
    """Real-time Video Super Resolution Network using Recurrent Multi-Branch Dilated Convolutions.

    Consumes a short window of consecutive low-resolution frames and produces
    one super-resolved image for the center frame: the frames are stacked
    along the channel axis, passed through a feature-alignment block, and the
    SR network reconstructs the output on top of a bicubic upsampling of the
    center frame.
    """

    def __init__(self, scale_factor=4):
        """
        Args:
            scale_factor (int): Spatial upscaling factor applied to the
                center frame. Defaults to 4.
        """
        super(ERSVR, self).__init__()

        self.scale_factor = scale_factor

        # Fuses the stacked window (3 frames x 3 channels = 9 input channels)
        # into a 64-channel aligned feature map.
        self.feature_alignment = FeatureAlignmentBlock(in_channels=9, out_channels=64)

        # Maps aligned features (+ the bicubic base) back to 3-channel RGB.
        self.sr_network = SRNetwork(in_channels=64, out_channels=3)

    def forward(self, x):
        """Super-resolve the center frame of a frame window.

        Args:
            x: Tensor of shape (batch, num_frames, channels, height, width).
                num_frames is expected to be odd so a center frame exists
                (the default pipeline uses 3 RGB frames).

        Returns:
            Tensor of shape (batch, channels, height * scale_factor,
            width * scale_factor).

        Raises:
            ValueError: If the SR network output shape differs from the
                bicubic-upsampled center frame.
        """
        batch_size, num_frames, channels, height, width = x.shape

        # Stack frames along the channel axis: (b, n, c, h, w) -> (b, n*c, h, w).
        # Equivalent to einops rearrange 'b n c h w -> b (n c) h w'.
        x = x.reshape(batch_size, num_frames * channels, height, width)

        # Slice out the center frame's channels. Derived from num_frames so
        # any odd-sized window works (the previous hard-coded 3:6 slice
        # assumed exactly 3 RGB frames; this reduces to the same slice there).
        center = num_frames // 2
        center_frame = x[:, center * channels:(center + 1) * channels, :, :]

        # Bicubic upsampling of the center frame serves as the base image.
        bicubic = F.interpolate(
            center_frame,
            scale_factor=self.scale_factor,
            mode='bicubic',
            align_corners=False
        )

        features = self.feature_alignment(x)

        output = self.sr_network(features, bicubic)

        # Sanity check: the SR branch must match the bicubic base resolution.
        if output.shape != bicubic.shape:
            raise ValueError(
                "Output and bicubic tensors must have the same dimensions "
                f"(output {tuple(output.shape)}, bicubic {tuple(bicubic.shape)})"
            )

        return output