tts_gallery / src /models /factory.py
Michael Hu
refactor: replace inline model definitions with ModelFactory and remove unused imports
ef4db28
from .tts.chatterbox_model import ChatterboxTTSModel
from .tts.kitten_model import KittenTTSModel
from .tts.piper_model import PiperTTSModel
from .tts.kokoro_model import KokoroTTSModel
from .tts.dia_model import DiaTTSModel
from .stt.whisper_model import FasterWhisperSTTModel
class ModelFactory:
    """Factory for the TTS and STT model wrappers known to the application.

    Models are addressed by a string identifier (for example
    ``"hexgrad/kokoro"``). Nothing is cached: every method that hands out a
    model constructs a new instance on each call.
    """

    @staticmethod
    def _tts_model_classes() -> dict:
        """Return the mapping of TTS model identifier to model class."""
        # Built inside a method (not as a class attribute) so the model
        # classes are only resolved when a factory method is actually called.
        return {
            "ResembleAI/chatterbox": ChatterboxTTSModel,
            "KittenML/KittenTTS": KittenTTSModel,
            "piper-tts": PiperTTSModel,
            "hexgrad/kokoro": KokoroTTSModel,
            "nari-labs/Dia-1.6B": DiaTTSModel,
        }

    @staticmethod
    def _stt_model_classes() -> dict:
        """Return the mapping of STT model identifier to model class."""
        return {
            "SYSTRAN/faster-whisper": FasterWhisperSTTModel,
        }

    @staticmethod
    def _instantiate(registry: dict, model_name):
        """Construct the class registered under ``model_name``.

        Returns ``None`` when ``model_name`` is not a key of ``registry``.
        """
        model_cls = registry.get(model_name)
        if model_cls is None:
            return None
        return model_cls()

    @staticmethod
    def get_tts_models() -> dict:
        """Get all available TTS models.

        Returns:
            A new dict mapping each TTS model identifier to a freshly
            constructed model instance, in registration order.
        """
        registry = ModelFactory._tts_model_classes()
        return {name: model_cls() for name, model_cls in registry.items()}

    @staticmethod
    def get_stt_models() -> dict:
        """Get all available STT models.

        Returns:
            A new dict mapping each STT model identifier to a freshly
            constructed model instance.
        """
        registry = ModelFactory._stt_model_classes()
        return {name: model_cls() for name, model_cls in registry.items()}

    @staticmethod
    def get_tts_model(model_name):
        """Get a specific TTS model by name.

        Only the requested model class is instantiated; the other registered
        models are not constructed (their constructors may be expensive --
        not visible from this module).

        Args:
            model_name: Identifier of the TTS model, e.g. ``"piper-tts"``.

        Returns:
            A new instance of the matching model, or ``None`` if the
            identifier is not registered.
        """
        return ModelFactory._instantiate(
            ModelFactory._tts_model_classes(), model_name
        )

    @staticmethod
    def get_stt_model(model_name):
        """Get a specific STT model by name.

        Only the requested model class is instantiated.

        Args:
            model_name: Identifier of the STT model, e.g.
                ``"SYSTRAN/faster-whisper"``.

        Returns:
            A new instance of the matching model, or ``None`` if the
            identifier is not registered.
        """
        return ModelFactory._instantiate(
            ModelFactory._stt_model_classes(), model_name
        )

    @staticmethod
    def get_model_descriptions() -> dict:
        """Get descriptions for all models.

        Returns:
            A dict mapping every TTS and STT model identifier to that
            model's ``description`` attribute. TTS entries are added first;
            an STT entry with the same identifier would overwrite it.
        """
        descriptions = {}
        # ``description`` is read from instances because it may be assigned
        # in each model's constructor, so the models are still built here.
        for get_models in (ModelFactory.get_tts_models, ModelFactory.get_stt_models):
            for model_name, model in get_models().items():
                descriptions[model_name] = model.description
        return descriptions