Update extract_feature_print.py
Browse files- extract_feature_print.py +2 -2
extract_feature_print.py
CHANGED
|
@@ -242,7 +242,7 @@ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
|
| 242 |
[model_path],
|
| 243 |
suffix="",
|
| 244 |
)
|
| 245 |
-
if Custom_Embed == False:
|
| 246 |
model = models[0]
|
| 247 |
if device not in ["mps", "cpu"]:
|
| 248 |
model = model.half()
|
|
@@ -275,7 +275,7 @@ else:
|
|
| 275 |
if device not in ["mps", "cpu"]
|
| 276 |
else feats.to(device),
|
| 277 |
"padding_mask": padding_mask.to(device),
|
| 278 |
-
"output_layer": 9 if version == "v1" else 12, # layer 9
|
| 279 |
}
|
| 280 |
with torch.no_grad():
|
| 281 |
if Custom_Embed == False:
|
|
|
|
| 242 |
[model_path],
|
| 243 |
suffix="",
|
| 244 |
)
|
| 245 |
+
if Custom_Embed == False or sample_embedding == "hubert_large_ll60k":
|
| 246 |
model = models[0]
|
| 247 |
if device not in ["mps", "cpu"]:
|
| 248 |
model = model.half()
|
|
|
|
| 275 |
if device not in ["mps", "cpu"]
|
| 276 |
else feats.to(device),
|
| 277 |
"padding_mask": padding_mask.to(device),
|
| 278 |
+
"output_layer": 9 if version == "v1" else 12 if sample_embedding != "hubert_large_ll60k" else 24, # layer 9
|
| 279 |
}
|
| 280 |
with torch.no_grad():
|
| 281 |
if Custom_Embed == False:
|