import os
import warnings
from pathlib import Path

import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image, ImageOps
import numpy as np

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


# Model definition
class DeblurNet(nn.Module):
    """U-Net-style encoder/decoder mapping an RGB image to a deblurred RGB image.

    Three encoder stages (64/128/256 channels) each followed by 2x2 max-pooling,
    a 512-channel bottleneck, and three decoder stages that upsample 2x and
    concatenate the matching encoder feature map (skip connection).  Because
    the input is pooled three times and upsampled back, height and width must
    be divisible by 8 or the ``torch.cat`` skip connections will not line up.

    The output goes through ``tanh``, so values lie in [-1, 1] -- the same
    range produced by the ``Normalize(mean=0.5, std=0.5)`` input transform.
    """

    def __init__(self):
        super().__init__()
        self.enc_conv1 = self.conv_block(3, 64)
        self.enc_conv2 = self.conv_block(64, 128)
        self.enc_conv3 = self.conv_block(128, 256)
        self.bottleneck = self.conv_block(256, 512)
        # Decoder inputs are (upsampled channels + skip-connection channels).
        self.dec_conv1 = self.conv_block(512 + 256, 256)
        self.dec_conv2 = self.conv_block(256 + 128, 128)
        self.dec_conv3 = self.conv_block(128 + 64, 64)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 conv + ReLU layers; padding=1 keeps spatial size unchanged."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run the encoder/decoder; returns a tensor shaped like ``x`` with values in [-1, 1]."""
        # Encoder: keep each stage's features for the skip connections.
        x1 = self.enc_conv1(x)
        x2 = self.pool(x1)
        x2 = self.enc_conv2(x2)
        x3 = self.pool(x2)
        x3 = self.enc_conv3(x3)
        x4 = self.pool(x3)
        x4 = self.bottleneck(x4)

        # Decoder: upsample, concatenate the matching encoder output, convolve.
        x = self.upsample(x4)
        x = torch.cat([x, x3], dim=1)
        x = self.dec_conv1(x)
        x = self.upsample(x)
        x = torch.cat([x, x2], dim=1)
        x = self.dec_conv2(x)
        x = self.upsample(x)
        x = torch.cat([x, x1], dim=1)
        x = self.dec_conv3(x)
        x = self.final_conv(x)
        return torch.tanh(x)


# Load model
model = DeblurNet().to(device)
model_path = os.path.join('model', 'best_deblur_model.pth')

# Without the checkpoint the network keeps its random initial weights and its
# output is meaningless noise, so remember whether loading succeeded.
model_loaded = os.path.exists(model_path)
if model_loaded:
    # The checkpoint is consumed as a state_dict (tensors only), so refuse to
    # unpickle arbitrary Python objects from the file.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    print("Model loaded successfully.")
else:
    print(f"Model file not found at {model_path}. Please check the path.")
# Inference mode regardless of whether weights were found.
model.eval()

# Image processing functions
# A fixed 256x256 size satisfies DeblurNet's divisible-by-8 requirement;
# Normalize maps [0, 1] pixels to [-1, 1], matching the model's tanh output.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


def postprocess_image(tensor):
    """Convert a (C, H, W) model output in [-1, 1] into an (H, W, C) uint8 array.

    Undoes the ``Normalize(mean=0.5, std=0.5)`` scaling, clamps to [0, 1] and
    rounds to the nearest 8-bit level.  A plain ``astype(np.uint8)`` would
    truncate, which biases every pixel slightly darker.
    """
    tensor = tensor * 0.5 + 0.5
    tensor = torch.clamp(tensor, 0, 1)
    image = tensor.detach().cpu().numpy()
    image = np.transpose(image, (1, 2, 0))  # channels-first -> channels-last
    return np.rint(image * 255).astype(np.uint8)


def deblur_image(filepath):
    """Deblur the image at ``filepath`` and return it as an (H, W, 3) uint8 array.

    The image is resized to 256x256 for the network, and the result is resized
    back to the original dimensions.

    Returns ``None`` in three cases: no file was given, the trained weights
    were not loaded, or the file could not be processed.  In the last two
    cases a message is printed to the console.
    """
    if not filepath:
        return None
    if not model_loaded:
        print(f"Cannot deblur: model weights were not loaded from {model_path}.")
        return None
    try:
        # The context manager releases the file handle.  convert() returns a
        # fully loaded copy, so the image remains usable after the file closes.
        # exif_transpose applies camera orientation so phone photos are not
        # processed and returned sideways.
        with Image.open(filepath) as img:
            input_image = ImageOps.exif_transpose(img).convert("RGB")
        # (width, height), the order PIL's resize() expects below.
        original_size = input_image.size

        # Preprocess: add a batch dimension and move to the model's device.
        input_tensor = transform(input_image).unsqueeze(0).to(device)

        # Inference
        with torch.no_grad():
            output_tensor = model(input_tensor)

        # Post-process the single batch item and restore the original size.
        output_image = postprocess_image(output_tensor[0])
        output_image = Image.fromarray(output_image).resize(original_size)
        return np.array(output_image)
    except Exception as e:
        # UI boundary: report the failure and show an empty result rather
        # than crashing the request handler.
        print(f"Error processing image: {e}")
        return None


# Custom CSS: hide the fullscreen/share buttons, header and footer, and apply a dark theme.
custom_css = """
/* Completely hide fullscreen and share buttons */
button[title="Fullscreen"],
button[title="Share"],
.gr-button[title="Fullscreen"],
.gr-button[title="Share"] {
    display: none !important; /* Remove from the layout */
    opacity: 0 !important; /* Make it invisible */
    visibility: hidden !important; /* Ensure it's hidden */
    width: 0 !important; /* Collapse the button size */
    height: 0 !important; /* Collapse the button size */
    overflow: hidden !important; /* Prevent any content visibility */
    pointer-events: none !important; /* Disable all interactions */
}

/* Hide Gradio's footer and header */
footer, header, .gradio-footer, .gradio-header {
    display: none !important;
}

/* Non-draggable images */
img {
    pointer-events: none !important;
    -webkit-user-drag: none !important;
    user-select: none !important;
}

/* Styling adjustments */
body, .gradio-container {
    background-color: #000000 !important;
    color: white !important;
}

.gr-button {
    background: #1e90ff !important;
    color: white !important;
    border: none !important;
    padding: 10px 20px !important;
    font-size: 14px !important;
    cursor: pointer;
}

.gr-button:hover {
    background: #0056b3 !important;
}

.gr-box, .gr-input, .gr-output {
    background-color: #1c1c1c !important;
    color: white !important;
    border: 1px solid #333333 !important;
}
"""

# Gradio interface.  file_types restricts the upload picker to images, the
# only input deblur_image can process.
demo = gr.Interface(
    fn=deblur_image,
    inputs=gr.File(label="Input", type="filepath", file_types=["image"]),
    outputs=gr.Image(type="numpy", label="Deblurred Result"),
    title="Image Deblurring",
    description="Upload a blurry image.",
    css=custom_css,
)

# Launch Gradio app
if __name__ == "__main__":
    demo.launch()