# XYScanNet_Demo / models / losses.py
# HanzhouLiu
# Add application file
# b56342d
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
class Vgg19(torch.nn.Module):
    """Feature extractor made of layers 0-11 of the VGG-19 ``features`` stack.

    The layers come from ``vgg19(pretrained=True).features``; each one is put
    into eval mode and registered in ``slice1`` under its original index.

    Args:
        requires_grad: If false (default), ``requires_grad`` is switched off
            on every parameter of this module.
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        backbone = vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        # Sub-module names mirror the positions inside the VGG feature stack.
        for index in range(12):
            layer = backbone[index]
            self.slice1.add_module(str(index), layer.eval())
        if requires_grad:
            return
        for weight in self.parameters():
            weight.requires_grad_(False)

    def forward(self, X):
        """Run ``X`` through the 12-layer slice and return its output."""
        return self.slice1(X)
class ContrastLoss(nn.Module):
    """Masked contrastive loss computed on VGG features.

    For every feature location the loss is the ratio between the channel-mean
    absolute distance restore<->sharp (positive) and restore<->blur
    (negative). The ratio is averaged over the locations selected by a mask
    marking where ``sharp`` and ``blur`` differ by more than ``threshold``.

    Args:
        ablation: Stored as ``self.ab``; not read by ``forward``.
        threshold: Channel-mean absolute difference between ``sharp`` and
            ``blur`` above which a pixel is included in the mask.
        eps: Small constant guarding both divisions against zero.
    """

    def __init__(self, ablation=False, threshold=0.01, eps=1e-7):
        super().__init__()
        self.vgg = Vgg19()
        # Only move to the GPU when one exists (same guard EdgeLoss uses), so
        # the loss can also be constructed on CPU-only machines.
        if torch.cuda.is_available():
            self.vgg = self.vgg.cuda()
        self.l1 = nn.L1Loss()  # kept for compatibility; unused in forward
        self.ab = ablation
        self.threshold = threshold
        self.eps = eps
        self.down_sample_4 = nn.Upsample(scale_factor=1 / 4, mode='bilinear')

    def forward(self, restore, sharp, blur):
        """Return the masked contrastive loss as a scalar tensor.

        Args:
            restore: Network output; gradients flow through this input only.
            sharp: Positive reference (treated as a constant).
            blur: Negative reference (treated as a constant).
        """
        restore_vgg = self.vgg(restore)
        with torch.no_grad():
            # Targets and mask are constants w.r.t. the optimisation.
            sharp_vgg = self.vgg(sharp)
            blur_vgg = self.vgg(blur)
            # Filter out regions where sharp and blur are (almost) identical.
            diff = torch.mean(torch.abs(sharp - blur), dim=1, keepdim=True)
            mask = (diff > self.threshold).to(restore_vgg.dtype)
            mask = self.down_sample_4(mask)

        d_ap = torch.mean(torch.abs(restore_vgg - sharp_vgg), dim=1, keepdim=True)
        d_an = torch.mean(torch.abs(restore_vgg - blur_vgg), dim=1, keepdim=True)

        # An empty mask used to give 0 / 0 = NaN; clamping the denominator
        # yields a zero loss instead and leaves non-empty masks unchanged.
        mask_size = torch.sum(mask).clamp_min(self.eps)
        contrastive = torch.sum((d_ap / (d_an + self.eps)) * mask) / mask_size
        return contrastive
class ContrastLoss_Ori(nn.Module):
    """Unmasked contrastive loss computed on VGG features.

    The loss is the L1 distance between the features of ``restore`` and
    ``sharp`` divided by the L1 distance between the features of ``restore``
    and ``blur`` (plus ``1e-7`` to avoid division by zero).

    Args:
        ablation: Stored as ``self.ab``; not read by ``forward``.
    """

    def __init__(self, ablation=False):
        super().__init__()
        self.vgg = Vgg19()
        # Only move to the GPU when one exists (same guard EdgeLoss uses), so
        # the loss can also be constructed on CPU-only machines.
        if torch.cuda.is_available():
            self.vgg = self.vgg.cuda()
        self.l1 = nn.L1Loss()
        self.ab = ablation

    def forward(self, restore, sharp, blur):
        """Return ``L1(restore, sharp) / (L1(restore, blur) + 1e-7)`` on VGG features.

        Gradients flow through ``restore`` only; the features of ``sharp`` and
        ``blur`` are treated as constants.
        """
        restore_vgg = self.vgg(restore)
        with torch.no_grad():
            sharp_vgg = self.vgg(sharp)
            blur_vgg = self.vgg(blur)
        d_ap = self.l1(restore_vgg, sharp_vgg)
        d_an = self.l1(restore_vgg, blur_vgg)
        contrastive_loss = d_ap / (d_an + 1e-7)
        return contrastive_loss
class CharbonnierLoss(nn.Module):
    """Charbonnier loss: mean of ``sqrt((x - y)^2 + eps^2)``, a smooth L1.

    Args:
        eps: Smoothing constant; it is squared before being added under the
            square root.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def forward(self, x, y):
        """Return the mean Charbonnier penalty between ``x`` and ``y``."""
        residual = x - y
        eps_squared = self.eps * self.eps
        return torch.sqrt(residual * residual + eps_squared).mean()
class EdgeLoss(nn.Module):
    """Charbonnier loss between band-pass ("Laplacian") residuals of two images.

    A fixed 5x5 kernel, the outer product of ``[.05, .25, .4, .25, .05]`` with
    itself, is applied depth-wise. The kernel is stored once per channel for
    3-channel images; other channel counts are handled on the fly.
    """

    def __init__(self):
        super().__init__()
        k = torch.Tensor([[.05, .25, .4, .25, .05]])
        self.kernel = torch.matmul(k.t(), k).unsqueeze(0).repeat(3, 1, 1, 1)
        if torch.cuda.is_available():
            self.kernel = self.kernel.cuda()
        self.loss = CharbonnierLoss()

    def conv_gauss(self, img):
        """Blur ``img`` depth-wise with the 5x5 kernel, keeping its spatial size.

        ``img`` is indexed as (batch, channels, height, width). Borders are
        handled with replicate padding.
        """
        # The kernel is a plain attribute, so Module.to()/.cuda() never moves
        # it; match the input's device (and floating dtype) at use instead.
        # This is a no-op when they already agree.
        target_dtype = img.dtype if img.is_floating_point() else self.kernel.dtype
        kernel = self.kernel.to(device=img.device, dtype=target_dtype)

        channels = img.shape[1]
        if kernel.shape[0] != channels:
            # Every per-channel filter is identical, so replicate the first.
            kernel = kernel[:1].repeat(channels, 1, 1, 1)

        _, _, kh, kw = kernel.shape
        # F.pad order for the last two dims is (left, right, top, bottom).
        img = F.pad(img, (kw // 2, kw // 2, kh // 2, kh // 2), mode='replicate')
        return F.conv2d(img, kernel, groups=channels)

    def laplacian_kernel(self, current):
        """Return ``current`` minus a blur -> decimate -> zero-fill -> blur copy."""
        filtered = self.conv_gauss(current)  # filter
        down = filtered[:, :, ::2, ::2]  # downsample: keep even rows/columns
        new_filter = torch.zeros_like(filtered)
        # upsample: zero-fill; x4 compensates for the three zeros per sample
        new_filter[:, :, ::2, ::2] = down * 4
        filtered = self.conv_gauss(new_filter)  # filter
        diff = current - filtered
        return diff

    def forward(self, x, y):
        """Return the loss between the Laplacian residuals of ``x`` and ``y``."""
        loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
        return loss
class Stripformer_Loss(nn.Module):
    """Weighted sum of Charbonnier, edge and contrastive losses.

    ``loss = char + 0.05 * edge + 0.0005 * contrastive``
    """

    def __init__(self, ):
        super().__init__()
        self.char = CharbonnierLoss()
        self.edge = EdgeLoss()
        self.contrastive = ContrastLoss()

    def forward(self, restore, sharp, blur):
        """Combine the three terms for ``restore`` against ``sharp`` and ``blur``.

        The Charbonnier and edge terms compare ``restore`` with ``sharp``;
        the contrastive term additionally receives ``blur``.
        """
        pixel_term = self.char(restore, sharp)
        edge_term = 0.05 * self.edge(restore, sharp)
        contrast_term = 0.0005 * self.contrastive(restore, sharp, blur)
        return pixel_term + edge_term + contrast_term
def get_loss(model):
    """Build the content loss selected by ``model['content_loss']``.

    Recognised names are ``'Stripformer_Loss'`` and ``'CharbonnierLoss'``.

    Raises:
        ValueError: If the configured name is not one of the above.
    """
    name = model['content_loss']
    if name == 'Stripformer_Loss':
        return Stripformer_Loss()
    if name == 'CharbonnierLoss':
        return CharbonnierLoss()
    raise ValueError("ContentLoss [%s] not recognized." % name)
from typing import Union, List, Dict, Any, cast
import torch
import torch.nn as nn
class VGG(nn.Module):
    """VGG classifier: a convolutional ``features`` stack plus a 3-layer MLP head.

    Args:
        features: Convolutional feature extractor; its output is pooled to
            7x7 and must have 512 channels to match the first linear layer.
        num_classes: Size of the final linear layer's output.
        init_weights: If true, re-initialise conv, batch-norm and linear
            layers after construction.
        dropout: Drop probability of the two dropout layers in the head.
    """

    def __init__(
        self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
    ) -> None:
        super().__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        head = [
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head)
        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Apply the per-layer-type initialisation to every sub-module."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the classifier output for input ``x``."""
        pooled = self.avgpool(self.features(x))
        flat = torch.flatten(pooled, 1)
        return self.classifier(flat)
def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
    """Translate a VGG layer spec into an ``nn.Sequential``.

    Args:
        cfg: Sequence where ``"M"`` inserts a 2x2, stride-2 max-pool and any
            other entry is the output width of a 3x3, padding-1 convolution
            followed by an in-place ReLU. The first convolution takes 3
            input channels.
        batch_norm: If true, insert ``BatchNorm2d`` between each convolution
            and its ReLU.

    Returns:
        The layers, in order, wrapped in ``nn.Sequential``.
    """
    modules: List[nn.Module] = []
    channels_in = 3
    for entry in cfg:
        if entry == "M":
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        width = cast(int, entry)
        modules.append(nn.Conv2d(channels_in, width, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(width))
        modules.append(nn.ReLU(inplace=True))
        channels_in = width
    return nn.Sequential(*modules)
# Layer specifications consumed by ``make_layers``: an integer is the output
# channel count of a 3x3 convolution, "M" is a 2x2 max-pool. Configuration
# "E" (16 convolutions) is the one ``vgg19`` selects.
cfgs: Dict[str, List[Union[str, int]]] = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
    "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}
def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
    """Build a VGG model and optionally load weights from a local checkpoint.

    Args:
        arch: Architecture name; accepted for API compatibility, unused.
        cfg: Key into ``cfgs`` selecting the layer specification.
        batch_norm: Forwarded to ``make_layers``.
        pretrained: If true, skip random initialisation and load the state
            dict stored at ``weights_path``.
        progress: Accepted for API compatibility, unused (nothing is
            downloaded).
        **kwargs: Forwarded to ``VGG``. The optional key ``weights_path``
            is consumed here and overrides the default checkpoint location.

    Returns:
        The constructed ``VGG`` instance.
    """
    # The default is the original author's local path; pass ``weights_path``
    # (or edit this default) to point at your copy of vgg19-dcbb9e9d.pth.
    weights_path = kwargs.pop(
        "weights_path",
        "/home/hanzhou1996/low-level/StripMamba/models/vgg19-dcbb9e9d.pth",
    )
    if pretrained:
        kwargs["init_weights"] = False
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        # map_location keeps loading independent of the device the checkpoint
        # was saved from; callers move the model afterwards.
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(state_dict)
    return model
def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""Build the VGG 19-layer model (configuration "E", no batch norm).

    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.

    Args:
        pretrained (bool): If True, ``_vgg`` loads weights from a checkpoint
            file on disk instead of initialising randomly.
        progress (bool): Forwarded to ``_vgg``, which currently ignores it
            (no download takes place).
        **kwargs: Forwarded to ``_vgg``.
    """
    return _vgg(
        arch="vgg19",
        cfg="E",
        batch_norm=False,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )
"""
if __name__ == "__main__":
device = "cuda" if torch.cuda.is_available() else "cpu"
#model = VGG(make_layers(cfgs["E"], batch_norm=False)).to(device)
#model.load_state_dict(torch.load("models/vgg19-dcbb9e9d.pth"))
model = vgg19().to(device)
print(model.features)
BATCH_SIZE = 3
x = torch.randn(3, 3, 224, 224).to(device)
assert model(x).shape == torch.Size([BATCH_SIZE, 1000])
print(model(x).shape)
"""