Upload u2net.py
Browse files- model/u2net.py +24 -3
model/u2net.py
CHANGED
|
@@ -1,4 +1,25 @@
|
|
| 1 |
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
class REBNCONV(nn.Module):
|
| 7 |
+
def __init__(self, in_ch=3, out_ch=3, dirate=1):
|
| 8 |
+
super(REBNCONV, self).__init__()
|
| 9 |
+
self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1*dirate, dilation=1*dirate)
|
| 10 |
+
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
| 11 |
+
self.relu_s1 = nn.ReLU(inplace=True)
|
| 12 |
+
|
| 13 |
+
def forward(self, x):
|
| 14 |
+
return self.relu_s1(self.bn_s1(self.conv_s1(x)))
|
| 15 |
+
|
| 16 |
+
# Простой U2NET-заглушка
|
| 17 |
+
class U2NET(nn.Module):
|
| 18 |
+
def __init__(self, in_ch=3, out_ch=1):
|
| 19 |
+
super(U2NET, self).__init__()
|
| 20 |
+
self.rebnconv = REBNCONV(in_ch, out_ch)
|
| 21 |
+
|
| 22 |
+
def forward(self, x):
|
| 23 |
+
return self.rebnconv(x)
|
| 24 |
+
|
| 25 |
+
__all__ = ["U2NET"]
|