Update modeling_upscaler.py
Browse files- modeling_upscaler.py +13 -36
modeling_upscaler.py
CHANGED
|
@@ -96,38 +96,6 @@ class ESPCNExactUpsampler(nn.Module):
|
|
| 96 |
return x
|
| 97 |
|
| 98 |
|
| 99 |
-
class TwoStageSR(nn.Module):
|
| 100 |
-
def __init__(
|
| 101 |
-
self,
|
| 102 |
-
in_channels=3,
|
| 103 |
-
scale=4,
|
| 104 |
-
width=32,
|
| 105 |
-
num_blocks=3,
|
| 106 |
-
espcn_channels=64,
|
| 107 |
-
):
|
| 108 |
-
super().__init__()
|
| 109 |
-
self.scale = scale
|
| 110 |
-
self.restoration = RestorationNet(
|
| 111 |
-
in_channels=in_channels,
|
| 112 |
-
width=width,
|
| 113 |
-
num_blocks=num_blocks,
|
| 114 |
-
)
|
| 115 |
-
self.upsampler = ESPCNExactUpsampler(
|
| 116 |
-
scale_factor=scale,
|
| 117 |
-
num_channels=in_channels,
|
| 118 |
-
channels=espcn_channels,
|
| 119 |
-
)
|
| 120 |
-
|
| 121 |
-
def forward(self, lr):
|
| 122 |
-
lr_clean = self.restoration(lr)
|
| 123 |
-
hr_pred = self.upsampler(lr_clean)
|
| 124 |
-
return hr_pred, lr_clean
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
# -------------------------
|
| 128 |
-
# Transformers output
|
| 129 |
-
# -------------------------
|
| 130 |
-
|
| 131 |
@dataclass
|
| 132 |
class UpscalerOutput(ModelOutput):
|
| 133 |
sr: torch.FloatTensor
|
|
@@ -141,16 +109,25 @@ class UpscalerModel(PreTrainedModel):
|
|
| 141 |
def __init__(self, config: UpscalerConfig):
|
| 142 |
super().__init__(config)
|
| 143 |
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
in_channels=config.in_channels,
|
| 146 |
-
scale=config.scale,
|
| 147 |
width=config.width,
|
| 148 |
num_blocks=config.num_blocks,
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
)
|
| 151 |
|
| 152 |
self.post_init()
|
| 153 |
|
| 154 |
def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> UpscalerOutput:
|
| 155 |
-
|
|
|
|
| 156 |
return UpscalerOutput(sr=sr, restored=restored)
|
|
|
|
| 96 |
return x
|
| 97 |
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
@dataclass
|
| 100 |
class UpscalerOutput(ModelOutput):
|
| 101 |
sr: torch.FloatTensor
|
|
|
|
| 109 |
def __init__(self, config: UpscalerConfig):
|
| 110 |
super().__init__(config)
|
| 111 |
|
| 112 |
+
# IMPORTANT :
|
| 113 |
+
# on expose directement les sous-modules pour que les clés
|
| 114 |
+
# correspondent au checkpoint Lightning :
|
| 115 |
+
# restoration.* / upsampler.*
|
| 116 |
+
self.scale = config.scale
|
| 117 |
+
self.restoration = RestorationNet(
|
| 118 |
in_channels=config.in_channels,
|
|
|
|
| 119 |
width=config.width,
|
| 120 |
num_blocks=config.num_blocks,
|
| 121 |
+
)
|
| 122 |
+
self.upsampler = ESPCNExactUpsampler(
|
| 123 |
+
scale_factor=config.scale,
|
| 124 |
+
num_channels=config.in_channels,
|
| 125 |
+
channels=config.espcn_channels,
|
| 126 |
)
|
| 127 |
|
| 128 |
self.post_init()
|
| 129 |
|
| 130 |
def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> UpscalerOutput:
|
| 131 |
+
restored = self.restoration(pixel_values)
|
| 132 |
+
sr = self.upsampler(restored)
|
| 133 |
return UpscalerOutput(sr=sr, restored=restored)
|