File size: 2,327 Bytes
9411c06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
from itertools import chain
from argparse import Namespace
from fairseq import checkpoint_utils, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from transformers import AutoTokenizer
from omegaconf import OmegaConf, DictConfig
from src.models.taskers.clustering import HubertFeatureReader


def load_feature_extractor(
        ckpt_path: str,
        layer,
        max_chunk=1600000,
        custom_utils=None
):
    """Construct a ``HubertFeatureReader`` over the checkpoint at *ckpt_path*.

    Args:
        ckpt_path: filesystem path to the HuBERT checkpoint.
        layer: layer identifier forwarded to the reader (which layer's
            features to extract — see ``HubertFeatureReader``).
        max_chunk: maximum chunk size forwarded to the reader.
        custom_utils: optional utilities module forwarded to the reader.

    Returns:
        The initialized ``HubertFeatureReader`` instance.
    """
    return HubertFeatureReader(ckpt_path, layer, max_chunk, custom_utils)


def load_ensemble_model(cfg_path: str):
    """Load the model ensemble described by the OmegaConf file at *cfg_path*.

    Loads the checkpoint named in ``common_eval.path`` via fairseq, builds the
    LLM tokenizer from ``override.llm_ckpt_path``, and — when CUDA is available
    and pipeline model parallelism is off — moves each model's ``encoder`` and
    ``avfeat_to_llm`` sub-modules to the GPU and casts the model to half
    precision.

    Returns:
        Tuple ``(model, main_cfg, saved_cfg, llm_tokenizer)`` where *model*
        is the first ensemble member, in eval mode.
    """
    main_cfg = OmegaConf.load(cfg_path)

    # Parity with fairseq entry points: convert a legacy Namespace config.
    # (OmegaConf.load yields a DictConfig, so this is normally a no-op.)
    if isinstance(main_cfg, Namespace):
        main_cfg = convert_namespace_to_omegaconf(main_cfg)

    assert main_cfg.common_eval.path is not None, "--path required for recognition!"

    llm_tokenizer = AutoTokenizer.from_pretrained(main_cfg.override.llm_ckpt_path)

    # Best effort: the user module may not be importable in every environment.
    try:
        utils.import_user_module(main_cfg.common)
    except ImportError:
        pass

    use_cuda = torch.cuda.is_available()

    overrides = {
        'model': {
            'w2v_path': main_cfg.override.w2v_path,
            'llm_ckpt_path': main_cfg.override.llm_ckpt_path,
        }
    }
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        [main_cfg.common_eval.path, ],
        overrides, strict=False
    )

    models = [m.eval() for m in models]
    # No external LMs are loaded (the original iterated a [None] placeholder
    # alongside the ensemble and skipped it), so only the ensemble is placed.
    if use_cuda and not main_cfg.distributed_training.pipeline_model_parallel:
        for m in models:
            m.encoder.cuda()
            m.avfeat_to_llm.cuda()
            m.half()

    return models[0], main_cfg, saved_cfg, llm_tokenizer


def load_state_dict_for_extractor(
        extractor: torch.nn.Module,
        model: torch.nn.Module,
        prefix_len: int = 4,
):
    """Copy the weights of ``model.encoder`` into *extractor*.

    The encoder's state-dict keys carry a leading sub-module prefix of
    *prefix_len* characters that *extractor*'s own keys do not have
    (presumably something like ``"w2v."`` — TODO confirm against the
    encoder's attribute layout), so that prefix is stripped from every key
    before loading.

    Args:
        extractor: module that receives the renamed weights.
        model: module whose ``encoder`` sub-module supplies the weights.
        prefix_len: number of leading characters to strip from each key.
            Defaults to 4, matching the previously hard-coded ``k[4:]``.

    Returns:
        The missing/unexpected-keys result of ``load_state_dict`` so callers
        can inspect what did not match (previously discarded).
    """
    renamed = {
        key[prefix_len:]: value
        for key, value in model.encoder.state_dict().items()
    }
    # strict=False: extractor and encoder need not share every parameter;
    # only keys present on both sides are copied.
    return extractor.load_state_dict(renamed, strict=False)