Commit
·
3edaa28
1
Parent(s):
b5a0978
Fix weak deblurring: add tile-based inference for large images
NAFNetLocal's channel attention uses local pooling calibrated for 256x256.
For larger images, the attention kernel covers only a fraction of each
feature map, making the residual near-zero (output = input).
Fix: process large images in overlapping 256x256 tiles so each tile gets
proper global channel attention matching the training behavior. Small
images (<= 256px) still use a single forward pass.
app.py
CHANGED
|
@@ -107,6 +107,69 @@ def _unsharp_mask(img: np.ndarray, amount: float) -> np.ndarray:
|
|
| 107 |
return np.clip(sharpened, 0, 255).astype(np.uint8)
|
| 108 |
|
| 109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
def deblur(image: np.ndarray, strength: float, sharpen: float):
|
| 111 |
if image is None:
|
| 112 |
raise gr.Error("Please upload an image.")
|
|
@@ -130,21 +193,12 @@ def deblur(image: np.ndarray, strength: float, sharpen: float):
|
|
| 130 |
inp = _img2tensor_rgb(img_input)
|
| 131 |
|
| 132 |
try:
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
model.grids()
|
| 136 |
-
|
| 137 |
-
model.test()
|
| 138 |
-
|
| 139 |
-
if model.opt["val"].get("grids", False):
|
| 140 |
-
model.grids_inverse()
|
| 141 |
-
|
| 142 |
-
visuals = model.get_current_visuals()
|
| 143 |
-
sr_img = tensor2img([visuals["result"]], rgb2bgr=False)
|
| 144 |
except RuntimeError as exc:
|
| 145 |
if "out of memory" in str(exc).lower():
|
| 146 |
raise gr.Error(
|
| 147 |
-
"Out of
|
| 148 |
) from exc
|
| 149 |
raise gr.Error(f"Inference failed: {exc}") from exc
|
| 150 |
|
|
|
|
| 107 |
return np.clip(sharpened, 0, 255).astype(np.uint8)
|
| 108 |
|
| 109 |
|
| 110 |
# ---------------------------------------------------------------------------
# Tile-based inference — critical for proper deblurring on large images.
#
# NAFNetLocal replaces AdaptiveAvgPool2d(1) with a fixed-kernel AvgPool2d
# calibrated for the 256×256 training resolution. For images larger than
# ~256 px the kernel becomes *local* instead of *global*, which cripples the
# channel attention and makes the residual almost zero (output ≈ input).
#
# By processing in 256×256 tiles with overlap we guarantee that every tile
# goes through the network with global channel attention — matching the
# training behaviour and producing strong deblurring.
# ---------------------------------------------------------------------------
# Side length (px) of each square tile; matches the 256×256 training
# resolution referenced in the banner above.
TILE_SIZE = 256
# Overlap (px) between adjacent tiles; overlapping regions are averaged
# by _run_inference to hide tile seams.
TILE_OVERLAP = 48
+
|
| 125 |
+
|
| 126 |
+
def _tile_positions(length: int, tile: int, overlap: int) -> list:
|
| 127 |
+
"""Return start positions for overlapping tiles along one axis."""
|
| 128 |
+
if length <= tile:
|
| 129 |
+
return [0]
|
| 130 |
+
stride = tile - overlap
|
| 131 |
+
positions = list(range(0, length - tile + 1, stride))
|
| 132 |
+
# make sure the last tile reaches the edge
|
| 133 |
+
if positions[-1] + tile < length:
|
| 134 |
+
positions.append(length - tile)
|
| 135 |
+
return sorted(set(positions))
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _run_inference(model, lq: torch.Tensor,
                   tile_size: int = TILE_SIZE,
                   tile_overlap: int = TILE_OVERLAP) -> torch.Tensor:
    """Run deblur inference, tiling large inputs automatically.

    Args:
        model: basicsr-style model wrapper exposing ``feed_data``,
            ``test`` and ``get_current_visuals`` (as used below).
        lq: Low-quality input tensor of shape (1, C, H, W).
        tile_size: Side length of each square tile.
        tile_overlap: Overlap in pixels between adjacent tiles.

    Returns:
        Restored tensor of shape (1, C, H, W). Overlapping tile outputs
        are averaged, which hides seams between tiles.
    """
    _, c, h, w = lq.shape

    # Small image -> single forward pass (channel attention already global).
    if h <= tile_size and w <= tile_size:
        model.feed_data(data={"lq": lq})
        model.test()
        return model.get_current_visuals()["result"]

    # Large image -> overlapping tiles, averaged where they overlap.
    rows = _tile_positions(h, tile_size, tile_overlap)
    cols = _tile_positions(w, tile_size, tile_overlap)

    # Allocate the accumulators lazily so their dtype/device match whatever
    # the model actually returns. The previous version hard-coded default
    # (CPU, float32) tensors, which would raise on the `+=` below if the
    # model ever returned a CUDA or non-float32 result.
    out_acc = None
    count = None

    for y in rows:
        for x in cols:
            y_end = min(y + tile_size, h)
            x_end = min(x + tile_size, w)
            tile = lq[:, :, y:y_end, x:x_end]

            model.feed_data(data={"lq": tile})
            model.test()
            tile_out = model.get_current_visuals()["result"]

            if out_acc is None:
                out_acc = torch.zeros(1, c, h, w,
                                      dtype=tile_out.dtype,
                                      device=tile_out.device)
                count = torch.zeros(1, 1, h, w,
                                    dtype=tile_out.dtype,
                                    device=tile_out.device)

            out_acc[:, :, y:y_end, x:x_end] += tile_out
            count[:, :, y:y_end, x:x_end] += 1.0

    # Every pixel is covered by at least one tile, so clamp is a pure
    # safety net against division by zero.
    return out_acc / count.clamp(min=1.0)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
def deblur(image: np.ndarray, strength: float, sharpen: float):
|
| 174 |
if image is None:
|
| 175 |
raise gr.Error("Please upload an image.")
|
|
|
|
| 193 |
inp = _img2tensor_rgb(img_input)
|
| 194 |
|
| 195 |
try:
|
| 196 |
+
result = _run_inference(model, inp.unsqueeze(dim=0))
|
| 197 |
+
sr_img = tensor2img([result], rgb2bgr=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
except RuntimeError as exc:
|
| 199 |
if "out of memory" in str(exc).lower():
|
| 200 |
raise gr.Error(
|
| 201 |
+
"Out of memory. Try uploading a smaller image."
|
| 202 |
) from exc
|
| 203 |
raise gr.Error(f"Inference failed: {exc}") from exc
|
| 204 |
|