Spaces:
Sleeping
Sleeping
yoel
Refactor: mejora la interfaz de evaluación y la gestión de modelos, añadiendo soporte para super-resolución y optimizando la carga de datos
b79aa7a | import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class Stem(nn.Module):
    """Input stem: a strided 3x3 convolution followed by a strided max-pool.

    The convolution maps 3 input channels to 64 channels with stride 2 and
    no padding; the 3x3 max-pool then reduces the resolution again with
    stride 2. No normalisation or activation is applied here.
    """

    def __init__(self):
        super().__init__()
        conv = nn.Conv2d(3, 64, kernel_size=3, stride=2)
        pool = nn.MaxPool2d(kernel_size=3, stride=2)
        # Kept as a Sequential named ``conv`` so parameter names stay
        # ``conv.0.weight`` / ``conv.0.bias``.
        self.conv = nn.Sequential(conv, pool)

    def forward(self, x):
        """Return the downsampled 64-channel feature map for ``x``."""
        return self.conv(x)
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with an additive shortcut.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution (and of the shortcut
            projection, when one is needed).

    The shortcut is the identity when the block keeps both the channel
    count and the resolution; otherwise it is a bias-free 1x1 convolution
    followed by BatchNorm so that it matches the main path's output.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        first_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.conv1 = nn.Sequential(
            first_conv,
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
        )

        second_conv = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        # No activation here: it is applied after the residual addition.
        self.conv2 = nn.Sequential(second_conv, nn.BatchNorm2d(out_channels))

        needs_projection = in_channels != out_channels or stride != 1
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

        self.act = nn.LeakyReLU(inplace=True)

    def forward(self, x):
        """Apply the main path, add the shortcut, then activate."""
        skip = self.shortcut(x)
        out = self.conv2(self.conv1(x))
        out += skip
        return self.act(out)
class FromZero(nn.Module):
    """Residual image classifier built from ``Stem`` and ``ResidualBlock``.

    The network is a stem, four stages of two residual blocks each
    (64, 128, 256 and 512 channels; stages 2-4 start with a stride-2
    block), global average pooling and a single linear layer.

    Args:
        num_classes: Size of the output logit vector.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.stem = nn.Sequential(Stem())
        self.layer1 = nn.Sequential(*self._make_stage(64, 64, stride=1))
        self.layer2 = nn.Sequential(*self._make_stage(64, 128, stride=2))
        self.layer3 = nn.Sequential(*self._make_stage(128, 256, stride=2))
        # The last stage additionally ends with dropout (p=0.2).
        self.layer4 = nn.Sequential(
            *self._make_stage(256, 512, stride=2), nn.Dropout(0.2)
        )
        self.flatten = nn.Flatten()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(nn.Linear(512, num_classes))

    @staticmethod
    def _make_stage(in_channels, out_channels, stride):
        """Return the two residual blocks of one stage, in order."""
        return [
            ResidualBlock(in_channels, out_channels, stride=stride),
            ResidualBlock(out_channels, out_channels),
        ]

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        features = self.stem(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = self.avgpool(features)
        return self.fc(self.flatten(pooled))
class SRResidualBlock(nn.Module):
    """Bottleneck residual block using GroupNorm and SiLU.

    The main path is 1x1 conv -> ``kernel_size`` conv -> 1x1 conv, working
    at ``out_channels // 2`` channels in the middle. Each convolution is
    followed by an 8-group GroupNorm; SiLU follows the first two, and a
    final SiLU is applied after the shortcut is added.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block. Both ``out_channels``
            and ``out_channels // 2`` must be divisible by 8 because of the
            8-group GroupNorm layers.
        kernel_size: Kernel size of the middle convolution. Odd sizes are
            expected so that "same" padding is exact.
        stride: Stride of the middle convolution and of the shortcut
            projection.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super().__init__()
        mid_channels = out_channels // 2
        # Pad by half the kernel so the main path keeps the same spatial
        # size as the shortcut for any odd ``kernel_size``. A hard-coded
        # padding of 1 is only correct for 3x3 kernels and otherwise makes
        # the residual addition fail with a shape mismatch.
        padding = kernel_size // 2
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1),
            nn.GroupNorm(8, mid_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(
                mid_channels,
                mid_channels,
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
            ),
            nn.GroupNorm(8, mid_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1),
            nn.GroupNorm(8, out_channels),
        )
        # Attribute name kept as ``relu`` for compatibility even though the
        # activation is SiLU.
        self.relu = nn.SiLU(inplace=True)
        if stride != 1 or in_channels != out_channels:
            # Project the input so it matches the main path in channels
            # and resolution.
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                ),
                nn.GroupNorm(8, out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Return ``SiLU(main_path(x) + shortcut(x))``."""
        shortcut = self.shortcut(x)
        out = self.conv1(x)
        out += shortcut
        out = self.relu(out)
        return out
class DownBlock(nn.Module):
    """Encoder stage: one full-resolution block, then a stride-2 block.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both inner blocks.
        kernel_size: Kernel size forwarded to both ``SRResidualBlock``s.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        # Runs at the input resolution; its output is the skip tensor.
        self.skip_block = SRResidualBlock(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
        )
        # Consumes the skip tensor and downsamples it with stride 2.
        self.downsample_block = SRResidualBlock(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=2,
        )

    def forward(self, x):
        """Return ``(skip, downsampled)`` feature maps, in that order."""
        features = self.skip_block(x)
        reduced = self.downsample_block(features)
        return features, reduced
class SRUp(nn.Module):
    """Sub-pixel upsampling: 3x3 conv, pixel shuffle, then SiLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the upsampled output.
        scale_factor: Spatial upscaling factor applied by the pixel shuffle.
    """

    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()
        # The conv emits scale_factor**2 channels per output channel; the
        # pixel shuffle rearranges them into the spatial dimensions.
        expanded_channels = out_channels * (scale_factor**2)
        self.conv = nn.Conv2d(
            in_channels, expanded_channels, kernel_size=3, padding=1
        )
        self.shuffle = nn.PixelShuffle(scale_factor)
        self.act = nn.SiLU()

    def forward(self, x):
        """Return ``x`` upsampled by ``scale_factor`` in height and width."""
        expanded = self.conv(x)
        upsampled = self.shuffle(expanded)
        return self.act(upsampled)
class SRHead(nn.Module):
    """Output head: two upsample + residual stages and a final 3x3 conv.

    Args:
        channels: Channel count kept through both upsampling stages.
        out_channels: Channels of the final output.
    """

    def __init__(self, channels, out_channels):
        super().__init__()
        self.sr1 = SRUp(channels, channels)
        self.sr2 = SRUp(channels, channels)
        self.res = SRResidualBlock(channels, channels)
        self.res2 = SRResidualBlock(channels, channels)
        self.head = nn.Conv2d(channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Run upsample -> residual twice, then project to ``out_channels``."""
        pipeline = (self.sr1, self.res, self.sr2, self.res2, self.head)
        for stage in pipeline:
            x = stage(x)
        return x
class InputInjection(nn.Module):
    """Resize a tensor bilinearly and project it with a 1x1 convolution.

    Args:
        in_channels: Channels of the tensor being injected.
        target_channels: Channels produced by the 1x1 projection.
    """

    def __init__(self, in_channels, target_channels):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, target_channels, kernel_size=1)

    def forward(self, x, target_size):
        """Return ``x`` resized to ``target_size`` (H, W) and projected."""
        resized = F.interpolate(
            x,
            size=target_size,
            mode="bilinear",
            align_corners=False,
        )
        projected = self.proj(resized)
        return projected
class UNetSR(nn.Module):
    """U-Net style encoder/decoder with an upsampling head.

    Three ``DownBlock`` stages (64, 128, 256 channels) feed a bottleneck;
    three transposed-convolution stages bring the features back up and are
    concatenated with the matching encoder skip tensors. Resized, 1x1
    projected copies of the network input are added at several points,
    scaled by ``residual_scaling``. ``SRHead`` produces the output, to
    which a bilinearly resized copy of the input is added with the same
    scaling.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the produced image.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
    ):
        super().__init__()
        # Weight applied to every input-injection term and to the final
        # global residual.
        self.residual_scaling = 0.1

        # NOTE: ``input_inject256`` is registered but never called in
        # ``forward``. It is kept so the module's parameter names
        # (state_dict keys) do not change.
        self.input_inject256 = InputInjection(in_channels, 256)
        self.input_inject128 = InputInjection(in_channels, 128)
        self.input_inject64 = InputInjection(in_channels, 64)

        # Encoder.
        self.down1 = DownBlock(in_channels, 64)
        self.down2 = DownBlock(64, 128)
        self.down3 = DownBlock(128, 256)

        self.bottleneck = SRResidualBlock(256, 256)

        # Decoder: each transposed conv doubles the resolution, and the
        # following block fuses the result with the encoder skip tensor.
        self.pre_up3 = nn.ConvTranspose2d(
            256, 128, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.up3 = SRResidualBlock(128 + 256, 128)
        self.pre_up2 = nn.ConvTranspose2d(
            128, 64, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.up2 = SRResidualBlock(128 + 64, 64)
        self.pre_up1 = nn.ConvTranspose2d(
            64, 64, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.up1 = SRResidualBlock(128, 64)

        self.sr_head = SRHead(64, out_channels)

    @staticmethod
    def _match_spatial(x, reference):
        """Resize ``x`` to ``reference``'s (H, W) when the two differ.

        The stride-2 encoder convolutions round odd sizes up, while the
        transposed convolutions exactly double, so an upsampled tensor can
        be one pixel larger than its skip tensor. Resizing only in that
        case leaves the result untouched for sizes that already match.
        """
        target_size = reference.shape[2:]
        if x.shape[2:] != target_size:
            x = F.interpolate(
                x, size=target_size, mode="bilinear", align_corners=False
            )
        return x

    def forward(self, x):
        """Return the network output for the input batch ``x``.

        The output's spatial size is whatever ``sr_head`` produces from
        features at the input resolution; the input itself is resized to
        that size for the final residual addition.
        """
        residual = x

        # Encoder, re-injecting the raw input after the first two stages.
        x1, x = self.down1(x)
        x = x + self.residual_scaling * self.input_inject64(residual, x.shape[2:])
        x2, x = self.down2(x)
        x = x + self.residual_scaling * self.input_inject128(residual, x.shape[2:])
        x3, x = self.down3(x)

        x = self.bottleneck(x)

        # Decoder: upsample, align to the skip tensor, concatenate, fuse.
        x = self._match_spatial(self.pre_up3(x), x3)
        x = torch.cat([x, x3], dim=1)
        x = self.up3(x)

        x = self._match_spatial(self.pre_up2(x), x2)
        x = torch.cat([x, x2], dim=1)
        x = self.up2(x)
        x = x + self.residual_scaling * self.input_inject64(residual, x.shape[2:])

        x = self._match_spatial(self.pre_up1(x), x1)
        x = torch.cat([x, x1], dim=1)
        x = self.up1(x)
        x = x + self.residual_scaling * self.input_inject64(residual, x.shape[2:])

        out = self.sr_head(x)
        # Global residual: add the bilinearly resized input, scaled down.
        upscaled_input = F.interpolate(
            residual, size=out.shape[2:], mode="bilinear", align_corners=False
        )
        out = out + upscaled_input * self.residual_scaling
        return out