Spaces:
Running on Zero
Running on Zero
| from typing import Literal | |
| import torch | |
| import torch.nn as nn | |
# Allowed layer orderings for conv_block (C = conv, N = norm, A = activation).
ConvMode = Literal[
    "CNA",
    "NAC",
    "CNAC",
]
def act(act_type: str, inplace: bool = True, neg_slope: float = 0.2, n_prelu: int = 1) -> nn.Module:
    """Return the activation layer used by the blocks in this module.

    Only LeakyReLU is ever built. ``act_type`` and ``n_prelu`` are accepted
    (presumably for call-site compatibility) but do not affect the result.

    Args:
        act_type: Requested activation name; ignored.
        inplace: Forwarded to ``nn.LeakyReLU``.
        neg_slope: Negative slope of the LeakyReLU.
        n_prelu: Ignored.

    Returns:
        ``nn.LeakyReLU`` with the given slope and in-place flag.
    """
    layer = nn.LeakyReLU(negative_slope=neg_slope, inplace=inplace)
    return layer
def get_valid_padding(kernel_size: int, dilation: int) -> int:
    """Return the per-side zero padding for a 'same'-style convolution.

    A dilated kernel spans ``kernel_size + (kernel_size - 1) * (dilation - 1)``
    input pixels. Padding each side by ``(span - 1) // 2`` preserves the spatial
    size at stride 1 when that span is odd.
    """
    span = kernel_size + (kernel_size - 1) * (dilation - 1)
    return (span - 1) // 2
def sequential(*args: nn.Module) -> nn.Sequential:
    """Build one ``nn.Sequential`` from modules, flattening one level of nesting.

    Each ``nn.Sequential`` argument contributes its direct children. Any other
    ``nn.Module`` is added as-is. Non-module arguments (e.g. ``None``
    placeholders for optional layers) are skipped.
    """

    def _expand(item) -> list:
        # Direct children of a Sequential; a lone module; or nothing at all.
        if isinstance(item, nn.Sequential):
            return list(item.children())
        if isinstance(item, nn.Module):
            return [item]
        return []

    return nn.Sequential(*[layer for item in args for layer in _expand(item)])
def conv_block(in_nc: int, out_nc: int, kernel_size: int, stride: int = 1, dilation: int = 1,
               groups: int = 1, bias: bool = True, pad_type: str = "zero", norm_type=None,
               act_type: str | None = "relu", mode: ConvMode = "CNA", c2x2: bool = False) -> nn.Sequential:
    """Create a Conv2d block with an optional activation.

    Args:
        in_nc: Number of input channels.
        out_nc: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        dilation: Convolution dilation.
        groups: Convolution groups.
        bias: Whether the convolution has a bias term.
        pad_type: ``"zero"``, ``"reflect"`` or ``"replicate"`` pad the input by
            ``get_valid_padding(kernel_size, dilation)`` using that fill. Any
            other value applies no padding (a "valid" convolution).
        norm_type: Unused; no normalization layer is built.
        act_type: Activation name passed to ``act``; a falsy value means no
            activation layer.
        mode: ``"CNA"``/``"CNAC"`` place the activation after the conv.
            ``"NAC"`` places it before the conv (pre-activation).
        c2x2: Unused.

    Returns:
        An ``nn.Sequential`` holding the conv and, if requested, the activation,
        ordered according to ``mode``.

    Raises:
        ValueError: If ``mode`` is not one of ``"CNA"``, ``"NAC"``, ``"CNAC"``.
    """
    if mode not in ("CNA", "NAC", "CNAC"):
        raise ValueError(f"Unsupported conv mode: {mode!r}")

    # Map this module's pad names onto nn.Conv2d's padding_mode. Padding inside
    # the conv (instead of a separate pad layer) keeps the conv at the same
    # Sequential index, so state_dict keys of CNA blocks are unchanged.
    padding_mode = {"zero": "zeros", "reflect": "reflect", "replicate": "replicate"}.get(pad_type)
    if padding_mode is None:
        # Unrecognized pad_type: keep the previous behaviour of no padding.
        padding, padding_mode = 0, "zeros"
    else:
        padding = get_valid_padding(kernel_size, dilation)
    c = nn.Conv2d(in_nc, out_nc, kernel_size, stride, padding, dilation, groups, bias,
                  padding_mode=padding_mode)

    if mode == "NAC":
        # Pre-activation: the activation must not run in-place, otherwise it
        # would overwrite the block's input tensor (e.g. the identity branch
        # of a residual wrapper).
        a = act(act_type, inplace=False) if act_type else None
        # NOTE: when an activation is present the conv sits at index 1 here.
        return sequential(a, c)

    a = act(act_type) if act_type else None
    return sequential(c, a)
def upconv_block(in_nc: int, out_nc: int, upscale_factor: int = 2, kernel_size: int = 3,
                 stride: int = 1, bias: bool = True, pad_type: str = "zero", norm_type=None,
                 act_type: str = "relu", mode: str = "nearest", c2x2: bool = False) -> nn.Sequential:
    """Create an upsampling layer followed by a ``conv_block``.

    Args:
        in_nc: Number of input channels.
        out_nc: Number of output channels.
        upscale_factor: ``scale_factor`` for ``nn.Upsample``.
        kernel_size: Kernel size of the following conv.
        stride: Stride of the following conv.
        bias: Whether the conv has a bias term.
        pad_type: Forwarded to ``conv_block``.
        norm_type: Unused.
        act_type: Forwarded to ``conv_block``.
        mode: Interpolation mode for ``nn.Upsample`` (not the conv layer order;
            ``conv_block`` is called with its default ``mode``).
        c2x2: Unused.

    Returns:
        A flat ``nn.Sequential``: the upsample layer, then the conv block's layers.
    """
    upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(
        in_nc,
        out_nc,
        kernel_size,
        stride,
        bias=bias,
        pad_type=pad_type,
        act_type=act_type,
    )
    return sequential(upsample, conv)
class ShortcutBlock(nn.Module):
    """Wrap a module with an identity skip connection: ``out = x + sub(x)``.

    The wrapped module is registered as ``self.sub``, so its parameters
    appear under the ``sub.`` prefix in ``state_dict``.
    """

    def __init__(self, submodule: nn.Module):
        super().__init__()
        self.sub = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The branch output must match x's shape (or broadcast against it).
        branch = self.sub(x)
        return x + branch