Update extract_feature_print.py
Browse files- extract_feature_print.py +13 -2
extract_feature_print.py
CHANGED
|
@@ -1,4 +1,13 @@
|
|
| 1 |
import os, sys, traceback
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
| 4 |
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
|
|
@@ -60,7 +69,9 @@ def readwave(wav_path, normalize=False):
|
|
| 60 |
feats = feats.view(1, -1)
|
| 61 |
return feats
|
| 62 |
|
| 63 |
-
|
|
|
|
|
|
|
| 64 |
# HuBERT model
|
| 65 |
printt("load model(s) from {}".format(model_path))
|
| 66 |
# if hubert model is exist
|
|
@@ -74,7 +85,7 @@ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
|
| 74 |
[model_path],
|
| 75 |
suffix="",
|
| 76 |
)
|
| 77 |
-
if
|
| 78 |
model = models[0]
|
| 79 |
else:
|
| 80 |
model = HubertModelWithFinalProj.from_pretrained(model_path)
|
|
|
|
| 1 |
import os, sys, traceback
|
| 2 |
+
from transformers import HubertModel
|
| 3 |
+
import os
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
#HuggingFacePlaceHolder = None
|
| 7 |
+
class HubertModelWithFinalProj(HubertModel):
|
| 8 |
+
def __init__(self, config):
|
| 9 |
+
super().__init__(config)
|
| 10 |
+
self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
| 11 |
|
| 12 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
| 13 |
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
|
|
|
|
| 69 |
feats = feats.view(1, -1)
|
| 70 |
return feats
|
| 71 |
|
| 72 |
+
if os.path.split(model_path)[0] == "/kaggle/working/Mangio-RVC-Fork/Custom" == False:
|
| 73 |
+
model_path = "/kaggle/working/Mangio-RVC-Fork/hubert_base.pt"
|
| 74 |
+
Custom_Embed = True
|
| 75 |
# HuBERT model
|
| 76 |
printt("load model(s) from {}".format(model_path))
|
| 77 |
# if hubert model is exist
|
|
|
|
| 85 |
[model_path],
|
| 86 |
suffix="",
|
| 87 |
)
|
| 88 |
+
if Custom_Embed == False:
|
| 89 |
model = models[0]
|
| 90 |
else:
|
| 91 |
model = HubertModelWithFinalProj.from_pretrained(model_path)
|