Spaces:
Build error
Build error
Update inference.py
Browse files- inference.py +6 -6
inference.py
CHANGED
|
@@ -2,7 +2,7 @@ import torch
|
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
|
| 5 |
-
#
|
| 6 |
class UNet(nn.Module):
|
| 7 |
def __init__(self):
|
| 8 |
super(UNet, self).__init__()
|
|
@@ -10,7 +10,7 @@ class UNet(nn.Module):
|
|
| 10 |
self.encoder2 = self.conv_block(64, 128)
|
| 11 |
self.encoder3 = self.conv_block(128, 256)
|
| 12 |
self.encoder4 = self.conv_block(256, 512)
|
| 13 |
-
self.encoder5 = self.conv_block(512, 1024)
|
| 14 |
self.bottleneck = self.conv_block(1024, 2048)
|
| 15 |
self.upconv5 = nn.ConvTranspose2d(2048, 1024, kernel_size=2, stride=2)
|
| 16 |
self.decoder5 = self.conv_block(2048, 1024)
|
|
@@ -63,7 +63,7 @@ class UNet(nn.Module):
|
|
| 63 |
|
| 64 |
def load_model(model_path, device='cpu'):
|
| 65 |
"""
|
| 66 |
-
|
| 67 |
"""
|
| 68 |
model = UNet().to(device)
|
| 69 |
model.load_state_dict(torch.load(model_path, map_location=device))
|
|
@@ -73,10 +73,10 @@ def load_model(model_path, device='cpu'):
|
|
| 73 |
|
| 74 |
def predict(model, image_tensor):
|
| 75 |
"""
|
| 76 |
-
Realiza
|
| 77 |
-
- model: modelo
|
| 78 |
- image_tensor: tensor FloatTensor [C,H,W] normalizado.
|
| 79 |
-
Retorna
|
| 80 |
"""
|
| 81 |
with torch.no_grad():
|
| 82 |
output = model(image_tensor.unsqueeze(0))
|
|
|
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
|
| 5 |
+
# Definição da arquitetura UNet (idêntica à do treinamento).
|
| 6 |
class UNet(nn.Module):
|
| 7 |
def __init__(self):
|
| 8 |
super(UNet, self).__init__()
|
|
|
|
| 10 |
self.encoder2 = self.conv_block(64, 128)
|
| 11 |
self.encoder3 = self.conv_block(128, 256)
|
| 12 |
self.encoder4 = self.conv_block(256, 512)
|
| 13 |
+
self.encoder5 = self.conv_block(512, 1024) # Camada adicional
|
| 14 |
self.bottleneck = self.conv_block(1024, 2048)
|
| 15 |
self.upconv5 = nn.ConvTranspose2d(2048, 1024, kernel_size=2, stride=2)
|
| 16 |
self.decoder5 = self.conv_block(2048, 1024)
|
|
|
|
| 63 |
|
| 64 |
def load_model(model_path, device='cpu'):
|
| 65 |
"""
|
| 66 |
+
Carrega o modelo UNet com os pesos de 'model_path'.
|
| 67 |
"""
|
| 68 |
model = UNet().to(device)
|
| 69 |
model.load_state_dict(torch.load(model_path, map_location=device))
|
|
|
|
| 73 |
|
| 74 |
def predict(model, image_tensor):
|
| 75 |
"""
|
| 76 |
+
Realiza a predição da máscara de instância para uma imagem.
|
| 77 |
+
- model: modelo carregado (UNet).
|
| 78 |
- image_tensor: tensor FloatTensor [C,H,W] normalizado.
|
| 79 |
+
Retorna um tensor [1,H,W] com as probabilidades/máscara.
|
| 80 |
"""
|
| 81 |
with torch.no_grad():
|
| 82 |
output = model(image_tensor.unsqueeze(0))
|