itishalogicgo committed on
Commit
3edaa28
·
1 Parent(s): b5a0978

Fix weak deblurring: add tile-based inference for large images

Browse files

NAFNetLocal's channel attention uses local pooling calibrated for 256x256.
For larger images, the attention kernel covers only a fraction of each
feature map, making the residual near-zero (output = input).

Fix: process large images in overlapping 256x256 tiles so each tile gets
proper global channel attention matching the training behavior. Small
images (<= 256px) still use a single forward pass.

Files changed (1) hide show
  1. app.py +66 -12
app.py CHANGED
@@ -107,6 +107,69 @@ def _unsharp_mask(img: np.ndarray, amount: float) -> np.ndarray:
107
  return np.clip(sharpened, 0, 255).astype(np.uint8)
108
 
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  def deblur(image: np.ndarray, strength: float, sharpen: float):
111
  if image is None:
112
  raise gr.Error("Please upload an image.")
@@ -130,21 +193,12 @@ def deblur(image: np.ndarray, strength: float, sharpen: float):
130
  inp = _img2tensor_rgb(img_input)
131
 
132
  try:
133
- model.feed_data(data={"lq": inp.unsqueeze(dim=0)})
134
- if model.opt["val"].get("grids", False):
135
- model.grids()
136
-
137
- model.test()
138
-
139
- if model.opt["val"].get("grids", False):
140
- model.grids_inverse()
141
-
142
- visuals = model.get_current_visuals()
143
- sr_img = tensor2img([visuals["result"]], rgb2bgr=False)
144
  except RuntimeError as exc:
145
  if "out of memory" in str(exc).lower():
146
  raise gr.Error(
147
- "Out of GPU memory. Try a smaller image or enable grid-based tiling."
148
  ) from exc
149
  raise gr.Error(f"Inference failed: {exc}") from exc
150
 
 
107
  return np.clip(sharpened, 0, 255).astype(np.uint8)
108
 
109
 
110
+ # ---------------------------------------------------------------------------
111
+ # Tile-based inference — critical for proper deblurring on large images.
112
+ #
113
+ # NAFNetLocal replaces AdaptiveAvgPool2d(1) with a fixed-kernel AvgPool2d
114
+ # calibrated for the 256×256 training resolution. For images larger than
115
+ # ~256 px the kernel becomes *local* instead of *global*, which cripples the
116
+ # channel attention and makes the residual almost zero (output ≈ input).
117
+ #
118
+ # By processing in 256×256 tiles with overlap we guarantee that every tile
119
+ # goes through the network with global channel attention — matching the
120
+ # training behaviour and producing strong deblurring.
121
+ # ---------------------------------------------------------------------------
122
+ TILE_SIZE = 256
123
+ TILE_OVERLAP = 48
124
+
125
+
126
+ def _tile_positions(length: int, tile: int, overlap: int) -> list:
127
+ """Return start positions for overlapping tiles along one axis."""
128
+ if length <= tile:
129
+ return [0]
130
+ stride = tile - overlap
131
+ positions = list(range(0, length - tile + 1, stride))
132
+ # make sure the last tile reaches the edge
133
+ if positions[-1] + tile < length:
134
+ positions.append(length - tile)
135
+ return sorted(set(positions))
136
+
137
+
138
+ def _run_inference(model, lq: torch.Tensor,
139
+ tile_size: int = TILE_SIZE,
140
+ tile_overlap: int = TILE_OVERLAP) -> torch.Tensor:
141
+ """Run deblur inference — with automatic tiling for large images."""
142
+ _, c, h, w = lq.shape
143
+
144
+ # Small image → single forward pass (attention is already global)
145
+ if h <= tile_size and w <= tile_size:
146
+ model.feed_data(data={"lq": lq})
147
+ model.test()
148
+ return model.get_current_visuals()["result"]
149
+
150
+ # Large image → tile-based inference
151
+ rows = _tile_positions(h, tile_size, tile_overlap)
152
+ cols = _tile_positions(w, tile_size, tile_overlap)
153
+
154
+ out_acc = torch.zeros(1, c, h, w)
155
+ count = torch.zeros(1, 1, h, w)
156
+
157
+ for y in rows:
158
+ for x in cols:
159
+ y_end = min(y + tile_size, h)
160
+ x_end = min(x + tile_size, w)
161
+ tile = lq[:, :, y:y_end, x:x_end]
162
+
163
+ model.feed_data(data={"lq": tile})
164
+ model.test()
165
+ tile_out = model.get_current_visuals()["result"]
166
+
167
+ out_acc[:, :, y:y_end, x:x_end] += tile_out
168
+ count[:, :, y:y_end, x:x_end] += 1.0
169
+
170
+ return out_acc / count.clamp(min=1.0)
171
+
172
+
173
  def deblur(image: np.ndarray, strength: float, sharpen: float):
174
  if image is None:
175
  raise gr.Error("Please upload an image.")
 
193
  inp = _img2tensor_rgb(img_input)
194
 
195
  try:
196
+ result = _run_inference(model, inp.unsqueeze(dim=0))
197
+ sr_img = tensor2img([result], rgb2bgr=False)
 
 
 
 
 
 
 
 
 
198
  except RuntimeError as exc:
199
  if "out of memory" in str(exc).lower():
200
  raise gr.Error(
201
+ "Out of memory. Try uploading a smaller image."
202
  ) from exc
203
  raise gr.Error(f"Inference failed: {exc}") from exc
204