CPU edits
Browse files- birefnet.py +7 -3
birefnet.py
CHANGED
|
@@ -383,7 +383,7 @@ class PyramidVisionTransformerImpr(nn.Module):
|
|
| 383 |
embed_dim=embed_dims[3])
|
| 384 |
|
| 385 |
# transformer encoder
|
| 386 |
-
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
| 387 |
cur = 0
|
| 388 |
self.block1 = nn.ModuleList([Block(
|
| 389 |
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
@@ -1128,8 +1128,12 @@ class SwinTransformer(nn.Module):
|
|
| 1128 |
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 1129 |
|
| 1130 |
# stochastic depth
|
| 1131 |
-
|
| 1132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1133 |
# build layers
|
| 1134 |
self.layers = nn.ModuleList()
|
| 1135 |
for i_layer in range(self.num_layers):
|
|
|
|
| 383 |
embed_dim=embed_dims[3])
|
| 384 |
|
| 385 |
# transformer encoder
|
| 386 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
| 387 |
cur = 0
|
| 388 |
self.block1 = nn.ModuleList([Block(
|
| 389 |
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
|
|
| 1128 |
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 1129 |
|
| 1130 |
# stochastic depth
|
| 1131 |
+
# stochastic depth decay rule (pure python: safe even if model is being initialized on `meta`)
|
| 1132 |
+
total_depth = int(sum(depths))
|
| 1133 |
+
if total_depth <= 1:
|
| 1134 |
+
dpr = [0.0]
|
| 1135 |
+
else:
|
| 1136 |
+
dpr = [float(drop_path_rate) * i / (total_depth - 1) for i in range(total_depth)]
|
| 1137 |
# build layers
|
| 1138 |
self.layers = nn.ModuleList()
|
| 1139 |
for i_layer in range(self.num_layers):
|