Spaces:
Runtime error
Runtime error
| import torch | |
| import re | |
| import os | |
| from vita.model.vita_tts.audioLLM import AudioLLM | |
| from vita.model.vita_tts.encoder.cmvn import GlobalCMVN, load_cmvn | |
| from vita.model.vita_tts.encoder.encoder import speechEncoder | |
def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
    """Load a state dict from *path* into *model* and return its side-car config.

    The checkpoint is loaded with ``strict=False``, so keys missing from the
    checkpoint or unknown to *model* are skipped instead of raising.  If *path*
    ends in ``.pt`` and a ``.yaml`` file with the same stem exists next to it,
    that YAML file is parsed and returned.

    Args:
        model: Module whose parameters are overwritten in place.
        path: Filesystem path of the checkpoint file.

    Returns:
        The parsed YAML config, or ``{}`` when there is no side-car ``.yaml``
        file (or it is empty).
    """
    if torch.cuda.is_available():
        print('Checkpoint: loading from checkpoint %s for GPU' % path)
        checkpoint = torch.load(path)
    else:
        # Remap any GPU-saved tensors so the checkpoint loads on CPU-only hosts.
        print('Checkpoint: loading from checkpoint %s for CPU' % path)
        checkpoint = torch.load(path, map_location='cpu')
    # strict=False: tolerate partial checkpoints (missing/unexpected keys are ignored).
    model.load_state_dict(checkpoint, strict=False)

    configs = {}
    # Only look for a side-car config next to genuine ``.pt`` files.  The old
    # ``re.sub('.pt$', ...)`` left other paths (e.g. ``.pth``) unchanged, so the
    # binary checkpoint itself was then read as YAML; its unescaped ``.`` also
    # matched names such as ``modelxpt``.
    if path.endswith('.pt'):
        info_path = path[:-len('.pt')] + '.yaml'
        if os.path.exists(info_path):
            # Imported locally: the module never imported ``yaml``, so the
            # original reference raised NameError on this path.
            import yaml
            with open(info_path, 'r', encoding='utf-8') as fin:
                # safe_load returns None for an empty file; keep the dict contract.
                configs = yaml.safe_load(fin) or {}
    return configs
def init_encoder_llm(configs):
    """Build an ``AudioLLM`` wrapped around a freshly constructed speech encoder.

    Args:
        configs: Configuration mapping. Keys read here:

            * ``cmvn_file`` (optional): passed to ``load_cmvn``. If it is
              ``None`` or absent, no CMVN layer is built.
            * ``is_json_cmvn``: required only when ``cmvn_file`` is set.
            * ``input_dim``: first positional argument to ``speechEncoder``.
            * ``encoder_conf``: extra keyword arguments for ``speechEncoder``.
            * ``model_conf``: extra keyword arguments for ``AudioLLM``.

    Returns:
        The constructed ``AudioLLM`` instance.
    """
    global_cmvn = None
    cmvn_file = configs.get('cmvn_file')
    if cmvn_file is not None:
        # load_cmvn's results go through torch.from_numpy, so they are
        # presumably numpy arrays (mean, inverse std) — confirm in load_cmvn.
        mean, istd = load_cmvn(cmvn_file, configs['is_json_cmvn'])
        global_cmvn = GlobalCMVN(
            torch.from_numpy(mean).float(),
            torch.from_numpy(istd).float())

    # init speech encoder
    encoder = speechEncoder(configs['input_dim'], global_cmvn=global_cmvn,
                            **configs['encoder_conf'])
    # init audioLLM
    return AudioLLM(encoder=encoder, **configs['model_conf'])