| from wenet.ssl.bestrq.bestrq_model import BestRQModel |
| from wenet.ssl.wav2vec2.wav2vec2_model import Wav2vec2Model |
| from wenet.ssl.w2vbert.w2vbert_model import W2VBERTModel |
|
|
# Registry of supported self-supervised (SSL) model classes, keyed by the
# string expected in ``configs['model']`` (looked up by ``init_model``).
WENET_SSL_MODEL_CLASS = dict(
    w2vbert_model=W2VBERTModel,
    wav2vec_model=Wav2vec2Model,
    bestrq_model=BestRQModel,
)
|
|
|
|
def init_model(configs, encoder):
    """Build the self-supervised (SSL) model selected by ``configs``.

    Args:
        configs: Parsed config mapping. ``configs['model']`` names the model
            type and must be a key of ``WENET_SSL_MODEL_CLASS``;
            ``configs['model_conf']`` (optional) holds keyword arguments
            forwarded to that model class's constructor.
        encoder: Object passed to the model constructor as ``encoder=``
            (presumably an already-built encoder module -- confirm against
            callers).

    Returns:
        An instance of the class registered under ``configs['model']``.

    Raises:
        ValueError: if ``'model'`` is missing from ``configs`` or names a
            model type that is not registered in ``WENET_SSL_MODEL_CLASS``.
    """
    # Explicit checks instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently turn these into bare KeyErrors.
    if 'model' not in configs:
        raise ValueError("configs must contain a 'model' entry")
    model_type = configs['model']
    model_class = WENET_SSL_MODEL_CLASS.get(model_type)
    if model_class is None:
        raise ValueError(
            f"unsupported SSL model type {model_type!r}; "
            f"expected one of {sorted(WENET_SSL_MODEL_CLASS)}")
    # A missing or empty (None) ``model_conf`` section means "no extra
    # constructor arguments" rather than a KeyError / TypeError.
    model_conf = configs.get('model_conf') or {}
    return model_class(encoder=encoder, **model_conf)
|
|