Spaces:
Sleeping
Sleeping
Commit ·
8afad70
1
Parent(s): 84c115f
Fix ESRGAN model structure to match checkpoint exact state dict
Browse files
app.py
CHANGED
|
@@ -185,53 +185,53 @@ FG_ESRGAN_PATH = CACHE_DIR / "esrgan"
|
|
| 185 |
_fg_esrgan_model = None
|
| 186 |
_fg_esrgan_loading = False
|
| 187 |
|
| 188 |
-
def
|
| 189 |
-
|
| 190 |
-
def _conv_block(in_nc: int, out_nc: int) -> torch.nn.Sequential:
|
| 191 |
return torch.nn.Sequential(
|
| 192 |
torch.nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1),
|
| 193 |
torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
| 194 |
)
|
| 195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
class _RRDB(torch.nn.Module):
|
| 197 |
-
def __init__(self, nf
|
| 198 |
super().__init__()
|
| 199 |
-
|
| 200 |
-
self.
|
| 201 |
-
self.
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
self.conv2_3 = _conv_block(nf + 2 * gc, gc)
|
| 208 |
-
self.conv2_4 = _conv_block(nf + 3 * gc, gc)
|
| 209 |
-
self.conv2_5 = torch.nn.Sequential(torch.nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))
|
| 210 |
-
self.RDB3 = _conv_block(nf, gc)
|
| 211 |
-
self.conv3_2 = _conv_block(nf + gc, gc)
|
| 212 |
-
self.conv3_3 = _conv_block(nf + 2 * gc, gc)
|
| 213 |
-
self.conv3_4 = _conv_block(nf + 3 * gc, gc)
|
| 214 |
-
self.conv3_5 = torch.nn.Sequential(torch.nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))
|
| 215 |
-
|
| 216 |
-
def _dense_block(self, x, conv2, conv3, conv4, conv5):
|
| 217 |
-
x1 = self.RDB1(x)
|
| 218 |
-
x2 = conv2(torch.cat((x, x1), 1))
|
| 219 |
-
x3 = conv3(torch.cat((x, x1, x2), 1))
|
| 220 |
-
x4 = conv4(torch.cat((x, x1, x2, x3), 1))
|
| 221 |
-
x5 = conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
| 222 |
-
return x5 * 0.2
|
| 223 |
-
|
| 224 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 225 |
-
out = self._dense_block(x, self.conv1_2, self.conv1_3, self.conv1_4, self.conv1_5)
|
| 226 |
-
out = self._dense_block(out + x, self.conv2_2, self.conv2_3, self.conv2_4, self.conv2_5)
|
| 227 |
-
out = self._dense_block(out + (out + x), self.conv3_2, self.conv3_3, self.conv3_4, self.conv3_5)
|
| 228 |
return out * 0.2 + x
|
| 229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
class _RRDBNet(torch.nn.Module):
|
| 231 |
-
def __init__(self,
|
| 232 |
super().__init__()
|
| 233 |
self.model = torch.nn.Sequential(
|
| 234 |
-
torch.nn.Conv2d(
|
| 235 |
_SkipBlock(torch.nn.Sequential(
|
| 236 |
*(_RRDB(nf) for _ in range(nb)),
|
| 237 |
torch.nn.Conv2d(nf, nf, kernel_size=3, padding=1),
|
|
@@ -244,18 +244,12 @@ class _RRDBNet(torch.nn.Module):
|
|
| 244 |
torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
| 245 |
torch.nn.Conv2d(nf, nf, kernel_size=3, padding=1),
|
| 246 |
torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
| 247 |
-
torch.nn.Conv2d(nf,
|
| 248 |
)
|
| 249 |
|
| 250 |
-
def forward(self, x
|
| 251 |
return self.model(x)
|
| 252 |
|
| 253 |
-
class _SkipBlock(torch.nn.Module):
|
| 254 |
-
def __init__(self, sub: torch.nn.Module):
|
| 255 |
-
super().__init__()
|
| 256 |
-
self.sub = sub
|
| 257 |
-
def forward(self, x): return x + self.sub(x)
|
| 258 |
-
|
| 259 |
def _load_fg_esrgan():
|
| 260 |
global _fg_esrgan_model, _fg_esrgan_loading
|
| 261 |
if _fg_esrgan_model is not None: return _fg_esrgan_model
|
|
@@ -267,9 +261,8 @@ def _load_fg_esrgan():
|
|
| 267 |
ckpt = hf_hub_download("philz1337x/upscaler", "4x-UltraSharp.pth", cache_dir=str(FG_ESRGAN_PATH))
|
| 268 |
logger.info("Loading ESRGAN...")
|
| 269 |
state = torch.load(ckpt, map_location="cpu", weights_only=True)
|
| 270 |
-
|
| 271 |
-
model
|
| 272 |
-
model.load_state_dict(state)
|
| 273 |
model.eval()
|
| 274 |
_fg_esrgan_model = model
|
| 275 |
logger.info("ESRGAN 4X ready (CPU)")
|
|
@@ -304,8 +297,8 @@ def upscale_finegrain(image_bytes: bytes, use_sd_refinement: bool = False) -> tu
|
|
| 304 |
else:
|
| 305 |
# Tiled upscale
|
| 306 |
stride = tile_size - overlap
|
| 307 |
-
cols =
|
| 308 |
-
rows =
|
| 309 |
out_arr = np.zeros((h * 4, w * 4, 3), dtype=np.float32)
|
| 310 |
weight = np.zeros((h * 4, w * 4, 1), dtype=np.float32)
|
| 311 |
for row in range(rows):
|
|
|
|
| 185 |
_fg_esrgan_model = None
|
| 186 |
_fg_esrgan_loading = False
|
| 187 |
|
| 188 |
+
def _conv_block(in_nc, out_nc):
|
|
|
|
|
|
|
| 189 |
return torch.nn.Sequential(
|
| 190 |
torch.nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1),
|
| 191 |
torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
| 192 |
)
|
| 193 |
|
| 194 |
+
class _ResidualDenseBlock5C(torch.nn.Module):
|
| 195 |
+
def __init__(self, nf=64, gc=32):
|
| 196 |
+
super().__init__()
|
| 197 |
+
self.conv1 = _conv_block(nf, gc)
|
| 198 |
+
self.conv2 = _conv_block(nf + gc, gc)
|
| 199 |
+
self.conv3 = _conv_block(nf + 2 * gc, gc)
|
| 200 |
+
self.conv4 = _conv_block(nf + 3 * gc, gc)
|
| 201 |
+
self.conv5 = torch.nn.Sequential(torch.nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))
|
| 202 |
+
|
| 203 |
+
def forward(self, x):
|
| 204 |
+
x1 = self.conv1(x)
|
| 205 |
+
x2 = self.conv2(torch.cat((x, x1), 1))
|
| 206 |
+
x3 = self.conv3(torch.cat((x, x1, x2), 1))
|
| 207 |
+
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
|
| 208 |
+
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
| 209 |
+
return x5 * 0.2 + x
|
| 210 |
+
|
| 211 |
class _RRDB(torch.nn.Module):
|
| 212 |
+
def __init__(self, nf):
|
| 213 |
super().__init__()
|
| 214 |
+
self.RDB1 = _ResidualDenseBlock5C(nf)
|
| 215 |
+
self.RDB2 = _ResidualDenseBlock5C(nf)
|
| 216 |
+
self.RDB3 = _ResidualDenseBlock5C(nf)
|
| 217 |
+
|
| 218 |
+
def forward(self, x):
|
| 219 |
+
out = self.RDB1(x)
|
| 220 |
+
out = self.RDB2(out)
|
| 221 |
+
out = self.RDB3(out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
return out * 0.2 + x
|
| 223 |
|
| 224 |
+
class _SkipBlock(torch.nn.Module):
|
| 225 |
+
def __init__(self, sub):
|
| 226 |
+
super().__init__()
|
| 227 |
+
self.sub = sub
|
| 228 |
+
def forward(self, x): return x + self.sub(x)
|
| 229 |
+
|
| 230 |
class _RRDBNet(torch.nn.Module):
|
| 231 |
+
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23):
|
| 232 |
super().__init__()
|
| 233 |
self.model = torch.nn.Sequential(
|
| 234 |
+
torch.nn.Conv2d(in_nc, nf, kernel_size=3, padding=1),
|
| 235 |
_SkipBlock(torch.nn.Sequential(
|
| 236 |
*(_RRDB(nf) for _ in range(nb)),
|
| 237 |
torch.nn.Conv2d(nf, nf, kernel_size=3, padding=1),
|
|
|
|
| 244 |
torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
| 245 |
torch.nn.Conv2d(nf, nf, kernel_size=3, padding=1),
|
| 246 |
torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
| 247 |
+
torch.nn.Conv2d(nf, out_nc, kernel_size=3, padding=1),
|
| 248 |
)
|
| 249 |
|
| 250 |
+
def forward(self, x):
|
| 251 |
return self.model(x)
|
| 252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
def _load_fg_esrgan():
|
| 254 |
global _fg_esrgan_model, _fg_esrgan_loading
|
| 255 |
if _fg_esrgan_model is not None: return _fg_esrgan_model
|
|
|
|
| 261 |
ckpt = hf_hub_download("philz1337x/upscaler", "4x-UltraSharp.pth", cache_dir=str(FG_ESRGAN_PATH))
|
| 262 |
logger.info("Loading ESRGAN...")
|
| 263 |
state = torch.load(ckpt, map_location="cpu", weights_only=True)
|
| 264 |
+
model = _RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)
|
| 265 |
+
model.load_state_dict(state, strict=False)
|
|
|
|
| 266 |
model.eval()
|
| 267 |
_fg_esrgan_model = model
|
| 268 |
logger.info("ESRGAN 4X ready (CPU)")
|
|
|
|
| 297 |
else:
|
| 298 |
# Tiled upscale
|
| 299 |
stride = tile_size - overlap
|
| 300 |
+
cols = -(-max(0, w - overlap) // stride) if w > tile_size else 1
|
| 301 |
+
rows = -(-max(0, h - overlap) // stride) if h > tile_size else 1
|
| 302 |
out_arr = np.zeros((h * 4, w * 4, 3), dtype=np.float32)
|
| 303 |
weight = np.zeros((h * 4, w * 4, 1), dtype=np.float32)
|
| 304 |
for row in range(rows):
|