# app.py — Hugging Face Space image-deblurring demo (commit 2a9e469, "second commit").
# (Stray page-scrape text replaced with this comment; it was not valid Python.)
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import os
# Check if CUDA is available
# Select the GPU when present, otherwise fall back to CPU; the model and all
# input tensors below are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
class DeblurNet(nn.Module):
    """U-Net-style encoder/decoder network for image deblurring.

    Three encoder stages each halve the spatial resolution via max-pooling,
    a bottleneck widens the channels to 512, and three decoder stages
    upsample bilinearly, each concatenating the matching encoder feature
    map (skip connection) before convolving. The final conv maps back to
    3 RGB channels and the output is squashed to [-1, 1] with tanh,
    matching the Normalize(mean=0.5, std=0.5) preprocessing.
    """

    def __init__(self):
        super(DeblurNet, self).__init__()
        # --- Encoder ---
        self.enc_conv1 = self.conv_block(3, 64)
        self.enc_conv2 = self.conv_block(64, 128)
        self.enc_conv3 = self.conv_block(128, 256)
        # --- Bottleneck (1/8 resolution) ---
        self.bottleneck = self.conv_block(256, 512)
        # --- Decoder: input channels account for the skip concatenation ---
        self.dec_conv1 = self.conv_block(512 + 256, 256)
        self.dec_conv2 = self.conv_block(256 + 128, 128)
        self.dec_conv3 = self.conv_block(128 + 64, 64)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        # Stateless resampling ops, shared across stages.
        self.pool = nn.MaxPool2d(2, 2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 same-padding convolutions, each followed by an in-place ReLU."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        # Encoder path: keep each stage's output for the skip connections.
        skip1 = self.enc_conv1(x)
        skip2 = self.enc_conv2(self.pool(skip1))
        skip3 = self.enc_conv3(self.pool(skip2))
        # Bottleneck at 1/8 of the input resolution.
        feat = self.bottleneck(self.pool(skip3))
        # Decoder path: upsample, concatenate the matching skip, convolve.
        feat = self.dec_conv1(torch.cat([self.upsample(feat), skip3], dim=1))
        feat = self.dec_conv2(torch.cat([self.upsample(feat), skip2], dim=1))
        feat = self.dec_conv3(torch.cat([self.upsample(feat), skip1], dim=1))
        # Map back to RGB and squash to [-1, 1].
        return torch.tanh(self.final_conv(feat))
# Load model
# Instantiate the network on the chosen device and load the trained weights
# shipped with the Space at model/best_deblur_model.pth.
model = DeblurNet().to(device)
model_path = os.path.join('model', 'best_deblur_model.pth')
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — acceptable for this trusted local checkpoint, but consider
# weights_only=True on torch versions that support it. TODO confirm version.
model.load_state_dict(torch.load(model_path, map_location=device))
# Switch to inference mode (standard practice before serving predictions).
model.eval()
# Image processing functions
# Preprocessing pipeline: resize to the model's fixed 256x256 working
# resolution, convert to a [0, 1] CHW tensor, then normalize each channel
# to [-1, 1] (mean=std=0.5) to match the model's tanh output range.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
def postprocess_image(tensor):
    """Convert a single [-1, 1] CHW model output tensor to an HWC uint8 image.

    Inverts the Normalize(mean=0.5, std=0.5) applied during preprocessing,
    clamps any overshoot into [0, 1], and scales to the 0-255 byte range.
    """
    # Undo normalization: [-1, 1] -> [0, 1], clipping out-of-range values.
    denorm = (tensor * 0.5 + 0.5).clamp(0, 1)
    # CHW tensor -> HWC numpy array.
    arr = denorm.detach().cpu().numpy().transpose(1, 2, 0)
    return (arr * 255).astype(np.uint8)
def deblur_image(input_image):
    """Gradio handler: deblur an uploaded image with the global model.

    Parameters
    ----------
    input_image : np.ndarray | PIL.Image.Image | None
        Image from the Gradio input component (HWC uint8 for type="numpy").

    Returns
    -------
    np.ndarray | None
        Deblurred RGB image resized back to the upload's resolution, or
        None when there is no input or processing fails.
    """
    if input_image is None:
        return None
    try:
        # Convert to PIL Image
        if isinstance(input_image, np.ndarray):
            input_image = Image.fromarray(input_image)
        # Fix: grayscale ('L') or RGBA uploads would produce a 1- or
        # 4-channel tensor and crash the 3-channel Normalize; force RGB
        # so any input mode is accepted.
        input_image = input_image.convert('RGB')
        # Save original size so the result matches the upload.
        original_size = input_image.size
        # Preprocess to the model's fixed 256x256 normalized input.
        input_tensor = transform(input_image).unsqueeze(0).to(device)
        # Inference
        with torch.no_grad():
            output_tensor = model(input_tensor)
        # Postprocess the single batch element back to uint8 HWC.
        output_image = postprocess_image(output_tensor[0])
        # The model works at 256x256; scale back to the original size.
        output_image = Image.fromarray(output_image).resize(original_size)
        return np.array(output_image)
    except Exception as e:
        # Best-effort UI handler: log and return None so Gradio shows an
        # empty output instead of surfacing a stack trace to the user.
        print(f"Error processing image: {e}")
        return None
# Create Gradio interface
# Single-function demo: numpy image in, numpy image out, wired to deblur_image.
demo = gr.Interface(
    fn=deblur_image,
    inputs=gr.Image(type="numpy", label="Upload Blurry Image"),
    outputs=gr.Image(type="numpy", label="Deblurred Result"),
    title="Image Deblurring",
    description="Upload a blurry image and get it deblurred using deep learning.",
    # Clickable sample input; the file must exist in the repo. TODO confirm
    # examples/example1.jpg is present.
    examples=[
        ["examples/example1.jpg"]
    ]
)
if __name__ == "__main__":
    # Launch the Gradio server (blocking) when run as a script.
    demo.launch()