File size: 3,665 Bytes
07bfc84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ba2aed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_upscaler import UpscalerConfig


# -------------------------
# Architecture (kept identical to the original network definition)
# -------------------------

class ResidualBlock(nn.Module):
    """Conv -> ReLU -> Conv branch added back onto its own input.

    Both convolutions are 3x3 with ``padding=1`` and keep the channel count,
    so the output has the same shape as the input.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Attribute names double as state_dict keys, so they stay as-is.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Return ``x`` plus the conv-ReLU-conv residual computed from it."""
        # The in-place ReLU only touches conv1's fresh output, never ``x``.
        residual = self.conv2(self.act(self.conv1(x)))
        return x + residual


class RestorationNet(nn.Module):
    """Residual network that predicts a correction added to its input.

    A 3x3 convolution lifts the input to ``width`` feature channels, a stack
    of ``num_blocks`` ``ResidualBlock`` modules processes them, and a final
    3x3 convolution projects back to ``in_channels``. The result is summed
    with the input, so the output shape equals the input shape.
    """

    def __init__(self, in_channels=3, width=32, num_blocks=3):
        super().__init__()
        self.in_conv = nn.Conv2d(in_channels, width, kernel_size=3, padding=1)
        residual_stack = [ResidualBlock(width) for _ in range(num_blocks)]
        self.blocks = nn.Sequential(*residual_stack)
        self.out_conv = nn.Conv2d(width, in_channels, kernel_size=3, padding=1)

    def forward(self, lr):
        """Return ``lr`` plus the correction predicted from it."""
        features = self.in_conv(lr)
        features = self.blocks(features)
        correction = self.out_conv(features)
        return lr + correction


class ESPCNUpsampler(nn.Module):
    """Sub-pixel convolution upsampler (conv -> conv -> conv -> PixelShuffle).

    Args:
        in_channels: Number of channels of the input and of the output.
        scale: Spatial upscaling factor; must be 2, 3 or 4.
        feat1: Output channels of the first (5x5) convolution.
        feat2: Output channels of the second (3x3) convolution.
        use_refine: If true, apply one extra 3x3 convolution after the
            pixel shuffle; otherwise ``self.refine`` is ``None``.

    Raises:
        ValueError: If ``scale`` is not one of 2, 3 or 4.
    """

    def __init__(self, in_channels=3, scale=2, feat1=64, feat2=32, use_refine=False):
        super().__init__()
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if scale not in (2, 3, 4):
            raise ValueError(f"scale must be 2, 3 or 4, got {scale!r}")

        self.conv1 = nn.Conv2d(in_channels, feat1, 5, padding=2)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(feat1, feat2, 3, padding=1)
        self.act2 = nn.ReLU(inplace=True)

        # PixelShuffle(scale) folds scale**2 channel groups into space, so
        # conv3 must emit in_channels * scale**2 channels.
        self.conv3 = nn.Conv2d(feat2, in_channels * (scale ** 2), 3, padding=1)
        self.ps = nn.PixelShuffle(scale)

        self.refine = nn.Conv2d(in_channels, in_channels, 3, padding=1) if use_refine else None

    def forward(self, x):
        """Upscale ``x`` from (B, C, H, W) to (B, C, H * scale, W * scale)."""
        y = self.act1(self.conv1(x))
        y = self.act2(self.conv2(y))
        y = self.ps(self.conv3(y))
        if self.refine is not None:
            y = self.refine(y)
        return y


class TwoStageSR(nn.Module):
    """Two-stage pipeline: ``RestorationNet`` followed by ``ESPCNUpsampler``.

    ``in_channels``, ``width`` and ``num_blocks`` configure the restoration
    stage; ``in_channels``, ``scale``, ``feat1``, ``feat2`` and ``use_refine``
    configure the upsampler. ``scale`` is also stored on the instance.
    """

    def __init__(self, in_channels=3, scale=2, width=32, num_blocks=3, feat1=64, feat2=32, use_refine=False):
        super().__init__()
        self.scale = scale

        restoration_kwargs = dict(in_channels=in_channels, width=width, num_blocks=num_blocks)
        upsampler_kwargs = dict(
            in_channels=in_channels,
            scale=scale,
            feat1=feat1,
            feat2=feat2,
            use_refine=use_refine,
        )

        # Stage order matters for both parameter naming and initialisation order.
        self.restoration = RestorationNet(**restoration_kwargs)
        self.upsampler = ESPCNUpsampler(**upsampler_kwargs)

    def forward(self, lr):
        """Run the restoration stage on ``lr`` and upsample its result."""
        return self.upsampler(self.restoration(lr))


# -------------------------
# Transformers output
# -------------------------

@dataclass
class UpscalerOutput(ModelOutput):
    """Output container returned by ``UpscalerModel.forward``."""

    # Tensor produced by the wrapped network for the given ``pixel_values``.
    sr: torch.FloatTensor


class UpscalerModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``TwoStageSR``.

    The network is built from the fields of an ``UpscalerConfig`` and stored
    as ``self.model``.
    """

    config_class = UpscalerConfig
    main_input_name = "pixel_values"

    def __init__(self, config: UpscalerConfig):
        super().__init__(config)

        # Config fields forwarded one-to-one as TwoStageSR keyword arguments.
        hyperparameter_names = (
            "in_channels",
            "scale",
            "width",
            "num_blocks",
            "feat1",
            "feat2",
            "use_refine",
        )
        network_kwargs = {name: getattr(config, name) for name in hyperparameter_names}
        self.model = TwoStageSR(**network_kwargs)

        self.post_init()

    def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> UpscalerOutput:
        """Run the wrapped network and package its result.

        Args:
            pixel_values: Input image batch. The original author documents it
                as a float tensor in [0, 1] with shape (B, 3, H, W);
                presumably the channel count follows ``config.in_channels``
                (TODO confirm).
            **kwargs: Accepted but not used by this method.

        Returns:
            ``UpscalerOutput`` whose ``sr`` field holds the network output.
        """
        upscaled = self.model(pixel_values)
        return UpscalerOutput(sr=upscaled)