evaluador / models.py
yoel
Refactor: reorganiza etiquetas y corrige validación de archivos en la interfaz de evaluación
302b2b5
import torch
import torch.nn as nn
from torchvision import models
class Stem(nn.Module):
    """Input stem: a strided 3x3 convolution followed by 3x3 max-pooling.

    Maps an ``in_channels``-channel image to 64 feature maps. Neither layer is
    padded and both use stride 2, so the spatial size shrinks roughly 4x.
    For a side of length H the output side is
    ``((H - 3) // 2 + 1 - 3) // 2 + 1``; for example, 32x32 becomes 7x7.
    Inputs smaller than 7 pixels on a side are too small for the max-pool.

    Args:
        in_channels: Channels of the input image. Defaults to 3, which matches
            the previously hard-coded value. Pass e.g. 1 for single-channel
            input.
    """

    def __init__(self, in_channels: int = 3):
        super().__init__()
        # The ``conv`` attribute name and the layer order fix the state_dict
        # keys (``conv.0.weight``, ``conv.0.bias``). Keep them stable so that
        # previously saved weights still load.
        # NOTE(review): there is no BatchNorm/activation between the conv and
        # the pool, so the only nonlinearity here is the max-pool. Adding one
        # would change the state_dict layout; confirm whether this is intended.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

    def forward(self, x):
        """Apply the conv + max-pool stem to ``x``."""
        return self.conv(x)
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers joined by an additive skip connection.

    Main path: ``conv3x3(stride) -> BN -> LeakyReLU -> conv3x3 -> BN``. Its
    output is summed with the shortcut branch and passed through a final
    LeakyReLU. The shortcut is the identity when the input already has the
    target shape (``stride == 1`` and equal channel counts). Otherwise a
    strided 1x1 convolution plus BatchNorm projects the input to match.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        stride: Stride of the first 3x3 conv and of the projection shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Every conv feeds straight into BatchNorm, which re-centres its input,
        # so a conv bias would be redundant.
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
        )

        shape_preserved = stride == 1 and in_channels == out_channels
        if shape_preserved:
            self.shortcut = nn.Identity()
        else:
            # Project the input so it can be added to the main-path output.
            projection = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=stride, bias=False
            )
            self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(out_channels))

        self.act = nn.LeakyReLU(inplace=True)

    def forward(self, x):
        """Return ``LeakyReLU(main_path(x) + shortcut(x))``."""
        skip = self.shortcut(x)
        out = self.conv2(self.conv1(x))
        # The in-place add is safe: ``out`` is conv2's fresh output, not ``x``.
        out += skip
        return self.act(out)
class FromZero(nn.Module):
    """Residual CNN classifier assembled from ``Stem`` and ``ResidualBlock``.

    Pipeline:

    1. ``Stem`` (3 -> 64 channels).
    2. Four stages of two ``ResidualBlock`` each, with 64, 128, 256 and 512
       channels. Stages 2-4 downsample with a stride-2 first block.
    3. ``Dropout(0.2)`` on the last feature map.
    4. Global average pooling.
    5. A linear layer.

    ``forward`` returns raw scores of shape ``(N, num_classes)``; no softmax
    is applied.

    Args:
        num_classes: Number of output scores produced by the final layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Attribute names, Sequential nesting and registration order define
        # the state_dict keys (e.g. ``stem.0.conv.0.weight``, ``fc.0.weight``)
        # and the parameter order. Keep them stable for checkpoint loading.
        self.stem = nn.Sequential(Stem())
        self.layer1 = self._stage(64, 64, stride=1)
        self.layer2 = self._stage(64, 128, stride=2)
        self.layer3 = self._stage(128, 256, stride=2)
        self.layer4 = self._stage(256, 512, stride=2, tail=(nn.Dropout(0.2),))
        self.flatten = nn.Flatten()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(nn.Linear(512, num_classes))

    @staticmethod
    def _stage(in_ch, out_ch, stride, tail=()):
        """Build one stage: a strided/widening block, a plain block, then ``tail``."""
        return nn.Sequential(
            ResidualBlock(in_ch, out_ch, stride=stride),
            ResidualBlock(out_ch, out_ch),
            *tail,
        )

    def forward(self, x):
        """Map an input image batch to ``(N, num_classes)`` scores."""
        features = self.stem(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = self.flatten(self.avgpool(features))
        return self.fc(pooled)