merve HF Staff commited on
Commit
329f351
·
verified ·
1 Parent(s): 74c1fc0
Files changed (1) hide show
  1. birefnet.py +7 -3
birefnet.py CHANGED
@@ -383,7 +383,7 @@ class PyramidVisionTransformerImpr(nn.Module):
383
  embed_dim=embed_dims[3])
384
 
385
  # transformer encoder
386
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
387
  cur = 0
388
  self.block1 = nn.ModuleList([Block(
389
  dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
@@ -1128,8 +1128,12 @@ class SwinTransformer(nn.Module):
1128
  self.pos_drop = nn.Dropout(p=drop_rate)
1129
 
1130
  # stochastic depth
1131
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
1132
-
 
 
 
 
1133
  # build layers
1134
  self.layers = nn.ModuleList()
1135
  for i_layer in range(self.num_layers):
 
383
  embed_dim=embed_dims[3])
384
 
385
  # transformer encoder
386
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
387
  cur = 0
388
  self.block1 = nn.ModuleList([Block(
389
  dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
 
1128
  self.pos_drop = nn.Dropout(p=drop_rate)
1129
 
1130
  # stochastic depth
1131
+ # stochastic depth decay rule (pure python: safe even if model is being initialized on `meta`)
1132
+ total_depth = int(sum(depths))
1133
+ if total_depth <= 1:
1134
+ dpr = [0.0]
1135
+ else:
1136
+ dpr = [float(drop_path_rate) * i / (total_depth - 1) for i in range(total_depth)]
1137
  # build layers
1138
  self.layers = nn.ModuleList()
1139
  for i_layer in range(self.num_layers):