Fix `Tensor.item() cannot be called on meta tensors` error during loading with transformers>=5

#1
by zhoukz - opened
Files changed (1) hide show
  1. modeling_dasheng.py +1 -1
modeling_dasheng.py CHANGED
@@ -321,7 +321,7 @@ class AudioTransformerMAE_Encoder(nn.Module):
321
 
322
  norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
323
  act_layer = act_layer or nn.GELU
324
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
325
  self.pos_drop = nn.Dropout(p=drop_rate)
326
  block_function = globals()[block_type]
327
  self.blocks = nn.Sequential(
 
321
 
322
  norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
323
  act_layer = act_layer or nn.GELU
324
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth, device="cpu")] # stochastic depth decay rule
325
  self.pos_drop = nn.Dropout(p=drop_rate)
326
  block_function = globals()[block_type]
327
  self.blocks = nn.Sequential(