"""Gradio demo that deblurs an uploaded image with a small U-Net style CNN.

At import time this module selects a device, builds ``DeblurNet`` and loads its
weights from ``model/best_deblur_model.pth`` (a path relative to the current
working directory). Running it as a script launches the Gradio UI.
"""

import logging
import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

logger = logging.getLogger(__name__)

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


class DeblurNet(nn.Module):
    """U-Net style encoder/decoder mapping a blurry RGB image to a sharp one.

    Three 2x max-pool stages downsample the input, a bottleneck block processes
    the coarsest features, and three bilinear 2x upsampling stages restore the
    resolution. Each upsampling stage concatenates the matching encoder feature
    map (skip connection) before its conv block. The output goes through
    ``tanh``, so values lie in ``[-1, 1]``, matching the ``Normalize`` scaling
    used by ``transform`` below.

    Input height and width must be divisible by 8; otherwise the upsampled
    decoder maps do not match the encoder maps and ``torch.cat`` fails.
    """

    def __init__(self):
        super().__init__()

        # Encoder
        self.enc_conv1 = self.conv_block(3, 64)
        self.enc_conv2 = self.conv_block(64, 128)
        self.enc_conv3 = self.conv_block(128, 256)

        # Bottleneck
        self.bottleneck = self.conv_block(256, 512)

        # Decoder; input channels include the concatenated skip connection
        self.dec_conv1 = self.conv_block(512 + 256, 256)
        self.dec_conv2 = self.conv_block(256 + 128, 128)
        self.dec_conv3 = self.conv_block(128 + 64, 64)

        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)

        # Pooling and upsampling
        self.pool = nn.MaxPool2d(2, 2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 conv + ReLU layers that preserve spatial size.

        The layer order defines the checkpoint's parameter names
        (``<block>.0.*`` and ``<block>.2.*``); keep it unchanged so saved
        weights still load.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run the network on a batch ``x`` of shape ``(N, 3, H, W)``."""
        # Encoder
        x1 = self.enc_conv1(x)
        x2 = self.pool(x1)
        x2 = self.enc_conv2(x2)
        x3 = self.pool(x2)
        x3 = self.enc_conv3(x3)
        x4 = self.pool(x3)

        # Bottleneck
        x4 = self.bottleneck(x4)

        # Decoder with skip connections (channel-wise concatenation)
        x = self.upsample(x4)
        x = torch.cat([x, x3], dim=1)
        x = self.dec_conv1(x)

        x = self.upsample(x)
        x = torch.cat([x, x2], dim=1)
        x = self.dec_conv2(x)

        x = self.upsample(x)
        x = torch.cat([x, x1], dim=1)
        x = self.dec_conv3(x)

        x = self.final_conv(x)
        return torch.tanh(x)


# Load model
model = DeblurNet().to(device)
model_path = os.path.join('model', 'best_deblur_model.pth')
# NOTE(security): torch.load unpickles the checkpoint, which can execute
# arbitrary code; only load checkpoint files you trust. On torch >= 1.13,
# passing weights_only=True restricts loading to tensors and plain containers.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Image processing functions
# Inputs are resized to 256x256 (presumably the training resolution; confirm)
# and scaled from [0, 1] to [-1, 1]. NOTE: this ignores aspect ratio, and the
# result is resized back afterwards, so fine detail in large inputs is lost.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


def postprocess_image(tensor):
    """Convert one ``(3, H, W)`` network output in ``[-1, 1]`` to an HWC uint8 array.

    Args:
        tensor: A single image from the model's output batch, channels first.

    Returns:
        ``np.ndarray`` of shape ``(H, W, 3)`` with dtype ``uint8``.
    """
    tensor = tensor * 0.5 + 0.5  # undo Normalize(mean=0.5, std=0.5)
    tensor = torch.clamp(tensor, 0, 1)
    image = tensor.detach().cpu().numpy()
    image = np.transpose(image, (1, 2, 0))  # CHW -> HWC
    # Round instead of truncating so pixel values are not biased downward.
    return np.round(image * 255).astype(np.uint8)


def deblur_image(input_image):
    """Deblur one image for the Gradio interface.

    Args:
        input_image: Image array from ``gr.Image(type="numpy")`` (or a PIL
            image); ``None`` when nothing was uploaded.

    Returns:
        The deblurred image as a ``uint8`` array of shape ``(H, W, 3)`` at the
        input's original size, or ``None`` if no image was given.

    Raises:
        gr.Error: If preprocessing or inference fails. Gradio shows the message
            to the user instead of leaving the output silently empty.
    """
    if input_image is None:
        return None

    try:
        # Convert to PIL Image
        if isinstance(input_image, np.ndarray):
            input_image = Image.fromarray(input_image)
        # The first conv expects exactly 3 channels: normalize grayscale,
        # RGBA and palette images to RGB.
        input_image = input_image.convert('RGB')

        # Save original size; PIL sizes are (width, height), as resize() expects
        original_size = input_image.size

        # Preprocess
        input_tensor = transform(input_image).unsqueeze(0).to(device)

        # Inference
        with torch.no_grad():
            output_tensor = model(input_tensor)

        # Postprocess
        output_image = postprocess_image(output_tensor[0])

        # Resize back to original size
        output_image = Image.fromarray(output_image).resize(original_size)

        return np.array(output_image)
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the logs and surface a
        # readable error in the UI.
        logger.exception("Error processing image")
        raise gr.Error(f"Error processing image: {e}") from e


# Create Gradio interface
demo = gr.Interface(
    fn=deblur_image,
    inputs=gr.Image(type="numpy", label="Upload Blurry Image"),
    outputs=gr.Image(type="numpy", label="Deblurred Result"),
    title="Image Deblurring",
    description="Upload a blurry image and get it deblurred using deep learning.",
    examples=[
        ["examples/example1.jpg"]
    ],
)

if __name__ == "__main__":
    demo.launch()