Upload folder using huggingface_hub
Browse files- ersvr/models/ersvr.py +49 -0
- ersvr/models/feature_alignment.py +24 -0
- ersvr/models/mbd.py +28 -0
- ersvr/models/sr_network.py +44 -0
- ersvr/models/student.py +59 -0
- ersvr/models/upsampling.py +33 -0
ersvr/models/ersvr.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
from .feature_alignment import FeatureAlignmentBlock
|
| 5 |
+
from .sr_network import SRNetwork
|
| 6 |
+
|
| 7 |
+
class ERSVR(nn.Module):
    """Real-time Video Super Resolution Network using Recurrent Multi-Branch Dilated Convolutions.

    The network flattens a short clip of frames into a single multi-channel
    image, extracts aligned features from it, and predicts a residual that is
    added to a bicubically upsampled copy of the center frame.

    Args:
        scale_factor: Factor used for the bicubic upsampling of the center
            frame. The learned branch must produce the same spatial size,
            otherwise ``forward`` fails with a shape mismatch.
    """

    def __init__(self, scale_factor=4):
        super().__init__()

        self.scale_factor = scale_factor

        # Feature alignment block: consumes the 9-channel stack of flattened frames.
        self.feature_alignment = FeatureAlignmentBlock(in_channels=9, out_channels=64)

        # SR network: turns the 64-channel features into a 3-channel image.
        self.sr_network = SRNetwork(in_channels=64, out_channels=3)

    def forward(self, x):
        """Super-resolve the center frame of a clip.

        Args:
            x: Tensor laid out as (B, N, C, H, W). The feature alignment block
                is built for N * C == 9, i.e. three RGB frames.

        Returns:
            The tensor returned by the SR network, which has the same shape as
            the bicubically upsampled center frame.

        Raises:
            ValueError: If ``x`` is not 5-dimensional, or if the SR network
                output and the bicubic reference differ in shape.
        """
        if x.dim() != 5:
            raise ValueError(
                f"Expected a 5D input (B, N, C, H, W), got shape {tuple(x.shape)}"
            )

        # Take the center frame before flattening so the selection follows the
        # actual frame count instead of a hard-coded channel slice.
        num_frames = x.shape[1]
        center_frame = x[:, num_frames // 2]

        # Fold the frame axis into the channel axis: (B, N*C, H, W).
        x = rearrange(x, 'b n c h w -> b (n c) h w')

        # Bicubic upsampling of the center frame for the residual connection.
        bicubic = F.interpolate(
            center_frame,
            scale_factor=self.scale_factor,
            mode='bicubic',
            align_corners=False
        )

        features = self.feature_alignment(x)
        output = self.sr_network(features, bicubic)

        # Guard against a silently broadcast residual addition; the shapes are
        # reported in the exception rather than printed to stdout.
        if output.shape != bicubic.shape:
            raise ValueError(
                "Output and bicubic tensors must have the same dimensions "
                f"(output: {tuple(output.shape)}, bicubic: {tuple(bicubic.shape)})"
            )

        return output
|
ersvr/models/feature_alignment.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from .mbd import MBDModule
|
| 3 |
+
|
| 4 |
+
class FeatureAlignmentBlock(nn.Module):
    """Feature Alignment Block for processing concatenated frames.

    Applies three 3x3 convolution + ReLU stages followed by a multi-branch
    dilated convolution module. All convolutions use padding that keeps the
    spatial size unchanged.

    Args:
        in_channels: Channel count of the stacked input frames.
        out_channels: Channel count produced by every stage of the block.
    """

    def __init__(self, in_channels=9, out_channels=64):
        super().__init__()

        # Channel widths at the boundaries of the three conv stages.
        widths = [in_channels, out_channels, out_channels, out_channels]
        stages = []
        for src, dst in zip(widths, widths[1:]):
            stages.append(nn.Conv2d(src, dst, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
        self.conv_layers = nn.Sequential(*stages)

        self.mbd = MBDModule(out_channels, out_channels)

    def forward(self, x):
        """Map stacked frames (B, in_channels, H, W) to features (B, out_channels, H, W)."""
        return self.mbd(self.conv_layers(x))
|
ersvr/models/mbd.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class MBDModule(nn.Module):
    """Multi-Branch Dilated Convolution Module.

    A 1x1 convolution projects the input to ``out_channels``; the result is
    fed to several parallel 3x3 convolutions with different dilation rates,
    whose outputs are concatenated along the channel axis and fused back to
    ``out_channels`` by another 1x1 convolution.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count of every branch and of the output.
        dilations: Dilation rate of each parallel branch. Defaults to
            ``(1, 2, 4)``.

    Raises:
        ValueError: If ``dilations`` is empty.
    """

    def __init__(self, in_channels, out_channels, dilations=(1, 2, 4)):
        super().__init__()

        dilations = tuple(dilations)
        if not dilations:
            raise ValueError("dilations must contain at least one rate")

        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # padding == dilation keeps the spatial size of a 3x3 convolution
        # unchanged, so all branches can be concatenated.
        self.dilated_convs = nn.ModuleList([
            nn.Conv2d(out_channels, out_channels, kernel_size=3,
                      padding=d, dilation=d)
            for d in dilations
        ])

        self.fusion = nn.Conv2d(out_channels * len(dilations), out_channels, kernel_size=1)

    def forward(self, x):
        """Return fused multi-dilation features with the input's spatial size."""
        x = self.pointwise(x)
        branches = [conv(x) for conv in self.dilated_convs]
        return self.fusion(torch.cat(branches, dim=1))
|
ersvr/models/sr_network.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .upsampling import UpsamplingBlock
|
| 4 |
+
|
| 5 |
+
class SRNetwork(nn.Module):
    """Super Resolution Network with ESPCN-like backbone.

    Nine 3x3 convolution + ReLU stages at 64 channels, followed by an
    upsampling block and a final 3x3 convolution. The result is added to a
    pre-upsampled reference image (residual learning).

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count of the produced image.
    """

    def __init__(self, in_channels=64, out_channels=3):
        super().__init__()

        # Build the nine conv+ReLU pairs in a loop; only the first conv sees
        # `in_channels`, the rest are 64 -> 64. Layer order (and therefore the
        # Sequential indices used as state_dict keys) matches an explicit listing.
        layers = []
        channels = in_channels
        for _ in range(9):
            layers.append(nn.Conv2d(channels, 64, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            channels = 64
        self.conv_layers = nn.Sequential(*layers)

        self.upsampling = UpsamplingBlock(64)
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)

    def forward(self, x, bicubic):
        """Predict a residual from ``x`` and add it to ``bicubic``.

        Args:
            x: Feature map with ``in_channels`` channels.
            bicubic: Reference image added to the network output; it must be
                broadcast-compatible with the upsampled prediction.

        Returns:
            The sum of the predicted residual and ``bicubic``.
        """
        x = self.conv_layers(x)
        x = self.upsampling(x)
        x = self.final_conv(x)
        return x + bicubic
|
ersvr/models/student.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
class DepthwiseSeparableConv(nn.Module):
    """
    Depthwise separable convolution block.

    A per-channel (grouped) convolution is followed by a 1x1 convolution that
    mixes channels, then batch normalization and ReLU. Both convolutions are
    bias-free.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # groups == in_channels gives one spatial filter per input channel.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=False,
        )
        # 1x1 convolution that recombines the per-channel responses.
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply depthwise conv, pointwise conv, batch norm and ReLU in that order."""
        spatial = self.depthwise(x)
        mixed = self.pointwise(spatial)
        return self.relu(self.bn(mixed))
|
| 21 |
+
|
| 22 |
+
class StudentSRNet(nn.Module):
    """
    Ultra-lightweight Student Model for Video Super-Resolution.

    - Input:  (B, 3, 3, H, W)   # 3 frames, 3 channels each
    - Output: (B, 3, H*4, W*4)  # Super-resolved center frame

    The upsampling path is two fixed PixelShuffle(2) stages, so the output is
    always 4x the input resolution; ``scale_factor`` is only stored on the
    instance and does not change the architecture.
    """

    def __init__(self, scale_factor=4):
        super().__init__()
        self.scale_factor = scale_factor

        # Stem over the 9-channel stack of flattened frames.
        self.input_conv = nn.Conv2d(9, 16, kernel_size=3, padding=1)

        # Depthwise-separable body: 16 -> 32 -> 32 -> 16 channels.
        self.block1 = DepthwiseSeparableConv(16, 32)
        self.block2 = DepthwiseSeparableConv(32, 32)
        self.block3 = DepthwiseSeparableConv(32, 16)

        # Each stage expands 16 -> 64 channels, then PixelShuffle(2) trades the
        # extra channels for a 2x larger spatial grid (back to 16 channels).
        self.upsample1 = nn.Sequential(
            nn.Conv2d(16, 64, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.ReLU(inplace=True),
        )
        self.upsample2 = nn.Sequential(
            nn.Conv2d(16, 64, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.ReLU(inplace=True),
        )

        self.output_conv = nn.Conv2d(16, 3, kernel_size=3, padding=1)

    def forward(self, x):
        """Flatten the frame axis into channels and run the stages in sequence."""
        # x: (B, N, C, H, W) -> (B, N*C, H, W)
        b, n, c, h, w = x.shape
        x = x.reshape(b, n * c, h, w)

        pipeline = (
            self.input_conv,
            self.block1,
            self.block2,
            self.block3,
            self.upsample1,
            self.upsample2,
            self.output_conv,
        )
        for stage in pipeline:
            x = stage(x)
        return x
|
ersvr/models/upsampling.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
class SubpixelUpsampling(nn.Module):
    """Sub-pixel upsampling: a 3x3 convolution followed by PixelShuffle.

    The convolution multiplies the channel count by ``scale_factor ** 2`` and
    PixelShuffle rearranges those channels into a spatial grid that is
    ``scale_factor`` times larger, so the channel count of the output equals
    ``in_channels``.

    Args:
        in_channels: Channel count of the input (and of the output).
        scale_factor: Spatial upscaling factor.
    """

    def __init__(self, in_channels, scale_factor=2):
        super().__init__()

        self.scale_factor = scale_factor

        expanded_channels = in_channels * (scale_factor ** 2)
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)

    def forward(self, x):
        """Return ``x`` upsampled by ``scale_factor`` in both spatial dimensions."""
        return self.pixel_shuffle(self.conv(x))
|
| 21 |
+
|
| 22 |
+
class UpsamplingBlock(nn.Module):
    """4x upsampling built from two chained SubpixelUpsampling stages.

    Each stage is created with its default scale factor, and both stages keep
    the channel count at ``in_channels``.

    Args:
        in_channels: Channel count of the input and output feature maps.
    """

    def __init__(self, in_channels):
        super().__init__()

        self.upsample1 = SubpixelUpsampling(in_channels)
        self.upsample2 = SubpixelUpsampling(in_channels)

    def forward(self, x):
        """Apply the first and then the second upsampling stage."""
        return self.upsample2(self.upsample1(x))
|