# Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch

from wenet.osum_echat.init_llmasr import init_llmasr
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.cmvn import GlobalCMVN
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder
from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder
from wenet.utils.checkpoint import load_checkpoint, load_trained_modules
from wenet.utils.cmvn import load_cmvn
from wenet.whisper.whisper import Whisper
| WENET_ENCODER_CLASSES = { | |
| "transformer": TransformerEncoder, | |
| "conformer": ConformerEncoder, | |
| } | |
| WENET_DECODER_CLASSES = { | |
| "transformer": TransformerDecoder, | |
| "bitransformer": BiTransformerDecoder, | |
| } | |
| WENET_CTC_CLASSES = { | |
| "ctc": CTC, | |
| } | |
| WENET_MODEL_CLASSES = { | |
| "asr_model": ASRModel, | |
| "whisper": Whisper, | |
| } | |
| def init_speech_model(args, configs): | |
| # TODO(xcsong): Forcefully read the 'cmvn' attribute. | |
| if configs.get('cmvn', None) == 'global_cmvn': | |
| mean, istd = load_cmvn(configs['cmvn_conf']['cmvn_file'], | |
| configs['cmvn_conf']['is_json_cmvn']) | |
| global_cmvn = GlobalCMVN( | |
| torch.from_numpy(mean).float(), | |
| torch.from_numpy(istd).float()) | |
| else: | |
| global_cmvn = None | |
| input_dim = configs['input_dim'] | |
| vocab_size = configs['output_dim'] | |
| encoder_type = configs.get('encoder', 'conformer') | |
| decoder_type = configs.get('decoder', 'bitransformer') | |
| ctc_type = configs.get('ctc', 'ctc') | |
| encoder = WENET_ENCODER_CLASSES[encoder_type]( | |
| input_dim, | |
| global_cmvn=global_cmvn, | |
| **configs['encoder_conf'], | |
| **configs['encoder_conf']['efficient_conf'] | |
| if 'efficient_conf' in configs['encoder_conf'] else {}) | |
| decoder = WENET_DECODER_CLASSES[decoder_type](vocab_size, | |
| encoder.output_size(), | |
| **configs['decoder_conf']) | |
| ctc = WENET_CTC_CLASSES[ctc_type]( | |
| vocab_size, | |
| encoder.output_size(), | |
| blank_id=configs['ctc_conf']['ctc_blank_id'] | |
| if 'ctc_conf' in configs else 0) | |
| model_type = configs.get('model', 'asr_model') | |
| model = WENET_MODEL_CLASSES[model_type]( | |
| vocab_size=vocab_size, | |
| encoder=encoder, | |
| decoder=decoder, | |
| ctc=ctc, | |
| special_tokens=configs.get('tokenizer_conf', | |
| {}).get('special_tokens', None), | |
| **configs['model_conf']) | |
| return model, configs | |
| def init_model(args, configs): | |
| model_type = configs.get('model', 'asr_model') | |
| configs['model'] = model_type | |
| if model_type == "osum_echat": | |
| is_inference =configs.get('is_inference', False) | |
| model = init_llmasr(args, configs, is_inference=is_inference) | |
| return model | |
| else: | |
| model, configs = init_speech_model(args, configs) | |
| # If specify checkpoint, load some info from checkpoint | |
| if hasattr(args, 'checkpoint') and args.checkpoint is not None: | |
| infos = load_checkpoint(model, args.checkpoint) | |
| elif hasattr(args, 'enc_init') and args.enc_init is not None: | |
| infos = load_trained_modules(model, args) | |
| else: | |
| infos = {} | |
| if configs.get('init_step', False): | |
| infos = {} | |
| configs["init_infos"] = infos | |
| if hasattr(args, 'use_lora') and args.use_lora: | |
| if hasattr(args, 'lora_ckpt_path') and args.lora_ckpt_path: | |
| load_checkpoint(model, args.lora_ckpt_path) | |
| print(configs) | |
| # Trye to tie some weights | |
| if hasattr(model, 'tie_or_clone_weights'): | |
| if not hasattr(args, 'jit'): | |
| args.jit = True # i.e. export onnx/jit/ipex | |
| model.tie_or_clone_weights(args.jit) | |
| if int(os.environ.get('RANK', 0)) == 0: | |
| print(configs) | |
| return model, configs | |