TutlaytAI commited on
Commit
5327de9
·
verified ·
1 Parent(s): c13bcbd

Update modeling_upscaler.py

Browse files
Files changed (1) hide show
  1. modeling_upscaler.py +13 -36
modeling_upscaler.py CHANGED
@@ -96,38 +96,6 @@ class ESPCNExactUpsampler(nn.Module):
96
  return x
97
 
98
 
99
- class TwoStageSR(nn.Module):
100
- def __init__(
101
- self,
102
- in_channels=3,
103
- scale=4,
104
- width=32,
105
- num_blocks=3,
106
- espcn_channels=64,
107
- ):
108
- super().__init__()
109
- self.scale = scale
110
- self.restoration = RestorationNet(
111
- in_channels=in_channels,
112
- width=width,
113
- num_blocks=num_blocks,
114
- )
115
- self.upsampler = ESPCNExactUpsampler(
116
- scale_factor=scale,
117
- num_channels=in_channels,
118
- channels=espcn_channels,
119
- )
120
-
121
- def forward(self, lr):
122
- lr_clean = self.restoration(lr)
123
- hr_pred = self.upsampler(lr_clean)
124
- return hr_pred, lr_clean
125
-
126
-
127
- # -------------------------
128
- # Transformers output
129
- # -------------------------
130
-
131
  @dataclass
132
  class UpscalerOutput(ModelOutput):
133
  sr: torch.FloatTensor
@@ -141,16 +109,25 @@ class UpscalerModel(PreTrainedModel):
141
  def __init__(self, config: UpscalerConfig):
142
  super().__init__(config)
143
 
144
- self.model = TwoStageSR(
 
 
 
 
 
145
  in_channels=config.in_channels,
146
- scale=config.scale,
147
  width=config.width,
148
  num_blocks=config.num_blocks,
149
- espcn_channels=config.espcn_channels,
 
 
 
 
150
  )
151
 
152
  self.post_init()
153
 
154
  def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> UpscalerOutput:
155
- sr, restored = self.model(pixel_values)
 
156
  return UpscalerOutput(sr=sr, restored=restored)
 
96
  return x
97
 
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  @dataclass
100
  class UpscalerOutput(ModelOutput):
101
  sr: torch.FloatTensor
 
109
  def __init__(self, config: UpscalerConfig):
110
  super().__init__(config)
111
 
112
+ # IMPORTANT :
113
+ # on expose directement les sous-modules pour que les clés
114
+ # correspondent au checkpoint Lightning :
115
+ # restoration.* / upsampler.*
116
+ self.scale = config.scale
117
+ self.restoration = RestorationNet(
118
  in_channels=config.in_channels,
 
119
  width=config.width,
120
  num_blocks=config.num_blocks,
121
+ )
122
+ self.upsampler = ESPCNExactUpsampler(
123
+ scale_factor=config.scale,
124
+ num_channels=config.in_channels,
125
+ channels=config.espcn_channels,
126
  )
127
 
128
  self.post_init()
129
 
130
  def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> UpscalerOutput:
131
+ restored = self.restoration(pixel_values)
132
+ sr = self.upsampler(restored)
133
  return UpscalerOutput(sr=sr, restored=restored)