Pj12 commited on
Commit
af3434e
·
verified ·
1 Parent(s): fa8c613

Update extract_feature_print.py

Browse files
Files changed (1) hide show
  1. extract_feature_print.py +13 -2
extract_feature_print.py CHANGED
@@ -1,4 +1,13 @@
1
  import os, sys, traceback
 
 
 
 
 
 
 
 
 
2
 
3
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
4
  os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
@@ -60,7 +69,9 @@ def readwave(wav_path, normalize=False):
60
  feats = feats.view(1, -1)
61
  return feats
62
 
63
-
 
 
64
  # HuBERT model
65
  printt("load model(s) from {}".format(model_path))
66
  # if hubert model is exist
@@ -74,7 +85,7 @@ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
74
  [model_path],
75
  suffix="",
76
  )
77
- if os.path.split(model_path)[0] == "/kaggle/working/Mangio-RVC-Fork/Custom" == False:
78
  model = models[0]
79
  else:
80
  model = HubertModelWithFinalProj.from_pretrained(model_path)
 
1
  import os, sys, traceback
2
+ from transformers import HubertModel
3
+ import os
4
+ from torch import nn
5
+
6
+ #HuggingFacePlaceHolder = None
7
+ class HubertModelWithFinalProj(HubertModel):
8
+ def __init__(self, config):
9
+ super().__init__(config)
10
+ self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
11
 
12
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
13
  os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
 
69
  feats = feats.view(1, -1)
70
  return feats
71
 
72
+ if os.path.split(model_path)[0] == "/kaggle/working/Mangio-RVC-Fork/Custom" == False:
73
+ model_path = "/kaggle/working/Mangio-RVC-Fork/hubert_base.pt"
74
+ Custom_Embed = True
75
  # HuBERT model
76
  printt("load model(s) from {}".format(model_path))
77
  # if hubert model is exist
 
85
  [model_path],
86
  suffix="",
87
  )
88
+ if Custom_Embed == False:
89
  model = models[0]
90
  else:
91
  model = HubertModelWithFinalProj.from_pretrained(model_path)