Update extract_feature_print.py
Browse files- extract_feature_print.py +2 -2
extract_feature_print.py
CHANGED
|
@@ -157,9 +157,9 @@ def main():
|
|
| 157 |
|
| 158 |
# Resolve model path and name
|
| 159 |
custom_mappings = {
|
|
|
|
| 160 |
"hubert_base": ("hubert_base.pt", "hubert"),
|
| 161 |
"contentvec_base": ("contentvec_base.pt", "contentvec"),
|
| 162 |
-
"hubert_large_ll60k": ("hubert_large_ll60k.pt", "hubert"),
|
| 163 |
}
|
| 164 |
if os.path.split(model_path)[-1] == "Custom" and model_name in custom_mappings:
|
| 165 |
model_path, resolved_model_name = custom_mappings[model_name]
|
|
@@ -170,7 +170,7 @@ def main():
|
|
| 170 |
sys.exit(1)
|
| 171 |
|
| 172 |
# Load model
|
| 173 |
-
model_config = model_configs.get(model_name, model_configs[
|
| 174 |
model_dict = model_config["load_model"](model_path, config.device, config.is_half)
|
| 175 |
model = model_dict["model"]
|
| 176 |
additional_configs = model_dict.get("saved_cfg")
|
|
|
|
| 157 |
|
| 158 |
# Resolve model path and name
|
| 159 |
custom_mappings = {
|
| 160 |
+
"wav2vec_2" : ("wav2vec_small_960h.pt", "wav2vec"),
|
| 161 |
"hubert_base": ("hubert_base.pt", "hubert"),
|
| 162 |
"contentvec_base": ("contentvec_base.pt", "contentvec"),
|
|
|
|
| 163 |
}
|
| 164 |
if os.path.split(model_path)[-1] == "Custom" and model_name in custom_mappings:
|
| 165 |
model_path, resolved_model_name = custom_mappings[model_name]
|
|
|
|
| 170 |
sys.exit(1)
|
| 171 |
|
| 172 |
# Load model
|
| 173 |
+
model_config = model_configs.get(model_name, model_configs[resolved_model_name])
|
| 174 |
model_dict = model_config["load_model"](model_path, config.device, config.is_half)
|
| 175 |
model = model_dict["model"]
|
| 176 |
additional_configs = model_dict.get("saved_cfg")
|