| | |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import numbers |
| | from einops import rearrange |
| | import os |
| |
|
| |
|
def to_3d(x):
    """Flatten a 4-D feature map (b, c, h, w) into token form (b, h*w, c).

    Implemented with native tensor ops (permute + reshape) instead of
    einops.rearrange — the pattern 'b c h w -> b (h w) c' is a plain
    transpose-and-merge, so the third-party call adds only overhead.
    """
    b, c, h, w = x.shape
    return x.permute(0, 2, 3, 1).reshape(b, h * w, c)
| |
|
def to_4d(x, h, w):
    """Inverse of to_3d: reshape tokens (b, h*w, c) back to (b, c, h, w).

    Implemented with native tensor ops (reshape + permute) instead of
    einops.rearrange — 'b (h w) c -> b c h w' is a plain split-and-transpose.
    """
    b, hw, c = x.shape
    return x.reshape(b, h, w, c).permute(0, 3, 1, 2)
| |
|
class BiasFree_LayerNorm(nn.Module):
    """Bias-free LayerNorm: rescale by the standard deviation only.

    Unlike standard LayerNorm there is no mean subtraction and no additive
    bias — the input is divided by its (biased) std over the last axis and
    scaled by a learnable per-channel weight.
    """

    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        # Accept either an int or a 1-element shape.
        shape = (normalized_shape,) if isinstance(normalized_shape, numbers.Integral) else normalized_shape
        shape = torch.Size(shape)
        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.normalized_shape = shape

    def forward(self, x):
        # Biased (population) variance over the channel axis; eps avoids
        # division by zero for constant inputs.
        variance = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(variance + 1e-5) * self.weight
| |
|
class WithBias_LayerNorm(nn.Module):
    """Standard LayerNorm over the last axis with learnable scale and bias."""

    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        # Accept either an int or a 1-element shape.
        shape = (normalized_shape,) if isinstance(normalized_shape, numbers.Integral) else normalized_shape
        shape = torch.Size(shape)
        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape

    def forward(self, x):
        # Center, rescale by the biased std, then apply affine parameters.
        mu = x.mean(-1, keepdim=True)
        variance = x.var(-1, keepdim=True, unbiased=False)
        normalized = (x - mu) / torch.sqrt(variance + 1e-5)
        return normalized * self.weight + self.bias
| |
|
| |
|
class LayerNorm(nn.Module):
    """Channel LayerNorm for 4-D feature maps (b, c, h, w).

    Flattens spatial dims to token form, normalizes over the channel axis
    with either the bias-free or with-bias variant, then restores the
    original 4-D layout.
    """

    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        norm_cls = BiasFree_LayerNorm if LayerNorm_type == 'BiasFree' else WithBias_LayerNorm
        self.body = norm_cls(dim)

    def forward(self, x):
        height, width = x.shape[-2:]
        tokens = to_3d(x)
        return to_4d(self.body(tokens), height, width)
| |
|
| |
|
| |
|
| | |
| | |
class FeedForward(nn.Module):
    """Gated feed-forward network built from convolutions.

    A 1x1 conv expands to 2x the hidden width, a depth-wise 3x3 conv mixes
    spatially, the result is split in two halves and gated
    (gelu(half1) * half2), and a 1x1 conv projects back to `dim` channels.
    """

    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim * ffn_expansion_factor)
        expanded = hidden_features * 2

        self.project_in = nn.Conv2d(dim, expanded, kernel_size=1, bias=bias)
        # groups == channels makes this a depth-wise convolution.
        self.dwconv = nn.Conv2d(expanded, expanded, kernel_size=3, stride=1,
                                padding=1, groups=expanded, bias=bias)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        gate, value = self.dwconv(self.project_in(x)).chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)
| |
|
| |
|
| |
|
| | |
| | |
class Attention(nn.Module):
    """Multi-head channel attention over 4-D feature maps.

    Attention is computed across channels (a per-head C x C map) rather than
    across pixels, so cost stays linear in the number of pixels.  Q/K/V come
    from a shared 1x1 conv followed by a depth-wise 3x3 conv; Q and K are
    L2-normalized along the pixel axis, and the logits are scaled by a
    learnable per-head temperature.

    Changes from the original: the einops.rearrange calls were pure
    split/merge reshapes, so they are replaced with native `Tensor.reshape`
    (identical memory layout, no third-party call), and `F.normalize` is
    used for consistency with the file's existing `F` alias.
    """

    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        # Learnable scaling of the attention logits, one scalar per head.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        # Depth-wise 3x3 conv (groups == channels) mixing Q/K/V spatially.
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1,
                                    padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        # (b, head*ch, h, w) -> (b, head, ch, h*w): split heads, flatten
        # spatial axes.  Adjacent-dim split/merge, so reshape suffices.
        head_ch = c // self.num_heads
        q = q.reshape(b, self.num_heads, head_ch, h * w)
        k = k.reshape(b, self.num_heads, head_ch, h * w)
        v = v.reshape(b, self.num_heads, head_ch, h * w)

        # Normalize along the pixel axis so logits are cosine similarities.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        # Channel-by-channel attention map: (b, head, ch, ch).
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = attn @ v

        # (b, head, ch, h*w) -> (b, head*ch, h, w).
        out = out.reshape(b, c, h, w)

        out = self.project_out(out)
        return out
| |
|
| |
|
| |
|
| | |
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: x + Attn(LN(x)), then x + FFN(LN(x))."""

    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(TransformerBlock, self).__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        # Residual attention followed by residual feed-forward,
        # each applied to a normalized copy of its input.
        attended = x + self.attn(self.norm1(x))
        return attended + self.ffn(self.norm2(attended))
| |
|
| |
|
| |
|
| | |
| | |
class OverlapPatchEmbed(nn.Module):
    """Shallow feature embedding: a single 3x3 conv mapping the input image
    to `embed_dim` channels at full resolution (overlapping "patches")."""

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3,
                              stride=1, padding=1, bias=bias)

    def forward(self, x):
        return self.proj(x)
| |
|
| |
|
| |
|
| | |
| | |
class Downsample(nn.Module):
    """Halve spatial resolution while doubling channels.

    The conv halves the channel count, then PixelUnshuffle(2) trades a 2x2
    spatial block for a 4x channel increase — net effect:
    (n_feat, H, W) -> (2*n_feat, H/2, W/2).
    """

    def __init__(self, n_feat):
        super(Downsample, self).__init__()

        layers = [
            nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelUnshuffle(2),
        ]
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        return self.body(x)
| |
|
class Upsample(nn.Module):
    """Double spatial resolution while halving channels.

    The conv doubles the channel count, then PixelShuffle(2) trades a 4x
    channel factor for a 2x2 spatial block — net effect:
    (n_feat, H, W) -> (n_feat/2, 2H, 2W).
    """

    def __init__(self, n_feat):
        super(Upsample, self).__init__()

        layers = [
            nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
        ]
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        return self.body(x)
| |
|
| | |
| | |
class Restormer(nn.Module):
    """Restormer: 4-level U-Net of TransformerBlocks for image restoration.

    The input image is embedded by OverlapPatchEmbed, processed by encoder
    stacks at 4 resolutions (each Downsample halves H/W and doubles
    channels), decoded back up with skip connections from the matching
    encoder level, refined, and projected to `out_channels` with a final
    3x3 conv.  Decoder level 1 keeps the concatenated 2*dim channels all
    the way through refinement and output.

    Args:
        inp_channels: channels of the degraded input image.
        out_channels: channels of the restored output image.
        dim: base channel width at level 1 (doubles per level).
        num_blocks: TransformerBlocks per stage [lvl1, lvl2, lvl3, latent].
        num_refinement_blocks: extra blocks after decoder level 1.
        heads: attention heads per level.
        ffn_expansion_factor: hidden-width multiplier inside FeedForward.
        bias: whether convs use bias terms.
        LayerNorm_type: 'BiasFree' or 'WithBias'.
        dual_pixel_task: if True, add a 1x1-conv skip from the patch
            embedding instead of the global input residual.

    NOTE(review): the list defaults ([4,6,6,8], [1,2,4,8]) are shared
    across calls; safe here because they are only indexed, never mutated.
    Attribute names must stay exactly as-is so pretrained state_dicts load.
    """
    def __init__(self,
        inp_channels=3,
        out_channels=3,
        dim = 48,
        num_blocks = [4,6,6,8],
        num_refinement_blocks = 4,
        heads = [1,2,4,8],
        ffn_expansion_factor = 2.66,
        bias = False,
        LayerNorm_type = 'WithBias',
        dual_pixel_task = False
    ):

        super(Restormer, self).__init__()

        # Level 1 (full resolution, dim channels).
        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])

        # Level 2 (H/2, 2*dim channels).
        self.down1_2 = Downsample(dim)
        self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])

        # Level 3 (H/4, 4*dim channels).
        self.down2_3 = Downsample(int(dim*2**1))
        self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])

        # Bottleneck (H/8, 8*dim channels).
        self.down3_4 = Downsample(int(dim*2**2))
        self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])

        # Decoder level 3: upsample, concat encoder-3 skip (-> 8*dim),
        # reduce back to 4*dim with a 1x1 conv.
        self.up4_3 = Upsample(int(dim*2**3))
        self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias)
        self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])

        # Decoder level 2: upsample, concat encoder-2 skip (-> 4*dim),
        # reduce back to 2*dim.
        self.up3_2 = Upsample(int(dim*2**2))
        self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
        self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])

        # Decoder level 1: upsample, concat encoder-1 skip (-> 2*dim);
        # no channel reduction here — level 1 decoding runs at 2*dim.
        self.up2_1 = Upsample(int(dim*2**1))

        self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])

        # Final refinement stage, still at 2*dim channels.
        self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])

        # Optional dual-pixel branch: 1x1 conv lifting the dim-channel patch
        # embedding to 2*dim so it can be added as a skip before output.
        self.dual_pixel_task = dual_pixel_task
        if self.dual_pixel_task:
            self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias)

        self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img):
        """Restore `inp_img` (b, inp_channels, H, W) -> (b, out_channels, H, W).

        H and W must survive three 2x down/upsampling rounds, i.e. be
        divisible by 8 — TODO confirm against callers.
        """
        # --- Encoder ---
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)

        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)

        inp_enc_level4 = self.down3_4(out_enc_level3)
        latent = self.latent(inp_enc_level4)

        # --- Decoder with encoder skip connections ---
        inp_dec_level3 = self.up4_3(latent)
        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
        out_dec_level3 = self.decoder_level3(inp_dec_level3)

        inp_dec_level2 = self.up3_2(out_dec_level3)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
        out_dec_level2 = self.decoder_level2(inp_dec_level2)

        inp_dec_level1 = self.up2_1(out_dec_level2)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)

        out_dec_level1 = self.refinement(out_dec_level1)

        if self.dual_pixel_task:
            # Dual-pixel: add the lifted patch embedding, no input residual.
            out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
            out_dec_level1 = self.output(out_dec_level1)

        else:
            # Global residual: the network predicts a correction to the
            # input image (assumes out_channels == inp_channels).
            out_dec_level1 = self.output(out_dec_level1) + inp_img

        return out_dec_level1
| |
|
| | |
| |
|
| | |
| | from torchvision import transforms |
| | from PIL import Image |
| | import gradio as gr |
| | |
| |
|
| | |
# --- Device selection: prefer GPU when available. ---
print("Setting device...")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Path to the fine-tuned Restormer checkpoint (expected to be a state_dict).
model_path = 'restormer_finetuned.pth'

# Fail fast with an actionable message if the weights file is missing.
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model weights not found at '{model_path}'. "
                            "Make sure the file is uploaded to the Hugging Face Space repository.")

print(f"Loading model from: {model_path}")
# Default hyper-parameters — the checkpoint must match this architecture
# exactly for load_state_dict to succeed.
model = Restormer()

try:
    # NOTE(review): torch.load unpickles arbitrary objects; consider
    # weights_only=True (torch >= 1.13) since this should be a plain
    # state_dict — confirm against how the checkpoint was saved.
    checkpoint = torch.load(model_path, map_location=device)

    model.load_state_dict(checkpoint)
    print("Model weights loaded successfully.")
except Exception as e:
    print(f"Error loading model weights: {e}")
    print("Please ensure 'restormer_finetuned.pth' is the correct file and contains the model's state_dict.")
    # Re-raise so the app fails loudly instead of serving a broken model.
    raise e

model.to(device)
model.eval()  # inference mode for the whole app lifetime
print("Model loaded and set to evaluation mode.")
| | |
| |
|
| | |
def deblur_image(uploaded_image):
    """Run the global Restormer `model` on a PIL image and return the
    deblurred PIL image (or None when no image was provided).

    Raises gr.Error so Gradio surfaces processing failures to the user.
    """
    if uploaded_image is None:
        return None

    print("Processing image...")
    try:
        rgb = uploaded_image.convert("RGB")

        # Resize to a fixed 256x256 and convert to a [0, 1] tensor.
        preprocess = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])
        batch = preprocess(rgb).unsqueeze(0).to(device)

        with torch.no_grad():
            restored = model(batch)

        # Clip to the displayable [0, 1] range before PIL conversion.
        restored = torch.clamp(restored, 0, 1)

        result = transforms.ToPILImage()(restored.squeeze(0).cpu())
        print("Image processed successfully.")
        return result
    except Exception as e:
        print(f"Error during image processing: {e}")
        raise gr.Error(f"Failed to process image: {e}")
| |
|
| |
|
| | |
print("Launching Gradio interface...")

# Custom CSS injected into gr.Blocks: centers the app in an 800px card,
# styles the header typography, and gives the input/output panels and
# buttons a card-like look with hover feedback.
custom_css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.container {
    max-width: 800px;
    margin: auto;
    padding-top: 1.5rem;
}
#component-0 {
    max-width: 800px;
    margin: auto;
}
.gradio-interface {
    padding: 1.5rem;
    border-radius: 8px;
    background: white;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.gradio-interface .gradio-interface-header {
    text-align: center;
    margin-bottom: 2rem;
}
.gradio-interface .gradio-interface-header h1 {
    font-size: 2.5rem;
    font-weight: 600;
    color: #2d3748;
    margin-bottom: 1rem;
}
.gradio-interface .gradio-interface-header p {
    font-size: 1.1rem;
    color: #4a5568;
    max-width: 600px;
    margin: auto;
}
.gradio-interface .gradio-interface-content {
    display: flex;
    flex-direction: column;
    gap: 2rem;
}
.gradio-interface .gradio-interface-content .gradio-interface-input,
.gradio-interface .gradio-interface-content .gradio-interface-output {
    background: #f7fafc;
    padding: 1.5rem;
    border-radius: 8px;
    border: 1px solid #e2e8f0;
}
.gradio-interface .gradio-interface-content .gradio-interface-input label,
.gradio-interface .gradio-interface-content .gradio-interface-output label {
    font-size: 1.2rem;
    font-weight: 500;
    color: #2d3748;
    margin-bottom: 0.5rem;
}
.gradio-interface .gradio-interface-content button {
    background: #4299e1;
    color: white;
    padding: 0.75rem 1.5rem;
    border-radius: 6px;
    font-weight: 500;
    transition: all 0.2s;
}
.gradio-interface .gradio-interface-content button:hover {
    background: #3182ce;
    transform: translateY(-1px);
}
"""
| |
|
| | |
# Build and launch the Gradio UI.
with gr.Blocks(css=custom_css) as iface:
    with gr.Column(elem_classes=["container"]):
        gr.Markdown("""
        # 🎨 DeBlurred

        Transform your blurry images into crystal clear masterpieces.
        """)

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    type="pil",
                    label="Upload Blurry Image",
                    elem_id="input-image"
                )
                submit_btn = gr.Button("✨ Deblur Image", variant="primary")

            with gr.Column():
                output_image = gr.Image(
                    type="pil",
                    label="Deblurred Result",
                    elem_id="output-image"
                )

        gr.Markdown("""
        ### 📝 Instructions
        1. Upload a blurry image using the upload button or drag and drop
        2. Click the "Deblur Image" button
        3. Wait for the Model to process your image
        4. Download the deblurred result

        ### ℹ️ About
        Developed as part of a coursework project, the app uses image processing techniques to reduce blur and restore clarity. It's designed to be simple, fast, and accessible to users who want quick improvements without needing advanced editing skills.
        """)

        # Reset both panels.
        clear_btn = gr.Button("Clear")
        clear_btn.click(fn=lambda: (None, None), inputs=None, outputs=[input_image, output_image])

        # Sample images shipped with the Space.
        image_files = [f"blur{i}.jpg" for i in range(1, 6) if os.path.exists(f"blur{i}.jpg")]
        # BUGFIX: only build the Examples component when sample files exist —
        # gr.Examples raises ValueError when `examples` is None/empty, which
        # previously crashed startup on a Space without the blur*.jpg files.
        if image_files:
            examples = [[img] for img in image_files]
            gr.Examples(examples=examples, inputs=input_image, outputs=output_image, fn=deblur_image, cache_examples=True)

        # Main action: run the model on the uploaded image.
        submit_btn.click(
            fn=deblur_image,
            inputs=input_image,
            outputs=output_image
        )

iface.launch(
    share=True,
    show_error=True,
)
print("Gradio interface launched.")
| | |