File size: 3,594 Bytes
b56342d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c25769
b56342d
2c25769
 
 
b56342d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
import spaces
import torch
import torchvision
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from PIL import Image
import yaml
import os
from models.networks import get_generator


# ===========================
# 1. Device setup
# ===========================
# Prefer CUDA when a GPU is present; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"🔥 Using device: {device}")


# ===========================
# 2. Model Loading
# ===========================
def load_model(job_name="xyscannetp_gopro"):
    """
    Build the XYScanNet generator and restore its best checkpoint.

    The stage-2 YAML config under ``config/<job_name>/`` supplies both the
    model definition and the experiment name used to locate the weights
    file under ``results/<job_name>/models/``. The checkpoint is mapped
    onto whatever device was selected at startup (CPU or GPU).
    """
    config_file = os.path.join("config", job_name, "config_stage2.yaml")
    with open(config_file, "r") as handle:
        config = yaml.safe_load(handle)

    checkpoint = os.path.join(
        "results", job_name, "models", f"best_{config['experiment_desc']}.pth"
    )
    print(f"🔹 Loading model from {checkpoint}")

    net = get_generator(config["model"])
    state = torch.load(checkpoint, map_location=device)
    net.load_state_dict(state)
    net.eval().to(device)

    print(f"✅ Model loaded on {device}")
    return net


# Instantiate the network once at import time so every Gradio request
# reuses the same loaded weights instead of reloading them per call.
print("Initializing XYScanNet model...")
MODEL = load_model()
print("Model ready.")


# ===========================
# 3. Helper functions
# ===========================
def pad_to_multiple_of_8(img_tensor, multiple=8):
    """
    Reflect-pad a batched image tensor so H and W are multiples of *multiple*.

    Args:
        img_tensor: tensor of shape [B, C, H, W].
        multiple: divisor that height and width must satisfy (default 8,
            the window size the network expects).

    Returns:
        Tuple ``(padded_tensor, orig_h, orig_w)`` so the caller can later
        crop the network output back to the input size.

    Note:
        ``reflect`` padding requires each pad amount to be strictly smaller
        than the corresponding input dimension, so inputs with a side
        smaller than *multiple* may raise inside ``F.pad``.
    """
    _, _, h, w = img_tensor.shape
    # (-x) % m is the smallest non-negative amount that rounds x up to a
    # multiple of m (0 when already aligned).
    pad_h = (-h) % multiple
    pad_w = (-w) % multiple
    # F.pad order is (left, right, top, bottom); pad only right/bottom so
    # the origin stays fixed and cropping back is a plain slice.
    padded = F.pad(img_tensor, (0, pad_w, 0, pad_h), mode="reflect")
    return padded, h, w


def crop_back(img_tensor, orig_h, orig_w):
    """Slice a padded [B, C, H, W] tensor back down to ``orig_h`` x ``orig_w``."""
    return img_tensor[..., :orig_h, :orig_w]


# ===========================
# 4. Inference Function
# ===========================
# The decorator below *requests* GPU if available,
# but won't crash if only CPU exists.
@spaces.GPU
def run_deblur(input_image: Image.Image):
    """
    Deblur a single image with the preloaded XYScanNet model.

    Args:
        input_image: any PIL image; converted to RGB internally.

    Returns:
        A PIL image of the restored result at the original resolution.
    """
    # PIL RGB -> float tensor [1, C, H, W] normalized to [-0.5, 0.5]
    # (the network is trained on zero-centered inputs).
    # NOTE: torch.autograd.Variable has been a deprecated no-op wrapper
    # since PyTorch 0.4, so plain tensors are used here.
    img = np.asarray(input_image.convert("RGB"), dtype=np.float32) / 255.0
    img_tensor = torch.from_numpy(np.transpose(img, (2, 0, 1))) - 0.5
    img_tensor = img_tensor.unsqueeze(0).to(device)

    # Pad so H and W are multiples of the network's window size.
    img_tensor, orig_h, orig_w = pad_to_multiple_of_8(img_tensor)

    # Inference
    with torch.no_grad():
        # The generator returns (restored, aux1, aux2); only the first
        # output is the image.
        result_image, _, _ = MODEL(img_tensor)
        result_image = result_image + 0.5  # undo zero-centering
        result_image = crop_back(result_image, orig_h, orig_w)

    # Convert to PIL Image for display
    out_img = result_image.squeeze(0).clamp(0, 1).cpu()
    return torchvision.transforms.ToPILImage()(out_img)


# ===========================
# 5. Gradio Interface
# ===========================
# Single-image in / single-image out demo; examples are the bundled
# blur1..blur5 sample inputs.
demo = gr.Interface(
    fn=run_deblur,
    inputs=gr.Image(type="pil", label="Upload a Blurry Image"),
    outputs=gr.Image(type="pil", label="Deblurred Result"),
    title="XYScanNet: Mamba-based Image Deblurring (GPU Demo)",
    description=(
        "Upload a blurry image to see how XYScanNet restores it using a Mamba-based vision state-space model."
    ),
    examples=[[f"examples/blur{i}.png"] for i in range(1, 6)],
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()