File size: 5,453 Bytes
2ea3bc2 0200dd6 7612bbc d2bb28a 7612bbc d2bb28a 7612bbc 72e804e 7612bbc 72e804e 7612bbc d2bb28a 7612bbc 0b84bc4 d2bb28a 7612bbc 0b84bc4 7612bbc c7ffd84 0b84bc4 7612bbc 0b84bc4 7612bbc 0b84bc4 c7ffd84 e17de0c 7612bbc 3abc1d7 e17de0c 7612bbc 3abc1d7 e17de0c 7612bbc e17de0c 0b84bc4 7612bbc e17de0c 7612bbc 0b84bc4 3abc1d7 e17de0c 7612bbc 10f01a8 72e804e 3079e81 72e804e e17de0c 72e804e 2ea3bc2 7612bbc 68811da 3abc1d7 7612bbc 3079e81 08fa660 7612bbc 0200dd6 2ea3bc2 0200dd6 10f01a8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 | import os
import warnings
from pathlib import Path
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
# Select the compute device: a CUDA GPU when PyTorch can see one, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# Model definition
class DeblurNet(nn.Module):
    """U-Net-style encoder/decoder that maps an RGB image tensor to an RGB image tensor.

    Three 2x2 max-pool stages shrink the feature maps on the way down
    (64 -> 128 -> 256 channels, then a 512-channel bottleneck). On the way up,
    each stage is bilinearly upsampled, concatenated with the encoder feature
    map of the same resolution (skip connection) and passed through a conv
    block. A final 3x3 conv produces 3 channels, which ``tanh`` squashes into
    ``[-1, 1]``.

    Input is a ``(N, 3, H, W)`` tensor. Each decoder stage is resized to the
    exact spatial size of its skip connection, so ``H`` and ``W`` need not be
    multiples of 8. They must still be at least 8 so that the bottleneck
    feature map is non-empty. The output has the same spatial size as the
    input.
    """

    def __init__(self):
        super().__init__()
        # Encoder path.
        self.enc_conv1 = self.conv_block(3, 64)
        self.enc_conv2 = self.conv_block(64, 128)
        self.enc_conv3 = self.conv_block(128, 256)
        self.bottleneck = self.conv_block(256, 512)
        # Decoder path: input channels = upsampled channels + skip channels.
        self.dec_conv1 = self.conv_block(512 + 256, 256)
        self.dec_conv2 = self.conv_block(256 + 128, 128)
        self.dec_conv3 = self.conv_block(128 + 64, 64)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Kept so existing code that references `self.upsample` keeps working.
        # It holds no parameters, so state_dict keys are unaffected.
        # forward() resizes to the skip size instead (see _upsample_to).
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def conv_block(self, in_channels, out_channels):
        """Return two (3x3 conv -> ReLU) layers; ``padding=1`` preserves H and W."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _upsample_to(x, ref):
        """Bilinearly resize ``x`` to the spatial (H, W) size of ``ref``.

        With ``align_corners=True``, PyTorch derives the sampling grid from
        the input and output sizes only. For even sizes this therefore gives
        exactly the same result as ``scale_factor=2``. For odd sizes, where
        pooling floored the dimension, it still produces a tensor that can
        be concatenated with ``ref``.
        """
        return nn.functional.interpolate(
            x, size=ref.shape[-2:], mode='bilinear', align_corners=True
        )

    def forward(self, x):
        """Run the encoder/decoder and return a ``tanh``-activated 3-channel map."""
        x1 = self.enc_conv1(x)
        x2 = self.enc_conv2(self.pool(x1))
        x3 = self.enc_conv3(self.pool(x2))
        x4 = self.bottleneck(self.pool(x3))

        x = self.dec_conv1(torch.cat([self._upsample_to(x4, x3), x3], dim=1))
        x = self.dec_conv2(torch.cat([self._upsample_to(x, x2), x2], dim=1))
        x = self.dec_conv3(torch.cat([self._upsample_to(x, x1), x1], dim=1))
        return torch.tanh(self.final_conv(x))
# Load model: build the network on the selected device, then restore trained
# weights from the checkpoint if it is present.
model = DeblurNet().to(device)
model_path = os.path.join('model', 'best_deblur_model.pth')  # relative to the current working directory
# Ensure model path exists before loading
if os.path.exists(model_path):
    # NOTE(review): torch.load unpickles the file. Only a tensor state_dict is
    # needed here, so consider passing weights_only=True (torch >= 1.13) if the
    # checkpoint could ever come from an untrusted source.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print("Model loaded successfully.")
else:
    # The app keeps running with randomly initialised weights in this case.
    print(f"Model file not found at {model_path}. Please check the path.")
# Always switch to inference mode, including when no checkpoint was found.
# The module is then never used for serving while still in training mode.
model.eval()
# Image preprocessing applied to every uploaded image:
#   1. resize to a fixed 256x256 (aspect ratio is not preserved),
#   2. convert the PIL image to a float CHW tensor in [0, 1],
#   3. shift/scale each channel with mean 0.5 / std 0.5, i.e. into [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
def postprocess_image(tensor):
    """Convert one model output tensor into a displayable 8-bit image array.

    Args:
        tensor: A single image as a ``(C, H, W)`` tensor, in the ``[-1, 1]``
            range produced by the model's ``tanh`` output. It may live on any
            device and may require grad.

    Returns:
        A ``(H, W, C)`` ``numpy.uint8`` array. Values are mapped back to
        ``[0, 1]``, clamped, and rounded to the nearest of the 256 levels.

    Raises:
        ValueError: If ``tensor`` is not 3-dimensional.
    """
    if tensor.dim() != 3:
        raise ValueError(f"expected a (C, H, W) tensor, got shape {tuple(tensor.shape)}")
    # Detach first so the arithmetic below never builds an autograd graph.
    # Then undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1].
    unit = torch.clamp(tensor.detach() * 0.5 + 0.5, 0, 1)
    hwc = np.transpose(unit.cpu().numpy(), (1, 2, 0))
    # Round rather than truncate. A plain astype() would floor, so e.g.
    # 0.999 would give 254 and every pixel would be biased low.
    return np.rint(hwc * 255).astype(np.uint8)
def deblur_image(filepath):
    """Deblur the image stored at ``filepath`` using the global ``model``.

    Steps:

    - The image is converted to RGB and preprocessed with ``transform``.
    - It is run through ``model`` without gradient tracking.
    - The result is turned into an 8-bit HWC array by ``postprocess_image``.
    - That array is resized back to the input's original ``(width, height)``.

    Args:
        filepath: Path to the uploaded image, as produced by
            ``gr.File(type="filepath")``, or a falsy value when nothing was
            uploaded.

    Returns:
        An ``(H, W, 3)`` ``uint8`` NumPy array. Returns ``None`` if no file
        was given or if any step failed; the error is printed in that case.
    """
    if not filepath:
        return None
    try:
        # The context manager closes the underlying file handle once decoded.
        # convert() loads the pixels and returns an independent copy, so
        # input_image remains valid after the file is closed.
        with Image.open(filepath) as source:
            input_image = source.convert("RGB")
        # PIL size is (width, height); kept to undo the fixed-size resize later.
        original_size = input_image.size
        # Preprocess and add a batch dimension of 1.
        input_tensor = transform(input_image).unsqueeze(0).to(device)
        # Inference
        with torch.no_grad():
            output_tensor = model(input_tensor)
        # Post-process the single batch element and restore the original size.
        output_array = postprocess_image(output_tensor[0])
        restored = Image.fromarray(output_array).resize(original_size)
        return np.array(restored)
    except Exception as e:
        # Best-effort UI handler: report the error and show an empty output
        # instead of crashing the Gradio worker.
        print(f"Error processing image: {e}")
        return None
# Custom CSS injected into the Gradio page. It:
#   - hides the Fullscreen and Share buttons,
#   - hides Gradio's header and footer,
#   - makes <img> elements non-interactive and non-draggable,
#   - applies a dark colour scheme with blue buttons.
# NOTE(review): selectors such as .gr-button / .gr-box / .gradio-footer target
# Gradio class names that may differ between Gradio versions; confirm them
# against the installed version.
# NOTE(review): `pointer-events: none` on every <img> also applies to images
# inside Gradio components; confirm this does not block intended interactions.
custom_css = """
/* Completely hide fullscreen and share buttons */
button[title="Fullscreen"],
button[title="Share"],
.gr-button[title="Fullscreen"],
.gr-button[title="Share"] {
display: none !important; /* Remove from the layout */
opacity: 0 !important; /* Make it invisible */
visibility: hidden !important; /* Ensure it's hidden */
width: 0 !important; /* Collapse the button size */
height: 0 !important; /* Collapse the button size */
overflow: hidden !important; /* Prevent any content visibility */
pointer-events: none !important; /* Disable all interactions */
}
/* Hide Gradio's footer and header */
footer, header, .gradio-footer, .gradio-header {
display: none !important;
}
/* Non-draggable images */
img {
pointer-events: none !important;
-webkit-user-drag: none !important;
user-select: none !important;
}
/* Styling adjustments */
body, .gradio-container {
background-color: #000000 !important;
color: white !important;
}
.gr-button {
background: #1e90ff !important;
color: white !important;
border: none !important;
padding: 10px 20px !important;
font-size: 14px !important;
cursor: pointer;
}
.gr-button:hover {
background: #0056b3 !important;
}
.gr-box, .gr-input, .gr-output {
background-color: #1c1c1c !important;
color: white !important;
border: 1px solid #333333 !important;
}
"""
# Gradio front end: a file upload in, the deblurred image (NumPy array) out,
# styled with the custom CSS above.
demo = gr.Interface(
    title="Image Deblurring",
    description="Upload a blurry image.",
    fn=deblur_image,
    inputs=gr.File(type="filepath", label="Input"),
    outputs=gr.Image(label="Deblurred Result", type="numpy"),
    css=custom_css,
)

# Start the web server only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()
|