Commit
·
0bce61e
1
Parent(s):
7c09185
For the compatibility with the meta device used in transformers==5.0.0.
Browse files- birefnet.py +3 -3
birefnet.py
CHANGED
|
@@ -385,7 +385,7 @@ class PyramidVisionTransformerImpr(nn.Module):
|
|
| 385 |
embed_dim=embed_dims[3])
|
| 386 |
|
| 387 |
# transformer encoder
|
| 388 |
-
dpr =
|
| 389 |
cur = 0
|
| 390 |
self.block1 = nn.ModuleList([Block(
|
| 391 |
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
@@ -444,7 +444,7 @@ class PyramidVisionTransformerImpr(nn.Module):
|
|
| 444 |
#load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
|
| 445 |
|
| 446 |
def reset_drop_path(self, drop_path_rate):
|
| 447 |
-
dpr =
|
| 448 |
cur = 0
|
| 449 |
for i in range(self.depths[0]):
|
| 450 |
self.block1[i].drop_path.drop_prob = dpr[cur + i]
|
|
@@ -1130,7 +1130,7 @@ class SwinTransformer(nn.Module):
|
|
| 1130 |
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 1131 |
|
| 1132 |
# stochastic depth
|
| 1133 |
-
dpr =
|
| 1134 |
|
| 1135 |
# build layers
|
| 1136 |
self.layers = nn.ModuleList()
|
|
|
|
| 385 |
embed_dim=embed_dims[3])
|
| 386 |
|
| 387 |
# transformer encoder
|
| 388 |
+
dpr = np.linspace(0, drop_path_rate, sum(depths)).tolist() # stochastic depth decay rule
|
| 389 |
cur = 0
|
| 390 |
self.block1 = nn.ModuleList([Block(
|
| 391 |
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
|
|
| 444 |
#load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
|
| 445 |
|
| 446 |
def reset_drop_path(self, drop_path_rate):
|
| 447 |
+
dpr = np.linspace(0, drop_path_rate, sum(self.depths)).tolist()
|
| 448 |
cur = 0
|
| 449 |
for i in range(self.depths[0]):
|
| 450 |
self.block1[i].drop_path.drop_prob = dpr[cur + i]
|
|
|
|
| 1130 |
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 1131 |
|
| 1132 |
# stochastic depth
|
| 1133 |
+
dpr = np.linspace(0, drop_path_rate, sum(depths)).tolist() # stochastic depth decay rule
|
| 1134 |
|
| 1135 |
# build layers
|
| 1136 |
self.layers = nn.ModuleList()
|