ihtesham0345 commited on
Commit
8afad70
·
1 Parent(s): 84c115f

Fix ESRGAN model structure to match checkpoint exact state dict

Browse files
Files changed (1) hide show
  1. app.py +41 -48
app.py CHANGED
@@ -185,53 +185,53 @@ FG_ESRGAN_PATH = CACHE_DIR / "esrgan"
185
  _fg_esrgan_model = None
186
  _fg_esrgan_loading = False
187
 
188
- def _math_ceil(x): return int(x + 0.999) if x > 0 else 0
189
-
190
- def _conv_block(in_nc: int, out_nc: int) -> torch.nn.Sequential:
191
  return torch.nn.Sequential(
192
  torch.nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1),
193
  torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
194
  )
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  class _RRDB(torch.nn.Module):
197
- def __init__(self, nf: int):
198
  super().__init__()
199
- gc = 32
200
- self.RDB1 = _conv_block(nf, gc)
201
- self.conv1_2 = _conv_block(nf + gc, gc)
202
- self.conv1_3 = _conv_block(nf + 2 * gc, gc)
203
- self.conv1_4 = _conv_block(nf + 3 * gc, gc)
204
- self.conv1_5 = torch.nn.Sequential(torch.nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))
205
- self.RDB2 = _conv_block(nf, gc)
206
- self.conv2_2 = _conv_block(nf + gc, gc)
207
- self.conv2_3 = _conv_block(nf + 2 * gc, gc)
208
- self.conv2_4 = _conv_block(nf + 3 * gc, gc)
209
- self.conv2_5 = torch.nn.Sequential(torch.nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))
210
- self.RDB3 = _conv_block(nf, gc)
211
- self.conv3_2 = _conv_block(nf + gc, gc)
212
- self.conv3_3 = _conv_block(nf + 2 * gc, gc)
213
- self.conv3_4 = _conv_block(nf + 3 * gc, gc)
214
- self.conv3_5 = torch.nn.Sequential(torch.nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))
215
-
216
- def _dense_block(self, x, conv2, conv3, conv4, conv5):
217
- x1 = self.RDB1(x)
218
- x2 = conv2(torch.cat((x, x1), 1))
219
- x3 = conv3(torch.cat((x, x1, x2), 1))
220
- x4 = conv4(torch.cat((x, x1, x2, x3), 1))
221
- x5 = conv5(torch.cat((x, x1, x2, x3, x4), 1))
222
- return x5 * 0.2
223
-
224
- def forward(self, x: torch.Tensor) -> torch.Tensor:
225
- out = self._dense_block(x, self.conv1_2, self.conv1_3, self.conv1_4, self.conv1_5)
226
- out = self._dense_block(out + x, self.conv2_2, self.conv2_3, self.conv2_4, self.conv2_5)
227
- out = self._dense_block(out + (out + x), self.conv3_2, self.conv3_3, self.conv3_4, self.conv3_5)
228
  return out * 0.2 + x
229
 
 
 
 
 
 
 
230
  class _RRDBNet(torch.nn.Module):
231
- def __init__(self, nf: int = 64, nb: int = 23):
232
  super().__init__()
233
  self.model = torch.nn.Sequential(
234
- torch.nn.Conv2d(3, nf, kernel_size=3, padding=1),
235
  _SkipBlock(torch.nn.Sequential(
236
  *(_RRDB(nf) for _ in range(nb)),
237
  torch.nn.Conv2d(nf, nf, kernel_size=3, padding=1),
@@ -244,18 +244,12 @@ class _RRDBNet(torch.nn.Module):
244
  torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
245
  torch.nn.Conv2d(nf, nf, kernel_size=3, padding=1),
246
  torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
247
- torch.nn.Conv2d(nf, 3, kernel_size=3, padding=1),
248
  )
249
 
250
- def forward(self, x: torch.Tensor) -> torch.Tensor:
251
  return self.model(x)
252
 
253
- class _SkipBlock(torch.nn.Module):
254
- def __init__(self, sub: torch.nn.Module):
255
- super().__init__()
256
- self.sub = sub
257
- def forward(self, x): return x + self.sub(x)
258
-
259
  def _load_fg_esrgan():
260
  global _fg_esrgan_model, _fg_esrgan_loading
261
  if _fg_esrgan_model is not None: return _fg_esrgan_model
@@ -267,9 +261,8 @@ def _load_fg_esrgan():
267
  ckpt = hf_hub_download("philz1337x/upscaler", "4x-UltraSharp.pth", cache_dir=str(FG_ESRGAN_PATH))
268
  logger.info("Loading ESRGAN...")
269
  state = torch.load(ckpt, map_location="cpu", weights_only=True)
270
- in_nc, out_nc, nf, nb = 3, 3, 64, 23 # defaults for 4x-UltraSharp
271
- model = _RRDBNet(nf=nf, nb=nb)
272
- model.load_state_dict(state)
273
  model.eval()
274
  _fg_esrgan_model = model
275
  logger.info("ESRGAN 4X ready (CPU)")
@@ -304,8 +297,8 @@ def upscale_finegrain(image_bytes: bytes, use_sd_refinement: bool = False) -> tu
304
  else:
305
  # Tiled upscale
306
  stride = tile_size - overlap
307
- cols = _math_ceil((w - overlap) / stride) if w > tile_size else 1
308
- rows = _math_ceil((h - overlap) / stride) if h > tile_size else 1
309
  out_arr = np.zeros((h * 4, w * 4, 3), dtype=np.float32)
310
  weight = np.zeros((h * 4, w * 4, 1), dtype=np.float32)
311
  for row in range(rows):
 
185
  _fg_esrgan_model = None
186
  _fg_esrgan_loading = False
187
 
188
+ def _conv_block(in_nc, out_nc):
 
 
189
  return torch.nn.Sequential(
190
  torch.nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1),
191
  torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
192
  )
193
 
194
+ class _ResidualDenseBlock5C(torch.nn.Module):
195
+ def __init__(self, nf=64, gc=32):
196
+ super().__init__()
197
+ self.conv1 = _conv_block(nf, gc)
198
+ self.conv2 = _conv_block(nf + gc, gc)
199
+ self.conv3 = _conv_block(nf + 2 * gc, gc)
200
+ self.conv4 = _conv_block(nf + 3 * gc, gc)
201
+ self.conv5 = torch.nn.Sequential(torch.nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))
202
+
203
+ def forward(self, x):
204
+ x1 = self.conv1(x)
205
+ x2 = self.conv2(torch.cat((x, x1), 1))
206
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
207
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
208
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
209
+ return x5 * 0.2 + x
210
+
211
  class _RRDB(torch.nn.Module):
212
+ def __init__(self, nf):
213
  super().__init__()
214
+ self.RDB1 = _ResidualDenseBlock5C(nf)
215
+ self.RDB2 = _ResidualDenseBlock5C(nf)
216
+ self.RDB3 = _ResidualDenseBlock5C(nf)
217
+
218
+ def forward(self, x):
219
+ out = self.RDB1(x)
220
+ out = self.RDB2(out)
221
+ out = self.RDB3(out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  return out * 0.2 + x
223
 
224
+ class _SkipBlock(torch.nn.Module):
225
+ def __init__(self, sub):
226
+ super().__init__()
227
+ self.sub = sub
228
+ def forward(self, x): return x + self.sub(x)
229
+
230
  class _RRDBNet(torch.nn.Module):
231
+ def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23):
232
  super().__init__()
233
  self.model = torch.nn.Sequential(
234
+ torch.nn.Conv2d(in_nc, nf, kernel_size=3, padding=1),
235
  _SkipBlock(torch.nn.Sequential(
236
  *(_RRDB(nf) for _ in range(nb)),
237
  torch.nn.Conv2d(nf, nf, kernel_size=3, padding=1),
 
244
  torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
245
  torch.nn.Conv2d(nf, nf, kernel_size=3, padding=1),
246
  torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
247
+ torch.nn.Conv2d(nf, out_nc, kernel_size=3, padding=1),
248
  )
249
 
250
+ def forward(self, x):
251
  return self.model(x)
252
 
 
 
 
 
 
 
253
  def _load_fg_esrgan():
254
  global _fg_esrgan_model, _fg_esrgan_loading
255
  if _fg_esrgan_model is not None: return _fg_esrgan_model
 
261
  ckpt = hf_hub_download("philz1337x/upscaler", "4x-UltraSharp.pth", cache_dir=str(FG_ESRGAN_PATH))
262
  logger.info("Loading ESRGAN...")
263
  state = torch.load(ckpt, map_location="cpu", weights_only=True)
264
+ model = _RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)
265
+ model.load_state_dict(state, strict=False)
 
266
  model.eval()
267
  _fg_esrgan_model = model
268
  logger.info("ESRGAN 4X ready (CPU)")
 
297
  else:
298
  # Tiled upscale
299
  stride = tile_size - overlap
300
+ cols = -(-max(0, w - overlap) // stride) if w > tile_size else 1
301
+ rows = -(-max(0, h - overlap) // stride) if h > tile_size else 1
302
  out_arr = np.zeros((h * 4, w * 4, 3), dtype=np.float32)
303
  weight = np.zeros((h * 4, w * 4, 1), dtype=np.float32)
304
  for row in range(rows):