# XYScanNet_Demo / app.py
# Author: HanzhouLiu — commit 2c25769 ("png in app py")
import gradio as gr
import spaces
import torch
import torchvision
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from PIL import Image
import yaml
import os
from models.networks import get_generator
# ===========================
# 1. Device setup
# ===========================
# Prefer CUDA when a GPU is present; otherwise fall back to CPU so the
# demo still runs on CPU-only Spaces hardware.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"πŸ”₯ Using device: {device}")
# ===========================
# 2. Model Loading
# ===========================
def load_model(job_name="xyscannetp_gopro"):
    """
    Build the XYScanNet generator and restore its pretrained weights.

    Reads config/<job_name>/config_stage2.yaml for the model spec, then
    restores results/<job_name>/models/best_<experiment_desc>.pth onto
    the globally selected device and switches the model to eval mode.
    """
    config_file = os.path.join("config", job_name, "config_stage2.yaml")
    with open(config_file, "r") as handle:
        cfg = yaml.safe_load(handle)

    weights_path = os.path.join(
        "results", job_name, "models", f"best_{cfg['experiment_desc']}.pth"
    )
    print(f"πŸ”Ή Loading model from {weights_path}")

    net = get_generator(cfg["model"])
    # NOTE(review): torch.load unpickles arbitrary objects — fine for these
    # bundled checkpoints, but never point it at untrusted files.
    state = torch.load(weights_path, map_location=device)
    net.load_state_dict(state)
    net.to(device)
    net.eval()
    print(f"βœ… Model loaded on {device}")
    return net
print("Initializing XYScanNet model...")
# Load the generator once at import time so every request reuses the same
# in-memory model instead of re-reading the checkpoint per call.
MODEL = load_model()
print("Model ready.")
# ===========================
# 3. Helper functions
# ===========================
def pad_to_multiple_of_8(img_tensor, multiple=8):
    """
    Reflect-pad an image tensor so its spatial dims align to ``multiple``.

    Pads only on the right and bottom, so the original content stays at
    the top-left corner and can be recovered with ``crop_back``.

    Args:
        img_tensor: 4-D tensor of shape [B, C, H, W].
        multiple: alignment target for H and W (default 8, the window
            size the network expects; kept as default for compatibility).

    Returns:
        Tuple ``(padded_tensor, orig_h, orig_w)`` where ``orig_h``/
        ``orig_w`` are the pre-padding spatial dims.
    """
    _, _, h, w = img_tensor.shape
    # (multiple - dim % multiple) % multiple == 0 when already aligned.
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    if pad_h or pad_w:
        # "reflect" mode requires each pad amount to be smaller than the
        # corresponding input dim; holds for any real photo vs. multiple=8.
        img_tensor = F.pad(img_tensor, (0, pad_w, 0, pad_h), mode="reflect")
    return img_tensor, h, w
def crop_back(img_tensor, orig_h, orig_w):
    """Trim padded network output back to its pre-padding spatial size."""
    cropped = img_tensor[..., :orig_h, :orig_w]
    return cropped
# ===========================
# 4. Inference Function
# ===========================
# The decorator below *requests* GPU if available,
# but won't crash if only CPU exists.
@spaces.GPU
def run_deblur(input_image: Image.Image) -> Image.Image:
    """
    Deblur one image with the module-level MODEL.

    Converts the PIL image to a float32 tensor normalized to [-0.5, 0.5],
    pads H/W up to multiples of 8, runs the network under no_grad, then
    undoes the normalization and padding and returns a PIL image.

    Args:
        input_image: any PIL image; converted to RGB before inference.

    Returns:
        The restored image as a PIL Image, same size as the input.
    """
    # PIL RGB -> [C, H, W] float32 tensor in [-0.5, 0.5].
    img = np.array(input_image.convert("RGB"))
    img_tensor = (
        torch.from_numpy(np.transpose(img / 255.0, (2, 0, 1)).astype("float32")) - 0.5
    )
    # torch.autograd.Variable has been a no-op since PyTorch 0.4; plain
    # tensors are the modern idiom, so the wrapper is dropped.
    img_tensor = img_tensor.unsqueeze(0).to(device)

    # Pad so spatial dims satisfy the network's multiple-of-8 requirement.
    img_tensor, orig_h, orig_w = pad_to_multiple_of_8(img_tensor)

    # Inference only — no autograd graph needed.
    with torch.no_grad():
        result_image, _, _ = MODEL(img_tensor)
    result_image = result_image + 0.5  # undo the [-0.5, 0.5] shift
    result_image = crop_back(result_image, orig_h, orig_w)

    # [1, C, H, W] in [0, 1] -> PIL image for the Gradio output widget.
    out_img = result_image.squeeze(0).clamp(0, 1).cpu()
    return torchvision.transforms.ToPILImage()(out_img)
# ===========================
# 5. Gradio Interface
# ===========================
# Single-image-in / single-image-out UI around run_deblur.
demo = gr.Interface(
    fn=run_deblur,
    inputs=gr.Image(type="pil", label="Upload a Blurry Image"),
    outputs=gr.Image(type="pil", label="Deblurred Result"),
    title="XYScanNet: Mamba-based Image Deblurring (GPU Demo)",
    description=(
        "Upload a blurry image to see how XYScanNet restores it using a Mamba-based vision state-space model."
    ),
    # Bundled sample inputs shown as clickable examples under the widget.
    examples=[
        ["examples/blur1.png"],
        ["examples/blur2.png"],
        ["examples/blur3.png"],
        ["examples/blur4.png"],
        ["examples/blur5.png"],
    ],
    # NOTE(review): allow_flagging was renamed flagging_mode in Gradio 4.x —
    # confirm against the pinned Gradio version before upgrading.
    allow_flagging="never",
)
if __name__ == "__main__":
    demo.launch()