File size: 4,035 Bytes
0200dd6
 
7612bbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a9e469
7612bbc
 
0200dd6
 
7612bbc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import os

# Run on the GPU when PyTorch can see one; fall back to the CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

class DeblurNet(nn.Module):
    """U-Net-style encoder/decoder used to deblur RGB image tensors.

    Architecture:
        * three encoder stages (3->64->128->256 channels), each followed by
          2x2 max-pooling;
        * a 512-channel bottleneck;
        * three decoder stages that upsample bilinearly and concatenate the
          matching encoder feature map (U-Net skip connection) before
          convolving;
        * a final 3x3 convolution back to 3 channels followed by ``tanh``, so
          outputs lie in [-1, 1].

    Input is a batch of shape ``(N, 3, H, W)``; the output has the same shape.
    H and W need not be multiples of 8: each decoder stage is resized to the
    exact spatial size of its skip connection. They must each be at least 8
    so the three pooling steps leave a non-empty bottleneck.
    """

    def __init__(self):
        super().__init__()

        # Encoder
        self.enc_conv1 = self.conv_block(3, 64)
        self.enc_conv2 = self.conv_block(64, 128)
        self.enc_conv3 = self.conv_block(128, 256)

        # Bottleneck
        self.bottleneck = self.conv_block(256, 512)

        # Decoder; input channels = upsampled channels + concatenated skip channels
        self.dec_conv1 = self.conv_block(512 + 256, 256)
        self.dec_conv2 = self.conv_block(256 + 128, 128)
        self.dec_conv3 = self.conv_block(128 + 64, 64)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)

        # Pooling and upsampling. ``upsample`` has no parameters and is kept
        # as an attribute for backward compatibility; ``forward`` uses
        # ``_upsample_to`` so odd spatial sizes line up with the skips.
        self.pool = nn.MaxPool2d(2, 2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def conv_block(self, in_channels, out_channels):
        """Return two size-preserving 3x3 conv + ReLU layers (in -> out -> out channels)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def _upsample_to(self, x, ref):
        """Bilinearly resize ``x`` to the spatial (H, W) size of ``ref``.

        With ``align_corners=True`` this matches ``self.upsample`` exactly
        whenever ``ref`` is twice the size of ``x``, and it also handles
        the off-by-one sizes that floor-division pooling produces.
        """
        return nn.functional.interpolate(
            x, size=ref.shape[-2:], mode='bilinear', align_corners=True
        )

    def forward(self, x):
        """Deblur a batch ``x`` of shape ``(N, 3, H, W)``; returns the same shape."""
        # Encoder: keep each stage's output for the decoder's skip connections.
        skip1 = self.enc_conv1(x)
        skip2 = self.enc_conv2(self.pool(skip1))
        skip3 = self.enc_conv3(self.pool(skip2))

        # Bottleneck at 1/8 resolution (each pooling step rounds down).
        x = self.bottleneck(self.pool(skip3))

        # Decoder: upsample to the skip's exact size, concatenate along the
        # channel axis, then convolve.
        x = self.dec_conv1(torch.cat([self._upsample_to(x, skip3), skip3], dim=1))
        x = self.dec_conv2(torch.cat([self._upsample_to(x, skip2), skip2], dim=1))
        x = self.dec_conv3(torch.cat([self._upsample_to(x, skip1), skip1], dim=1))

        return torch.tanh(self.final_conv(x))

# Instantiate the network on the selected device and restore the trained
# weights from model/best_deblur_model.pth (relative to the working directory).
model = DeblurNet().to(device)
model_path = os.path.join('model', 'best_deblur_model.pth')
# The checkpoint is consumed as a state_dict, so restrict unpickling to
# tensors and primitive containers (weights_only=True) rather than allowing
# arbitrary pickled objects; PyTorch 2.6+ uses this as the default.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
# Put the module in evaluation mode before serving requests.
model.eval()

# Preprocessing applied before inference: resize to the fixed 256x256 network
# input, convert the PIL image to a float tensor, then normalize each of the
# three channels with mean 0.5 / std 0.5 (inverted by postprocess_image).
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

def postprocess_image(tensor):
    """Turn a normalized ``(C, H, W)`` tensor back into an ``(H, W, C)`` uint8 array.

    The tensor is mapped from [-1, 1] to [0, 1] (undoing the 0.5/0.5
    normalization), clamped, moved to host memory, moved from channels-first
    to channels-last, and scaled to 0-255. Fractional values are truncated,
    not rounded.
    """
    unit_range = (tensor * 0.5 + 0.5).clamp(0, 1)
    channels_last = unit_range.detach().cpu().numpy().transpose(1, 2, 0)
    return (channels_last * 255).astype(np.uint8)

def deblur_image(input_image):
    """Run the deblurring model on one image and return the result.

    The image is converted to RGB and resized to the network's 256x256
    input. It is then passed through ``model`` and converted back to pixels.
    Finally it is resized to its original dimensions.

    Args:
        input_image: A ``numpy.ndarray`` (as supplied by the Gradio
            ``gr.Image(type="numpy")`` input) or a ``PIL.Image.Image``.
            ``None`` (nothing uploaded) is passed straight through.

    Returns:
        A ``numpy.ndarray`` of shape ``(H, W, 3)`` and dtype uint8 holding the
        deblurred image at the input's original size. Returns ``None`` if
        there was no input or if processing failed; in that case the error
        is printed.
    """
    if input_image is None:
        return None

    try:
        if isinstance(input_image, np.ndarray):
            input_image = Image.fromarray(input_image)

        # The 3-channel Normalize step and the 3-input-channel network need
        # RGB; grayscale, RGBA or palette images would fail with a channel
        # mismatch otherwise.
        input_image = input_image.convert('RGB')

        # PIL reports size as (width, height), which is what resize() expects.
        original_size = input_image.size

        input_tensor = transform(input_image).unsqueeze(0).to(device)

        with torch.no_grad():
            output_tensor = model(input_tensor)

        # Drop the batch dimension and convert back to an HxWx3 uint8 array.
        output_image = postprocess_image(output_tensor[0])

        output_image = Image.fromarray(output_image).resize(original_size)
        return np.array(output_image)

    except Exception as e:
        # Best-effort UI callback: report the failure and show an empty
        # result instead of crashing the Gradio handler.
        print(f"Error processing image: {e}")
        return None

# Gradio UI: one image upload routed through deblur_image to one image
# output, with a single bundled example file.
demo = gr.Interface(
    deblur_image,
    gr.Image(type="numpy", label="Upload Blurry Image"),
    gr.Image(type="numpy", label="Deblurred Result"),
    title="Image Deblurring",
    description="Upload a blurry image and get it deblurred using deep learning.",
    examples=[["examples/example1.jpg"]],
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()