| |
| |
| |
|
|
| import jax.numpy as jnp |
| import flax.linen as nn |
|
|
|
|
class RDB_Conv(nn.Module):
    """One densely connected convolution layer of a residual dense block.

    Applies a ``kSize x kSize`` convolution that produces ``growRate`` feature
    maps, then ReLU. The result is concatenated onto the input along the last
    axis, so the output has ``in_features + growRate`` features.

    NOTE(review): assumes channels-last input (e.g. NHWC), the layout
    ``flax.linen.Conv`` convolves over — confirm against callers.
    """

    growRate: int  # number of feature maps this layer appends
    kSize: int = 3  # square spatial kernel size

    @nn.compact
    def __call__(self, x):
        # 'SAME' padding keeps the spatial size equal to the input's for every
        # kernel size.
        # - Odd kSize: this is exactly (kSize - 1) // 2 on each side, the value
        #   previously passed as an integer, so outputs and parameters are
        #   unchanged.
        # - Even kSize: the old symmetric integer padding shrank each spatial
        #   dim by one, which made the concatenation below fail.
        # The conv stays inside nn.Sequential so existing parameter names are
        # preserved.
        out = nn.Sequential([
            nn.Conv(self.growRate, (self.kSize, self.kSize), padding='SAME'),
            nn.activation.relu,
        ])(x)
        # Dense connectivity: keep the input features and append the new ones.
        return jnp.concatenate((x, out), -1)
|
|
|
|
class RDB(nn.Module):
    """Residual dense block.

    The block runs in three steps:

    1. Stack ``nConvLayers`` ``RDB_Conv`` layers. Each one appends
       ``growRate`` features to the running feature stack.
    2. Fuse the stack back to ``growRate0`` features with a 1x1 convolution.
    3. Add the block's input to the fused features.

    The input is expected to carry ``growRate0`` features on its last axis,
    so that the residual addition lines up with the fused output.
    """

    growRate0: int  # features entering and leaving the block
    growRate: int  # features appended by each RDB_Conv layer
    nConvLayers: int  # number of RDB_Conv layers in the dense stack
    # Kernel size of the inner RDB_Conv layers. It defaults to 3, the value
    # previously implied by RDB_Conv's own default.
    kSize: int = 3

    @nn.compact
    def __call__(self, x):
        res = x

        # Dense stack: each layer sees the concatenation of all earlier features.
        for _ in range(self.nConvLayers):
            x = RDB_Conv(self.growRate, self.kSize)(x)

        # 1x1 fusion back to growRate0 features so the residual add is valid.
        x = nn.Conv(self.growRate0, (1, 1))(x)

        return x + res
|
|
|
|
class RDN(nn.Module):
    """Residual dense network feature extractor.

    All convolutions here use flax's default 'SAME' padding. The forward
    pass runs in four steps:

    1. Apply two ``RDNkSize`` convolutions to ``G0`` features.
    2. Run the residual dense blocks in sequence, keeping every block's
       output.
    3. Concatenate all block outputs along the last axis, then apply a 1x1
       conv and an ``RDNkSize`` conv back to ``G0`` features.
    4. Add the output of the very first convolution.

    ``RDNconfig`` selects the tuple (number of blocks, conv layers per block,
    growth rate): 'A' -> (20, 6, 32) and 'B' -> (16, 8, 64). Any other value
    raises ``KeyError``.

    ``scale`` and ``n_colors`` are not read by ``__call__``, and no
    upsampling happens here.
    """

    G0: int = 64  # feature width of the main path and of every block
    RDNkSize: int = 3  # kernel size of the non-1x1 convolutions in this module
    RDNconfig: str = 'B'  # key into the per-config table in __call__
    scale: int = 2  # not used by __call__
    n_colors: int = 3  # not used by __call__

    @nn.compact
    def __call__(self, x, _=None):
        # The optional second positional argument is accepted and ignored.
        table = {
            'A': (20, 6, 32),
            'B': (16, 8, 64),
        }
        num_blocks, layers_per_block, growth = table[self.RDNconfig]
        kernel = (self.RDNkSize, self.RDNkSize)

        # Submodules are created in the same order and with the same types as
        # before, so flax's auto-generated submodule names stay stable.
        shallow = nn.Conv(self.G0, kernel)(x)
        feats = nn.Conv(self.G0, kernel)(shallow)

        block_outputs = []
        for _idx in range(num_blocks):
            feats = RDB(
                growRate0=self.G0,
                growRate=growth,
                nConvLayers=layers_per_block,
            )(feats)
            block_outputs.append(feats)

        stacked = jnp.concatenate(block_outputs, -1)
        fuse = nn.Sequential([
            nn.Conv(self.G0, (1, 1)),
            nn.Conv(self.G0, kernel),
        ])
        fused = fuse(stacked)

        # Long skip connection from the first convolution's output.
        return fused + shallow
|
|