Pj12 commited on
Commit
a26ad04
·
verified ·
1 Parent(s): 746fa95

Update extract_feature_print.py

Browse files
Files changed (1) hide show
  1. extract_feature_print.py +2 -2
extract_feature_print.py CHANGED
@@ -157,9 +157,9 @@ def main():
157
 
158
  # Resolve model path and name
159
  custom_mappings = {
 
160
  "hubert_base": ("hubert_base.pt", "hubert"),
161
  "contentvec_base": ("contentvec_base.pt", "contentvec"),
162
- "hubert_large_ll60k": ("hubert_large_ll60k.pt", "hubert"),
163
  }
164
  if os.path.split(model_path)[-1] == "Custom" and model_name in custom_mappings:
165
  model_path, resolved_model_name = custom_mappings[model_name]
@@ -170,7 +170,7 @@ def main():
170
  sys.exit(1)
171
 
172
  # Load model
173
- model_config = model_configs.get(model_name, model_configs["hubert"])
174
  model_dict = model_config["load_model"](model_path, config.device, config.is_half)
175
  model = model_dict["model"]
176
  additional_configs = model_dict.get("saved_cfg")
 
157
 
158
  # Resolve model path and name
159
  custom_mappings = {
160
+ "wav2vec_2" : ("wav2vec_small_960h.pt", "wav2vec"),
161
  "hubert_base": ("hubert_base.pt", "hubert"),
162
  "contentvec_base": ("contentvec_base.pt", "contentvec"),
 
163
  }
164
  if os.path.split(model_path)[-1] == "Custom" and model_name in custom_mappings:
165
  model_path, resolved_model_name = custom_mappings[model_name]
 
170
  sys.exit(1)
171
 
172
  # Load model
173
+ model_config = model_configs.get(model_name, model_configs[resolved_model_name])
174
  model_dict = model_config["load_model"](model_path, config.device, config.is_half)
175
  model = model_dict["model"]
176
  additional_configs = model_dict.get("saved_cfg")