Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import numpy as np | |
| import os | |
# Run on a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
class DeblurNet(nn.Module):
    """U-Net-style encoder/decoder mapping a 3-channel image to a 3-channel image.

    Three encoder stages (64/128/256 channels), each followed by 2x2 max-pooling,
    feed a 512-channel bottleneck. Three decoder stages each upsample, concatenate
    the matching encoder feature map (skip connection) along the channel axis, and
    convolve. A final 3x3 convolution is squashed with ``tanh``, so outputs lie in
    [-1, 1].

    Spatial sizes need not be divisible by 8: each decoder stage resamples to the
    exact size of its skip tensor. Each spatial dimension must still be at least 8,
    otherwise the third pooling produces an empty feature map.

    Attribute names and the layout of ``conv_block`` determine the ``state_dict``
    keys (e.g. ``enc_conv1.0.weight``), so they must stay unchanged for saved
    checkpoints to keep loading.
    """

    def __init__(self):
        super(DeblurNet, self).__init__()
        # Encoder
        self.enc_conv1 = self.conv_block(3, 64)
        self.enc_conv2 = self.conv_block(64, 128)
        self.enc_conv3 = self.conv_block(128, 256)
        # Bottleneck
        self.bottleneck = self.conv_block(256, 512)
        # Decoder: input channels = upsampled channels + skip-connection channels.
        self.dec_conv1 = self.conv_block(512 + 256, 256)
        self.dec_conv2 = self.conv_block(256 + 128, 128)
        self.dec_conv3 = self.conv_block(128 + 64, 64)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        # Pooling and upsampling (parameter-free, so not part of the state_dict).
        self.pool = nn.MaxPool2d(2, 2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 conv + ReLU layers; padding=1 keeps the spatial size."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def _up_and_merge(self, x, skip):
        """Upsample ``x`` to ``skip``'s spatial size and concatenate them on channels."""
        target = tuple(skip.shape[-2:])
        if x.shape[-2] * 2 == target[0] and x.shape[-1] * 2 == target[1]:
            # Exact 2x case (sizes divisible by 8): use the original layer unchanged.
            x = self.upsample(x)
        else:
            # Max-pooling floored an odd size, so plain 2x would not match the skip
            # tensor; resample to its exact size with the same interpolation settings.
            x = nn.functional.interpolate(
                x, size=target, mode='bilinear', align_corners=True
            )
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Deblur a batch ``x`` of shape (N, 3, H, W); returns (N, 3, H, W) in [-1, 1]."""
        # Encoder: keep each stage's output for the skip connections.
        x1 = self.enc_conv1(x)
        x2 = self.enc_conv2(self.pool(x1))
        x3 = self.enc_conv3(self.pool(x2))
        # Bottleneck after the third pooling.
        x4 = self.bottleneck(self.pool(x3))
        # Decoder with skip connections.
        x = self.dec_conv1(self._up_and_merge(x4, x3))
        x = self.dec_conv2(self._up_and_merge(x, x2))
        x = self.dec_conv3(self._up_and_merge(x, x1))
        return torch.tanh(self.final_conv(x))
# Build the network once at import time, restore the trained weights from the
# checkpoint under ./model (relative to the working directory), and switch to
# evaluation mode so every request reuses the same ready model.
# NOTE(review): torch.load unpickles the file; only point it at trusted checkpoints.
model_path = os.path.join('model', 'best_deblur_model.pth')
model = DeblurNet().to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.train(False)
# Preprocessing pipeline: resize every input to the fixed 256x256 working
# resolution (aspect ratio is not preserved), convert to a float tensor in
# [0, 1], then map each of the three channels to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
def postprocess_image(tensor):
    """Convert a normalised (C, H, W) image tensor to an (H, W, C) ``uint8`` array.

    Undoes the ``x * 0.5 + 0.5`` normalisation (mapping [-1, 1] to [0, 1]),
    clamps to [0, 1] and scales to the 0-255 range.

    Args:
        tensor: a 3-D ``(C, H, W)`` float tensor on any device; gradients are
            detached.

    Returns:
        A ``numpy.ndarray`` of shape ``(H, W, C)`` and dtype ``uint8``.
    """
    image = torch.clamp(tensor.detach() * 0.5 + 0.5, 0, 1)
    image = np.transpose(image.cpu().numpy(), (1, 2, 0))
    # Round rather than truncate: float round-off can yield e.g. 254.99997,
    # which astype() alone would floor to 254, darkening pixels by one level.
    return np.rint(image * 255).astype(np.uint8)
def deblur_image(input_image):
    """Gradio handler: deblur one uploaded image and return it at its original size.

    Args:
        input_image: the upload as a ``numpy.ndarray`` (what
            ``gr.Image(type="numpy")`` delivers) or a ``PIL.Image.Image``;
            ``None`` when nothing has been uploaded.

    Returns:
        The deblurred image as an ``(H, W, 3)`` ``uint8`` array with the input's
        width and height, or ``None`` when there is no input or processing
        fails (the error is printed).
    """
    if input_image is None:
        return None

    try:
        if isinstance(input_image, np.ndarray):
            input_image = Image.fromarray(input_image)
        # The model and the 3-channel Normalize need RGB; grayscale (L), RGBA or
        # palette uploads would otherwise fail with a channel-count mismatch.
        input_image = input_image.convert('RGB')

        # PIL reports (width, height), which is also what Image.resize expects.
        original_size = input_image.size

        # NOTE(review): inference runs at the fixed 256x256 size set by `transform`
        # and is resized back below, so uploads larger than 256 px lose detail.
        input_tensor = transform(input_image).unsqueeze(0).to(device)
        with torch.no_grad():
            output_tensor = model(input_tensor)

        output_image = postprocess_image(output_tensor[0])
        # Restore the upload's dimensions (and hence its aspect ratio).
        output_image = Image.fromarray(output_image).resize(original_size)
        return np.array(output_image)
    except Exception as e:
        # Best-effort UI: report the failure and clear the output instead of crashing.
        print(f"Error processing image: {e}")
        return None
# Create the Gradio interface: one image in, one deblurred image out.
# Example inputs are resolved relative to the working directory. Only files that
# actually exist are offered, so a missing example image is skipped instead of
# handing Gradio a path it cannot load.
demo = gr.Interface(
    fn=deblur_image,
    inputs=gr.Image(type="numpy", label="Upload Blurry Image"),
    outputs=gr.Image(type="numpy", label="Deblurred Result"),
    title="Image Deblurring",
    description="Upload a blurry image and get it deblurred using deep learning.",
    examples=[
        [path] for path in ["examples/example1.jpg"] if os.path.isfile(path)
    ] or None,
)
# Start the Gradio server only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    demo.launch()