ZhengPeng7 commited on
Commit
560a60d
·
1 Parent(s): ed25252

Upgrade the ways of importing timm modules (>=1.0.23).

Browse files
Files changed (1) hide show
  1. birefnet.py +10 -16
birefnet.py CHANGED
@@ -166,8 +166,8 @@ import torch
166
  import torch.nn as nn
167
  from functools import partial
168
 
169
- from timm.models.layers import DropPath, to_2tuple, trunc_normal_
170
- from timm.models.registry import register_model
171
 
172
  import math
173
 
@@ -547,7 +547,6 @@ def _conv_filter(state_dict, patch_size=16):
547
  return out_dict
548
 
549
 
550
- ## @register_model
551
  class pvt_v2_b0(PyramidVisionTransformerImpr):
552
  def __init__(self, **kwargs):
553
  super(pvt_v2_b0, self).__init__(
@@ -557,7 +556,6 @@ class pvt_v2_b0(PyramidVisionTransformerImpr):
557
 
558
 
559
 
560
- ## @register_model
561
  class pvt_v2_b1(PyramidVisionTransformerImpr):
562
  def __init__(self, **kwargs):
563
  super(pvt_v2_b1, self).__init__(
@@ -565,7 +563,6 @@ class pvt_v2_b1(PyramidVisionTransformerImpr):
565
  qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
566
  drop_rate=0.0, drop_path_rate=0.1)
567
 
568
- ## @register_model
569
  class pvt_v2_b2(PyramidVisionTransformerImpr):
570
  def __init__(self, in_channels=3, **kwargs):
571
  super(pvt_v2_b2, self).__init__(
@@ -573,7 +570,6 @@ class pvt_v2_b2(PyramidVisionTransformerImpr):
573
  qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
574
  drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels)
575
 
576
- ## @register_model
577
  class pvt_v2_b3(PyramidVisionTransformerImpr):
578
  def __init__(self, **kwargs):
579
  super(pvt_v2_b3, self).__init__(
@@ -581,7 +577,6 @@ class pvt_v2_b3(PyramidVisionTransformerImpr):
581
  qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
582
  drop_rate=0.0, drop_path_rate=0.1)
583
 
584
- ## @register_model
585
  class pvt_v2_b4(PyramidVisionTransformerImpr):
586
  def __init__(self, **kwargs):
587
  super(pvt_v2_b4, self).__init__(
@@ -590,7 +585,6 @@ class pvt_v2_b4(PyramidVisionTransformerImpr):
590
  drop_rate=0.0, drop_path_rate=0.1)
591
 
592
 
593
- ## @register_model
594
  class pvt_v2_b5(PyramidVisionTransformerImpr):
595
  def __init__(self, **kwargs):
596
  super(pvt_v2_b5, self).__init__(
@@ -614,7 +608,7 @@ import torch.nn as nn
614
  import torch.nn.functional as F
615
  import torch.utils.checkpoint as checkpoint
616
  import numpy as np
617
- from timm.models.layers import DropPath, to_2tuple, trunc_normal_
618
 
619
  # from config import Config
620
 
@@ -1195,7 +1189,7 @@ class SwinTransformer(nn.Module):
1195
  # interpolate the position embedding to the corresponding size
1196
  absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
1197
  x = (x + absolute_pos_embed) # B Wh*Ww C
1198
-
1199
  outs = []#x.contiguous()]
1200
  x = x.flatten(2).transpose(1, 2)
1201
  x = self.pos_drop(x)
@@ -1252,13 +1246,13 @@ class DeformableConv2d(nn.Module):
1252
  bias=False):
1253
 
1254
  super(DeformableConv2d, self).__init__()
1255
-
1256
  assert type(kernel_size) == tuple or type(kernel_size) == int
1257
 
1258
  kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size)
1259
  self.stride = stride if type(stride) == tuple else (stride, stride)
1260
  self.padding = padding
1261
-
1262
  self.offset_conv = nn.Conv2d(in_channels,
1263
  2 * kernel_size[0] * kernel_size[1],
1264
  kernel_size=kernel_size,
@@ -1268,7 +1262,7 @@ class DeformableConv2d(nn.Module):
1268
 
1269
  nn.init.constant_(self.offset_conv.weight, 0.)
1270
  nn.init.constant_(self.offset_conv.bias, 0.)
1271
-
1272
  self.modulator_conv = nn.Conv2d(in_channels,
1273
  1 * kernel_size[0] * kernel_size[1],
1274
  kernel_size=kernel_size,
@@ -1292,7 +1286,7 @@ class DeformableConv2d(nn.Module):
1292
 
1293
  offset = self.offset_conv(x)#.clamp(-max_offset, max_offset)
1294
  modulator = 2. * torch.sigmoid(self.modulator_conv(x))
1295
-
1296
  x = deform_conv2d(
1297
  input=x,
1298
  offset=offset,
@@ -1490,7 +1484,7 @@ class ResBlk(nn.Module):
1490
 
1491
  self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
1492
  self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
1493
-
1494
  self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
1495
 
1496
  def forward(self, x):
@@ -2141,7 +2135,7 @@ class Decoder(nn.Module):
2141
  self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2142
  self.gdt_convs_pred_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2143
  self.gdt_convs_pred_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2144
-
2145
  self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2146
  self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2147
  self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
 
166
  import torch.nn as nn
167
  from functools import partial
168
 
169
+ from timm.layers import DropPath, to_2tuple, trunc_normal_
170
+
171
 
172
  import math
173
 
 
547
  return out_dict
548
 
549
 
 
550
  class pvt_v2_b0(PyramidVisionTransformerImpr):
551
  def __init__(self, **kwargs):
552
  super(pvt_v2_b0, self).__init__(
 
556
 
557
 
558
 
 
559
  class pvt_v2_b1(PyramidVisionTransformerImpr):
560
  def __init__(self, **kwargs):
561
  super(pvt_v2_b1, self).__init__(
 
563
  qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
564
  drop_rate=0.0, drop_path_rate=0.1)
565
 
 
566
  class pvt_v2_b2(PyramidVisionTransformerImpr):
567
  def __init__(self, in_channels=3, **kwargs):
568
  super(pvt_v2_b2, self).__init__(
 
570
  qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
571
  drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels)
572
 
 
573
  class pvt_v2_b3(PyramidVisionTransformerImpr):
574
  def __init__(self, **kwargs):
575
  super(pvt_v2_b3, self).__init__(
 
577
  qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
578
  drop_rate=0.0, drop_path_rate=0.1)
579
 
 
580
  class pvt_v2_b4(PyramidVisionTransformerImpr):
581
  def __init__(self, **kwargs):
582
  super(pvt_v2_b4, self).__init__(
 
585
  drop_rate=0.0, drop_path_rate=0.1)
586
 
587
 
 
588
  class pvt_v2_b5(PyramidVisionTransformerImpr):
589
  def __init__(self, **kwargs):
590
  super(pvt_v2_b5, self).__init__(
 
608
  import torch.nn.functional as F
609
  import torch.utils.checkpoint as checkpoint
610
  import numpy as np
611
+ from timm.layers import DropPath, to_2tuple, trunc_normal_
612
 
613
  # from config import Config
614
 
 
1189
  # interpolate the position embedding to the corresponding size
1190
  absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
1191
  x = (x + absolute_pos_embed) # B Wh*Ww C
1192
+
1193
  outs = []#x.contiguous()]
1194
  x = x.flatten(2).transpose(1, 2)
1195
  x = self.pos_drop(x)
 
1246
  bias=False):
1247
 
1248
  super(DeformableConv2d, self).__init__()
1249
+
1250
  assert type(kernel_size) == tuple or type(kernel_size) == int
1251
 
1252
  kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size)
1253
  self.stride = stride if type(stride) == tuple else (stride, stride)
1254
  self.padding = padding
1255
+
1256
  self.offset_conv = nn.Conv2d(in_channels,
1257
  2 * kernel_size[0] * kernel_size[1],
1258
  kernel_size=kernel_size,
 
1262
 
1263
  nn.init.constant_(self.offset_conv.weight, 0.)
1264
  nn.init.constant_(self.offset_conv.bias, 0.)
1265
+
1266
  self.modulator_conv = nn.Conv2d(in_channels,
1267
  1 * kernel_size[0] * kernel_size[1],
1268
  kernel_size=kernel_size,
 
1286
 
1287
  offset = self.offset_conv(x)#.clamp(-max_offset, max_offset)
1288
  modulator = 2. * torch.sigmoid(self.modulator_conv(x))
1289
+
1290
  x = deform_conv2d(
1291
  input=x,
1292
  offset=offset,
 
1484
 
1485
  self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
1486
  self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
1487
+
1488
  self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
1489
 
1490
  def forward(self, x):
 
2135
  self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2136
  self.gdt_convs_pred_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2137
  self.gdt_convs_pred_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2138
+
2139
  self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2140
  self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2141
  self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))