| | import os |
| | import re |
| | from fairseq import checkpoint_utils |
| |
|
| |
|
def get_index_path_from_model(sid):
    """Return the path of the feature-index file that belongs to a voice model.

    The model identifier is reduced to a base name in three steps:
    drop a trailing ``.pth``/``.onnx`` extension, drop any leading
    directories, and drop an ``_e<epoch>_s<step>`` checkpoint suffix.
    For example, ``weights/foo_e100_s2000.pth`` becomes ``foo``.

    The tree named by the ``index_root`` environment variable is then
    walked bottom-up. The first ``*.index`` file that meets both of these
    conditions is returned:

    - its file name does not contain ``"trained"``;
    - its path *relative to* ``index_root`` contains the base name.

    Args:
        sid: Model file name or path.

    Returns:
        The path (``index_root`` joined with the file's location) of the
        first matching index file. Returns ``""`` when nothing matches,
        when ``index_root`` is unset or empty, or when the base name is
        empty.
    """
    # Anchor both alternatives. In the old pattern r'\.pth|\.onnx$' the "|"
    # binds looser than "$", so ".pth" was removed anywhere in the string.
    stripped = re.sub(r"\.(?:pth|onnx)$", "", sid)
    model_name = os.path.basename(stripped)

    # A name ending in "_e<digits>_s<digits>" has those two trailing
    # "_"-separated parts removed, so the lookup uses the base model name.
    if re.fullmatch(r".+_e\d+_s\d+", model_name):
        base_model_name = model_name.rsplit("_", 2)[0]
    else:
        base_model_name = model_name

    index_root = os.getenv("index_root")
    # Two cases are treated as "no index found":
    # - os.walk(None) raises TypeError;
    # - an empty base name is a substring of every path.
    if not index_root or not base_model_name:
        return ""

    for root, _, files in os.walk(index_root, topdown=False):
        for name in files:
            if not name.endswith(".index") or "trained" in name:
                continue
            path = os.path.join(root, name)
            # Match only below index_root. Every candidate shares the root
            # prefix, so a model name that happens to occur in it would
            # otherwise match whichever file is visited first.
            if base_model_name in os.path.relpath(path, index_root):
                return path
    return ""
| |
|
| |
|
def load_hubert(config, hubert_path="assets/hubert/hubert_base.pt"):
    """Load the HuBERT checkpoint and prepare it for inference.

    Args:
        config: Object exposing two attributes:

            - ``device``: passed to the model's ``.to()``;
            - ``is_half``: selects ``.half()`` when truthy, ``.float()``
              otherwise.

        hubert_path: Checkpoint file to load. Defaults to the previously
            hard-coded ``assets/hubert/hubert_base.pt``, which is resolved
            relative to the current working directory.

    Returns:
        The first model of the loaded ensemble, moved to ``config.device``,
        cast to half or float precision, and put in eval mode.
    """
    models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
        [hubert_path],
        suffix="",
    )
    hubert_model = models[0].to(config.device)
    # Both branches cast explicitly, so the resulting precision follows
    # config.is_half rather than whatever the checkpoint stored
    # (presumably a torch module; confirm against fairseq's return type).
    hubert_model = hubert_model.half() if config.is_half else hubert_model.float()
    return hubert_model.eval()
| |
|