# Hugging Face Spaces demo ("Running on Zero" — ZeroGPU hardware).
# Standard library
import os
from pathlib import Path

# Third-party
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
import torchvision
import yaml
from PIL import Image
from torch.autograd import Variable

# Local
from models.networks import get_generator
# ===========================
# 1. Device setup
# ===========================
# Prefer CUDA when available; the model and all tensors are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Original log string contained mojibake (a garbled emoji); replaced with clean text.
print(f"Using device: {device}")
# ===========================
# 2. Model Loading
# ===========================
def load_model(job_name="xyscannetp_gopro"):
    """
    Load the pretrained XYScanNet generator and move it to ``device``.

    Parameters
    ----------
    job_name : str
        Experiment folder name used under both ``config/`` and ``results/``.

    Returns
    -------
    torch.nn.Module
        Generator in eval mode on the selected device.
    """
    cfg_path = Path("config") / job_name / "config_stage2.yaml"
    with cfg_path.open("r") as f:
        config = yaml.safe_load(f)

    # Checkpoint name is derived from the experiment description in the config.
    weights_path = (
        Path("results") / job_name / "models" / f"best_{config['experiment_desc']}.pth"
    )
    print(f"Loading model from {weights_path}")

    model = get_generator(config["model"])
    # NOTE(review): torch.load unpickles arbitrary objects — acceptable for a
    # trusted local checkpoint, but consider weights_only=True on torch >= 2.0.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval().to(device)
    print(f"Model loaded on {device}")
    return model
# Load the model once at import time so every Gradio request reuses it.
print("Initializing XYScanNet model...")
MODEL = load_model()
print("Model ready.")
# ===========================
# 3. Helper functions
# ===========================
def pad_to_multiple_of_8(img_tensor, multiple=8):
    """
    Reflect-pad a batched image tensor so H and W are multiples of ``multiple``.

    Parameters
    ----------
    img_tensor : torch.Tensor
        Tensor of shape (B, C, H, W).
    multiple : int
        Alignment to pad to (default 8, matching the model's window size).

    Returns
    -------
    (torch.Tensor, int, int)
        The padded tensor and the original (H, W) for cropping back later.

    Note: reflect padding requires the pad amount to be smaller than the
    corresponding input dimension, so very small images may fail.
    """
    _, _, h, w = img_tensor.shape
    # -x % m is the distance from x up to the next multiple of m (0 if aligned).
    pad_h = -h % multiple
    pad_w = -w % multiple
    if pad_h or pad_w:
        # F.pad's last-dim-first ordering: (left, right, top, bottom).
        img_tensor = F.pad(img_tensor, (0, pad_w, 0, pad_h), mode="reflect")
    return img_tensor, h, w
def crop_back(img_tensor, orig_h, orig_w):
    """Trim padding so the tensor's spatial size matches the original image."""
    trimmed = img_tensor[..., :orig_h, :orig_w]
    return trimmed
# ===========================
# 4. Inference Function
# ===========================
# NOTE(review): the original comments describe a decorator that "requests GPU
# if available" (presumably @spaces.GPU for ZeroGPU Spaces), but no decorator
# is present in this copy — confirm whether it was lost and restore it if so.
def run_deblur(input_image: Image.Image):
    """
    Deblur a single image with the globally loaded ``MODEL``.

    Parameters
    ----------
    input_image : PIL.Image.Image
        Input image; converted to RGB internally.

    Returns
    -------
    PIL.Image.Image
        Restored image cropped back to the original resolution.
    """
    # PIL RGB -> float32 tensor [1, C, H, W], normalized to [-0.5, 0.5].
    # (torch.autograd.Variable was removed: it has been a no-op since PyTorch 0.4.)
    img = np.asarray(input_image.convert("RGB"), dtype=np.float32) / 255.0
    img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0) - 0.5
    img_tensor = img_tensor.to(device)

    # Pad so H and W are multiples of 8 (the model's window constraint).
    img_tensor, orig_h, orig_w = pad_to_multiple_of_8(img_tensor)

    with torch.no_grad():
        # NOTE(review): assumes the generator returns a 3-tuple whose first
        # element is the restored image — confirm against models.networks.
        result_image, _, _ = MODEL(img_tensor)
        result_image = result_image + 0.5  # back to [0, 1] range
        result_image = crop_back(result_image, orig_h, orig_w)

    # Tensor -> PIL for Gradio display.
    out_img = result_image.squeeze(0).clamp(0, 1).cpu()
    return torchvision.transforms.ToPILImage()(out_img)
# ===========================
# 5. Gradio Interface
# ===========================
# Sample blurry inputs shipped with the Space.
_EXAMPLES = [
    ["examples/blur1.png"],
    ["examples/blur2.png"],
    ["examples/blur3.png"],
    ["examples/blur4.png"],
    ["examples/blur5.png"],
]

demo = gr.Interface(
    fn=run_deblur,
    title="XYScanNet: Mamba-based Image Deblurring (GPU Demo)",
    description="Upload a blurry image to see how XYScanNet restores it using a Mamba-based vision state-space model.",
    inputs=gr.Image(type="pil", label="Upload a Blurry Image"),
    outputs=gr.Image(type="pil", label="Deblurred Result"),
    examples=_EXAMPLES,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()