Pj12 commited on
Commit
c4f3474
·
verified ·
1 Parent(s): 7ad1011

Update extract_feature_print.py

Browse files
Files changed (1) hide show
  1. extract_feature_print.py +2 -2
extract_feature_print.py CHANGED
@@ -242,7 +242,7 @@ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
242
  [model_path],
243
  suffix="",
244
  )
245
- if Custom_Embed == False:
246
  model = models[0]
247
  if device not in ["mps", "cpu"]:
248
  model = model.half()
@@ -275,7 +275,7 @@ else:
275
  if device not in ["mps", "cpu"]
276
  else feats.to(device),
277
  "padding_mask": padding_mask.to(device),
278
- "output_layer": 9 if version == "v1" else 12, # layer 9
279
  }
280
  with torch.no_grad():
281
  if Custom_Embed == False:
 
242
  [model_path],
243
  suffix="",
244
  )
245
+ if Custom_Embed == False or sample_embedding == "hubert_large_ll60k":
246
  model = models[0]
247
  if device not in ["mps", "cpu"]:
248
  model = model.half()
 
275
  if device not in ["mps", "cpu"]
276
  else feats.to(device),
277
  "padding_mask": padding_mask.to(device),
278
+ "output_layer": 9 if version == "v1" else 12 if sample_embedding != "hubert_large_ll60k" else 24, # layer 9
279
  }
280
  with torch.no_grad():
281
  if Custom_Embed == False: