|
|
import os |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
ROOT = Path(__file__).resolve().parent |
|
|
sys.path.insert(0, str(ROOT)) |
|
|
|
|
|
import cv2 |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from basicsr.models import create_model |
|
|
from basicsr.utils import img2tensor as _img2tensor, tensor2img |
|
|
from basicsr.utils.options import parse |
|
|
|
|
|
# Default config and weight locations, resolved relative to this file so the
# app behaves the same regardless of the current working directory.
DEFAULT_OPT_PATH = ROOT / "options" / "test" / "GoPro" / "NAFNet-width64.yml"
DEFAULT_WEIGHTS_PATH = (
    ROOT / "experiments" / "pretrained_models" / "NAFNet-GoPro-width64.pth"
)
|
|
|
|
|
|
|
|
def _download_file(url: str, dst: Path) -> None:
    """Stream *url* to *dst* in 1 MiB chunks, creating parent dirs as needed.

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    # Imported lazily so the app starts even when requests is absent and the
    # weights are already on disk.
    import requests

    dst.parent.mkdir(parents=True, exist_ok=True)
    with requests.get(url, stream=True, timeout=120) as response:
        response.raise_for_status()
        with open(dst, "wb") as out_file:
            for block in response.iter_content(chunk_size=1024 * 1024):
                # Skip keep-alive chunks, which arrive as empty bytes.
                if not block:
                    continue
                out_file.write(block)
|
|
|
|
|
|
|
|
def _ensure_weights(path: Path) -> None: |
|
|
if path.exists(): |
|
|
return |
|
|
url = os.getenv("MODEL_URL", "").strip() |
|
|
if not url: |
|
|
raise FileNotFoundError( |
|
|
f"Missing weights at {path}. Provide MODEL_URL env var or add the file." |
|
|
) |
|
|
_download_file(url, path) |
|
|
|
|
|
|
|
|
def _normalize_input(img: np.ndarray) -> np.ndarray: |
|
|
if img.dtype != np.uint8: |
|
|
img = img.astype(np.float32) |
|
|
if img.max() <= 1.0: |
|
|
img = img * 255.0 |
|
|
img = np.clip(img, 0, 255).astype(np.uint8) |
|
|
return img |
|
|
|
|
|
|
|
|
def _img2tensor_rgb(img: np.ndarray) -> torch.Tensor:
    """Convert a uint8 HWC RGB image to a float32 CHW tensor in [0, 1]."""
    scaled = img.astype(np.float32) / 255.0
    # The input is already RGB, so no BGR->RGB channel swap is needed.
    return _img2tensor(scaled, bgr2rgb=False, float32=True)
|
|
|
|
|
|
|
|
def _load_model():
    """Build the NAFNet deblur model from the bundled YAML config.

    Ensures the pretrained weights exist (downloading them via MODEL_URL if
    needed), parses the basicsr option file, patches it for single-process
    CPU/GPU inference, and instantiates the model via create_model.

    Returns:
        A basicsr model instance with a `net_g` network attribute.

    Raises:
        FileNotFoundError: if weights are missing and MODEL_URL is unset.
    """
    _ensure_weights(DEFAULT_WEIGHTS_PATH)
    opt = parse(str(DEFAULT_OPT_PATH), is_train=False)
    # Single-process inference: disable distributed mode.
    opt["dist"] = False
    if not torch.cuda.is_available():
        # basicsr interprets num_gpu == 0 as "run on CPU".
        opt["num_gpu"] = 0

    # The YAML may reference the weights with a repo-relative path; anchor it
    # to ROOT so loading works from any working directory.
    pretrain = opt["path"].get("pretrain_network_g")
    if pretrain and not os.path.isabs(pretrain):
        opt["path"]["pretrain_network_g"] = str(ROOT / pretrain)

    # Force the plain NAFNet architecture regardless of what the config names.
    # NOTE(review): presumably overrides a Local/TLC variant — confirm this
    # matches the pretrained weights' architecture.
    opt["network_g"]["type"] = "NAFNet"

    model = create_model(opt)
    print(f"[blur_remover] Model loaded on {next(model.net_g.parameters()).device}, "
          f"parameters: {sum(p.numel() for p in model.net_g.parameters()):,}")
    return model
|
|
|
|
|
|
|
|
# Process-wide model singleton, constructed lazily by _get_model().
MODEL = None
|
|
|
|
|
|
|
|
def _get_model(): |
|
|
global MODEL |
|
|
if MODEL is None: |
|
|
MODEL = _load_model() |
|
|
return MODEL |
|
|
|
|
|
|
|
|
def _diff_map(inp: np.ndarray, out: np.ndarray) -> np.ndarray: |
|
|
if inp.shape != out.shape: |
|
|
return out |
|
|
diff = np.abs(out.astype(np.int16) - inp.astype(np.int16)).astype(np.uint8) |
|
|
|
|
|
diff = np.clip(diff * 3, 0, 255).astype(np.uint8) |
|
|
return diff |
|
|
|
|
|
|
|
|
def _apply_strength(inp: np.ndarray, out: np.ndarray, strength: float) -> np.ndarray: |
|
|
if strength == 1.0: |
|
|
return out |
|
|
blended = inp.astype(np.float32) + strength * (out.astype(np.float32) - inp.astype(np.float32)) |
|
|
return np.clip(blended, 0, 255).astype(np.uint8) |
|
|
|
|
|
|
|
|
def _unsharp_mask(img: np.ndarray, amount: float) -> np.ndarray: |
|
|
"""Apply unsharp masking for perceptual sharpening.""" |
|
|
if amount <= 0: |
|
|
return img |
|
|
sigma = 1.0 + amount * 2.0 |
|
|
blurred = cv2.GaussianBlur(img, (0, 0), sigmaX=sigma, sigmaY=sigma) |
|
|
sharpened = cv2.addWeighted(img, 1.0 + amount, blurred, -amount, 0) |
|
|
return np.clip(sharpened, 0, 255).astype(np.uint8) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Tiled-inference parameters: square tiles of TILE_SIZE pixels, overlapping
# neighbours by TILE_OVERLAP pixels; overlapping predictions are averaged
# to hide seams (see _run_single_pass).
TILE_SIZE = 256
TILE_OVERLAP = 48
|
|
|
|
|
|
|
|
def _tile_positions(length: int, tile: int, overlap: int) -> list: |
|
|
"""Return start positions for overlapping tiles along one axis.""" |
|
|
if length <= tile: |
|
|
return [0] |
|
|
stride = tile - overlap |
|
|
positions = list(range(0, length - tile + 1, stride)) |
|
|
if positions[-1] + tile < length: |
|
|
positions.append(length - tile) |
|
|
return sorted(set(positions)) |
|
|
|
|
|
|
|
|
def _run_single_pass(model, lq: torch.Tensor,
                     tile_size: int = TILE_SIZE,
                     tile_overlap: int = TILE_OVERLAP) -> torch.Tensor:
    """Run one deblur pass with automatic tiling for large images.

    Args:
        model: basicsr model exposing feed_data / test / get_current_visuals.
        lq: low-quality input tensor of shape (1, C, H, W).
        tile_size: side length of the square tiles fed to the model.
        tile_overlap: overlap between neighbouring tiles; overlapping
            predictions are averaged to suppress seam artifacts.

    Returns:
        The deblurred tensor with the same (1, C, H, W) shape as `lq`.
    """
    _, c, h, w = lq.shape

    # Small images fit in a single forward pass — no tiling needed.
    if h <= tile_size and w <= tile_size:
        model.feed_data(data={"lq": lq})
        model.test()
        return model.get_current_visuals()["result"]

    rows = _tile_positions(h, tile_size, tile_overlap)
    cols = _tile_positions(w, tile_size, tile_overlap)

    # Per-pixel sum of tile outputs plus a hit counter, averaged at the end
    # so overlapping regions blend smoothly.
    # NOTE(review): accumulators live on CPU — assumes get_current_visuals
    # returns CPU tensors; confirm against the basicsr model implementation.
    out_acc = torch.zeros(1, c, h, w)
    count = torch.zeros(1, 1, h, w)

    for y in rows:
        for x in cols:
            y_end = min(y + tile_size, h)
            x_end = min(x + tile_size, w)
            tile = lq[:, :, y:y_end, x:x_end]

            model.feed_data(data={"lq": tile})
            model.test()
            tile_out = model.get_current_visuals()["result"]

            out_acc[:, :, y:y_end, x:x_end] += tile_out
            count[:, :, y:y_end, x:x_end] += 1.0

    # clamp guards against division by zero should any pixel go uncovered.
    return out_acc / count.clamp(min=1.0)
|
|
|
|
|
|
|
|
def _run_inference(model, lq: torch.Tensor, passes: int = 1) -> torch.Tensor:
    """Deblur by feeding the model's output back through it *passes* times."""
    result = lq
    for _pass_index in range(passes):
        result = _run_single_pass(model, result)
    return result
|
|
|
|
|
|
|
|
def deblur(image: np.ndarray, strength: float, sharpen: float, passes: int):
    """Gradio callback: run NAFNet deblurring on an uploaded image.

    Args:
        image: numpy image from the Gradio input (grayscale, RGB, or RGBA).
        strength: blend factor for the deblur effect; 1.0 = raw model output.
        sharpen: unsharp-mask amount applied after blending.
        passes: number of times to feed the output back through the model.

    Returns:
        Tuple of (deblurred uint8 image, amplified difference map).

    Raises:
        gr.Error: when no image is provided, the weights are missing, model
            loading fails, or inference fails (including out-of-memory).
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    try:
        model = _get_model()
    except FileNotFoundError as exc:
        # Friendly guidance for the most common deployment mistake.
        raise gr.Error(
            "Model weights not found. Set MODEL_URL in the Space settings "
            "or add the weight file at experiments/pretrained_models/NAFNet-GoPro-width64.pth."
        ) from exc
    except Exception as exc:
        raise gr.Error(f"Failed to load model: {exc}") from exc

    # Normalize to 3-channel uint8 RGB whatever Gradio hands us.
    img_input = _normalize_input(image)
    if img_input.ndim == 2:
        img_input = cv2.cvtColor(img_input, cv2.COLOR_GRAY2RGB)
    if img_input.shape[2] == 4:
        img_input = cv2.cvtColor(img_input, cv2.COLOR_RGBA2RGB)

    inp = _img2tensor_rgb(img_input)

    try:
        # unsqueeze adds the batch dimension expected by the model.
        result = _run_inference(model, inp.unsqueeze(dim=0), passes=int(passes))
        sr_img = tensor2img([result], rgb2bgr=False)
    except RuntimeError as exc:
        if "out of memory" in str(exc).lower():
            raise gr.Error(
                "Out of memory. Try uploading a smaller image or reducing passes."
            ) from exc
        raise gr.Error(f"Inference failed: {exc}") from exc

    # Post-process: blend toward/away from the input, sharpen, then build
    # the diagnostic difference map against the original.
    sr_img = _apply_strength(img_input, sr_img, strength)
    sr_img = _unsharp_mask(sr_img, sharpen)
    diff = _diff_map(img_input, sr_img)
    return sr_img, diff
|
|
|
|
|
|
|
|
def build_ui():
    """Construct and return the Gradio Blocks interface for the demo.

    Lays out an input/output/diff image row, the three tuning sliders
    (strength, sharpen, passes), and wires the Deblur button to deblur().
    """
    with gr.Blocks(title="NAFNet Deblur") as demo:
        gr.Markdown(
            "# NAFNet Deblur\n"
            "Upload a blurry image and get a deblurred result.\n\n"
            "**Tips:** Increase **Strength** to amplify the effect. "
            "Raise **Sharpen** for extra crispness. "
            "Use **Passes** > 1 for heavily blurred images."
        )
        with gr.Row():
            inp = gr.Image(label="Input (Blurry)", type="numpy")
            out = gr.Image(label="Output (Deblurred)", type="numpy")
            diff = gr.Image(label="Diff (x3)", type="numpy")
        with gr.Row():
            # Slider ranges match the values deblur() handles sensibly.
            strength = gr.Slider(
                0.5, 5.0, value=2.0, step=0.1,
                label="Strength (amplify deblur effect)")
            sharpen = gr.Slider(
                0.0, 2.0, value=0.5, step=0.05,
                label="Sharpen (post-processing)")
            passes = gr.Slider(
                1, 5, value=2, step=1,
                label="Passes (run model N times)")
        btn = gr.Button("Deblur", variant="primary")
        btn.click(fn=deblur, inputs=[inp, strength, sharpen, passes], outputs=[out, diff])
    return demo
|
|
|
|
|
|
|
|
# Build the app at import time so hosting platforms (e.g. Hugging Face
# Spaces) can discover `app` as the entry point.
app = build_ui()


if __name__ == "__main__":
    app.launch()
|
|
|