Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import torchvision.transforms as tfs | |
| import os | |
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Build a 2-D convolution padded by ``kernel_size // 2``.

    With the default stride of 1 and an odd ``kernel_size`` this keeps the
    spatial size of the input unchanged.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: square kernel size.
        bias: whether the convolution has a learnable bias.

    Returns:
        nn.Conv2d: the configured convolution layer.
    """
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=half_kernel,
        bias=bias,
    )
class PALayer(nn.Module):
    """Pixel attention: multiply the input by a learned single-channel gate.

    A 1x1-conv bottleneck (``channel -> channel // 8 -> 1``) followed by a
    sigmoid yields a one-channel map in (0, 1) that is broadcast over all
    channels of the input.
    """

    def __init__(self, channel):
        super().__init__()
        bottleneck = channel // 8
        # The Sequential positions (convs at 0 and 2) form the parameter names
        # ``pa.0.*`` / ``pa.2.*`` that load_state_dict matches against.
        self.pa = nn.Sequential(
            nn.Conv2d(channel, bottleneck, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck, 1, 1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.pa(x)
class CALayer(nn.Module):
    """Channel attention: rescale every channel of the input by a learned gate.

    Global average pooling reduces each channel to one value; a 1x1-conv
    bottleneck (``channel -> channel // 8 -> channel``) and a sigmoid turn
    those values into per-channel weights in (0, 1).
    """

    def __init__(self, channel):
        super().__init__()
        squeezed = channel // 8
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Convs sit at Sequential positions 0 and 2 -> ``ca.0.*`` / ``ca.2.*``
        # state_dict keys; keep this layout stable for checkpoint loading.
        self.ca = nn.Sequential(
            nn.Conv2d(channel, squeezed, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channel, 1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        channel_weights = self.ca(self.avg_pool(x))
        return x * channel_weights
class Block(nn.Module):
    """Residual block: conv+ReLU with a local skip, conv, channel attention,
    pixel attention, and a skip from the block input to its output.

    Args:
        conv: convolution factory called as ``conv(in, out, kernel_size, bias=...)``.
        dim: channel count, unchanged through the block.
        kernel_size: kernel size passed to ``conv``.
    """

    def __init__(self, conv, dim, kernel_size):
        super().__init__()
        # Attribute names are state_dict keys consumed by load_state_dict.
        self.conv1 = conv(dim, dim, kernel_size, bias=True)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = conv(dim, dim, kernel_size, bias=True)
        self.calayer = CALayer(dim)
        self.palayer = PALayer(dim)

    def forward(self, x):
        # Local residual around the first conv/ReLU.
        hidden = self.act1(self.conv1(x)) + x
        # conv2 -> channel attention -> pixel attention, then the outer skip.
        attended = self.palayer(self.calayer(self.conv2(hidden)))
        return attended + x
class Group(nn.Module):
    """``blocks`` consecutive Blocks followed by one conv, wrapped in a residual skip.

    Args:
        conv: convolution factory called as ``conv(in, out, kernel_size)``.
        dim: channel count, unchanged through the group.
        kernel_size: kernel size passed to ``conv``.
        blocks: number of Block modules in the stack.
    """

    def __init__(self, conv, dim, kernel_size, blocks):
        super().__init__()
        stages = [Block(conv, dim, kernel_size) for _ in range(blocks)]
        # The trailing conv lands at index ``blocks`` of ``gp`` (state_dict key gp.<blocks>.*).
        self.gp = nn.Sequential(*stages, conv(dim, dim, kernel_size))

    def forward(self, x):
        return self.gp(x) + x
class FFA(nn.Module):
    """FFA-Net style dehazing network.

    A head conv maps the 3-channel input to ``dim`` (64) channels. Three
    chained Groups produce features that are fused with learned per-group,
    per-channel weights, refined by pixel attention, projected back to 3
    channels by two convs, and finally added to the network input.

    Args:
        gps: number of groups; only 3 is supported (asserted).
        blocks: number of Blocks inside each Group.
        conv: convolution factory called as ``conv(in, out, kernel_size)``.
    """

    def __init__(self, gps, blocks, conv=default_conv):
        super().__init__()
        self.gps = gps
        self.dim = 64
        ksize = 3
        # The head conv is created before the groups; construction order
        # determines the order in which parameters are randomly initialised.
        head = conv(3, self.dim, ksize)
        assert self.gps == 3
        self.g1 = Group(conv, self.dim, ksize, blocks=blocks)
        self.g2 = Group(conv, self.dim, ksize, blocks=blocks)
        self.g3 = Group(conv, self.dim, ksize, blocks=blocks)
        stacked_channels = self.dim * self.gps
        reduced_channels = self.dim // 16
        # Convs at Sequential positions 1 and 3 -> ``ca.1.*`` / ``ca.3.*`` keys.
        self.ca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(stacked_channels, reduced_channels, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced_channels, stacked_channels, 1, bias=True),
            nn.Sigmoid(),
        )
        self.palayer = PALayer(self.dim)
        tail = [conv(self.dim, self.dim, ksize), conv(self.dim, 3, ksize)]
        self.pre = nn.Sequential(head)
        self.post = nn.Sequential(*tail)

    def forward(self, x1):
        feats = self.pre(x1)
        res1 = self.g1(feats)
        res2 = self.g2(res1)
        res3 = self.g3(res2)
        # One weight per (group, channel), reshaped so it broadcasts over H and W.
        gate = self.ca(torch.cat([res1, res2, res3], dim=1))
        gate = gate.view(-1, self.gps, self.dim)[..., None, None]
        fused = gate[:, 0] * res1 + gate[:, 1] * res2 + gate[:, 2] * res3
        restored = self.post(self.palayer(fused))
        return restored + x1
MODEL_PATH = 'tti.pk'
gps = 3
blocks = 19
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _load_checkpoint(path, map_location):
    """Load the checkpoint at ``path``, preferring PyTorch's weights-only unpickler.

    If the restricted load fails for any reason, a warning plus the underlying
    error is printed and the file is loaded again with full unpickling.

    Args:
        path: checkpoint file path.
        map_location: forwarded to ``torch.load`` (device to map tensors onto).

    Returns:
        The object stored in the checkpoint (the caller below expects a dict
        with a ``'model'`` key).
    """
    try:
        # Allow-list NumPy's scalar reconstructor for the weights-only unpickler;
        # presumably the checkpoint stores NumPy scalars next to the weights
        # (TODO confirm).
        torch.serialization.add_safe_globals([np.core.multiarray.scalar])
        return torch.load(path, map_location=map_location, weights_only=True)
    except Exception as err:  # not a bare except: let Ctrl-C / SystemExit propagate
        print("Warning: Loading checkpoint with weights_only=False. Ensure the checkpoint is from a trusted source.")
        print(f"  (weights-only load failed: {type(err).__name__}: {err})")
        # SECURITY: weights_only=False performs a full pickle load, which can
        # execute arbitrary code; only acceptable for a trusted local checkpoint.
        return torch.load(path, map_location=map_location, weights_only=False)


# Fail fast before building (and moving to the device) a large network.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model checkpoint not found at {MODEL_PATH}")
net = FFA(gps=gps, blocks=blocks).to(device)
# DataParallel registers the model as ``.module``, so its state_dict keys carry a
# "module." prefix; presumably the checkpoint was saved from a wrapped model too.
net = torch.nn.DataParallel(net)
checkpoint = _load_checkpoint(MODEL_PATH, device)
net.load_state_dict(checkpoint['model'])
net.eval()
print(f"Model loaded successfully on {device}")
def dehaze_image(image):
    """
    Process a hazy image and return the dehazed result.

    Args:
        image: PIL Image or numpy array (anything ``Image.fromarray`` accepts),
            or None when the Gradio input is empty or has been cleared.

    Returns:
        PIL Image: Dehazed RGB image, or None if ``image`` is None or
        processing failed (the error is printed, not raised).
    """
    # Gradio passes None when the input is cleared (e.g. via input_image.change);
    # return quietly instead of logging a spurious AttributeError.
    if image is None:
        return None
    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        haze_img = image.convert("RGB")
        transform = tfs.Compose([
            tfs.ToTensor(),
            tfs.Normalize(mean=[0.64, 0.6, 0.58], std=[0.14, 0.15, 0.152]),
        ])
        # Add a batch dimension of 1 for the network.
        haze_tensor = transform(haze_img).unsqueeze(0).to(device)
        with torch.no_grad():
            pred = net(haze_tensor)
        # Clamp to [0, 1], then round (rather than truncate) to 8-bit so the
        # output is not systematically biased darker by up to one level.
        pred_u8 = (pred.clamp(0, 1) * 255).round().to(torch.uint8)
        # (1, C, H, W) -> (H, W, C) contiguous array for PIL.
        pred_numpy = pred_u8.squeeze(0).permute(1, 2, 0).contiguous().cpu().numpy()
        return Image.fromarray(pred_numpy)
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return None
# Example hazy inputs offered as clickable thumbnails in the UI; the relative
# paths are resolved against the current working directory.
SAMPLE_IMAGES = ["./img/s2.png", "./img/s4.png"]
def load_sample_image(sample_path):
    """Load and return a sample image with its pixel data read into memory.

    Args:
        sample_path: path of the image file.

    Returns:
        PIL Image, or None if the file is missing or cannot be read (a message
        is printed in both cases).
    """
    try:
        # Image.open is lazy and keeps the file open until the pixels are read;
        # force the read inside a context manager so the handle is closed.
        with Image.open(sample_path) as img:
            img.load()
            return img
    except FileNotFoundError:
        print(f"Sample image not found: {sample_path}")
        return None
    except Exception as e:
        print(f"Error loading sample image {sample_path}: {e}")
        return None
def create_interface():
    """Assemble the Gradio Blocks UI for the dehazing demo.

    The left column holds the upload widget, two sample thumbnails and the
    "Remove Haze" button; the right column shows the result. Selecting a
    thumbnail loads that sample into the input, and both the button click and
    any change of the input run ``dehaze_image``.

    Returns:
        gr.Blocks: the assembled (not yet launched) app.
    """

    def _sample_or_none(index):
        # Guarded lookup: a missing SAMPLE_IMAGES entry yields an empty slot.
        if index < len(SAMPLE_IMAGES):
            return load_sample_image(SAMPLE_IMAGES[index])
        return None

    # Shared settings for the two small sample thumbnails.
    thumbnail_opts = dict(
        interactive=True,
        width=150,
        height=150,
        container=True,
        show_download_button=False,
    )

    with gr.Blocks(title="Image Dehazing App", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🌫️ Image Dehazing with FFA-Net")
        gr.Markdown("Upload a hazy image to remove fog, haze, and improve visibility!")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Upload Hazy Image", type="pil", height=400)
                gr.Markdown("### Try Sample Images")
                with gr.Row():
                    sample1_btn = gr.Image(value=_sample_or_none(0), label="Sample 1", **thumbnail_opts)
                    sample2_btn = gr.Image(value=_sample_or_none(1), label="Sample 2", **thumbnail_opts)
                # NOTE(review): confirm the button belongs in the left column
                # (below the thumbnail row) rather than inside that row.
                process_btn = gr.Button("Remove Haze ✨", variant="primary", size="lg")
            with gr.Column():
                output_image = gr.Image(label="Dehazed Result", type="pil", height=400)

        # Named handlers rather than lambdas/partials: Gradio uses the function
        # name as the default API endpoint name.
        def use_sample1():
            return _sample_or_none(0)

        def use_sample2():
            return _sample_or_none(1)

        for thumbnail, handler in ((sample1_btn, use_sample1), (sample2_btn, use_sample2)):
            thumbnail.select(fn=handler, outputs=input_image)

        process_btn.click(fn=dehaze_image, inputs=input_image, outputs=output_image, api_name="dehaze")
        # Also dehaze automatically whenever the input image changes.
        input_image.change(fn=dehaze_image, inputs=input_image, outputs=output_image)

        gr.Markdown("""
### About
This app uses the FFA-Net (Feature Fusion Attention Network) for single image dehazing.
The model removes atmospheric haze and fog to restore clear, vibrant images.
**Tips for best results:**
- Use good quality images with visible haze or fog
- Model works best on indoor images
**Made by <a href="https://www.linkedin.com/in/aditsg26/">Aditya Singh</a> and <a href="https://www.linkedin.com/in/ramandeep-singh-makkar/">Ramandeep Singh Makkar</a>**
""")
    return demo
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all interfaces, not only localhost
        "server_port": 7860,
        "share": False,  # do not request a public share link
        "debug": False,
    }
    demo = create_interface()
    demo.launch(**launch_options)