evaluador / models.py
yoel
Refactor: mejora la interfaz de evaluación y la gestión de modelos, añadiendo soporte para super-resolución y optimizando la carga de datos
b79aa7a
import torch
import torch.nn as nn
import torch.nn.functional as F
class Stem(nn.Module):
    """Input stem: an unpadded stride-2 3x3 convolution, then 3x3/2 max pooling.

    Maps a 3-channel image to 64 feature channels. Neither layer pads, so each
    spatial dimension shrinks by roughly a factor of four.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(3, 64, kernel_size=3, stride=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        # Kept as an nn.Sequential named ``conv`` so checkpoint keys remain
        # ``conv.0.weight`` / ``conv.0.bias``.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
class ResidualBlock(nn.Module):
    """Basic residual block: two 3x3 conv + BatchNorm layers with LeakyReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        stride: stride of the first 3x3 convolution (values > 1 downsample).

    The skip path is the identity when the shape is unchanged; otherwise it is
    a bias-free strided 1x1 convolution followed by BatchNorm.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        def conv3x3(cin, cout, step):
            # Bias is redundant because BatchNorm follows every convolution.
            return nn.Conv2d(
                cin, cout, kernel_size=3, stride=step, padding=1, bias=False
            )

        self.conv1 = nn.Sequential(
            conv3x3(in_channels, out_channels, stride),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            conv3x3(out_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
        )

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

        self.act = nn.LeakyReLU(inplace=True)

    def forward(self, x):
        skip = self.shortcut(x)
        out = self.conv2(self.conv1(x))
        out += skip
        return self.act(out)
class FromZero(nn.Module):
    """Image classifier: ``Stem`` followed by four stages of two residual blocks.

    Channel widths go 64 -> 128 -> 256 -> 512. From the second stage on, the
    first block of each stage uses stride 2. The last stage ends with
    ``Dropout(0.2)``. Global average pooling and one linear layer then produce
    ``num_classes`` raw (unnormalized) scores.

    Args:
        num_classes: number of output scores.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Wrapped in nn.Sequential so checkpoint keys stay ``stem.0.conv...``.
        self.stem = nn.Sequential(Stem())
        self.layer1 = self._make_stage(64, 64, stride=1)
        self.layer2 = self._make_stage(64, 128, stride=2)
        self.layer3 = self._make_stage(128, 256, stride=2)
        self.layer4 = self._make_stage(256, 512, stride=2, dropout=0.2)
        self.flatten = nn.Flatten()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Sequential kept so the key names stay ``fc.0.weight`` / ``fc.0.bias``.
        self.fc = nn.Sequential(nn.Linear(512, num_classes))

    @staticmethod
    def _make_stage(in_channels, out_channels, stride, dropout=None):
        """Build two residual blocks (the first may downsample), plus optional dropout."""
        modules = [
            ResidualBlock(in_channels, out_channels, stride=stride),
            ResidualBlock(out_channels, out_channels),
        ]
        if dropout is not None:
            modules.append(nn.Dropout(dropout))
        return nn.Sequential(*modules)

    def forward(self, x):
        stages = (self.stem, self.layer1, self.layer2, self.layer3, self.layer4)
        for stage in stages:
            x = stage(x)
        pooled = self.avgpool(x)
        return self.fc(self.flatten(pooled))
class SRResidualBlock(nn.Module):
    """Bottleneck residual block (1x1 -> kxk -> 1x1) with GroupNorm and SiLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block. The bottleneck width is
            ``out_channels // 2``. Every GroupNorm uses 8 groups, so both
            ``out_channels`` and ``out_channels // 2`` must be divisible by 8.
        kernel_size: size of the middle convolution. Odd sizes get "same"
            padding, so any odd kernel works with either stride.
        stride: stride of the middle convolution and of the projection
            shortcut (values > 1 downsample).

    The shortcut is the identity when the shape is unchanged. Otherwise it is
    a 3x3 (padding 1) strided convolution followed by GroupNorm.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super().__init__()
        mid_channels = out_channels // 2
        # The shortcut always uses a 3x3 / padding-1 convolution. The main
        # branch only produces the same spatial size if its padding scales with
        # the kernel: a fixed padding of 1 matches only for kernel_size=3 and
        # breaks the residual addition for 1, 5, 7, ... Even kernels keep the
        # legacy padding of 1.
        padding = kernel_size // 2 if kernel_size % 2 == 1 else 1
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1),
            nn.GroupNorm(8, mid_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(
                mid_channels,
                mid_channels,
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
            ),
            nn.GroupNorm(8, mid_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1),
            nn.GroupNorm(8, out_channels),
        )
        # Named ``relu`` for checkpoint/attribute compatibility; it is a SiLU.
        self.relu = nn.SiLU(inplace=True)
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                ),
                nn.GroupNorm(8, out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        shortcut = self.shortcut(x)
        out = self.conv1(x)
        out += shortcut
        return self.relu(out)
class DownBlock(nn.Module):
    """Encoder stage: a stride-1 residual block followed by a stride-2 one.

    ``forward`` returns ``(skip, down)``: the stride-1 features (for the
    decoder's skip connection) and their downsampled version (for the next
    encoder stage).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.skip_block = SRResidualBlock(
            in_channels, out_channels, kernel_size=kernel_size, stride=1
        )
        self.downsample_block = SRResidualBlock(
            out_channels, out_channels, kernel_size=kernel_size, stride=2
        )

    def forward(self, x):
        features = self.skip_block(x)
        return features, self.downsample_block(features)
class SRUp(nn.Module):
    """Sub-pixel upsampler: 3x3 conv -> PixelShuffle(scale_factor) -> SiLU.

    The convolution expands to ``out_channels * scale_factor**2`` channels.
    PixelShuffle then folds those channels into a ``scale_factor``-times
    larger spatial grid with ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()
        expanded_channels = out_channels * scale_factor * scale_factor
        self.conv = nn.Conv2d(
            in_channels, expanded_channels, kernel_size=3, padding=1
        )
        self.shuffle = nn.PixelShuffle(scale_factor)
        self.act = nn.SiLU()

    def forward(self, x):
        return self.act(self.shuffle(self.conv(x)))
class SRHead(nn.Module):
    """Super-resolution head: two (x2 sub-pixel upsample -> residual block) steps.

    The steps are followed by a 3x3 convolution to ``out_channels``. With
    ``SRUp``'s default factor of 2, the spatial size grows 4x overall.
    """

    def __init__(self, channels, out_channels):
        super().__init__()
        # Registration order (sr1, sr2, res, res2, head) is kept so checkpoint
        # layout and parameter order are unchanged.
        self.sr1 = SRUp(channels, channels)
        self.sr2 = SRUp(channels, channels)
        self.res = SRResidualBlock(channels, channels)
        self.res2 = SRResidualBlock(channels, channels)
        self.head = nn.Conv2d(channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        for upsample, refine in ((self.sr1, self.res), (self.sr2, self.res2)):
            x = refine(upsample(x))
        return self.head(x)
class InputInjection(nn.Module):
    """Resize an input to a given spatial size, then project it with a 1x1 conv.

    Resizing is bilinear (``align_corners=False``). The projection maps
    ``in_channels`` to ``target_channels``, so the result can be added to a
    feature map of that size and width.
    """

    def __init__(self, in_channels, target_channels):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, target_channels, kernel_size=1)

    def forward(self, x, target_size):
        return self.proj(
            F.interpolate(
                x, size=target_size, mode="bilinear", align_corners=False
            )
        )
class UNetSR(nn.Module):
    """U-Net super-resolution network with a 4x sub-pixel head.

    Structure:
    - Encoder: three ``DownBlock`` stages, each halving the spatial size
      (rounded up).
    - Decoder: upsamples with stride-2 transposed convolutions and
      concatenates each level's encoder skip features.
    - Input injection: a resized, 1x1-projected copy of the input, scaled by
      ``residual_scaling``, is added after ``down1``, ``down2``, ``up2`` and
      ``up1``.
    - Output: ``SRHead`` upsamples the 64-channel decoder output by 4x, and a
      bilinear upsample of the input, also scaled by ``residual_scaling``, is
      added to the result.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output image. The final residual
            addition requires this to match ``in_channels`` (or broadcast).

    Input height/width do not have to be multiples of 8. The encoder produces
    ``ceil(n / 2)`` per level, while each transposed conv produces exactly
    ``2 * n``. So an odd size at some level makes the upsampled tensor one
    pixel larger than its skip; it is cropped to the skip's size before
    concatenation (a no-op otherwise).
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
    ):
        super().__init__()
        # Weight applied to every injected copy of the input and to the final
        # bilinear skip of the input.
        self.residual_scaling = 0.1
        # NOTE(review): ``input_inject256`` is registered but never used in
        # ``forward`` (nothing is injected after ``down3``). It is kept so
        # existing checkpoints still load with ``strict=True``. Wiring it in
        # would change the output of already-trained models.
        self.input_inject256 = InputInjection(in_channels, 256)
        self.input_inject128 = InputInjection(in_channels, 128)
        # Shared by three injection points (after down1, after up2, after up1).
        self.input_inject64 = InputInjection(in_channels, 64)

        # Encoder: each DownBlock returns (skip features, downsampled features).
        self.down1 = DownBlock(in_channels, 64)
        self.down2 = DownBlock(64, 128)
        self.down3 = DownBlock(128, 256)
        self.bottleneck = SRResidualBlock(256, 256)

        # Decoder: x2 transposed conv, then concat with the matching skip, so
        # each ``up*`` block sees (upsampled + skip) channels.
        self.pre_up3 = nn.ConvTranspose2d(
            256, 128, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.up3 = SRResidualBlock(128 + 256, 128)
        self.pre_up2 = nn.ConvTranspose2d(
            128, 64, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.up2 = SRResidualBlock(128 + 64, 64)
        self.pre_up1 = nn.ConvTranspose2d(
            64, 64, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.up1 = SRResidualBlock(128, 64)
        self.sr_head = SRHead(64, out_channels)

    @staticmethod
    def _crop_to(x, ref):
        """Crop ``x`` at the bottom/right to ``ref``'s height and width."""
        if x.shape[-2:] == ref.shape[-2:]:
            return x
        height, width = ref.shape[-2:]
        return x[..., :height, :width]

    def forward(self, x):
        residual = x
        x1, x = self.down1(x)
        x = x + self.residual_scaling * self.input_inject64(residual, x.shape[2:])
        x2, x = self.down2(x)
        x = x + self.residual_scaling * self.input_inject128(residual, x.shape[2:])
        x3, x = self.down3(x)
        x = self.bottleneck(x)

        x = self.pre_up3(x)
        x = torch.cat([self._crop_to(x, x3), x3], dim=1)
        x = self.up3(x)

        x = self.pre_up2(x)
        x = torch.cat([self._crop_to(x, x2), x2], dim=1)
        x = self.up2(x)
        x = x + self.residual_scaling * self.input_inject64(residual, x.shape[2:])

        x = self.pre_up1(x)
        x = torch.cat([self._crop_to(x, x1), x1], dim=1)
        x = self.up1(x)
        x = x + self.residual_scaling * self.input_inject64(residual, x.shape[2:])

        out = self.sr_head(x)
        upscaled_input = F.interpolate(
            residual, size=out.shape[2:], mode="bilinear", align_corners=False
        )
        return out + upscaled_input * self.residual_scaling