| import os |
| import re |
| from fairseq import checkpoint_utils |
|
|
|
|
def get_index_path_from_model(sid):
    """Find the ``.index`` file that belongs to a model file.

    The model file name is reduced to a base name:

    * a trailing ``.pth`` or ``.onnx`` extension is dropped, along with any
      directory part;
    * names ending in ``_e<digits>_s<digits>`` are cut back to the part
      before that suffix.

    The directory tree named by the ``index_root`` environment variable is
    then walked bottom-up. The first ``*.index`` file whose name does not
    contain ``"trained"``, and whose full path contains the base name, is
    returned. The suffix is presumably an epoch/step checkpoint marker.

    Args:
        sid: Model file name or path, e.g. ``"voice_e100_s2000.pth"``.

    Returns:
        The path of the first matching index file. Returns ``""`` when nothing
        matches, when the base name is empty, or when ``index_root`` is unset
        or empty.
    """
    index_root = os.getenv("index_root")
    if not index_root:
        # os.walk(None) would raise TypeError; treat "no index directory
        # configured" the same as "no index found".
        return ""

    # Both alternatives must be anchored. The former r'\.pth|\.onnx$' removed
    # ".pth" anywhere in the name, not only as the file extension.
    stripped = re.sub(r"\.(?:pth|onnx)$", "", sid)
    model_name = os.path.basename(stripped)

    # Reduce "<name>_e<digits>_s<digits>" to "<name>" before matching.
    if re.match(r".+_e\d+_s\d+$", model_name):
        model_name = model_name.rsplit("_", 2)[0]

    if not model_name:
        # An empty string is a substring of every path and would match any index.
        return ""

    # Lazily enumerate candidate index files so the search stops at the first hit.
    candidates = (
        os.path.join(root, name)
        for root, _, files in os.walk(index_root, topdown=False)
        for name in files
        if name.endswith(".index") and "trained" not in name
    )
    # Match against the full path, so the base name may appear either in the
    # index file name or in one of its parent directory names.
    return next((path for path in candidates if model_name in path), "")
|
|
|
|
def load_hubert(config, hubert_path="assets/hubert/hubert_base.pt"):
    """Load the HuBERT model from a fairseq checkpoint and prepare it for use.

    Args:
        config: Runtime configuration. Only two attributes are read:
            ``config.device``, the target passed to ``.to()``, and
            ``config.is_half``, which selects ``.half()`` when true and
            ``.float()`` otherwise.
        hubert_path: Path to the checkpoint. It defaults to
            ``"assets/hubert/hubert_base.pt"``, the path previously
            hard-coded here, so existing callers are unaffected.

    Returns:
        The first model of the loaded ensemble. It has been moved to
        ``config.device`` and cast to the requested precision, and it is
        returned through ``.eval()``.
    """
    # NOTE(review): fairseq loads checkpoints via torch.load, which is
    # pickle-based. Confirm this, and only point hubert_path at trusted files.
    models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
        [hubert_path],
        suffix="",
    )
    hubert_model = models[0].to(config.device)
    # Cast to the precision requested by the config.
    hubert_model = hubert_model.half() if config.is_half else hubert_model.float()
    return hubert_model.eval()
|
|