dftest1 / src /features /drunet /drunet_model.py
akcanca's picture
Upload 110 files (#1)
07fe054 verified
"""
DRUNet (Denoising Residual UNet) model implementation.
Based on: "Plug-and-Play Image Restoration with Deep Denoiser Prior"
https://github.com/cszn/DPIR
This implementation matches the pretrained weights structure exactly.
"""
import torch
import torch.nn as nn
class ResBlock(nn.Module):
    """Bias-free residual unit computing ``x + Conv3x3(ReLU(Conv3x3(x)))``.

    Args:
        channels: Number of feature channels. Both convolutions map
            ``channels -> channels`` with ``padding=1``, so the spatial size
            and channel count of the input are preserved.
    """

    def __init__(self, channels):
        super().__init__()
        # The attribute name ``res`` and the conv / ReLU / conv ordering
        # (indices 0, 1, 2) determine the state-dict keys, so keep them as-is.
        branch = [
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
        ]
        self.res = nn.Sequential(*branch)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        # The in-place ReLU acts on the first conv's output, not on ``x``,
        # so the skip connection still sees the unmodified input.
        residual = self.res(x)
        return x + residual
class UNetRes(nn.Module):
    """
    UNet with residual blocks, laid out to match DPIR/KAIR pretrained weights.

    Architecture:
    - m_head: Conv 3x3, in_nc -> nc[0]
    - m_down1: nb x ResBlock(nc[0]) + StrideConv(nc[0] -> nc[1])
    - m_down2: nb x ResBlock(nc[1]) + StrideConv(nc[1] -> nc[2])
    - m_down3: nb x ResBlock(nc[2]) + StrideConv(nc[2] -> nc[3])
    - m_body: nb x ResBlock(nc[3])
    - m_up3: ConvTranspose(nc[3] -> nc[2]) + nb x ResBlock(nc[2])
    - m_up2: ConvTranspose(nc[2] -> nc[1]) + nb x ResBlock(nc[1])
    - m_up1: ConvTranspose(nc[1] -> nc[0]) + nb x ResBlock(nc[0])
    - m_tail: Conv 3x3, nc[0] -> out_nc

    Note: All convolutions use bias=False to match pretrained weights.

    Args:
        in_nc: Number of input channels.
        out_nc: Number of output channels.
        nc: Sequence of four channel widths, one per resolution level
            (indexed 0..3). Any indexable sequence works; the default is a
            tuple so that no mutable object is shared between calls.
        nb: Number of residual blocks per stage.
    """

    def __init__(self, in_nc=2, out_nc=1, nc=(64, 128, 256, 512), nb=4):
        super(UNetRes, self).__init__()
        # Head
        self.m_head = nn.Conv2d(in_nc, nc[0], 3, padding=1, bias=False)
        # Encoder: each stage ends in a 2x2 stride-2 conv that halves H and W.
        self.m_down1 = nn.Sequential(
            *[ResBlock(nc[0]) for _ in range(nb)],
            nn.Conv2d(nc[0], nc[1], 2, stride=2, bias=False),
        )
        self.m_down2 = nn.Sequential(
            *[ResBlock(nc[1]) for _ in range(nb)],
            nn.Conv2d(nc[1], nc[2], 2, stride=2, bias=False),
        )
        self.m_down3 = nn.Sequential(
            *[ResBlock(nc[2]) for _ in range(nb)],
            nn.Conv2d(nc[2], nc[3], 2, stride=2, bias=False),
        )
        # Bottleneck
        self.m_body = nn.Sequential(
            *[ResBlock(nc[3]) for _ in range(nb)]
        )
        # Decoder: each stage starts with a 2x2 stride-2 transposed conv
        # that doubles H and W.
        self.m_up3 = nn.Sequential(
            nn.ConvTranspose2d(nc[3], nc[2], 2, stride=2, bias=False),
            *[ResBlock(nc[2]) for _ in range(nb)]
        )
        self.m_up2 = nn.Sequential(
            nn.ConvTranspose2d(nc[2], nc[1], 2, stride=2, bias=False),
            *[ResBlock(nc[1]) for _ in range(nb)]
        )
        self.m_up1 = nn.Sequential(
            nn.ConvTranspose2d(nc[1], nc[0], 2, stride=2, bias=False),
            *[ResBlock(nc[0]) for _ in range(nb)]
        )
        # Tail
        self.m_tail = nn.Conv2d(nc[0], out_nc, 3, padding=1, bias=False)

    def forward(self, x):
        """
        Run the network on a 4-D tensor indexed as (batch, channel, H, W).

        The input is padded on the bottom/right so H and W are multiples of 8
        (three stride-2 stages), and the output is cropped back to the
        original H x W.

        Args:
            x: Input tensor with ``in_nc`` channels.
        Returns:
            Tensor with ``out_nc`` channels and the same H, W as the input.
        """
        # Pad to multiple of 8
        h, w = x.shape[2], x.shape[3]
        pad_h = (8 - h % 8) % 8
        pad_w = (8 - w % 8) % 8
        padded = pad_h > 0 or pad_w > 0
        if padded:
            # Reflect padding is only valid when the pad amount is strictly
            # smaller than the dimension being padded; for very small inputs
            # fall back to edge replication instead of raising.
            mode = 'reflect' if (pad_h < h and pad_w < w) else 'replicate'
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode=mode)

        x1 = self.m_head(x)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body(x4)
        # Additive skip connections from the matching encoder resolution.
        x = self.m_up3(x + x4)
        x = self.m_up2(x + x3)
        x = self.m_up1(x + x2)
        x = self.m_tail(x + x1)

        # Remove padding
        if padded:
            x = x[:, :, :h, :w]
        return x
class DRUNet(nn.Module):
    """
    Wrapper around UNetRes that appends a constant noise-level channel.

    A fixed default noise level is stored at construction time so that
    forensic residual extraction uses consistent denoising behavior; it can
    still be overridden per call.
    """

    def __init__(self, in_nc=1, out_nc=1, noise_level=15):
        """
        Args:
            in_nc: Image input channels (1 for grayscale)
            out_nc: Output channels (1 for grayscale)
            noise_level: Default noise level on the 0-255 scale
        """
        super().__init__()
        # The backbone receives the image channels plus one noise-map channel.
        backbone_in = in_nc + 1
        self.unet = UNetRes(in_nc=backbone_in, out_nc=out_nc)
        # Kept internally on the 0-1 scale.
        self.noise_level = noise_level / 255.0

    def forward(self, x, noise_level=None):
        """
        Args:
            x: Input image tensor (B, C, H, W), documented as being in [0, 1]
            noise_level: Optional per-call override, already on the 0-1 scale
                (unlike the constructor argument, it is not divided by 255)
        Returns:
            Output of the backbone, shape (B, out_nc, H, W). No clamping is
            applied here.
        """
        sigma = self.noise_level if noise_level is None else noise_level
        batch, _, height, width = x.shape
        # Constant plane holding the noise level, on the input's device/dtype.
        sigma_plane = torch.full(
            (batch, 1, height, width), sigma, device=x.device, dtype=x.dtype
        )
        stacked = torch.cat([x, sigma_plane], dim=1)
        return self.unet(stacked)
def load_drunet_gray(weights_path, noise_level=15, device=None):
    """
    Build a grayscale DRUNet and populate it from a checkpoint file.

    Args:
        weights_path: Path to drunet_gray.pth
        noise_level: Default noise level for denoising (0-255 scale)
        device: torch device; when None, CUDA is used if available,
            otherwise CPU
    Returns:
        DRUNet model moved to the chosen device and switched to eval mode
    """
    target = device
    if target is None:
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = DRUNet(in_nc=1, out_nc=1, noise_level=noise_level)
    # The checkpoint holds the backbone's parameters only, so it is loaded
    # into ``net.unet`` rather than into the wrapper.
    weights = torch.load(weights_path, map_location=target, weights_only=True)
    net.unet.load_state_dict(weights)

    net = net.to(target)
    return net.eval()