|
|
import os |
|
|
import warnings |
|
|
from pathlib import Path |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchvision.transforms as transforms |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
# Select the compute device: prefer the GPU when CUDA is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
|
|
|
|
|
|
|
|
class DeblurNet(nn.Module):
    """U-Net style encoder-decoder network for single-image deblurring.

    Expects a 3-channel image tensor normalised to [-1, 1] and produces a
    3-channel tensor in the same range (final tanh). Spatial dimensions are
    preserved; inputs should be divisible by 8 so the three pool/upsample
    stages round-trip cleanly (the app resizes everything to 256x256).
    """

    def __init__(self):
        super().__init__()
        # Encoder: three downsampling stages followed by a bottleneck.
        # NOTE: attribute names must stay as-is — they define the
        # state-dict keys used when loading the saved checkpoint.
        self.enc_conv1 = self.conv_block(3, 64)
        self.enc_conv2 = self.conv_block(64, 128)
        self.enc_conv3 = self.conv_block(128, 256)
        self.bottleneck = self.conv_block(256, 512)
        # Decoder: each stage consumes the upsampled features concatenated
        # with the matching encoder skip, hence the summed channel counts.
        self.dec_conv1 = self.conv_block(512 + 256, 256)
        self.dec_conv2 = self.conv_block(256 + 128, 128)
        self.dec_conv3 = self.conv_block(128 + 64, 64)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def conv_block(self, in_channels, out_channels):
        """Build a (conv3x3 -> ReLU) x 2 stage that keeps spatial size."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the encoder-decoder pass; returns a tanh-bounded tensor."""
        # Encoder path — keep each stage's activations for skip connections.
        skip1 = self.enc_conv1(x)
        skip2 = self.enc_conv2(self.pool(skip1))
        skip3 = self.enc_conv3(self.pool(skip2))
        bottom = self.bottleneck(self.pool(skip3))

        # Decoder path — upsample, fuse the skip along channels, refine.
        up = self.dec_conv1(torch.cat([self.upsample(bottom), skip3], dim=1))
        up = self.dec_conv2(torch.cat([self.upsample(up), skip2], dim=1))
        up = self.dec_conv3(torch.cat([self.upsample(up), skip1], dim=1))

        # Bound the output to [-1, 1] to match the input normalisation.
        return torch.tanh(self.final_conv(up))
|
|
|
|
|
|
|
|
# Instantiate the network and move it to the selected device.
model = DeblurNet().to(device)
# Use pathlib (already imported above) instead of os.path for the
# checkpoint location — platform-independent and consistent with the import.
model_path = Path('model') / 'best_deblur_model.pth'

if model_path.exists():
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from a trusted source.
    model.load_state_dict(torch.load(model_path, map_location=device))
    # Switch to inference mode (affects dropout/batch-norm style layers).
    model.eval()
    print("Model loaded successfully.")
else:
    # Best-effort startup: the app still launches, but inference will run
    # with randomly initialised weights until a checkpoint is provided.
    print(f"Model file not found at {model_path}. Please check the path.")
|
|
|
|
|
|
|
|
# Pre-processing pipeline applied to every uploaded image: resize to the
# fixed 256x256 input resolution, convert to a CHW float tensor, then
# normalise each channel to [-1, 1] — matching the network's tanh output
# range (postprocess_image reverses this with t * 0.5 + 0.5).
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
|
|
|
|
|
def postprocess_image(tensor):
    """Convert a model output tensor (C, H, W) in [-1, 1] to an HWC uint8 image.

    Reverses the Normalize(mean=0.5, std=0.5) applied at input time, clamps
    to the valid range, and rescales to 0-255.
    """
    # Undo the [-1, 1] normalisation and guard against out-of-range values.
    denorm = torch.clamp(tensor * 0.5 + 0.5, 0, 1)
    # Detach from the graph and move to host memory before converting.
    pixels = denorm.detach().cpu().numpy()
    # Reorder channels-first (C, H, W) into channels-last (H, W, C).
    pixels = pixels.transpose(1, 2, 0)
    return (pixels * 255).astype(np.uint8)
|
|
|
|
|
def deblur_image(filepath):
    """Deblur the image stored at *filepath* and return it as a numpy array.

    Returns None when no file was provided or when processing fails —
    best-effort behaviour so the Gradio UI degrades gracefully instead of
    surfacing a stack trace.
    """
    if not filepath:
        return None

    try:
        source = Image.open(filepath).convert("RGB")
        # Remember the upload's (width, height) so the result can be
        # scaled back after the fixed-size model pass.
        original_dims = source.size

        # Resize/normalise to the 256x256 tensor the network expects and
        # add a batch dimension.
        batch = transform(source).unsqueeze(0).to(device)

        # Inference only — no gradients needed.
        with torch.no_grad():
            prediction = model(batch)

        # Drop the batch dim, denormalise to uint8, restore original size.
        restored = postprocess_image(prediction[0])
        restored = Image.fromarray(restored).resize(original_dims)

        return np.array(restored)

    except Exception as e:
        # Deliberate broad catch: any failure (bad file, decode error,
        # model issue) is logged and reported to the UI as "no result".
        print(f"Error processing image: {e}")
        return None
|
|
|
|
|
|
|
|
# CSS injected into the Gradio page: hides the fullscreen/share buttons and
# Gradio's header/footer, makes images non-draggable, and applies a dark
# (black background, blue buttons) theme. Passed verbatim to gr.Interface
# below via css=custom_css.
custom_css = """
/* Completely hide fullscreen and share buttons */
button[title="Fullscreen"],
button[title="Share"],
.gr-button[title="Fullscreen"],
.gr-button[title="Share"] {
display: none !important; /* Remove from the layout */
opacity: 0 !important; /* Make it invisible */
visibility: hidden !important; /* Ensure it's hidden */
width: 0 !important; /* Collapse the button size */
height: 0 !important; /* Collapse the button size */
overflow: hidden !important; /* Prevent any content visibility */
pointer-events: none !important; /* Disable all interactions */
}
/* Hide Gradio's footer and header */
footer, header, .gradio-footer, .gradio-header {
display: none !important;
}
/* Non-draggable images */
img {
pointer-events: none !important;
-webkit-user-drag: none !important;
user-select: none !important;
}
/* Styling adjustments */
body, .gradio-container {
background-color: #000000 !important;
color: white !important;
}
.gr-button {
background: #1e90ff !important;
color: white !important;
border: none !important;
padding: 10px 20px !important;
font-size: 14px !important;
cursor: pointer;
}
.gr-button:hover {
background: #0056b3 !important;
}
.gr-box, .gr-input, .gr-output {
background-color: #1c1c1c !important;
color: white !important;
border: 1px solid #333333 !important;
}
"""
|
|
|
|
|
|
|
|
# Gradio UI: a single file-upload input routed through deblur_image to an
# image output. type="filepath" hands deblur_image a path string; the
# numpy output type matches the np.array returned on success (None on
# failure simply clears the output). custom_css hides Gradio chrome and
# applies the dark theme.
demo = gr.Interface(
    fn=deblur_image,
    inputs=gr.File(label="Input", type="filepath"),
    outputs=gr.Image(type="numpy", label="Deblurred Result"),
    title="Image Deblurring",
    description="Upload a blurry image.",
    css=custom_css
)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script (not on import).
    demo.launch()
|
|
|