File size: 5,453 Bytes
2ea3bc2
 
 
 
0200dd6
 
7612bbc
 
 
 
 
d2bb28a
7612bbc
 
 
d2bb28a
7612bbc
 
 
 
 
 
 
 
 
 
 
 
 
72e804e
7612bbc
 
 
 
 
 
 
72e804e
7612bbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2bb28a
7612bbc
 
 
0b84bc4
 
 
 
 
 
 
 
d2bb28a
7612bbc
 
 
 
 
 
 
0b84bc4
7612bbc
 
 
 
 
 
c7ffd84
0b84bc4
 
7612bbc
0b84bc4
7612bbc
0b84bc4
c7ffd84
e17de0c
7612bbc
3abc1d7
e17de0c
7612bbc
3abc1d7
e17de0c
7612bbc
 
 
e17de0c
0b84bc4
7612bbc
e17de0c
7612bbc
 
0b84bc4
3abc1d7
e17de0c
7612bbc
 
 
 
10f01a8
72e804e
3079e81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72e804e
e17de0c
72e804e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ea3bc2
7612bbc
 
68811da
3abc1d7
7612bbc
3079e81
08fa660
7612bbc
0200dd6
2ea3bc2
0200dd6
10f01a8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
import warnings
from pathlib import Path

import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np

# Select the compute device once at import time; every model/tensor placement
# below (`.to(device)`, `map_location=device`) uses this handle.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Model definition
class DeblurNet(nn.Module):
    """U-Net-style encoder/decoder for single-image deblurring.

    Three encoder stages with 2x max-pool downsampling, a bottleneck, and
    three decoder stages with bilinear upsampling plus skip connections.
    The output passes through tanh, so pixel values lie in [-1, 1]
    (matching the Normalize(mean=0.5, std=0.5) preprocessing).
    """

    def __init__(self):
        super(DeblurNet, self).__init__()
        # Encoder path: channel width doubles at each stage.
        self.enc_conv1 = self.conv_block(3, 64)
        self.enc_conv2 = self.conv_block(64, 128)
        self.enc_conv3 = self.conv_block(128, 256)
        self.bottleneck = self.conv_block(256, 512)
        # Decoder path: input width includes the concatenated skip tensor.
        self.dec_conv1 = self.conv_block(512 + 256, 256)
        self.dec_conv2 = self.conv_block(256 + 128, 128)
        self.dec_conv3 = self.conv_block(128 + 64, 64)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 conv + ReLU layers; spatial size is preserved (padding=1)."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode, bottleneck, then decode with skip connections.

        Input (N, 3, H, W) with H and W divisible by 8; output has the same
        shape, tanh-activated.
        """
        skip1 = self.enc_conv1(x)
        skip2 = self.enc_conv2(self.pool(skip1))
        skip3 = self.enc_conv3(self.pool(skip2))
        bottom = self.bottleneck(self.pool(skip3))

        up = self.dec_conv1(torch.cat([self.upsample(bottom), skip3], dim=1))
        up = self.dec_conv2(torch.cat([self.upsample(up), skip2], dim=1))
        up = self.dec_conv3(torch.cat([self.upsample(up), skip1], dim=1))
        return torch.tanh(self.final_conv(up))

# Load model
model = DeblurNet().to(device)
model_path = os.path.join('model', 'best_deblur_model.pth')

# Load trained weights if the checkpoint exists; otherwise warn loudly —
# without it the app would silently serve random-weight output.
if os.path.exists(model_path):
    # NOTE(review): torch.load unpickles arbitrary objects; on torch>=2.0
    # prefer torch.load(..., weights_only=True) for untrusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print("Model loaded successfully.")
else:
    warnings.warn(
        f"Model file not found at {model_path}. "
        "Inference will run with untrained weights."
    )
# Always switch to inference mode, even when the checkpoint is missing
# (the original only did this on the successful-load branch).
model.eval()

# Preprocessing pipeline: fixed 256x256 input, per-channel scaling to
# [-1, 1] (the same range as DeblurNet's tanh output).
_preprocess_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocess_steps)

def postprocess_image(tensor):
    """Convert a model output tensor (C, H, W) in [-1, 1] to an HWC uint8 array.

    Undoes the Normalize(mean=0.5, std=0.5) preprocessing, clamps to [0, 1],
    and rescales to 8-bit pixel values.
    """
    denormed = torch.clamp(tensor * 0.5 + 0.5, 0, 1)
    # Detach from the graph, move to host memory, and reorder CHW -> HWC.
    hwc = denormed.detach().cpu().numpy().transpose(1, 2, 0)
    return (hwc * 255).astype(np.uint8)

def deblur_image(filepath):
    """Run the deblurring model on an uploaded file.

    Returns the deblurred image as an RGB numpy array resized back to the
    original dimensions, or None when no file was supplied or processing
    failed (the error is printed to stdout).
    """
    if not filepath:
        return None

    try:
        source = Image.open(filepath).convert("RGB")
        original_dims = source.size  # remembered so the result can be restored

        # Preprocess to a (1, 3, 256, 256) batch on the active device.
        batch = transform(source).unsqueeze(0).to(device)

        with torch.no_grad():
            prediction = model(batch)

        # Drop the batch dim, map back to uint8 pixels, restore original size.
        restored = Image.fromarray(postprocess_image(prediction[0]))
        return np.array(restored.resize(original_dims))
    except Exception as e:
        print(f"Error processing image: {e}")
        return None

# CSS injected into the Gradio page: hides the fullscreen/share controls and
# the Gradio header/footer, disables image dragging, and applies a dark theme.
custom_css = """
/* Completely hide fullscreen and share buttons */
button[title="Fullscreen"], 
button[title="Share"],
.gr-button[title="Fullscreen"],
.gr-button[title="Share"] {
    display: none !important;          /* Remove from the layout */
    opacity: 0 !important;             /* Make it invisible */
    visibility: hidden !important;     /* Ensure it's hidden */
    width: 0 !important;               /* Collapse the button size */
    height: 0 !important;              /* Collapse the button size */
    overflow: hidden !important;       /* Prevent any content visibility */
    pointer-events: none !important;   /* Disable all interactions */
}
/* Hide Gradio's footer and header */
footer, header, .gradio-footer, .gradio-header {
    display: none !important;
}
/* Non-draggable images */
img {
    pointer-events: none !important;
    -webkit-user-drag: none !important;
    user-select: none !important;
}
/* Styling adjustments */
body, .gradio-container {
    background-color: #000000 !important;
    color: white !important;
}
.gr-button {
    background: #1e90ff !important;
    color: white !important;
    border: none !important;
    padding: 10px 20px !important;
    font-size: 14px !important;
    cursor: pointer;
}
.gr-button:hover {
    background: #0056b3 !important;
}
.gr-box, .gr-input, .gr-output {
    background-color: #1c1c1c !important;
    color: white !important;
    border: 1px solid #333333 !important;
}
"""

# Gradio interface: the uploaded file's path is handed to deblur_image,
# which returns a numpy RGB array (or None on failure) for the image output.
demo = gr.Interface(
    fn=deblur_image,
    inputs=gr.File(label="Input", type="filepath"),
    outputs=gr.Image(type="numpy", label="Deblurred Result"),
    title="Image Deblurring",
    description="Upload a blurry image.",
    css=custom_css
)

# Launch the app only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()