File size: 593 Bytes
84ff315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from .speech_encoder import WhisperWrappedEncoder, DualWrappedEncoder
import torch.nn as nn

def build_speech_encoder(config):
    """Build the speech encoder selected by ``config.speech_encoder_type``.

    The type is matched case-insensitively by substring, in this order:

    * contains ``"whisper"`` -> ``WhisperWrappedEncoder.load(config)``
    * contains ``"dual"``    -> ``DualWrappedEncoder(config)``
    * contains ``"none"``    -> ``None`` (no speech encoder)

    Args:
        config: Any object exposing a ``speech_encoder_type`` attribute.

    Returns:
        The constructed encoder, or ``None`` when the type contains "none".

    Raises:
        ValueError: If ``speech_encoder_type`` is missing, is not a string,
            or matches none of the known encoder kinds.
    """
    speech_encoder_type = getattr(config, 'speech_encoder_type', None)

    print(f"Building speech encoder: {speech_encoder_type}")

    # A missing/None (or otherwise non-string) type used to crash on
    # ``.lower()`` with an opaque AttributeError; report it the same way as
    # any other unrecognised encoder type instead.
    if not isinstance(speech_encoder_type, str):
        raise ValueError(f'Unknown speech encoder: {speech_encoder_type}')

    encoder_kind = speech_encoder_type.lower()

    # Order matters: these are substring checks, so the first match wins
    # (e.g. a name containing both "whisper" and "dual" selects Whisper).
    if "whisper" in encoder_kind:
        return WhisperWrappedEncoder.load(config)
    if "dual" in encoder_kind:
        return DualWrappedEncoder(config)
    if "none" in encoder_kind:
        return None

    raise ValueError(f'Unknown speech encoder: {speech_encoder_type}')