import torch
import torch.nn as nn
import torch.nn.functional as F


class Vgg19(torch.nn.Module):
    """Fixed feature extractor: the first 12 layers of VGG19 (conv1_1 through relu3_1)."""
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        # Uses the local vgg19() defined at the bottom of this file, which loads
        # pretrained weights from a .pth file on disk.
        vgg_pretrained_features = vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        for x in range(12):
            self.slice1.add_module(str(x), vgg_pretrained_features[x].eval())
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        return h_relu1
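
# A minimal shape check (assumptions: a CUDA device and the local VGG19 weights
# are available). Layers 0-11 of VGG19 end at relu3_1, so the output has 256
# channels at 1/4 of the input resolution:
#
#   extractor = Vgg19().cuda()
#   feats = extractor(torch.randn(1, 3, 256, 256).cuda())
#   print(feats.shape)  # torch.Size([1, 256, 64, 64])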

class ContrastLoss(nn.Module):
    """Masked contrastive loss in VGG feature space: pull the restored image
    toward the sharp target (positive) and away from the blurry input
    (negative), evaluated only where the scene is actually blurred."""
    def __init__(self, ablation=False):
        super(ContrastLoss, self).__init__()
        self.vgg = Vgg19().cuda()
        self.l1 = nn.L1Loss()
        self.ab = ablation
        self.down_sample_4 = nn.Upsample(scale_factor=1 / 4, mode='bilinear', align_corners=False)

    def forward(self, restore, sharp, blur):
        B, C, H, W = restore.size()
        restore_vgg, sharp_vgg, blur_vgg = self.vgg(restore), self.vgg(sharp), self.vgg(blur)

        # Filter out sharp regions: keep only pixels where sharp and blur differ.
        threshold = 0.01
        mask = torch.mean(torch.abs(sharp - blur), dim=1).view(B, 1, H, W)
        mask[mask <= threshold] = 0
        mask[mask > threshold] = 1
        mask = self.down_sample_4(mask)  # match the 4x-downsampled VGG feature maps

        # Per-pixel distances to the positive (sharp) and negative (blur) anchors.
        d_ap = torch.mean(torch.abs(restore_vgg - sharp_vgg.detach()), dim=1).view(B, 1, H // 4, W // 4)
        d_an = torch.mean(torch.abs(restore_vgg - blur_vgg.detach()), dim=1).view(B, 1, H // 4, W // 4)
        mask_size = torch.sum(mask)
        contrastive = torch.sum((d_ap / (d_an + 1e-7)) * mask) / mask_size
        return contrastive
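
# Usage sketch (assumptions: CUDA and the pretrained weights are available, and
# H, W are divisible by 4 so the mask lines up with the VGG feature maps):
#
#   loss_fn = ContrastLoss()
#   restore, sharp, blur = (torch.rand(2, 3, 256, 256).cuda() for _ in range(3))
#   print(loss_fn(restore, sharp, blur))  # scalar: mean d_ap / d_an over blurry pixels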

class ContrastLoss_Ori(nn.Module):
    """The original, unmasked contrastive loss: the same d_ap / d_an ratio as
    ContrastLoss, but averaged over the whole image."""
    def __init__(self, ablation=False):
        super(ContrastLoss_Ori, self).__init__()
        self.vgg = Vgg19().cuda()
        self.l1 = nn.L1Loss()
        self.ab = ablation

    def forward(self, restore, sharp, blur):
        restore_vgg, sharp_vgg, blur_vgg = self.vgg(restore), self.vgg(sharp), self.vgg(blur)
        d_ap = self.l1(restore_vgg, sharp_vgg.detach())
        d_an = self.l1(restore_vgg, blur_vgg.detach())
        contrastive_loss = d_ap / (d_an + 1e-7)
        return contrastive_loss

class CharbonnierLoss(nn.Module):
    """Charbonnier loss: a smooth variant of L1 that is differentiable at zero."""
    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        diff = x - y
        loss = torch.mean(torch.sqrt(diff * diff + self.eps * self.eps))
        return loss
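
# Worked example: for a per-pixel difference d, the penalty is sqrt(d^2 + eps^2),
# which tracks |d| when |d| >> eps but keeps a well-defined zero gradient at
# d = 0, unlike plain L1. With the default eps = 1e-3:
#
#   d = 0.0  ->  sqrt(0 + 1e-6)     = 1e-3
#   d = 0.1  ->  sqrt(0.01 + 1e-6)  ~ 0.100005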

class EdgeLoss(nn.Module):
    """Charbonnier loss between the Laplacian (high-frequency) bands of two
    images, emphasizing edge fidelity."""
    def __init__(self):
        super(EdgeLoss, self).__init__()
        # 5x5 Gaussian kernel, applied depthwise to each of the 3 color channels.
        k = torch.Tensor([[.05, .25, .4, .25, .05]])
        self.kernel = torch.matmul(k.t(), k).unsqueeze(0).repeat(3, 1, 1, 1)
        if torch.cuda.is_available():
            self.kernel = self.kernel.cuda()
        self.loss = CharbonnierLoss()

    def conv_gauss(self, img):
        n_channels, _, kw, kh = self.kernel.shape
        img = F.pad(img, (kw // 2, kh // 2, kw // 2, kh // 2), mode='replicate')
        return F.conv2d(img, self.kernel, groups=n_channels)

    def laplacian_kernel(self, current):
        # One level of a Laplacian pyramid: blur, downsample, zero-fill upsample,
        # blur again, and keep the residual high-frequency detail.
        filtered = self.conv_gauss(current)        # blur
        down = filtered[:, :, ::2, ::2]            # downsample
        new_filter = torch.zeros_like(filtered)
        new_filter[:, :, ::2, ::2] = down * 4      # upsample (x4 compensates the zero fill)
        filtered = self.conv_gauss(new_filter)     # blur to interpolate
        diff = current - filtered
        return diff

    def forward(self, x, y):
        loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
        return loss
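
# Sanity-check sketch (assumption: with identical inputs the Laplacian bands
# match, so the loss bottoms out at the Charbonnier eps floor):
#
#   edge = EdgeLoss()
#   x = torch.rand(1, 3, 64, 64)
#   x = x.cuda() if torch.cuda.is_available() else x  # the kernel is on CUDA when available
#   print(edge(x, x))  # ~ 1e-3, i.e. sqrt(0 + eps^2)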

class Stripformer_Loss(nn.Module):
    """Total Stripformer loss: Charbonnier + 0.05 * edge + 0.0005 * contrastive."""
    def __init__(self):
        super(Stripformer_Loss, self).__init__()
        self.char = CharbonnierLoss()
        self.edge = EdgeLoss()
        self.contrastive = ContrastLoss()

    def forward(self, restore, sharp, blur):
        char = self.char(restore, sharp)
        edge = 0.05 * self.edge(restore, sharp)
        contrastive = 0.0005 * self.contrastive(restore, sharp, blur)
        loss = char + edge + contrastive
        return loss
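
# Training-step sketch (assumptions: CUDA is available, since ContrastLoss pins
# its VGG extractor to the GPU; `restored`, `sharp`, and `blur` are hypothetical
# (B, 3, H, W) batches):
#
#   criterion = Stripformer_Loss()
#   loss = criterion(restored, sharp, blur)
#   loss.backward()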

def get_loss(model):
    if model['content_loss'] == 'Stripformer_Loss':
        content_loss = Stripformer_Loss()
    elif model['content_loss'] == 'CharbonnierLoss':
        content_loss = CharbonnierLoss()
    else:
        raise ValueError("ContentLoss [%s] not recognized." % model['content_loss'])
    return content_loss
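
# Usage sketch (assumption: `model` is a config dict, as the key lookup above
# suggests):
#
#   config = {'content_loss': 'Stripformer_Loss'}
#   criterion = get_loss(config)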

# ---------------------------------------------------------------------------
# Vendored VGG implementation (torchvision's, with pretrained weights loaded
# from a local file instead of downloaded). torch and torch.nn are already
# imported at the top of this file.
# ---------------------------------------------------------------------------
from typing import Union, List, Dict, Any, cast

class VGG(nn.Module):
    def __init__(
        self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
    ) -> None:
        super().__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
    layers: List[nn.Module] = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            v = cast(int, v)
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)
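
# A minimal sketch of what make_layers produces for a tiny config: integers add
# a 3x3 conv (plus ReLU), "M" adds a 2x2 max-pool.
#
#   features = make_layers([64, "M"])
#   # -> Sequential(Conv2d(3, 64, 3, padding=1), ReLU(inplace=True), MaxPool2d(2, 2))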

cfgs: Dict[str, List[Union[str, int]]] = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
    "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}
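
# The letters follow torchvision's naming: "A" = VGG11, "B" = VGG13, "D" = VGG16,
# "E" = VGG19. Configuration "E" has 16 conv layers; with the 3 fully connected
# layers in the classifier, that gives VGG19 its 19 weight layers:
#
#   print(sum(isinstance(m, nn.Conv2d) for m in make_layers(cfgs["E"])))  # 16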

def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
    if pretrained:
        kwargs["init_weights"] = False
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        # Weights are read from a local file rather than downloaded; change the
        # path below to wherever your copy of vgg19-dcbb9e9d.pth lives.
        state_dict = torch.load("/home/hanzhou1996/low-level/StripMamba/models/vgg19-dcbb9e9d.pth")
        model.load_state_dict(state_dict)
    return model

def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 19-layer model (configuration "E") from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, loads ImageNet-pretrained weights (from the local file above)
        progress (bool): unused here; kept for signature compatibility with torchvision
    """
    return _vgg("vgg19", "E", False, pretrained, progress, **kwargs)
| """ | |
| if __name__ == "__main__": | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| #model = VGG(make_layers(cfgs["E"], batch_norm=False)).to(device) | |
| #model.load_state_dict(torch.load("models/vgg19-dcbb9e9d.pth")) | |
| model = vgg19().to(device) | |
| print(model.features) | |
| BATCH_SIZE = 3 | |
| x = torch.randn(3, 3, 224, 224).to(device) | |
| assert model(x).shape == torch.Size([BATCH_SIZE, 1000]) | |
| print(model(x).shape) | |
| """ |