| | import torch |
| | import torch.nn as nn |
| |
|
class UNetGeneratorImproved(nn.Module):
    """Simplified encoder-decoder generator for image colorization.

    Input: grayscale image batch (``in_channels`` channels, 1 by default).
    Output: colorized image batch (``out_channels`` channels, 3 by default),
    squashed into ``[-1, 1]`` by a final ``Tanh``.

    Note: despite the "UNet" name, this network has no skip connections
    between encoder and decoder; it is a plain convolutional encoder-decoder.

    The encoder halves the spatial size three times (stride-2 convolutions)
    and the decoder doubles it three times, so inputs whose height and width
    are multiples of 8 come back at exactly the same resolution. For other
    sizes the decoder output is smaller than the input, and ``forward``
    resizes it bilinearly back to the input's height and width.

    Args:
        in_channels: Number of input channels (default 1, grayscale).
        out_channels: Number of output channels (default 3, RGB).
        base_channels: Width of the first encoder stage; the deeper stages
            use 2x and 4x this value (default 64).
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 3,
        base_channels: int = 64,
    ) -> None:
        super().__init__()

        c1, c2, c3 = base_channels, base_channels * 2, base_channels * 4

        # Keep the exact nn.Sequential layout (layer order and indices) so that
        # state_dict keys such as "encoder.0.weight" stay checkpoint-compatible.

        # Encoder: three stride-2 convolutions, each halving H and W.
        # The first stage has no BatchNorm.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, c1, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(c1, c2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(c2),
            nn.ReLU(inplace=True),

            nn.Conv2d(c2, c3, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(c3),
            nn.ReLU(inplace=True),
        )

        # Decoder: three stride-2 transposed convolutions, each doubling H and W.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(c3, c2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(c2),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(c2, c1, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(c1),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(c1, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass of the model.

        Args:
            x: Batch shaped ``(N, in_channels, H, W)`` with ``H, W >= 8``
                (``BatchNorm2d`` requires 4-D input; smaller images shrink to
                zero size inside the encoder and PyTorch raises an error).

        Returns:
            Tensor shaped ``(N, out_channels, H, W)`` with values in
            ``[-1, 1]``, i.e. the same spatial size as ``x``.
        """
        in_size = x.shape[-2:]
        out = self.decoder(self.encoder(x))

        # Each encoder stage floors odd sizes (H -> H // 2), so when H or W is
        # not divisible by 8 the decoder returns a smaller image. Resize it back
        # so the prediction is pixel-aligned with the input (and any target).
        # Bilinear weights are a convex combination, so values stay in [-1, 1].
        if out.shape[-2:] != in_size:
            out = nn.functional.interpolate(
                out, size=tuple(in_size), mode="bilinear", align_corners=False
            )
        return out