import torch

from itertools import chain
from argparse import Namespace

from fairseq import checkpoint_utils, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from transformers import AutoTokenizer
from omegaconf import OmegaConf, DictConfig

from src.models.taskers.clustering import HubertFeatureReader


def load_feature_extractor(
    ckpt_path: str,
    layer: int,
    max_chunk: int = 1600000,
    custom_utils=None,
):
    """Build a HubertFeatureReader that extracts features from transformer
    layer ``layer`` of the checkpoint at ``ckpt_path``."""
    extractor = HubertFeatureReader(ckpt_path, layer, max_chunk, custom_utils)
    return extractor
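
# Minimal usage sketch. The checkpoint path, layer index, and audio file are
# placeholders, and get_feats() assumes this repo's HubertFeatureReader keeps
# fairseq's usual feature-reader interface:
#
#     extractor = load_feature_extractor("checkpoints/hubert_base.pt", layer=12)
#     feats = extractor.get_feats("sample.wav")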


def load_ensemble_model(cfg_path: str):
    """Load the fine-tuned ensemble described by the YAML config at
    ``cfg_path`` and return ``(model, cfg, saved_cfg, llm_tokenizer)``."""

    def main(main_cfg: DictConfig):
        # Accept legacy argparse namespaces as well as OmegaConf configs.
        if isinstance(main_cfg, Namespace):
            main_cfg = convert_namespace_to_omegaconf(main_cfg)

        assert main_cfg.common_eval.path is not None, "--path required for recognition!"

        llm_tokenizer = AutoTokenizer.from_pretrained(main_cfg.override.llm_ckpt_path)

        try:
            utils.import_user_module(main_cfg.common)
        except ImportError:
            pass

        use_cuda = torch.cuda.is_available()

        # Redirect the checkpoint to the local w2v / LLM weights before loading.
        model_override_cfg = {
            'model': {
                'w2v_path': main_cfg.override.w2v_path,
                'llm_ckpt_path': main_cfg.override.llm_ckpt_path,
            }
        }
        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            [main_cfg.common_eval.path],
            model_override_cfg,
            strict=False,
        )
        models = [model.eval() for model in models]
        lms = [None]  # no external language model is used

        # Move the encoder and the audio-visual-to-LLM projection to GPU in
        # half precision when CUDA is available.
        for model in chain(models, lms):
            if model is None:
                continue
            if use_cuda and not main_cfg.distributed_training.pipeline_model_parallel:
                model.encoder.cuda()
                model.avfeat_to_llm.cuda()
                model.half()

        model = models[0]
        return model, main_cfg, saved_cfg, llm_tokenizer

    cfg = OmegaConf.load(cfg_path)
    return main(main_cfg=cfg)


def load_state_dict_for_extractor(
    extractor: torch.nn.Module,
    model: torch.nn.Module,
):
    """Copy the fine-tuned encoder weights from ``model`` into ``extractor``,
    stripping the four-character wrapper prefix from each parameter name so
    the keys match the extractor's modules."""
    encoder_state_dict = model.encoder.state_dict()
    rename_encoder_state_dict = dict()
    for k, v in encoder_state_dict.items():
        # Drop the encoder wrapper's four-character key prefix.
        rename_encoder_state_dict[k[4:]] = v
    # strict=False tolerates keys that have no counterpart in the extractor.
    extractor.load_state_dict(rename_encoder_state_dict, strict=False)
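

if __name__ == "__main__":
    # End-to-end sketch: the config path below is a placeholder and layer=12
    # is an arbitrary choice, not values shipped with this repo. Load the
    # fine-tuned ensemble, build a feature extractor from the w2v checkpoint
    # named in the config, then copy the fine-tuned encoder weights into it.
    model, cfg, saved_cfg, tokenizer = load_ensemble_model("configs/decode.yaml")
    extractor = load_feature_extractor(cfg.override.w2v_path, layer=12)
    load_state_dict_for_extractor(extractor, model)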