# DeBlurred — app.py (Hugging Face Space)
# Last web commit: "Update app.py" by BAbhijit, revision 104249b (verified)
# Block 1: Imports and Model Definition ==========================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numbers
from einops import rearrange
import os # For checking model file path
def to_3d(x):
    """Flatten a (B, C, H, W) feature map into token form (B, H*W, C)."""
    batch, channels = x.shape[0], x.shape[1]
    return x.reshape(batch, channels, -1).transpose(1, 2)
def to_4d(x, h, w):
    """Restore token form (B, H*W, C) back to a (B, C, H, W) feature map."""
    return x.transpose(1, 2).reshape(x.shape[0], -1, h, w)
class BiasFree_LayerNorm(nn.Module):
    """LayerNorm variant with neither mean subtraction nor bias (Restormer).

    Scales the last dimension by 1/std (biased variance) and a learnable
    per-channel weight only.
    """

    def __init__(self, normalized_shape):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        shape = torch.Size(normalized_shape)
        assert len(shape) == 1
        # Learnable per-channel scale; no bias term by design.
        self.weight = nn.Parameter(torch.ones(shape))
        self.normalized_shape = shape

    def forward(self, x):
        # Biased variance over the channel (last) dim; eps guards div-by-zero.
        var = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(var + 1e-5) * self.weight
class WithBias_LayerNorm(nn.Module):
    """Conventional LayerNorm over the last dimension with learnable scale and bias."""

    def __init__(self, normalized_shape):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        shape = torch.Size(normalized_shape)
        assert len(shape) == 1
        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape

    def forward(self, x):
        # Center by the mean, scale by 1/std (biased variance), then affine.
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, keepdim=True, unbiased=False)
        normalized = (x - mean) / torch.sqrt(var + 1e-5)
        return normalized * self.weight + self.bias
class LayerNorm(nn.Module):
    """Channel LayerNorm for 4-D maps; dispatches to bias-free or with-bias body."""

    def __init__(self, dim, LayerNorm_type):
        super().__init__()
        norm_cls = BiasFree_LayerNorm if LayerNorm_type == 'BiasFree' else WithBias_LayerNorm
        self.body = norm_cls(dim)

    def forward(self, x):
        height, width = x.shape[-2:]
        # Flatten spatial dims, normalize channels, then restore the 4-D layout.
        return to_4d(self.body(to_3d(x)), height, width)
##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
    """Gated-Dconv feed-forward network (GDFN) from Restormer.

    Expands to 2*hidden channels, applies a depthwise 3x3 conv, gates one half
    with GELU of the other half, and projects back to `dim` channels.
    """

    def __init__(self, dim, ffn_expansion_factor, bias):
        super().__init__()
        hidden = int(dim * ffn_expansion_factor)
        self.project_in = nn.Conv2d(dim, hidden * 2, kernel_size=1, bias=bias)
        # Depthwise conv: groups == channels, so each channel is filtered alone.
        self.dwconv = nn.Conv2d(hidden * 2, hidden * 2, kernel_size=3, stride=1,
                                padding=1, groups=hidden * 2, bias=bias)
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        gate, value = self.dwconv(self.project_in(x)).chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)
##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    """Multi-DConv head transposed self-attention (MDTA) from Restormer.

    Attention is computed across channels (a C/heads x C/heads map per head)
    rather than across pixels, keeping cost linear in spatial size.
    """

    def __init__(self, dim, num_heads, bias):
        super().__init__()
        self.num_heads = num_heads
        # Learnable per-head scaling of the attention logits.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1,
                                    padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape
        q, k, v = self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1)

        def split_heads(t):
            # (b, c, h, w) -> (b, heads, c/heads, h*w): channels across heads,
            # spatial positions flattened into the last axis.
            return t.reshape(b, self.num_heads, c // self.num_heads, h * w)

        q = F.normalize(split_heads(q), dim=-1)
        k = F.normalize(split_heads(k), dim=-1)
        v = split_heads(v)

        logits = (q @ k.transpose(-2, -1)) * self.temperature
        attn = logits.softmax(dim=-1)
        out = (attn @ v).reshape(b, c, h, w)
        return self.project_out(out)
##########################################################################
class TransformerBlock(nn.Module):
    """Restormer block: pre-norm MDTA attention and pre-norm GDFN, each residual."""

    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super().__init__()
        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        return x + self.ffn(self.norm2(x))
##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    """Overlapped patch embedding: one 3x3 conv mapping in_c -> embed_dim channels."""

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super().__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1,
                              padding=1, bias=bias)

    def forward(self, x):
        return self.proj(x)
##########################################################################
## Resizing modules
class Downsample(nn.Module):
    """Halve spatial size: 3x3 conv to n_feat//2 channels, then PixelUnshuffle(2).

    Net effect: (B, n_feat, H, W) -> (B, 2*n_feat, H/2, W/2).
    """

    def __init__(self, n_feat):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelUnshuffle(2),
        )

    def forward(self, x):
        return self.body(x)
class Upsample(nn.Module):
    """Double spatial size: 3x3 conv to n_feat*2 channels, then PixelShuffle(2).

    Net effect: (B, n_feat, H, W) -> (B, n_feat/2, 2H, 2W).
    """

    def __init__(self, n_feat):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
        )

    def forward(self, x):
        return self.body(x)
##########################################################################
##---------- Restormer -----------------------
class Restormer(nn.Module):
    """Restormer: 4-level encoder-decoder restoration transformer.

    The encoder halves spatial resolution and doubles channels at each of
    three downsamplings; the decoder mirrors it with skip connections from
    the matching encoder level. The final output is a residual added to the
    input image, except in the dual-pixel task where a skip conv from the
    shallow features is added instead.

    NOTE(review): spatial dims of the input must be divisible by 8 (three
    PixelUnshuffle(2) stages) — callers are expected to resize/pad first.
    """
    def __init__(self,
        inp_channels=3,
        out_channels=3,
        dim = 48,                      # base channel width at level 1
        num_blocks = [4,6,6,8],        # TransformerBlocks per level (enc 1-3 + latent); mutable default — read-only here
        num_refinement_blocks = 4,
        heads = [1,2,4,8],             # attention heads per level
        ffn_expansion_factor = 2.66,
        bias = False,
        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
        dual_pixel_task = False        ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
    ):
        super(Restormer, self).__init__()

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        # ---- Encoder: dim -> 2*dim -> 4*dim channels across levels 1..3 ----
        self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
        self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
        self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
        self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
        self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
        self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4

        # ---- Bottleneck at 8*dim channels ----
        self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])

        # ---- Decoder: 1x1 convs fuse the concatenated skip connections ----
        self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3
        self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias)
        self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
        self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
        self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
        self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
        self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels)
        # Level-1 decoder runs at 2*dim channels (concat skip is kept un-reduced).
        self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])

        self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])

        #### For Dual-Pixel Defocus Deblurring Task ####
        self.dual_pixel_task = dual_pixel_task
        if self.dual_pixel_task:
            self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias)
        ###########################

        self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img):
        """Restore `inp_img` (B, inp_channels, H, W) -> (B, out_channels, H, W)."""
        # ---- Encoder path ----
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)

        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)

        inp_enc_level4 = self.down3_4(out_enc_level3)
        latent = self.latent(inp_enc_level4)

        # ---- Decoder path: upsample, concat encoder skip, fuse, transform ----
        inp_dec_level3 = self.up4_3(latent)
        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
        out_dec_level3 = self.decoder_level3(inp_dec_level3)

        inp_dec_level2 = self.up3_2(out_dec_level3)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
        out_dec_level2 = self.decoder_level2(inp_dec_level2)

        inp_dec_level1 = self.up2_1(out_dec_level2)
        # No channel reduction at level 1: decoder runs on the 2*dim concat.
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)

        out_dec_level1 = self.refinement(out_dec_level1)

        #### For Dual-Pixel Defocus Deblurring Task ####
        if self.dual_pixel_task:
            out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
            out_dec_level1 = self.output(out_dec_level1)
        ###########################
        else:
            # Global residual: the network predicts a correction to the input.
            out_dec_level1 = self.output(out_dec_level1) + inp_img

        return out_dec_level1
# =================================================================
# Block 1: Additional Imports =====================================
from torchvision import transforms
from PIL import Image
import gradio as gr
# =================================================================
# Block 2: Model Loading ==========================================
print("Setting device...")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# --- Ensure 'restormer_finetuned.pth' is in the Space repo root ---
model_path = 'restormer_finetuned.pth'

# Fail fast with a clear message if the weights file was not uploaded.
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model weights not found at '{model_path}'. "
                            "Make sure the file is uploaded to the Hugging Face Space repository.")

print(f"Loading model from: {model_path}")
# Instantiate the model (adjust args if needed)
model = Restormer()
try:
    checkpoint = torch.load(model_path, map_location=device)
    # Checkpoints are sometimes saved nested, e.g. {'state_dict': ...} or
    # {'params': ...} (the official Restormer release uses 'params').
    # Unwrap those layouts; otherwise assume the file *is* the state_dict.
    if isinstance(checkpoint, dict):
        for _key in ('state_dict', 'params'):
            if _key in checkpoint:
                checkpoint = checkpoint[_key]
                break
    model.load_state_dict(checkpoint)
    print("Model weights loaded successfully.")
except Exception as e:
    print(f"Error loading model weights: {e}")
    print("Please ensure 'restormer_finetuned.pth' is the correct file and contains the model's state_dict.")
    # Bare raise preserves the original traceback (unlike `raise e`).
    raise
model.to(device)
model.eval()
print("Model loaded and set to evaluation mode.")
# =================================================================
# Block 3: Inference Function =====================================
def deblur_image(uploaded_image):
    """Deblur a PIL image with the global Restormer `model`.

    Args:
        uploaded_image: PIL.Image from the Gradio input, or None.

    Returns:
        A deblurred PIL.Image at (approximately) the input resolution, or
        None when no image was supplied.

    Raises:
        gr.Error: on any processing failure, so the message shows in the UI.

    Instead of force-resizing to 256x256 (which destroyed resolution and
    aspect ratio), the image is reflect-padded so H and W are multiples of 8
    — required by Restormer's three PixelUnshuffle(2) stages — and the output
    is cropped back afterwards.
    """
    if uploaded_image is None:
        return None  # Nothing to process; Gradio shows an empty output.

    print("Processing image...")
    try:
        image = uploaded_image.convert("RGB")

        # Cap the longest side to keep memory/latency bounded on CPU Spaces.
        MAX_SIDE = 1024
        orig_w, orig_h = image.size
        longest = max(orig_w, orig_h)
        if longest > MAX_SIDE:
            scale = MAX_SIDE / longest
            new_size = (max(1, round(orig_w * scale)), max(1, round(orig_h * scale)))
            image = image.resize(new_size, Image.BICUBIC)

        input_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)  # [1, C, H, W]

        # Pad H and W up to the next multiple of 8; crop back after inference.
        _, _, h, w = input_tensor.shape
        pad_h = (8 - h % 8) % 8
        pad_w = (8 - w % 8) % 8
        if pad_h or pad_w:
            # Reflect padding avoids hard edges (requires pad < dim, true for
            # any image at least 8px per side).
            input_tensor = F.pad(input_tensor, (0, pad_w, 0, pad_h), mode='reflect')

        with torch.no_grad():
            output_tensor = model(input_tensor)

        # Drop the padding, then clamp to [0, 1] before converting to PIL.
        output_tensor = torch.clamp(output_tensor[:, :, :h, :w], 0, 1)
        output_image = transforms.ToPILImage()(output_tensor.squeeze(0).cpu())
        print("Image processed successfully.")
        return output_image
    except Exception as e:
        print(f"Error during image processing: {e}")
        # Provide feedback to the user in the UI
        raise gr.Error(f"Failed to process image: {e}")  # Shows error in Gradio UI
# Block 4: Gradio Interface Launch ================================
print("Launching Gradio interface...")
# Custom CSS for better styling
# NOTE(review): the `.gradio-interface*` selectors and `#component-0` target
# Gradio's internal DOM, which changes between Gradio releases — verify these
# still match the installed version (only `.gradio-container` is stable API).
custom_css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.container {
max-width: 800px;
margin: auto;
padding-top: 1.5rem;
}
#component-0 {
max-width: 800px;
margin: auto;
}
.gradio-interface {
padding: 1.5rem;
border-radius: 8px;
background: white;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.gradio-interface .gradio-interface-header {
text-align: center;
margin-bottom: 2rem;
}
.gradio-interface .gradio-interface-header h1 {
font-size: 2.5rem;
font-weight: 600;
color: #2d3748;
margin-bottom: 1rem;
}
.gradio-interface .gradio-interface-header p {
font-size: 1.1rem;
color: #4a5568;
max-width: 600px;
margin: auto;
}
.gradio-interface .gradio-interface-content {
display: flex;
flex-direction: column;
gap: 2rem;
}
.gradio-interface .gradio-interface-content .gradio-interface-input,
.gradio-interface .gradio-interface-content .gradio-interface-output {
background: #f7fafc;
padding: 1.5rem;
border-radius: 8px;
border: 1px solid #e2e8f0;
}
.gradio-interface .gradio-interface-content .gradio-interface-input label,
.gradio-interface .gradio-interface-content .gradio-interface-output label {
font-size: 1.2rem;
font-weight: 500;
color: #2d3748;
margin-bottom: 0.5rem;
}
.gradio-interface .gradio-interface-content button {
background: #4299e1;
color: white;
padding: 0.75rem 1.5rem;
border-radius: 6px;
font-weight: 500;
transition: all 0.2s;
}
.gradio-interface .gradio-interface-content button:hover {
background: #3182ce;
transform: translateY(-1px);
}
"""
# Create a more sophisticated interface using Blocks
with gr.Blocks(css=custom_css) as iface:
    with gr.Column(elem_classes=["container"]):
        gr.Markdown("""
# 🎨 DeBlurred
Transform your blurry images into crystal clear masterpieces.
""")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    type="pil",
                    label="Upload Blurry Image",
                    elem_id="input-image"
                )
                submit_btn = gr.Button("✨ Deblur Image", variant="primary")
            with gr.Column():
                output_image = gr.Image(
                    type="pil",
                    label="Deblurred Result",
                    elem_id="output-image"
                )
        gr.Markdown("""
### 📝 Instructions
1. Upload a blurry image using the upload button or drag and drop
2. Click the "Deblur Image" button
3. Wait for the Model to process your image
4. Download the deblurred result
### ℹ️ About
Developed as part of a coursework project, the app uses image processing techniques to reduce blur and restore clarity. It's designed to be simple, fast, and accessible to users who want quick improvements without needing advanced editing skills.
""")
        # Add a clear button to reset the input and output images
        clear_btn = gr.Button("Clear")
        clear_btn.click(fn=lambda: (None, None), inputs=None, outputs=[input_image, output_image])

        # Bundled example images (blur1.jpg..blur5.jpg) shown only if present.
        image_files = [f"blur{i}.jpg" for i in range(1, 6) if os.path.exists(f"blur{i}.jpg")]
        # BUG FIX: only create the Examples component when examples exist —
        # gr.Examples(examples=None, ...) raises at startup and crashed the app
        # whenever no blur*.jpg files were uploaded to the repo.
        if image_files:
            gr.Examples(
                examples=[[img] for img in image_files],
                inputs=input_image,
                outputs=output_image,
                fn=deblur_image,
                cache_examples=True,
            )

        # Connect the button click to the function
        submit_btn.click(
            fn=deblur_image,
            inputs=input_image,
            outputs=output_image
        )

# Launch the interface
iface.launch(
    share=True,  # Create a public link (ignored on HF Spaces; useful locally)
    show_error=True,  # Show detailed error messages
)
print("Gradio interface launched.")
# =================================================================