Commit
·
560a60d
1
Parent(s):
ed25252
Upgrade the way of importing timm modules (requires timm >= 1.0.23).
Browse files · birefnet.py (+10 −16)
birefnet.py
CHANGED
|
@@ -166,8 +166,8 @@ import torch
|
|
| 166 |
import torch.nn as nn
|
| 167 |
from functools import partial
|
| 168 |
|
| 169 |
-
from timm.
|
| 170 |
-
|
| 171 |
|
| 172 |
import math
|
| 173 |
|
|
@@ -547,7 +547,6 @@ def _conv_filter(state_dict, patch_size=16):
|
|
| 547 |
return out_dict
|
| 548 |
|
| 549 |
|
| 550 |
-
## @register_model
|
| 551 |
class pvt_v2_b0(PyramidVisionTransformerImpr):
|
| 552 |
def __init__(self, **kwargs):
|
| 553 |
super(pvt_v2_b0, self).__init__(
|
|
@@ -557,7 +556,6 @@ class pvt_v2_b0(PyramidVisionTransformerImpr):
|
|
| 557 |
|
| 558 |
|
| 559 |
|
| 560 |
-
## @register_model
|
| 561 |
class pvt_v2_b1(PyramidVisionTransformerImpr):
|
| 562 |
def __init__(self, **kwargs):
|
| 563 |
super(pvt_v2_b1, self).__init__(
|
|
@@ -565,7 +563,6 @@ class pvt_v2_b1(PyramidVisionTransformerImpr):
|
|
| 565 |
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
|
| 566 |
drop_rate=0.0, drop_path_rate=0.1)
|
| 567 |
|
| 568 |
-
## @register_model
|
| 569 |
class pvt_v2_b2(PyramidVisionTransformerImpr):
|
| 570 |
def __init__(self, in_channels=3, **kwargs):
|
| 571 |
super(pvt_v2_b2, self).__init__(
|
|
@@ -573,7 +570,6 @@ class pvt_v2_b2(PyramidVisionTransformerImpr):
|
|
| 573 |
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
|
| 574 |
drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels)
|
| 575 |
|
| 576 |
-
## @register_model
|
| 577 |
class pvt_v2_b3(PyramidVisionTransformerImpr):
|
| 578 |
def __init__(self, **kwargs):
|
| 579 |
super(pvt_v2_b3, self).__init__(
|
|
@@ -581,7 +577,6 @@ class pvt_v2_b3(PyramidVisionTransformerImpr):
|
|
| 581 |
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
|
| 582 |
drop_rate=0.0, drop_path_rate=0.1)
|
| 583 |
|
| 584 |
-
## @register_model
|
| 585 |
class pvt_v2_b4(PyramidVisionTransformerImpr):
|
| 586 |
def __init__(self, **kwargs):
|
| 587 |
super(pvt_v2_b4, self).__init__(
|
|
@@ -590,7 +585,6 @@ class pvt_v2_b4(PyramidVisionTransformerImpr):
|
|
| 590 |
drop_rate=0.0, drop_path_rate=0.1)
|
| 591 |
|
| 592 |
|
| 593 |
-
## @register_model
|
| 594 |
class pvt_v2_b5(PyramidVisionTransformerImpr):
|
| 595 |
def __init__(self, **kwargs):
|
| 596 |
super(pvt_v2_b5, self).__init__(
|
|
@@ -614,7 +608,7 @@ import torch.nn as nn
|
|
| 614 |
import torch.nn.functional as F
|
| 615 |
import torch.utils.checkpoint as checkpoint
|
| 616 |
import numpy as np
|
| 617 |
-
from timm.
|
| 618 |
|
| 619 |
# from config import Config
|
| 620 |
|
|
@@ -1195,7 +1189,7 @@ class SwinTransformer(nn.Module):
|
|
| 1195 |
# interpolate the position embedding to the corresponding size
|
| 1196 |
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
|
| 1197 |
x = (x + absolute_pos_embed) # B Wh*Ww C
|
| 1198 |
-
|
| 1199 |
outs = []#x.contiguous()]
|
| 1200 |
x = x.flatten(2).transpose(1, 2)
|
| 1201 |
x = self.pos_drop(x)
|
|
@@ -1252,13 +1246,13 @@ class DeformableConv2d(nn.Module):
|
|
| 1252 |
bias=False):
|
| 1253 |
|
| 1254 |
super(DeformableConv2d, self).__init__()
|
| 1255 |
-
|
| 1256 |
assert type(kernel_size) == tuple or type(kernel_size) == int
|
| 1257 |
|
| 1258 |
kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size)
|
| 1259 |
self.stride = stride if type(stride) == tuple else (stride, stride)
|
| 1260 |
self.padding = padding
|
| 1261 |
-
|
| 1262 |
self.offset_conv = nn.Conv2d(in_channels,
|
| 1263 |
2 * kernel_size[0] * kernel_size[1],
|
| 1264 |
kernel_size=kernel_size,
|
|
@@ -1268,7 +1262,7 @@ class DeformableConv2d(nn.Module):
|
|
| 1268 |
|
| 1269 |
nn.init.constant_(self.offset_conv.weight, 0.)
|
| 1270 |
nn.init.constant_(self.offset_conv.bias, 0.)
|
| 1271 |
-
|
| 1272 |
self.modulator_conv = nn.Conv2d(in_channels,
|
| 1273 |
1 * kernel_size[0] * kernel_size[1],
|
| 1274 |
kernel_size=kernel_size,
|
|
@@ -1292,7 +1286,7 @@ class DeformableConv2d(nn.Module):
|
|
| 1292 |
|
| 1293 |
offset = self.offset_conv(x)#.clamp(-max_offset, max_offset)
|
| 1294 |
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
|
| 1295 |
-
|
| 1296 |
x = deform_conv2d(
|
| 1297 |
input=x,
|
| 1298 |
offset=offset,
|
|
@@ -1490,7 +1484,7 @@ class ResBlk(nn.Module):
|
|
| 1490 |
|
| 1491 |
self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
|
| 1492 |
self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
|
| 1493 |
-
|
| 1494 |
self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
|
| 1495 |
|
| 1496 |
def forward(self, x):
|
|
@@ -2141,7 +2135,7 @@ class Decoder(nn.Module):
|
|
| 2141 |
self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|
| 2142 |
self.gdt_convs_pred_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|
| 2143 |
self.gdt_convs_pred_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|
| 2144 |
-
|
| 2145 |
self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|
| 2146 |
self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|
| 2147 |
self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|
|
|
|
| 166 |
import torch.nn as nn
|
| 167 |
from functools import partial
|
| 168 |
|
| 169 |
+
from timm.layers import DropPath, to_2tuple, trunc_normal_
|
| 170 |
+
|
| 171 |
|
| 172 |
import math
|
| 173 |
|
|
|
|
| 547 |
return out_dict
|
| 548 |
|
| 549 |
|
|
|
|
| 550 |
class pvt_v2_b0(PyramidVisionTransformerImpr):
|
| 551 |
def __init__(self, **kwargs):
|
| 552 |
super(pvt_v2_b0, self).__init__(
|
|
|
|
| 556 |
|
| 557 |
|
| 558 |
|
|
|
|
| 559 |
class pvt_v2_b1(PyramidVisionTransformerImpr):
|
| 560 |
def __init__(self, **kwargs):
|
| 561 |
super(pvt_v2_b1, self).__init__(
|
|
|
|
| 563 |
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
|
| 564 |
drop_rate=0.0, drop_path_rate=0.1)
|
| 565 |
|
|
|
|
| 566 |
class pvt_v2_b2(PyramidVisionTransformerImpr):
|
| 567 |
def __init__(self, in_channels=3, **kwargs):
|
| 568 |
super(pvt_v2_b2, self).__init__(
|
|
|
|
| 570 |
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
|
| 571 |
drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels)
|
| 572 |
|
|
|
|
| 573 |
class pvt_v2_b3(PyramidVisionTransformerImpr):
|
| 574 |
def __init__(self, **kwargs):
|
| 575 |
super(pvt_v2_b3, self).__init__(
|
|
|
|
| 577 |
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
|
| 578 |
drop_rate=0.0, drop_path_rate=0.1)
|
| 579 |
|
|
|
|
| 580 |
class pvt_v2_b4(PyramidVisionTransformerImpr):
|
| 581 |
def __init__(self, **kwargs):
|
| 582 |
super(pvt_v2_b4, self).__init__(
|
|
|
|
| 585 |
drop_rate=0.0, drop_path_rate=0.1)
|
| 586 |
|
| 587 |
|
|
|
|
| 588 |
class pvt_v2_b5(PyramidVisionTransformerImpr):
|
| 589 |
def __init__(self, **kwargs):
|
| 590 |
super(pvt_v2_b5, self).__init__(
|
|
|
|
| 608 |
import torch.nn.functional as F
|
| 609 |
import torch.utils.checkpoint as checkpoint
|
| 610 |
import numpy as np
|
| 611 |
+
from timm.layers import DropPath, to_2tuple, trunc_normal_
|
| 612 |
|
| 613 |
# from config import Config
|
| 614 |
|
|
|
|
| 1189 |
# interpolate the position embedding to the corresponding size
|
| 1190 |
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
|
| 1191 |
x = (x + absolute_pos_embed) # B Wh*Ww C
|
| 1192 |
+
|
| 1193 |
outs = []#x.contiguous()]
|
| 1194 |
x = x.flatten(2).transpose(1, 2)
|
| 1195 |
x = self.pos_drop(x)
|
|
|
|
| 1246 |
bias=False):
|
| 1247 |
|
| 1248 |
super(DeformableConv2d, self).__init__()
|
| 1249 |
+
|
| 1250 |
assert type(kernel_size) == tuple or type(kernel_size) == int
|
| 1251 |
|
| 1252 |
kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size)
|
| 1253 |
self.stride = stride if type(stride) == tuple else (stride, stride)
|
| 1254 |
self.padding = padding
|
| 1255 |
+
|
| 1256 |
self.offset_conv = nn.Conv2d(in_channels,
|
| 1257 |
2 * kernel_size[0] * kernel_size[1],
|
| 1258 |
kernel_size=kernel_size,
|
|
|
|
| 1262 |
|
| 1263 |
nn.init.constant_(self.offset_conv.weight, 0.)
|
| 1264 |
nn.init.constant_(self.offset_conv.bias, 0.)
|
| 1265 |
+
|
| 1266 |
self.modulator_conv = nn.Conv2d(in_channels,
|
| 1267 |
1 * kernel_size[0] * kernel_size[1],
|
| 1268 |
kernel_size=kernel_size,
|
|
|
|
| 1286 |
|
| 1287 |
offset = self.offset_conv(x)#.clamp(-max_offset, max_offset)
|
| 1288 |
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
|
| 1289 |
+
|
| 1290 |
x = deform_conv2d(
|
| 1291 |
input=x,
|
| 1292 |
offset=offset,
|
|
|
|
| 1484 |
|
| 1485 |
self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
|
| 1486 |
self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
|
| 1487 |
+
|
| 1488 |
self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
|
| 1489 |
|
| 1490 |
def forward(self, x):
|
|
|
|
| 2135 |
self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|
| 2136 |
self.gdt_convs_pred_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|
| 2137 |
self.gdt_convs_pred_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|
| 2138 |
+
|
| 2139 |
self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|
| 2140 |
self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|
| 2141 |
self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
|