# Deblur / app.py
# Last commit by pratyyush: "Update app.py" (10f01a8, verified)
import os
import warnings
from pathlib import Path
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
# Select the compute device once at import time: the first CUDA GPU PyTorch can
# see, otherwise the CPU. Everything else in the app moves tensors to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# Model definition
class DeblurNet(nn.Module):
    """U-Net-style encoder/decoder mapping an RGB image tensor to an RGB image tensor.

    Encoder: three double-conv blocks (3 -> 64 -> 128 -> 256 channels), each
    followed by 2x2 max-pooling, then a 512-channel bottleneck.
    Decoder: three stages that upsample, concatenate the matching encoder
    feature map (skip connection) along the channel axis, and apply a
    double-conv block. A final 3x3 conv plus tanh keeps outputs in [-1, 1].

    Input is expected as (batch, 3, H, W). Any H, W >= 8 works: upsampled
    decoder features are resized to the exact size of their skip connection,
    so odd or non-multiple-of-8 sizes no longer cause a concatenation error.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and Sequential layouts define the state_dict keys;
        # they must stay unchanged for saved checkpoints to load.
        self.enc_conv1 = self.conv_block(3, 64)
        self.enc_conv2 = self.conv_block(64, 128)
        self.enc_conv3 = self.conv_block(128, 256)
        self.bottleneck = self.conv_block(256, 512)
        self.dec_conv1 = self.conv_block(512 + 256, 256)
        self.dec_conv2 = self.conv_block(256 + 128, 128)
        self.dec_conv3 = self.conv_block(128 + 64, 64)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Parameter-free; kept for attribute compatibility. forward() uses
        # _upsample_to(), which gives the same result for even sizes.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 conv + ReLU layers; padding=1 preserves spatial size."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def _upsample_to(self, x, skip):
        """Bilinearly resize `x` to the spatial size of `skip`.

        With align_corners=True this equals a 2x upsample whenever `skip` is
        exactly twice the size of `x`, and it also handles the odd sizes that
        max-pooling produces from non-multiple-of-8 inputs.
        """
        return nn.functional.interpolate(
            x, size=skip.shape[-2:], mode='bilinear', align_corners=True
        )

    def forward(self, x):
        """Run the encoder/decoder; returns a tensor shaped like `x` with values in [-1, 1]."""
        # Encoder: keep each stage's output for the skip connections.
        x1 = self.enc_conv1(x)
        x2 = self.enc_conv2(self.pool(x1))
        x3 = self.enc_conv3(self.pool(x2))
        x4 = self.bottleneck(self.pool(x3))

        # Decoder: upsample to the skip's size, concatenate channels, convolve.
        x = self.dec_conv1(torch.cat([self._upsample_to(x4, x3), x3], dim=1))
        x = self.dec_conv2(torch.cat([self._upsample_to(x, x2), x2], dim=1))
        x = self.dec_conv3(torch.cat([self._upsample_to(x, x1), x1], dim=1))
        return torch.tanh(self.final_conv(x))
# Load model: build the network, then load trained weights if the checkpoint
# exists. If it is missing, the app still starts but with untrained weights.
model = DeblurNet().to(device)
model_path = os.path.join('model', 'best_deblur_model.pth')
if os.path.exists(model_path):
    # The checkpoint is passed straight to load_state_dict, so it only needs to
    # hold tensors. weights_only=True refuses arbitrary pickled objects, so a
    # tampered file cannot execute code on load.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    print("Model loaded successfully.")
else:
    warnings.warn(
        f"Model file not found at {model_path}. Please check the path. "
        "Continuing with an untrained model; outputs will not be meaningful."
    )
# Inference-only app: switch to eval mode whether or not weights were loaded.
model.eval()
# Preprocessing applied to every uploaded image before inference:
#   1. stretch to a fixed 256x256 (a (h, w) tuple does not preserve aspect ratio);
#   2. convert the PIL image to a CHW tensor (torchvision scales 8-bit pixels to [0, 1]);
#   3. normalize each RGB channel with mean 0.5 / std 0.5, mapping [0, 1] to [-1, 1],
#      the same range as the model's tanh output.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
def postprocess_image(tensor):
    """Convert one model output tensor into a displayable uint8 image array.

    Args:
        tensor: A single (C, H, W) output tensor with values nominally in
            [-1, 1] (the model ends in tanh). May live on any device.

    Returns:
        A ``numpy.ndarray`` of shape (H, W, C) and dtype uint8.
    """
    # Undo the [-1, 1] normalization, then clamp any overshoot back into [0, 1].
    image = tensor.detach().cpu() * 0.5 + 0.5
    image = torch.clamp(image, 0, 1)
    # Round instead of truncating: a plain uint8 cast would floor every value,
    # darkening the image by up to one intensity level (e.g. 127.5 -> 127).
    image = torch.round(image * 255).to(torch.uint8).numpy()
    # CHW -> HWC, the layout PIL and Gradio expect.
    return np.transpose(image, (1, 2, 0))
def deblur_image(filepath):
    """Deblur the uploaded image and return it at its original size.

    The image is resized by ``transform`` to the model's 256x256 input, run
    through ``model`` without gradients, and the prediction is resized back
    to the upload's original width and height.

    Args:
        filepath: Path of the uploaded file as supplied by the ``gr.File``
            input (``type="filepath"``). Falsy when nothing was uploaded.

    Returns:
        A uint8 ``numpy.ndarray`` of shape (H, W, 3), or ``None`` when no
        file was provided.

    Raises:
        gr.Error: if the file cannot be read as an image or inference fails.
            Gradio shows this message to the user instead of a blank output.
    """
    if not filepath:
        return None
    try:
        # Close the file handle deterministically. convert() returns a new,
        # fully loaded image that remains valid after the file is closed.
        with Image.open(filepath) as raw_image:
            input_image = raw_image.convert("RGB")
        # PIL reports (width, height), which is also what resize() expects.
        original_size = input_image.size

        input_tensor = transform(input_image).unsqueeze(0).to(device)
        with torch.no_grad():
            output_tensor = model(input_tensor)

        # Drop the batch dimension, convert to HWC uint8, restore original size.
        output_image = postprocess_image(output_tensor[0])
        output_image = Image.fromarray(output_image).resize(original_size)
        return np.array(output_image)
    except Exception as e:
        # UI boundary: keep the full detail in the server log and surface a
        # readable error in the app instead of silently returning nothing.
        print(f"Error processing image: {e}")
        raise gr.Error(
            "Could not process the uploaded image. Please upload a valid image file."
        ) from e
# Custom CSS injected into the Gradio page (passed as css= to gr.Interface).
# It hides the Fullscreen/Share buttons and the page header/footer. It sets
# pointer-events: none on every <img>, so images cannot be dragged or selected.
# It also applies a dark colour scheme to the body, buttons and input/output boxes.
# NOTE(review): the .gr-* and .gradio-* selectors target specific Gradio
# markup; confirm they still match the installed Gradio version.
custom_css = """
/* Completely hide fullscreen and share buttons */
button[title="Fullscreen"],
button[title="Share"],
.gr-button[title="Fullscreen"],
.gr-button[title="Share"] {
display: none !important; /* Remove from the layout */
opacity: 0 !important; /* Make it invisible */
visibility: hidden !important; /* Ensure it's hidden */
width: 0 !important; /* Collapse the button size */
height: 0 !important; /* Collapse the button size */
overflow: hidden !important; /* Prevent any content visibility */
pointer-events: none !important; /* Disable all interactions */
}
/* Hide Gradio's footer and header */
footer, header, .gradio-footer, .gradio-header {
display: none !important;
}
/* Non-draggable images */
img {
pointer-events: none !important;
-webkit-user-drag: none !important;
user-select: none !important;
}
/* Styling adjustments */
body, .gradio-container {
background-color: #000000 !important;
color: white !important;
}
.gr-button {
background: #1e90ff !important;
color: white !important;
border: none !important;
padding: 10px 20px !important;
font-size: 14px !important;
cursor: pointer;
}
.gr-button:hover {
background: #0056b3 !important;
}
.gr-box, .gr-input, .gr-output {
background-color: #1c1c1c !important;
color: white !important;
border: 1px solid #333333 !important;
}
"""
# Gradio UI: one file-upload input wired to deblur_image; the returned numpy
# array is rendered by an image output. custom_css restyles the page.
demo = gr.Interface(
    fn=deblur_image,
    # file_types=["image"] restricts the upload picker to image files, so users
    # are less likely to submit something PIL cannot open.
    inputs=gr.File(label="Input", type="filepath", file_types=["image"]),
    outputs=gr.Image(type="numpy", label="Deblurred Result"),
    title="Image Deblurring",
    description="Upload a blurry image.",
    css=custom_css,
)
def main():
    """Start the Gradio server for the deblurring demo."""
    demo.launch()


# Launch only when executed as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    main()