| | from dataclasses import dataclass |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from transformers import PreTrainedModel |
| | from transformers.utils import ModelOutput |
| |
|
| | from .configuration_upscaler import UpscalerConfig |
| |
|
| |
|
| | |
| | |
| | |
| |
|
class ResidualBlock(nn.Module):
    """Conv-ReLU-Conv body wrapped in an identity skip connection.

    Both convolutions are 3x3 with ``padding=1`` and keep ``channels`` in and
    out, so the output has exactly the same shape as the input.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Attribute names and creation order are part of the checkpoint format
        # (state_dict keys) and of the RNG-driven initialization; keep them fixed.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        # In-place ReLU only touches conv1's freshly allocated output, never ``x``.
        hidden = self.act(self.conv1(x))
        return x + self.conv2(hidden)
| |
|
| |
|
class RestorationNet(nn.Module):
    """Low-resolution restoration stage with a global residual connection.

    The input is lifted to ``width`` feature channels, passed through
    ``num_blocks`` residual blocks, projected back to ``in_channels``, and the
    result is added to the input. The output shape matches the input shape
    because every convolution is 3x3 with ``padding=1``.
    """

    def __init__(self, in_channels=3, width=32, num_blocks=3):
        super().__init__()
        # Keep attribute names / creation order stable for checkpoint compatibility.
        self.in_conv = nn.Conv2d(in_channels, width, kernel_size=3, padding=1)
        self.blocks = nn.Sequential(*(ResidualBlock(width) for _ in range(num_blocks)))
        self.out_conv = nn.Conv2d(width, in_channels, kernel_size=3, padding=1)

    def forward(self, lr):
        features = self.in_conv(lr)
        features = self.blocks(features)
        correction = self.out_conv(features)
        # Global skip: the network only predicts a correction to the input.
        return lr + correction
| |
|
| |
|
class ESPCNUpsampler(nn.Module):
    """Sub-pixel convolution upsampler (ESPCN-style).

    Two feature convolutions are followed by a convolution that produces
    ``in_channels * scale**2`` channels. ``nn.PixelShuffle(scale)`` then
    rearranges those channels into an image of shape
    ``(B, in_channels, H * scale, W * scale)``. An optional 3x3 "refine"
    convolution is applied to the upscaled image.

    Args:
        in_channels: Number of image channels in and out.
        scale: Upscaling factor; must be one of ``(2, 3, 4)``.
        feat1: Channels produced by the first (5x5) convolution.
        feat2: Channels produced by the second (3x3) convolution.
        use_refine: If True, add a 3x3 convolution after pixel shuffle.

    Raises:
        ValueError: If ``scale`` is not a supported factor.
    """

    _SUPPORTED_SCALES = (2, 3, 4)

    def __init__(self, in_channels=3, scale=2, feat1=64, feat2=32, use_refine=False):
        super().__init__()
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let an unsupported scale through silently.
        if scale not in self._SUPPORTED_SCALES:
            raise ValueError(
                f"scale must be one of {self._SUPPORTED_SCALES}, got {scale!r}"
            )
        # Attribute names / creation order kept stable for checkpoint compatibility.
        self.conv1 = nn.Conv2d(in_channels, feat1, 5, padding=2)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(feat1, feat2, 3, padding=1)
        self.act2 = nn.ReLU(inplace=True)

        # One output channel per (image channel, sub-pixel position) pair.
        self.conv3 = nn.Conv2d(feat2, in_channels * (scale ** 2), 3, padding=1)
        self.ps = nn.PixelShuffle(scale)

        self.refine = nn.Conv2d(in_channels, in_channels, 3, padding=1) if use_refine else None

    def forward(self, x):
        y = self.act1(self.conv1(x))
        y = self.act2(self.conv2(y))
        y = self.ps(self.conv3(y))
        if self.refine is not None:
            y = self.refine(y)
        return y
| |
|
| |
|
class TwoStageSR(nn.Module):
    """Super-resolution as restoration at low resolution, then upsampling.

    Stage 1 (``restoration``) returns a tensor with the same shape as its
    input. Stage 2 (``upsampler``) enlarges it by ``scale`` along height and
    width, giving an output of shape ``(B, in_channels, H * scale, W * scale)``.
    """

    def __init__(self, in_channels=3, scale=2, width=32, num_blocks=3, feat1=64, feat2=32, use_refine=False):
        super().__init__()
        # Stored for reference; not read by forward().
        self.scale = scale
        # Creation order / names kept stable (state_dict keys, RNG-driven init).
        self.restoration = RestorationNet(
            in_channels=in_channels,
            width=width,
            num_blocks=num_blocks,
        )
        self.upsampler = ESPCNUpsampler(
            in_channels=in_channels,
            scale=scale,
            feat1=feat1,
            feat2=feat2,
            use_refine=use_refine,
        )

    def forward(self, lr):
        return self.upsampler(self.restoration(lr))
| |
|
| |
|
| | |
| | |
| | |
| |
|
@dataclass
class UpscalerOutput(ModelOutput):
    """Output container returned by ``UpscalerModel.forward``.

    Args:
        sr: Super-resolved image batch produced by the model.
    """

    # Hugging Face ``ModelOutput`` convention: every field is Optional with a
    # ``None`` default. Positional/keyword construction (``UpscalerOutput(sr=t)``)
    # behaves exactly as before.
    sr: Optional[torch.FloatTensor] = None
| |
|
| |
|
class UpscalerModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``TwoStageSR``.

    Every architecture hyperparameter is read from the ``UpscalerConfig``;
    the wrapped network is exposed as ``self.model``.
    """

    config_class = UpscalerConfig
    main_input_name = "pixel_values"

    def __init__(self, config: UpscalerConfig):
        super().__init__(config)
        # Config attribute names match TwoStageSR's keyword arguments one-to-one.
        arch_kwargs = {
            name: getattr(config, name)
            for name in ("in_channels", "scale", "width", "num_blocks", "feat1", "feat2", "use_refine")
        }
        self.model = TwoStageSR(**arch_kwargs)
        # Standard final step of a PreTrainedModel __init__ (weight init etc.).
        self.post_init()

    def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> UpscalerOutput:
        """Upscale a batch of images.

        Args:
            pixel_values: Float tensor of shape ``(B, config.in_channels, H, W)``,
                expected to hold values in ``[0, 1]``.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            ``UpscalerOutput`` whose ``sr`` has shape
            ``(B, config.in_channels, H * config.scale, W * config.scale)``.
            The prediction is not clamped to ``[0, 1]``.
        """
        prediction = self.model(pixel_values)
        return UpscalerOutput(sr=prediction)