Gabriele commited on
Commit
05bf02c
·
1 Parent(s): 2bad776

Fix meta device crash: use pure Python for drop_path_rate

Browse files
Files changed (1) hide show
  1. image_encoder.py +1 -1
image_encoder.py CHANGED
@@ -695,7 +695,7 @@ class VisionTransformer(nn.Module):
695
  dpr = [drop_path_rate] * depth
696
  else:
697
  dpr = [
698
- x.item() for x in torch.linspace(0, drop_path_rate, depth)
699
  ] # stochastic depth decay rule
700
 
701
  if ffn_layer == "mlp":
 
695
  dpr = [drop_path_rate] * depth
696
  else:
697
  dpr = [
698
+ drop_path_rate * i / max(depth - 1, 1) for i in range(depth)
699
  ] # stochastic depth decay rule
700
 
701
  if ffn_layer == "mlp":