bluspater commited on
Commit
33c684d
·
verified ·
1 Parent(s): b74b143

Upload u2net.py

Browse files
Files changed (1) hide show
  1. model/u2net.py +24 -3
model/u2net.py CHANGED
@@ -1,4 +1,25 @@
1
 
2
- # Здесь должна быть модель U2NET
3
- # Предполагается, что файл model/u2net.py уже содержит реализацию U2NET
4
- # Файл будет скопирован позже
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ class REBNCONV(nn.Module):
7
+ def __init__(self, in_ch=3, out_ch=3, dirate=1):
8
+ super(REBNCONV, self).__init__()
9
+ self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1*dirate, dilation=1*dirate)
10
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
11
+ self.relu_s1 = nn.ReLU(inplace=True)
12
+
13
+ def forward(self, x):
14
+ return self.relu_s1(self.bn_s1(self.conv_s1(x)))
15
+
16
+ # Простой U2NET-заглушка
17
+ class U2NET(nn.Module):
18
+ def __init__(self, in_ch=3, out_ch=1):
19
+ super(U2NET, self).__init__()
20
+ self.rebnconv = REBNCONV(in_ch, out_ch)
21
+
22
+ def forward(self, x):
23
+ return self.rebnconv(x)
24
+
25
+ __all__ = ["U2NET"]