File size: 3,101 Bytes
0c96c9b 74b58ba 0c96c9b 74b58ba 0c96c9b 74b58ba 0c96c9b 74b58ba 0c96c9b 74b58ba 0c96c9b 74b58ba 0c96c9b 74b58ba 0c96c9b 74b58ba 0c96c9b 74b58ba 0c96c9b 74b58ba 0c96c9b 74b58ba 0c96c9b 74b58ba 0c96c9b 74b58ba 0c96c9b 74b58ba 0c96c9b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 | import torch
import torch.nn as nn
import torch.nn.functional as F
# UNet architecture definition (the same one used during training). Attribute
# names and layer shapes must not change: they determine the state_dict keys
# that load_state_dict() expects when loading the trained weights.
class UNet(nn.Module):
    """Five-level UNet that outputs a single-channel sigmoid map.

    Expects a batched input of shape [N, 3, H, W] (3 input channels, skip
    connections concatenated along dim 1) and returns [N, 1, H, W] with
    values in (0, 1).

    The encoder applies five 2x2 max-pools, so the internal spatial size must
    be a multiple of 32 for the skip connections to line up. Inputs whose H or
    W is not a multiple of 32 are zero-padded on the bottom/right up to the
    next multiple, and the output is cropped back to the original H x W.
    Inputs that already satisfy the constraint are processed unchanged.
    """

    # 2**5: five 2x2 max-pools each halve H and W.
    _SIZE_MULTIPLE = 32

    def __init__(self):
        super().__init__()
        # Encoder: channel count doubles at each level.
        self.encoder1 = self.conv_block(3, 64)
        self.encoder2 = self.conv_block(64, 128)
        self.encoder3 = self.conv_block(128, 256)
        self.encoder4 = self.conv_block(256, 512)
        self.encoder5 = self.conv_block(512, 1024)
        self.bottleneck = self.conv_block(1024, 2048)
        # Decoder: each upconv halves the channels and doubles H and W. The
        # following conv_block takes twice the channels because the matching
        # encoder output is concatenated onto it.
        self.upconv5 = nn.ConvTranspose2d(2048, 1024, kernel_size=2, stride=2)
        self.decoder5 = self.conv_block(2048, 1024)
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.decoder4 = self.conv_block(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = self.conv_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = self.conv_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = self.conv_block(128, 64)
        # 1x1 conv projecting 64 feature channels to a single output channel.
        self.conv_last = nn.Conv2d(64, 1, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 conv + ReLU layers (padding=1 keeps H and W)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.ReLU()
        )

    def forward(self, x):
        """Run the network on a batch ``x`` of shape [N, 3, H, W]."""
        height, width = x.shape[-2:]
        # Pad up to a multiple of 32 so every decoder level matches its skip
        # connection in size. Without this, torch.cat raises a size-mismatch
        # error for such inputs.
        pad_h = (-height) % self._SIZE_MULTIPLE
        pad_w = (-width) % self._SIZE_MULTIPLE
        padded = bool(pad_h or pad_w)
        if padded:
            # F.pad's tuple is (left, right, top, bottom) for the last two dims.
            x = F.pad(x, (0, pad_w, 0, pad_h))

        enc1 = self.encoder1(x)
        enc2 = self.encoder2(F.max_pool2d(enc1, 2))
        enc3 = self.encoder3(F.max_pool2d(enc2, 2))
        enc4 = self.encoder4(F.max_pool2d(enc3, 2))
        enc5 = self.encoder5(F.max_pool2d(enc4, 2))
        bottleneck = self.bottleneck(F.max_pool2d(enc5, 2))

        dec5 = self.upconv5(bottleneck)
        dec5 = self.decoder5(torch.cat((enc5, dec5), dim=1))
        dec4 = self.upconv4(dec5)
        dec4 = self.decoder4(torch.cat((enc4, dec4), dim=1))
        dec3 = self.upconv3(dec4)
        dec3 = self.decoder3(torch.cat((enc3, dec3), dim=1))
        dec2 = self.upconv2(dec3)
        dec2 = self.decoder2(torch.cat((enc2, dec2), dim=1))
        dec1 = self.upconv1(dec2)
        dec1 = self.decoder1(torch.cat((enc1, dec1), dim=1))

        out = torch.sigmoid(self.conv_last(dec1))
        if padded:
            # Drop the padded border so the output matches the input size.
            out = out[..., :height, :width]
        return out
def load_model(model_path, device='cpu'):
    """Build a UNet on ``device`` and load the weights stored at ``model_path``.

    Args:
        model_path: path to a file whose contents are a state_dict. The loaded
            object is passed directly to ``load_state_dict``.
        device: device to place the model on. Also used as ``map_location``,
            so tensors saved on a different device are remapped on load.

    Returns:
        The UNet with the loaded weights, switched to evaluation mode.
    """
    model = UNet().to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot run arbitrary code. A state_dict needs
    # nothing more. NOTE: the keyword requires torch >= 1.13.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model
def predict(model, image_tensor):
    """Run ``model`` on a single image and return its output for that image.

    Args:
        model: a loaded model (e.g. the UNet returned by ``load_model``). This
            function does not change its train/eval mode.
        image_tensor: one normalized float image of shape [C, H, W], without
            a batch dimension. It is moved to the model's device if needed.

    Returns:
        The model output with the batch dimension removed: [out_channels, H, W].
        For the UNet this is [1, H, W] with sigmoid probabilities in (0, 1).
        The result is on the same device as the model's parameters.
    """
    # Feed the image on the device holding the model's weights. Otherwise a
    # CPU tensor given to a GPU model fails with a device-mismatch error.
    # Parameter-less models are run on the input as-is.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        image_tensor = image_tensor.to(first_param.device)
    with torch.no_grad():
        # Add a batch dimension of size 1, then strip it from the output.
        output = model(image_tensor.unsqueeze(0))
    return output.squeeze(0)
|