"""Gradio demo that deblurs images with a pretrained NAFNet (GoPro, width 64)."""

import os
import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parent
# Make the repository-local ``basicsr`` package importable regardless of CWD.
sys.path.insert(0, str(ROOT))

import cv2  # noqa: E402
import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
from basicsr.models import create_model  # noqa: E402
from basicsr.utils import img2tensor as _img2tensor, tensor2img  # noqa: E402
from basicsr.utils.options import parse  # noqa: E402

DEFAULT_OPT_PATH = ROOT / "options" / "test" / "GoPro" / "NAFNet-width64.yml"
DEFAULT_WEIGHTS_PATH = (
    ROOT / "experiments" / "pretrained_models" / "NAFNet-GoPro-width64.pth"
)

# NOTE(review): these are presumably constructor options understood only by
# NAFNetLocal. They are dropped (if the option file sets them) before the
# network type is switched to plain NAFNet. Confirm against the arch file.
_LOCAL_ONLY_NET_KEYS = ("train_size", "fast_imp")


def _download_file(url: str, dst: Path) -> None:
    """Stream ``url`` into ``dst`` without ever leaving a partial file there.

    The payload is written to a sibling ``<name>.part`` file and moved onto
    ``dst`` only after the transfer has completed. This matters because
    ``_ensure_weights`` treats any existing file at ``dst`` as valid. A
    truncated download written in place would therefore be reused on every
    later start.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    import requests

    dst.parent.mkdir(parents=True, exist_ok=True)
    tmp = dst.with_name(dst.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=120) as r:
            r.raise_for_status()
            with open(tmp, "wb") as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        os.replace(tmp, dst)  # atomic rename on the same filesystem
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a stale .part behind.
        tmp.unlink(missing_ok=True)
        raise


def _ensure_weights(path: Path) -> None:
    """Make sure the weight file at ``path`` exists, downloading it if needed.

    The download URL is taken from the ``MODEL_URL`` environment variable.

    Raises:
        FileNotFoundError: if the file is missing and ``MODEL_URL`` is unset.
    """
    if path.exists():
        return
    url = os.getenv("MODEL_URL", "").strip()
    if not url:
        raise FileNotFoundError(
            f"Missing weights at {path}. Provide MODEL_URL env var or add the file."
        )
    _download_file(url, path)


def _normalize_input(img: np.ndarray) -> np.ndarray:
    """Convert an incoming image array to ``uint8`` in the 0..255 range.

    - ``uint8`` input is returned unchanged.
    - ``uint16`` input is rescaled to 0..255 (65535 / 255 == 257).
    - Any other dtype is treated as float. If its maximum is at most 1.0 it
      is assumed to be in [0, 1] and scaled by 255. It is then clipped to
      0..255 and cast.
    """
    if img.dtype == np.uint8:
        return img
    if img.dtype == np.uint16:
        # Without this, 16-bit images would be clipped to 255 (near white).
        return (img // 257).astype(np.uint8)
    img = img.astype(np.float32)
    if img.max() <= 1.0:
        img = img * 255.0
    return np.clip(img, 0, 255).astype(np.uint8)


def _img2tensor_rgb(img: np.ndarray) -> torch.Tensor:
    """Turn an RGB ``uint8`` HxWxC array into a float32 CxHxW tensor in [0, 1]."""
    img = img.astype(np.float32) / 255.0
    return _img2tensor(img, bgr2rgb=False, float32=True)


def _load_model():
    """Build the deblurring model from the GoPro NAFNet-width64 test options.

    Downloads the weights first if they are missing. It then disables
    distributed mode, sets ``num_gpu`` to 0 when CUDA is unavailable, and
    resolves a relative pretrained-weight path against ``ROOT``. Finally it
    swaps the network type to plain ``NAFNet`` (see the comment below).
    """
    _ensure_weights(DEFAULT_WEIGHTS_PATH)

    opt = parse(str(DEFAULT_OPT_PATH), is_train=False)
    opt["dist"] = False
    if not torch.cuda.is_available():
        opt["num_gpu"] = 0

    # Resolve the pretrained weight path to an absolute one so that loading
    # works from any current working directory.
    pretrain = opt["path"].get("pretrain_network_g")
    if pretrain and not os.path.isabs(pretrain):
        opt["path"]["pretrain_network_g"] = str(ROOT / pretrain)

    # ---- critical fix ----
    # Use plain NAFNet instead of NAFNetLocal.
    # NAFNetLocal replaces every AdaptiveAvgPool2d(1) with a custom AvgPool2d
    # whose kernel is calibrated for the 256x256 training size. On any
    # real-world image the kernel is smaller than the feature map, turning the
    # *global* channel attention into weak *local* attention -> almost no
    # deblurring. Plain NAFNet keeps standard AdaptiveAvgPool2d(1), which
    # always pools to 1x1 and so gives correct global channel attention at
    # every resolution. The pretrained weights are 100% compatible
    # (NAFNetLocal adds zero learnable parameters on top of NAFNet).
    net_opt = opt["network_g"]
    net_opt["type"] = "NAFNet"
    for key in _LOCAL_ONLY_NET_KEYS:
        net_opt.pop(key, None)

    model = create_model(opt)
    print(f"[blur_remover] Model loaded on {next(model.net_g.parameters()).device}, "
          f"parameters: {sum(p.numel() for p in model.net_g.parameters()):,}")
    return model


MODEL = None


def _get_model():
    """Return the process-wide model, loading it on first use."""
    global MODEL
    if MODEL is None:
        MODEL = _load_model()
    return MODEL


def _diff_map(inp: np.ndarray, out: np.ndarray) -> np.ndarray:
    """Return ``|out - inp|`` amplified 3x (saturating at 255) for display.

    If the shapes differ, ``out`` is returned unchanged.
    """
    if inp.shape != out.shape:
        return out
    # Stay in int16 until after amplification: multiplying a uint8 array by 3
    # wraps modulo 256, so clipping afterwards would have no effect.
    diff = np.abs(out.astype(np.int16) - inp.astype(np.int16))
    return np.clip(diff * 3, 0, 255).astype(np.uint8)


def _apply_strength(inp: np.ndarray, out: np.ndarray, strength: float) -> np.ndarray:
    """Blend ``inp`` toward (or past, when ``strength > 1``) ``out``.

    Computes ``inp + strength * (out - inp)`` in float32, then clips to
    0..255 and casts to uint8. ``strength == 1.0`` returns ``out`` unchanged.
    """
    if strength == 1.0:
        return out
    blended = inp.astype(np.float32) + strength * (out.astype(np.float32) - inp.astype(np.float32))
    return np.clip(blended, 0, 255).astype(np.uint8)


def _unsharp_mask(img: np.ndarray, amount: float) -> np.ndarray:
    """Apply unsharp masking for perceptual sharpening.

    The Gaussian sigma grows with ``amount`` (``1 + 2 * amount``). A
    non-positive ``amount`` returns ``img`` unchanged.
    """
    if amount <= 0:
        return img
    sigma = 1.0 + amount * 2.0
    blurred = cv2.GaussianBlur(img, (0, 0), sigmaX=sigma, sigmaY=sigma)
    sharpened = cv2.addWeighted(img, 1.0 + amount, blurred, -amount, 0)
    return np.clip(sharpened, 0, 255).astype(np.uint8)


# ---------------------------------------------------------------------------
# Tile-based inference for large images
# ---------------------------------------------------------------------------
TILE_SIZE = 256
TILE_OVERLAP = 48


def _tile_positions(length: int, tile: int, overlap: int) -> list:
    """Return ascending start offsets of overlapping tiles along one axis.

    Tiles advance by ``tile - overlap``. The last start is snapped to
    ``length - tile`` so the far edge is covered without overrunning it.
    An axis no longer than ``tile`` yields ``[0]``.

    Raises:
        ValueError: if ``overlap`` is negative (tiles would leave gaps) or
            not smaller than ``tile`` (the stride would be non-positive).
    """
    if length <= tile:
        return [0]
    if not 0 <= overlap < tile:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < tile, got overlap={overlap}, tile={tile}"
        )
    stride = tile - overlap
    positions = list(range(0, length - tile + 1, stride))
    if positions[-1] + tile < length:
        positions.append(length - tile)  # larger than every previous start
    return positions


def _run_single_pass(model, lq: torch.Tensor, tile_size: int = TILE_SIZE,
                     tile_overlap: int = TILE_OVERLAP) -> torch.Tensor:
    """Run one deblur pass on an NCHW batch, tiling when it exceeds ``tile_size``.

    Overlapping tile outputs are averaged uniformly. Tile outputs are
    accumulated into CPU float32 buffers. NOTE(review): this assumes
    ``get_current_visuals()["result"]`` is a CPU tensor of the tile's shape.
    """
    n, c, h, w = lq.shape
    if h <= tile_size and w <= tile_size:
        model.feed_data(data={"lq": lq})
        model.test()
        return model.get_current_visuals()["result"]

    rows = _tile_positions(h, tile_size, tile_overlap)
    cols = _tile_positions(w, tile_size, tile_overlap)
    out_acc = torch.zeros(n, c, h, w)
    count = torch.zeros(1, 1, h, w)

    for y in rows:
        for x in cols:
            y_end = min(y + tile_size, h)
            x_end = min(x + tile_size, w)
            tile = lq[:, :, y:y_end, x:x_end]
            model.feed_data(data={"lq": tile})
            model.test()
            tile_out = model.get_current_visuals()["result"]
            out_acc[:, :, y:y_end, x:x_end] += tile_out
            count[:, :, y:y_end, x:x_end] += 1.0

    return out_acc / count.clamp(min=1.0)


def _run_inference(model, lq: torch.Tensor, passes: int = 1) -> torch.Tensor:
    """Run deblur inference ``passes`` times, feeding each output back in.

    Each pass output is clamped to [0, 1], the range the input tensor is
    built in (see ``_img2tensor_rgb``). This stops later passes from seeing
    out-of-range values. The final clamp matches what ``tensor2img`` applies
    anyway.
    """
    current = lq
    for _ in range(passes):
        current = _run_single_pass(model, current).clamp(0.0, 1.0)
    return current


def deblur(image: np.ndarray, strength: float, sharpen: float, passes: int):
    """Gradio handler: deblur ``image`` and return ``(result, diff_map)``.

    Grayscale (2-D or single-channel) and RGBA inputs are converted to RGB.
    The model runs ``passes`` times, the result is blended with the input by
    ``strength``, and it is then unsharp-masked by ``sharpen``.

    Raises:
        gr.Error: on missing input, model-loading failure, or inference
            failure (with a dedicated message for out-of-memory).
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    try:
        model = _get_model()
    except FileNotFoundError as exc:
        raise gr.Error(
            "Model weights not found. Set MODEL_URL in the Space settings "
            "or add the weight file at experiments/pretrained_models/NAFNet-GoPro-width64.pth."
        ) from exc
    except Exception as exc:
        raise gr.Error(f"Failed to load model: {exc}") from exc

    img_input = _normalize_input(image)
    if img_input.ndim == 3 and img_input.shape[2] == 1:
        # (H, W, 1) would otherwise reach the 3-channel network unchanged.
        img_input = img_input[:, :, 0]
    if img_input.ndim == 2:
        img_input = cv2.cvtColor(img_input, cv2.COLOR_GRAY2RGB)
    elif img_input.shape[2] == 4:
        img_input = cv2.cvtColor(img_input, cv2.COLOR_RGBA2RGB)

    inp = _img2tensor_rgb(img_input)

    try:
        result = _run_inference(model, inp.unsqueeze(dim=0), passes=int(passes))
        sr_img = tensor2img([result], rgb2bgr=False)
    except RuntimeError as exc:
        if "out of memory" in str(exc).lower():
            raise gr.Error(
                "Out of memory. Try uploading a smaller image or reducing passes."
            ) from exc
        raise gr.Error(f"Inference failed: {exc}") from exc

    sr_img = _apply_strength(img_input, sr_img, strength)
    sr_img = _unsharp_mask(sr_img, sharpen)
    diff = _diff_map(img_input, sr_img)
    return sr_img, diff


def build_ui():
    """Build the Gradio Blocks interface wired to ``deblur``."""
    with gr.Blocks(title="NAFNet Deblur") as demo:
        gr.Markdown(
            "# NAFNet Deblur\n"
            "Upload a blurry image and get a deblurred result.\n\n"
            "**Tips:** Increase **Strength** to amplify the effect. "
            "Raise **Sharpen** for extra crispness. "
            "Use **Passes** > 1 for heavily blurred images."
        )
        with gr.Row():
            inp = gr.Image(label="Input (Blurry)", type="numpy")
            out = gr.Image(label="Output (Deblurred)", type="numpy")
            diff = gr.Image(label="Diff (x3)", type="numpy")
        with gr.Row():
            strength = gr.Slider(
                0.5, 5.0, value=2.0, step=0.1,
                label="Strength (amplify deblur effect)")
            sharpen = gr.Slider(
                0.0, 2.0, value=0.5, step=0.05,
                label="Sharpen (post-processing)")
            passes = gr.Slider(
                1, 5, value=2, step=1,
                label="Passes (run model N times)")
        btn = gr.Button("Deblur", variant="primary")
        btn.click(fn=deblur, inputs=[inp, strength, sharpen, passes], outputs=[out, diff])
    return demo


app = build_ui()

if __name__ == "__main__":
    app.launch()