File size: 648 Bytes
0b53d99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import logging
from typing import Dict, List, Union

from src.utils.TTS.utils.generic_utils import find_module

logger = logging.getLogger(__name__)


def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS":
    """Build the model object selected by ``config``.

    The model class is resolved with ``find_module`` inside ``"TTS.tts.models"``.
    The lookup key is the lower-cased ``config.base_model`` when the config
    contains a ``"base_model"`` entry that is not ``None``; otherwise it is the
    lower-cased ``config.model``.

    Args:
        config: Configuration object. It is read for ``model`` and, when
            present, ``base_model``, and is forwarded to ``init_from_config``.
        samples: Forwarded unchanged to ``init_from_config``; defaults to ``None``.

    Returns:
        The value returned by the resolved class's ``init_from_config``.
    """
    # The log line always reports config.model, even when base_model drives the lookup.
    logger.info("Using model: %s", config.model)

    # A populated base_model takes precedence over model as the lookup key.
    has_base_model = "base_model" in config and config["base_model"] is not None
    lookup_name = config.base_model if has_base_model else config.model

    # NOTE(review): find_module is imported from src.utils.TTS..., but the package
    # path passed here is "TTS.tts.models" -- confirm it resolves in this layout.
    model_cls = find_module("TTS.tts.models", lookup_name.lower())
    return model_cls.init_from_config(config=config, samples=samples)