from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_upscaler import UpscalerConfig


# -------------------------
# Architecture
# -------------------------
class ResidualBlock(nn.Module):
    """Conv -> ReLU -> Conv with an identity skip connection.

    Both convolutions are 3x3 with padding 1 and keep the channel count,
    so the output has the same shape as the input.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Return ``x`` plus the residual computed by the two convolutions."""
        residual = self.conv2(self.act(self.conv1(x)))
        return x + residual


class RestorationNet(nn.Module):
    """Same-resolution network that adds a learned correction to its input.

    The input is lifted to ``width`` feature channels, passed through
    ``num_blocks`` residual blocks, projected back to ``in_channels`` and
    added to the original input.
    """

    def __init__(self, in_channels=3, width=32, num_blocks=3):
        super().__init__()
        self.in_conv = nn.Conv2d(in_channels, width, kernel_size=3, padding=1)
        stack = [ResidualBlock(width) for _ in range(num_blocks)]
        self.blocks = nn.Sequential(*stack)
        self.out_conv = nn.Conv2d(width, in_channels, kernel_size=3, padding=1)

    def forward(self, lr):
        """Return ``lr`` plus the predicted correction (same shape as ``lr``)."""
        features = self.in_conv(lr)
        features = self.blocks(features)
        correction = self.out_conv(features)
        return lr + correction


class ESPCNUpsampler(nn.Module):
    """Sub-pixel convolution upsampler (conv stack followed by PixelShuffle).

    Only scales 2, 3 and 4 are accepted. When ``use_refine`` is true an extra
    3x3 convolution is applied to the upscaled result.
    """

    def __init__(self, in_channels=3, scale=2, feat1=64, feat2=32, use_refine=False):
        super().__init__()
        assert scale in (2, 3, 4)

        self.conv1 = nn.Conv2d(in_channels, feat1, kernel_size=5, padding=2)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(feat1, feat2, kernel_size=3, padding=1)
        self.act2 = nn.ReLU(inplace=True)

        # IMPORTANT: PixelShuffle folds scale**2 channels into each output
        # channel, so conv3 must emit in_channels * scale**2 channels.
        shuffle_channels = in_channels * (scale ** 2)
        self.conv3 = nn.Conv2d(feat2, shuffle_channels, kernel_size=3, padding=1)
        self.ps = nn.PixelShuffle(scale)

        if use_refine:
            self.refine = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        else:
            self.refine = None

    def forward(self, x):
        """Upscale ``x`` spatially by the configured scale factor."""
        feats = self.act1(self.conv1(x))
        feats = self.act2(self.conv2(feats))
        upscaled = self.ps(self.conv3(feats))
        if self.refine is None:
            return upscaled
        return self.refine(upscaled)


class TwoStageSR(nn.Module):
    """Two-stage super-resolution: restore at low resolution, then upsample."""

    def __init__(
        self,
        in_channels=3,
        scale=2,
        width=32,
        num_blocks=3,
        feat1=64,
        feat2=32,
        use_refine=False,
    ):
        super().__init__()
        self.scale = scale
        self.restoration = RestorationNet(
            in_channels=in_channels,
            width=width,
            num_blocks=num_blocks,
        )
        self.upsampler = ESPCNUpsampler(
            in_channels=in_channels,
            scale=scale,
            feat1=feat1,
            feat2=feat2,
            use_refine=use_refine,
        )

    def forward(self, lr):
        """Run the restoration stage on ``lr`` and upsample its output."""
        return self.upsampler(self.restoration(lr))


# -------------------------
# Transformers output
# -------------------------
@dataclass
class UpscalerOutput(ModelOutput):
    """Output container returned by :class:`UpscalerModel`."""

    # Super-resolved image tensor produced by the model.
    sr: torch.FloatTensor


class UpscalerModel(PreTrainedModel):
    """Transformers wrapper that builds a :class:`TwoStageSR` from a config."""

    config_class = UpscalerConfig
    main_input_name = "pixel_values"

    def __init__(self, config: UpscalerConfig):
        super().__init__(config)
        self.model = TwoStageSR(
            in_channels=config.in_channels,
            scale=config.scale,
            width=config.width,
            num_blocks=config.num_blocks,
            feat1=config.feat1,
            feat2=config.feat2,
            use_refine=config.use_refine,
        )
        self.post_init()

    def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> UpscalerOutput:
        """Upscale a batch of images.

        Args:
            pixel_values: float tensor of shape (B, C, H, W) with
                C == config.in_channels; values are expected in [0, 1].
            **kwargs: accepted for call-site compatibility and ignored.

        Returns:
            UpscalerOutput whose ``sr`` field holds the upscaled tensor.
        """
        upscaled = self.model(pixel_values)
        return UpscalerOutput(sr=upscaled)