Pj12 commited on
Commit
b2cb5ef
·
verified ·
1 Parent(s): b82aafb

Update extract_feature_print.py

Browse files
Files changed (1) hide show
  1. extract_feature_print.py +2 -2
extract_feature_print.py CHANGED
@@ -140,7 +140,7 @@ class HubertModelWithFinalProj(HubertModel):
140
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
141
  os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
142
 
143
- # device=sys.argv[1]
144
  n_part = int(sys.argv[2])
145
  i_part = int(sys.argv[3])
146
  if len(sys.argv) == 6:
@@ -157,7 +157,7 @@ import soundfile as sf
157
  import numpy as np
158
  from fairseq import checkpoint_utils
159
 
160
- device = "cpu"
161
  if torch.cuda.is_available():
162
  device = "cuda"
163
  elif torch.backends.mps.is_available():
 
140
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
141
  os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
142
 
143
+ device=sys.argv[1]
144
  n_part = int(sys.argv[2])
145
  i_part = int(sys.argv[3])
146
  if len(sys.argv) == 6:
 
157
  import numpy as np
158
  from fairseq import checkpoint_utils
159
 
160
+ #device = "cpu"
161
  if torch.cuda.is_available():
162
  device = "cuda"
163
  elif torch.backends.mps.is_available():