blure_remover / app.py
itishalogicgo's picture
Boost deblur visibility: multi-pass inference, stronger defaults, user controls
c6e7730
import os
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parent
sys.path.insert(0, str(ROOT))
import cv2
import gradio as gr
import numpy as np
import torch
from basicsr.models import create_model
from basicsr.utils import img2tensor as _img2tensor, tensor2img
from basicsr.utils.options import parse
# Location of the NAFNet GoPro inference config and of its pretrained
# checkpoint, both anchored at ROOT (the directory containing this file).
# The checkpoint may be absent; _ensure_weights() fetches it on demand.
DEFAULT_OPT_PATH = ROOT.joinpath("options", "test", "GoPro", "NAFNet-width64.yml")
DEFAULT_WEIGHTS_PATH = ROOT.joinpath(
    "experiments", "pretrained_models", "NAFNet-GoPro-width64.pth"
)
def _download_file(url: str, dst: Path) -> None:
    """Stream ``url`` to ``dst``, publishing the file only once it is complete.

    The payload is first written to a sibling ``<name>.part`` file. That file
    is renamed onto ``dst`` only after the whole body has been received.
    Writing straight into ``dst`` meant an interrupted or failed download left
    a truncated file behind. ``_ensure_weights`` only checks
    ``path.exists()``, so it would accept that corrupt checkpoint on every
    later start.

    Args:
        url: HTTP(S) URL to fetch.
        dst: Final destination; missing parent directories are created.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        Any network or filesystem error is propagated after the partial
        ``.part`` file has been removed.
    """
    import requests  # imported lazily: only needed when weights must be fetched

    dst.parent.mkdir(parents=True, exist_ok=True)
    partial = dst.with_name(dst.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=120) as resp:
            resp.raise_for_status()
            with open(partial, "wb") as fh:
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if chunk:  # skip empty chunks
                        fh.write(chunk)
        # Rename only after the file is fully written and closed.
        partial.replace(dst)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a half-written file around.
        try:
            partial.unlink()
        except FileNotFoundError:
            pass
        raise
def _ensure_weights(path: Path) -> None:
    """Make sure the checkpoint at ``path`` exists, downloading it if needed.

    When the file is missing, its URL is read from the ``MODEL_URL``
    environment variable (surrounding whitespace ignored) and fetched with
    ``_download_file``.

    Raises:
        FileNotFoundError: if the file is absent and ``MODEL_URL`` is unset
            or blank.
    """
    if not path.exists():
        model_url = os.environ.get("MODEL_URL", "").strip()
        if model_url:
            _download_file(model_url, path)
        else:
            raise FileNotFoundError(
                f"Missing weights at {path}. Provide MODEL_URL env var or add the file."
            )
def _normalize_input(img: np.ndarray) -> np.ndarray:
    """Coerce an image array to ``uint8`` in the 0-255 range.

    ``uint8`` input is returned unchanged (the very same object). Any other
    dtype is converted to float32. If its maximum is <= 1.0 it is treated as a
    0-1 image and scaled by 255. The values are then clipped to [0, 255] and
    cast to ``uint8``, truncating fractional parts.
    """
    if img.dtype == np.uint8:
        return img
    as_float = img.astype(np.float32)
    scale = 255.0 if as_float.max() <= 1.0 else 1.0
    return np.clip(as_float * scale, 0, 255).astype(np.uint8)
def _img2tensor_rgb(img: np.ndarray) -> torch.Tensor:
    """Divide ``img`` by 255 as float32 and hand it to basicsr's ``img2tensor``.

    Channel order is kept as-is (``bgr2rgb=False``), because the app already
    works in RGB.
    """
    pixels = img.astype(np.float32)  # fresh copy, safe to modify in place
    pixels /= 255.0
    return _img2tensor(pixels, bgr2rgb=False, float32=True)
def _load_model():
    """Build the NAFNet restoration model from the GoPro test config.

    Ensures the checkpoint exists, parses the YAML options for inference,
    disables distributed mode, requests zero GPUs when CUDA is unavailable,
    anchors a relative pretrained-weight path at ROOT, and forces the plain
    ``NAFNet`` architecture before calling basicsr's ``create_model``.
    """
    _ensure_weights(DEFAULT_WEIGHTS_PATH)

    options = parse(str(DEFAULT_OPT_PATH), is_train=False)
    options["dist"] = False
    if not torch.cuda.is_available():
        # No CUDA device visible. The log line below reports where the
        # model actually ended up.
        options["num_gpu"] = 0

    # A relative path in the YAML would otherwise be resolved against the
    # current working directory instead of this file's directory.
    path_opts = options["path"]
    weights = path_opts.get("pretrain_network_g")
    if weights and not os.path.isabs(weights):
        path_opts["pretrain_network_g"] = str(ROOT / weights)

    # Use plain NAFNet rather than NAFNetLocal. Per the original author's
    # analysis, NAFNetLocal swaps each AdaptiveAvgPool2d(1) for an AvgPool2d
    # whose kernel is sized for 256x256 training crops. On larger inputs the
    # intended *global* channel attention then becomes *local*, and
    # deblurring gets much weaker. NAFNetLocal adds no learnable parameters,
    # so the same checkpoint loads into plain NAFNet.
    options["network_g"]["type"] = "NAFNet"

    model = create_model(options)
    net = model.net_g
    device = next(net.parameters()).device
    n_params = sum(p.numel() for p in net.parameters())
    print(f"[blur_remover] Model loaded on {device}, parameters: {n_params:,}")
    return model
# Shared model instance, created on first use by _get_model().
MODEL = None


def _get_model():
    """Return the shared model, loading it on the first call.

    There is no locking, so concurrent first calls could each invoke
    ``_load_model``.
    """
    global MODEL
    if MODEL is not None:
        return MODEL
    MODEL = _load_model()
    return MODEL
def _diff_map(inp: np.ndarray, out: np.ndarray, gain: float = 3.0) -> np.ndarray:
    """Return an amplified absolute-difference image between ``inp`` and ``out``.

    The per-pixel ``|out - inp|`` is multiplied by ``gain`` (3 by default,
    matching the "Diff (x3)" panel label) and saturated to [0, 255] so that
    small changes made by the model become visible.

    The amplification is done in float *before* clipping. Previously the
    difference was cast to uint8 first and then multiplied, so every
    difference above 85 wrapped around modulo 256 (e.g. 100 * 3 -> 44)
    instead of saturating at 255. The strongest changes therefore looked
    dark.

    Args:
        inp: Reference image.
        out: Processed image.
        gain: Amplification factor applied to the absolute difference.

    Returns:
        ``uint8`` difference image, or ``out`` unchanged if the shapes differ.
    """
    if inp.shape != out.shape:
        return out
    diff = np.abs(out.astype(np.float32) - inp.astype(np.float32))
    return np.clip(diff * gain, 0, 255).astype(np.uint8)
def _apply_strength(inp: np.ndarray, out: np.ndarray, strength: float) -> np.ndarray:
    """Move from ``inp`` towards ``out`` by a factor of ``strength``.

    ``strength == 1.0`` returns ``out`` itself. Values below 1 interpolate
    back towards the input, and values above 1 extrapolate further along
    ``out - inp``. Blended results are clipped to [0, 255] as ``uint8``.
    """
    if strength != 1.0:
        base = inp.astype(np.float32)
        delta = out.astype(np.float32) - base
        return np.clip(base + strength * delta, 0, 255).astype(np.uint8)
    return out
def _unsharp_mask(img: np.ndarray, amount: float) -> np.ndarray:
    """Sharpen ``img`` by unsharp masking; ``amount <= 0`` returns it untouched.

    Computes ``(1 + amount) * img - amount * GaussianBlur(img)``. The blur
    sigma grows with the amount (``1 + 2 * amount``), and the result is
    clipped to [0, 255] as ``uint8``.
    """
    if amount <= 0:
        return img
    blur_sigma = 2.0 * amount + 1.0
    smoothed = cv2.GaussianBlur(img, (0, 0), sigmaX=blur_sigma, sigmaY=blur_sigma)
    mixed = cv2.addWeighted(img, 1.0 + amount, smoothed, -amount, 0)
    return np.clip(mixed, 0, 255).astype(np.uint8)
# ---------------------------------------------------------------------------
# Tiled inference: large images are restored as overlapping square tiles
# ---------------------------------------------------------------------------
TILE_SIZE = 256  # tile edge length, in pixels
TILE_OVERLAP = 48  # pixels shared by neighbouring tiles along each axis
def _tile_positions(length: int, tile: int, overlap: int) -> list:
    """Return the sorted start offsets of overlapping tiles covering ``[0, length)``.

    Tiles advance by ``tile - overlap``. When that stride leaves a remainder
    uncovered, one more tile is added that ends exactly at ``length``. If the
    whole axis fits in a single tile, ``[0]`` is returned.

    Args:
        length: Size of the axis to cover.
        tile: Tile size along that axis.
        overlap: Number of elements shared by consecutive tiles.

    Raises:
        ValueError: if tiling is needed and ``tile <= 0`` or ``overlap`` is
            not in ``[0, tile)``. Previously a zero stride surfaced as an
            opaque ``range()`` error, a negative stride as an ``IndexError``,
            and a negative overlap silently left uncovered gaps.
    """
    if length <= tile:
        return [0]
    stride = tile - overlap
    if tile <= 0 or overlap < 0 or stride <= 0:
        raise ValueError(
            f"invalid tiling: tile={tile}, overlap={overlap} "
            "(need tile > 0 and 0 <= overlap < tile)"
        )
    starts = list(range(0, length - tile + 1, stride))
    last = length - tile
    # range() never exceeds `last`, so appending keeps the list sorted/unique.
    if starts[-1] != last:
        starts.append(last)
    return starts
def _run_single_pass(model, lq: torch.Tensor,
                     tile_size: int = TILE_SIZE,
                     tile_overlap: int = TILE_OVERLAP) -> torch.Tensor:
    """Run the model once over ``lq``, tiling when it is larger than ``tile_size``.

    If both spatial dimensions fit in ``tile_size``, the whole tensor goes
    through the model in one call. Otherwise it is cut into overlapping tiles
    (offsets from ``_tile_positions``). Each tile is restored independently,
    and overlapping outputs are averaged with equal weight.

    NOTE(review): assumes ``lq`` is (1, C, H, W) and that the model returns a
    CPU tensor with the same shape as each tile. The accumulators are
    allocated with batch size 1 on the default device. TODO confirm.
    """
    def restore(patch):
        # One model invocation through the basicsr model API.
        model.feed_data(data={"lq": patch})
        model.test()
        return model.get_current_visuals()["result"]

    _, channels, height, width = lq.shape
    if height <= tile_size and width <= tile_size:
        return restore(lq)

    row_starts = _tile_positions(height, tile_size, tile_overlap)
    col_starts = _tile_positions(width, tile_size, tile_overlap)
    accum = torch.zeros(1, channels, height, width)
    weight = torch.zeros(1, 1, height, width)
    for top in row_starts:
        bottom = min(top + tile_size, height)
        for left in col_starts:
            right = min(left + tile_size, width)
            region = (slice(None), slice(None), slice(top, bottom), slice(left, right))
            accum[region] += restore(lq[region])
            weight[region] += 1.0
    return accum / weight.clamp(min=1.0)
def _run_inference(model, lq: torch.Tensor, passes: int = 1) -> torch.Tensor:
    """Apply ``_run_single_pass`` ``passes`` times, feeding each output back in.

    Repeated passes strengthen the effect on heavily blurred inputs. With
    ``passes <= 0`` the input tensor itself is returned.
    """
    result = lq
    for _pass_index in range(passes):
        result = _run_single_pass(model, result)
    return result
def _to_rgb3(img: np.ndarray) -> np.ndarray:
    """Convert a uint8 image with 1-4 channels to 3-channel RGB.

    Grayscale ``(H, W)`` / ``(H, W, 1)`` is replicated to RGB. RGBA drops its
    alpha channel. A 2-channel image is presumably gray + alpha, so only its
    first channel is kept. The original code passed ``(H, W, 1)`` and
    2-channel arrays through untouched, and they then failed inside the model.

    Raises:
        gr.Error: for any other array layout.
    """
    if img.ndim == 3 and img.shape[2] in (1, 2):
        # assumes channel 0 is luminance (and channel 1, if any, is alpha)
        img = np.ascontiguousarray(img[:, :, 0])
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    if img.ndim == 3 and img.shape[2] == 4:
        return cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
    if img.ndim == 3 and img.shape[2] == 3:
        return img
    raise gr.Error(
        f"Unsupported image shape {img.shape}; expected grayscale, RGB or RGBA."
    )


def deblur(image: np.ndarray, strength: float, sharpen: float, passes: int):
    """Gradio callback: deblur ``image`` and return ``(result, diff_map)``.

    Pipeline: load (or reuse) the model, normalise the input to uint8 RGB,
    and run ``passes`` model passes. Then blend with the input by
    ``strength`` and apply an unsharp mask of amount ``sharpen``. The second
    output is the amplified difference between the normalised input and the
    final result.

    Raises:
        gr.Error: for a missing upload, model-loading failures, unsupported
            image layouts, out-of-memory, or any other inference error.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    try:
        model = _get_model()
    except FileNotFoundError as exc:
        raise gr.Error(
            "Model weights not found. Set MODEL_URL in the Space settings "
            "or add the weight file at experiments/pretrained_models/NAFNet-GoPro-width64.pth."
        ) from exc
    except Exception as exc:
        raise gr.Error(f"Failed to load model: {exc}") from exc

    img_input = _to_rgb3(_normalize_input(image))
    inp = _img2tensor_rgb(img_input)
    try:
        result = _run_inference(model, inp.unsqueeze(dim=0), passes=int(passes))
        sr_img = tensor2img([result], rgb2bgr=False)
    except RuntimeError as exc:
        if "out of memory" in str(exc).lower():
            if torch.cuda.is_available():
                # Return cached-but-unused GPU memory to the allocator so a
                # long-running server has a better chance on the next request.
                torch.cuda.empty_cache()
            raise gr.Error(
                "Out of memory. Try uploading a smaller image or reducing passes."
            ) from exc
        raise gr.Error(f"Inference failed: {exc}") from exc

    sr_img = _apply_strength(img_input, sr_img, strength)
    sr_img = _unsharp_mask(sr_img, sharpen)
    return sr_img, _diff_map(img_input, sr_img)
def build_ui():
    """Assemble the Gradio Blocks interface and wire the Deblur button.

    Layout, top to bottom: intro markdown; a row with the input, output and
    diff images; a row of sliders (strength, sharpen, passes); the button.
    Returns the ``gr.Blocks`` object without launching it.
    """
    intro = (
        "# NAFNet Deblur\n"
        "Upload a blurry image and get a deblurred result.\n\n"
        "**Tips:** Increase **Strength** to amplify the effect. "
        "Raise **Sharpen** for extra crispness. "
        "Use **Passes** > 1 for heavily blurred images."
    )
    with gr.Blocks(title="NAFNet Deblur") as ui:
        gr.Markdown(intro)

        with gr.Row():
            blurry_view = gr.Image(label="Input (Blurry)", type="numpy")
            restored_view = gr.Image(label="Output (Deblurred)", type="numpy")
            diff_view = gr.Image(label="Diff (x3)", type="numpy")

        with gr.Row():
            strength_ctl = gr.Slider(
                minimum=0.5, maximum=5.0, value=2.0, step=0.1,
                label="Strength (amplify deblur effect)",
            )
            sharpen_ctl = gr.Slider(
                minimum=0.0, maximum=2.0, value=0.5, step=0.05,
                label="Sharpen (post-processing)",
            )
            passes_ctl = gr.Slider(
                minimum=1, maximum=5, value=2, step=1,
                label="Passes (run model N times)",
            )

        run_button = gr.Button("Deblur", variant="primary")
        run_button.click(
            fn=deblur,
            inputs=[blurry_view, strength_ctl, sharpen_ctl, passes_ctl],
            outputs=[restored_view, diff_view],
        )
    return ui
# Built at import time and exposed as a module-level `app`, presumably so a
# hosting runner can import it directly. TODO confirm the deployment target.
app = build_ui()

if __name__ == "__main__":
    # Executed as a script: start the Gradio server locally.
    app.launch()