# Demo / src / models / utils / model.py
# nguyenminh4099's picture
# Upload folder using huggingface_hub
# 9411c06 verified
import torch
from itertools import chain
from argparse import Namespace
from fairseq import checkpoint_utils, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from transformers import AutoTokenizer
from omegaconf import OmegaConf, DictConfig
from src.models.taskers.clustering import HubertFeatureReader
def load_feature_extractor(
    ckpt_path: str,
    layer,
    max_chunk=1600000,
    custom_utils=None
):
    """Build and return a ``HubertFeatureReader``.

    All four arguments are forwarded positionally, in order, to the
    ``HubertFeatureReader`` constructor; no other work is done here.

    Args:
        ckpt_path: Checkpoint path handed to the reader.
        layer: Layer selector handed to the reader (type not visible
            here; presumably a layer index -- confirm against the reader).
        max_chunk: Third constructor argument; defaults to 1600000.
        custom_utils: Fourth constructor argument; defaults to ``None``.

    Returns:
        The newly constructed ``HubertFeatureReader`` instance.
    """
    return HubertFeatureReader(ckpt_path, layer, max_chunk, custom_utils)
def load_ensemble_model(cfg_path: str):
    """Load the model described by an OmegaConf config file.

    The file at ``cfg_path`` is read with ``OmegaConf.load``. The checkpoint
    at ``common_eval.path`` is then loaded through fairseq with the model
    config overridden by ``override.w2v_path`` and ``override.llm_ckpt_path``.

    Args:
        cfg_path: Path of the configuration file to load.

    Returns:
        A 4-tuple ``(model, main_cfg, saved_cfg, llm_tokenizer)`` where
        ``model`` is the first loaded model (in eval mode), ``main_cfg`` is
        the loaded configuration, ``saved_cfg`` is the configuration returned
        by the checkpoint loader, and ``llm_tokenizer`` is the tokenizer
        built from ``override.llm_ckpt_path``.

    Raises:
        AssertionError: If ``common_eval.path`` is ``None``.
    """

    def _load_from_cfg(main_cfg: DictConfig):
        # Legacy argparse namespaces are converted to an OmegaConf config.
        if isinstance(main_cfg, Namespace):
            main_cfg = convert_namespace_to_omegaconf(main_cfg)

        assert main_cfg.common_eval.path is not None, "--path required for recognition!"

        llm_tokenizer = AutoTokenizer.from_pretrained(
            main_cfg.override.llm_ckpt_path
        )

        # Importing the user module is best-effort: a missing module is ignored.
        try:
            utils.import_user_module(main_cfg.common)
        except ImportError:
            pass

        use_cuda = torch.cuda.is_available()

        # Values injected into the checkpoint's stored model config.
        overrides = {
            'model': {
                'w2v_path': main_cfg.override.w2v_path,
                'llm_ckpt_path': main_cfg.override.llm_ckpt_path,
            }
        }
        checkpoint_paths = [main_cfg.common_eval.path]
        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            checkpoint_paths,
            overrides,
            strict=False,
        )

        # ``eval()`` returns the module itself, so the list keeps the same objects.
        models = [net.eval() for net in models]

        # The trailing ``None`` stands in for an absent language model and is
        # skipped by the condition below.
        # NOTE(review): the source indentation was lost; ``half()`` is assumed
        # to belong to the CUDA branch together with the two ``cuda()`` calls
        # -- confirm against the upstream file.
        for net in chain(models, [None]):
            if (
                net is not None
                and use_cuda
                and not main_cfg.distributed_training.pipeline_model_parallel
            ):
                net.encoder.cuda()
                net.avfeat_to_llm.cuda()
                net.half()

        return models[0], main_cfg, saved_cfg, llm_tokenizer

    cfg = OmegaConf.load(cfg_path)
    return _load_from_cfg(main_cfg=cfg)
def load_state_dict_for_extractor(
    extractor: torch.nn.Module,
    model: torch.nn.Module
):
    """Copy the weights of ``model.encoder`` into ``extractor``.

    Every key of ``model.encoder.state_dict()`` has its first four
    characters removed before the mapping is handed to
    ``extractor.load_state_dict`` (presumably to turn a ``w2v_model.*`` key
    into ``model.*`` -- confirm against the encoder wrapper).  Loading uses
    ``strict=False``, so keys that do not line up are not treated as errors.

    Args:
        extractor: Module that receives the weights; modified in place.
        model: Module exposing an ``encoder`` submodule to copy from.

    Returns:
        None.
    """
    prefix_len = 4  # number of leading characters dropped from each key
    trimmed_state = {
        name[prefix_len:]: tensor
        for name, tensor in model.encoder.state_dict().items()
    }
    extractor.load_state_dict(trimmed_state, strict=False)