vncgabriel commited on
Commit
74b58ba
·
verified ·
1 Parent(s): 0c96c9b

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +15 -5
inference.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
- # Definição da arquitetura UNet (mesma definida no treinamento).
6
  class UNet(nn.Module):
7
  def __init__(self):
8
  super(UNet, self).__init__()
@@ -23,11 +23,13 @@ class UNet(nn.Module):
23
  self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
24
  self.decoder1 = self.conv_block(128, 64)
25
  self.conv_last = nn.Conv2d(64, 1, kernel_size=1)
 
26
  def conv_block(self, in_channels, out_channels):
27
  return nn.Sequential(
28
  nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(),
29
  nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.ReLU()
30
  )
 
31
  def forward(self, x):
32
  enc1 = self.encoder1(x)
33
  enc2 = self.encoder2(F.max_pool2d(enc1, 2))
@@ -35,38 +37,46 @@ class UNet(nn.Module):
35
  enc4 = self.encoder4(F.max_pool2d(enc3, 2))
36
  enc5 = self.encoder5(F.max_pool2d(enc4, 2))
37
  bottleneck = self.bottleneck(F.max_pool2d(enc5, 2))
 
38
  dec5 = self.upconv5(bottleneck)
39
  dec5 = torch.cat((enc5, dec5), dim=1)
40
  dec5 = self.decoder5(dec5)
 
41
  dec4 = self.upconv4(dec5)
42
  dec4 = torch.cat((enc4, dec4), dim=1)
43
  dec4 = self.decoder4(dec4)
 
44
  dec3 = self.upconv3(dec4)
45
  dec3 = torch.cat((enc3, dec3), dim=1)
46
  dec3 = self.decoder3(dec3)
 
47
  dec2 = self.upconv2(dec3)
48
  dec2 = torch.cat((enc2, dec2), dim=1)
49
  dec2 = self.decoder2(dec2)
 
50
  dec1 = self.upconv1(dec2)
51
  dec1 = torch.cat((enc1, dec1), dim=1)
52
  dec1 = self.decoder1(dec1)
 
53
  return torch.sigmoid(self.conv_last(dec1))
54
 
 
55
  def load_model(model_path, device='cpu'):
56
  """
57
- Carrega o modelo UNet com os pesos de 'model_path'.
58
  """
59
  model = UNet().to(device)
60
  model.load_state_dict(torch.load(model_path, map_location=device))
61
  model.eval()
62
  return model
63
 
 
64
  def predict(model, image_tensor):
65
  """
66
- Realiza predição de máscara de instâncias para uma imagem.
67
- - model: modelo carregado (UNet).
68
  - image_tensor: tensor FloatTensor [C,H,W] normalizado.
69
- Retorna tensor [1,H,W] com probabilidades/mascara.
70
  """
71
  with torch.no_grad():
72
  output = model(image_tensor.unsqueeze(0))
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
+ # Definición de la arquitectura UNet (la misma utilizada en el entrenamiento).
6
  class UNet(nn.Module):
7
  def __init__(self):
8
  super(UNet, self).__init__()
 
23
  self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
24
  self.decoder1 = self.conv_block(128, 64)
25
  self.conv_last = nn.Conv2d(64, 1, kernel_size=1)
26
+
27
  def conv_block(self, in_channels, out_channels):
28
  return nn.Sequential(
29
  nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(),
30
  nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.ReLU()
31
  )
32
+
33
  def forward(self, x):
34
  enc1 = self.encoder1(x)
35
  enc2 = self.encoder2(F.max_pool2d(enc1, 2))
 
37
  enc4 = self.encoder4(F.max_pool2d(enc3, 2))
38
  enc5 = self.encoder5(F.max_pool2d(enc4, 2))
39
  bottleneck = self.bottleneck(F.max_pool2d(enc5, 2))
40
+
41
  dec5 = self.upconv5(bottleneck)
42
  dec5 = torch.cat((enc5, dec5), dim=1)
43
  dec5 = self.decoder5(dec5)
44
+
45
  dec4 = self.upconv4(dec5)
46
  dec4 = torch.cat((enc4, dec4), dim=1)
47
  dec4 = self.decoder4(dec4)
48
+
49
  dec3 = self.upconv3(dec4)
50
  dec3 = torch.cat((enc3, dec3), dim=1)
51
  dec3 = self.decoder3(dec3)
52
+
53
  dec2 = self.upconv2(dec3)
54
  dec2 = torch.cat((enc2, dec2), dim=1)
55
  dec2 = self.decoder2(dec2)
56
+
57
  dec1 = self.upconv1(dec2)
58
  dec1 = torch.cat((enc1, dec1), dim=1)
59
  dec1 = self.decoder1(dec1)
60
+
61
  return torch.sigmoid(self.conv_last(dec1))
62
 
63
+
64
  def load_model(model_path, device='cpu'):
65
  """
66
+ Carga el modelo UNet con los pesos desde 'model_path'.
67
  """
68
  model = UNet().to(device)
69
  model.load_state_dict(torch.load(model_path, map_location=device))
70
  model.eval()
71
  return model
72
 
73
+
74
  def predict(model, image_tensor):
75
  """
76
+ Realiza la predicción de la máscara de instancias para una imagen.
77
+ - model: modelo cargado (UNet).
78
  - image_tensor: tensor FloatTensor [C,H,W] normalizado.
79
+ Retorna un tensor [1,H,W] con probabilidades/máscara.
80
  """
81
  with torch.no_grad():
82
  output = model(image_tensor.unsqueeze(0))