Update extract_feature_print.py
Browse files- extract_feature_print.py +4 -2
extract_feature_print.py
CHANGED
|
@@ -74,8 +74,10 @@ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
|
| 74 |
[model_path],
|
| 75 |
suffix="",
|
| 76 |
)
|
| 77 |
-
|
| 78 |
-
model = models[0]
|
|
|
|
|
|
|
| 79 |
model = model.to(device)
|
| 80 |
printt("move model to %s" % device)
|
| 81 |
if device not in ["mps", "cpu"]:
|
|
|
|
| 74 |
[model_path],
|
| 75 |
suffix="",
|
| 76 |
)
|
| 77 |
+
if os.path.split(model_path)[0] == "/kaggle/working/Mangio-RVC-Fork/Custom" == False:
|
| 78 |
+
model = models[0]
|
| 79 |
+
else:
|
| 80 |
+
model = HubertModelWithFinalProj.from_pretrained(model_path)
|
| 81 |
model = model.to(device)
|
| 82 |
printt("move model to %s" % device)
|
| 83 |
if device not in ["mps", "cpu"]:
|