luisabwk commited on
Commit
8269d7b
·
verified ·
1 Parent(s): 0fb5e8c

Delete ffdnet_model.py

Browse files
Files changed (1) hide show
  1. ffdnet_model.py +0 -50
ffdnet_model.py DELETED
@@ -1,50 +0,0 @@
1
- """FFDNet denoiser (Zhang et al., 2018 — https://arxiv.org/abs/1710.04026).
2
-
3
- A arquitetura espelha exatamente a implementação oficial do KAIR
4
- (https://github.com/cszn/KAIR, arquivos `models/network_ffdnet.py` +
5
- `models/basicblock.py`) para que os pesos pré-treinados do release v1.0
6
- carreguem com `strict=True`, sem rename de chaves.
7
-
8
- Convs são empilhadas num único `nn.Sequential` (sem blocos aninhados),
9
- reproduzindo o comportamento do `basicblock.sequential(...)` do KAIR que
10
- achata filhos de `nn.Sequential`.
11
-
12
- Configurações oficiais:
13
- - color: in_nc=3, out_nc=3, nc=96, nb=12 (ffdnet_color.pth)
14
- - gray: in_nc=1, out_nc=1, nc=64, nb=15 (ffdnet_gray.pth)
15
- """
16
- from __future__ import annotations
17
-
18
- import math
19
-
20
- import torch
21
- import torch.nn as nn
22
- import torch.nn.functional as F
23
-
24
-
25
- class FFDNet(nn.Module):
26
- def __init__(self, in_nc: int = 3, out_nc: int = 3, nc: int = 96, nb: int = 12):
27
- super().__init__()
28
- self.sf = 2
29
-
30
- layers: list[nn.Module] = []
31
- layers.append(nn.Conv2d(in_nc * self.sf * self.sf + 1, nc, 3, 1, 1, bias=True))
32
- layers.append(nn.ReLU(inplace=True))
33
- for _ in range(nb - 2):
34
- layers.append(nn.Conv2d(nc, nc, 3, 1, 1, bias=True))
35
- layers.append(nn.ReLU(inplace=True))
36
- layers.append(nn.Conv2d(nc, out_nc * self.sf * self.sf, 3, 1, 1, bias=True))
37
- self.model = nn.Sequential(*layers)
38
-
39
- def forward(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
40
- h, w = x.shape[-2], x.shape[-1]
41
- pad_b = int(math.ceil(h / 2) * 2 - h)
42
- pad_r = int(math.ceil(w / 2) * 2 - w)
43
- x = F.pad(x, (0, pad_r, 0, pad_b), mode="replicate")
44
-
45
- x = F.pixel_unshuffle(x, self.sf)
46
- noise_map = sigma.view(-1, 1, 1, 1).expand(x.shape[0], 1, x.shape[-2], x.shape[-1])
47
- x = torch.cat([x, noise_map], dim=1)
48
- x = self.model(x)
49
- x = F.pixel_shuffle(x, self.sf)
50
- return x[..., :h, :w]