vncgabriel commited on
Commit
4f36d1b
·
verified ·
1 Parent(s): 9538b1c

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +25 -0
model.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class UNet(nn.Module):
5
+ def __init__(self, in_channels=3, out_channels=1):
6
+ super(UNet, self).__init__()
7
+ # Definição simplificada da UNet
8
+ self.enc1 = nn.Sequential(
9
+ nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
10
+ nn.ReLU(),
11
+ nn.Conv2d(64, 64, kernel_size=3, padding=1),
12
+ nn.ReLU()
13
+ )
14
+ self.pool = nn.MaxPool2d(2)
15
+ self.dec1 = nn.Sequential(
16
+ nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2),
17
+ nn.ReLU(),
18
+ nn.Conv2d(64, out_channels, kernel_size=1)
19
+ )
20
+
21
+ def forward(self, x):
22
+ x1 = self.enc1(x)
23
+ x2 = self.pool(x1)
24
+ x3 = self.dec1(x2)
25
+ return x3