import gradio as gr
import spaces
import torch
import torchvision
import torch.nn.functional as F
import numpy as np
from PIL import Image
import yaml
import os

from models.networks import get_generator

# ===========================
# 1. Device setup
# ===========================
# Automatically choose GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🔥 Using device: {device}")


# ===========================
# 2. Model Loading
# ===========================
def load_model(job_name="xyscannetp_gopro"):
    """
    Load the pretrained XYScanNet generator and move it to ``device``.

    Args:
        job_name: Experiment directory name under ``config/`` and
            ``results/`` identifying which checkpoint/config to load.

    Returns:
        The generator network in eval mode, on ``device``.
    """
    cfg_path = os.path.join("config", job_name, "config_stage2.yaml")
    with open(cfg_path, "r") as f:
        config = yaml.safe_load(f)

    weights_path = os.path.join(
        "results", job_name, "models", f"best_{config['experiment_desc']}.pth"
    )
    print(f"🔹 Loading model from {weights_path}")

    model = get_generator(config["model"])
    # NOTE(review): torch.load deserializes via pickle — only load
    # checkpoints from trusted sources. Since this is a plain state
    # dict, consider passing weights_only=True on torch >= 2.0.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval().to(device)
    print(f"✅ Model loaded on {device}")
    return model


print("Initializing XYScanNet model...")
MODEL = load_model()
print("Model ready.")


# ===========================
# 3. Helper functions
# ===========================
def pad_to_multiple_of_8(img_tensor):
    """
    Pad a [B, C, H, W] tensor on the bottom/right so that H and W
    become multiples of 8 (the network's required window size).

    Args:
        img_tensor: 4D image batch tensor.

    Returns:
        Tuple of (padded_tensor, original_h, original_w) so the output
        can later be cropped back with ``crop_back``.
    """
    _, _, h, w = img_tensor.shape
    pad_h = (8 - h % 8) % 8
    pad_w = (8 - w % 8) % 8
    # Reflect padding avoids hard black borders that would bias the model.
    img_tensor = F.pad(img_tensor, (0, pad_w, 0, pad_h), mode="reflect")
    return img_tensor, h, w


def crop_back(img_tensor, orig_h, orig_w):
    """Crop a padded [B, C, H, W] output back to its original size."""
    return img_tensor[:, :, :orig_h, :orig_w]


# ===========================
# 4. Inference Function
# ===========================
# The decorator below *requests* GPU if available,
# but won't crash if only CPU exists.
@spaces.GPU
def run_deblur(input_image: Image.Image):
    """
    Run deblurring inference on GPU if available, else CPU.

    Args:
        input_image: Input PIL image (any mode; converted to RGB).

    Returns:
        Deblurred PIL image at the original resolution.
    """
    # Convert PIL RGB -> Tensor [B, C, H, W] normalized to [-0.5, 0.5]
    img = np.array(input_image.convert("RGB"))
    img_tensor = (
        torch.from_numpy(np.transpose(img / 255.0, (2, 0, 1)).astype("float32"))
        - 0.5
    )
    # torch.autograd.Variable has been a deprecated no-op since
    # PyTorch 0.4 — plain tensors carry autograd state directly.
    img_tensor = img_tensor.unsqueeze(0).to(device)

    # Pad to valid window size
    img_tensor, orig_h, orig_w = pad_to_multiple_of_8(img_tensor)

    # Inference
    with torch.no_grad():
        result_image, _, _ = MODEL(img_tensor)
        # Undo the [-0.5, 0.5] normalization back to [0, 1].
        result_image = result_image + 0.5
        result_image = crop_back(result_image, orig_h, orig_w)

    # Convert to PIL Image for display
    out_img = result_image.squeeze(0).clamp(0, 1).cpu()
    out_pil = torchvision.transforms.ToPILImage()(out_img)
    return out_pil


# ===========================
# 5. Gradio Interface
# ===========================
demo = gr.Interface(
    fn=run_deblur,
    inputs=gr.Image(type="pil", label="Upload a Blurry Image"),
    outputs=gr.Image(type="pil", label="Deblurred Result"),
    title="XYScanNet: Mamba-based Image Deblurring (GPU Demo)",
    description=(
        "Upload a blurry image to see how XYScanNet restores it using a Mamba-based vision state-space model."
    ),
    examples=[
        ["examples/blur1.png"],
        ["examples/blur2.png"],
        ["examples/blur3.png"],
        ["examples/blur4.png"],
        ["examples/blur5.png"],
    ],
    # NOTE(review): allow_flagging is deprecated in Gradio 4.x in favor
    # of flagging_mode; kept as-is for compatibility with the pinned
    # Gradio version of this Space.
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()