from .tts.chatterbox_model import ChatterboxTTSModel
from .tts.kitten_model import KittenTTSModel
from .tts.piper_model import PiperTTSModel
from .tts.kokoro_model import KokoroTTSModel
from .tts.dia_model import DiaTTSModel
from .stt.whisper_model import FasterWhisperSTTModel


class ModelFactory:
    """Factory for creating TTS (text-to-speech) and STT (speech-to-text) models.

    Models are registered by name as classes and are only instantiated when
    a caller asks for them. Every call returns freshly constructed
    instances; nothing is cached between calls.
    """

    # Name -> class registries. Storing classes (not instances) lets the
    # single-model lookups build only the model that was requested instead
    # of constructing every registered model and discarding the rest.
    _TTS_MODEL_CLASSES = {
        "ResembleAI/chatterbox": ChatterboxTTSModel,
        "KittenML/KittenTTS": KittenTTSModel,
        "piper-tts": PiperTTSModel,
        "hexgrad/kokoro": KokoroTTSModel,
        "nari-labs/Dia-1.6B": DiaTTSModel,
    }
    _STT_MODEL_CLASSES = {
        "SYSTRAN/faster-whisper": FasterWhisperSTTModel,
    }

    @staticmethod
    def get_tts_models():
        """Return a new dict mapping each TTS model name to a new instance.

        Instances are constructed in registration order on every call.
        """
        return {
            name: model_cls()
            for name, model_cls in ModelFactory._TTS_MODEL_CLASSES.items()
        }

    @staticmethod
    def get_stt_models():
        """Return a new dict mapping each STT model name to a new instance.

        Instances are constructed in registration order on every call.
        """
        return {
            name: model_cls()
            for name, model_cls in ModelFactory._STT_MODEL_CLASSES.items()
        }

    @staticmethod
    def get_tts_model(model_name):
        """Return a new instance of the TTS model registered as *model_name*.

        Only the requested model is constructed. Returns ``None`` when no
        TTS model is registered under that name.
        """
        model_cls = ModelFactory._TTS_MODEL_CLASSES.get(model_name)
        return model_cls() if model_cls is not None else None

    @staticmethod
    def get_stt_model(model_name):
        """Return a new instance of the STT model registered as *model_name*.

        Only the requested model is constructed. Returns ``None`` when no
        STT model is registered under that name.
        """
        model_cls = ModelFactory._STT_MODEL_CLASSES.get(model_name)
        return model_cls() if model_cls is not None else None

    @staticmethod
    def get_model_descriptions():
        """Return a dict mapping every model name to its ``description``.

        TTS models are processed first, then STT models, so an STT entry
        would replace a TTS entry registered under the same name.
        """
        # ``description`` is read from instances, as before; whether it is
        # also available on the classes is not visible from this module.
        descriptions = {
            name: model.description
            for name, model in ModelFactory.get_tts_models().items()
        }
        descriptions.update(
            (name, model.description)
            for name, model in ModelFactory.get_stt_models().items()
        )
        return descriptions