Spaces:
Sleeping
Sleeping
File size: 1,796 Bytes
ef4db28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
from .tts.chatterbox_model import ChatterboxTTSModel
from .tts.kitten_model import KittenTTSModel
from .tts.piper_model import PiperTTSModel
from .tts.kokoro_model import KokoroTTSModel
from .tts.dia_model import DiaTTSModel
from .stt.whisper_model import FasterWhisperSTTModel
class ModelFactory:
    """Factory class for creating model instances.

    Every supported model is registered under a string name (for example
    ``"hexgrad/kokoro"``). The private registries map those names to model
    classes, so a single model can be created without constructing all of
    the others.
    """

    @staticmethod
    def _tts_model_classes():
        """Return the registry of TTS model classes, keyed by model name.

        Built inside a function (not as a class attribute) so the model
        classes are only looked up when a factory method is actually called.
        Insertion order determines the order of ``get_tts_models()``.
        """
        return {
            "ResembleAI/chatterbox": ChatterboxTTSModel,
            "KittenML/KittenTTS": KittenTTSModel,
            "piper-tts": PiperTTSModel,
            "hexgrad/kokoro": KokoroTTSModel,
            "nari-labs/Dia-1.6B": DiaTTSModel,
        }

    @staticmethod
    def _stt_model_classes():
        """Return the registry of STT model classes, keyed by model name."""
        return {
            "SYSTRAN/faster-whisper": FasterWhisperSTTModel,
        }

    @staticmethod
    def get_tts_models():
        """Get all available TTS models.

        Returns:
            A dict mapping each registered TTS model name to a newly
            constructed instance of that model. A fresh set of instances is
            created on every call.
        """
        return {
            name: model_cls()
            for name, model_cls in ModelFactory._tts_model_classes().items()
        }

    @staticmethod
    def get_stt_models():
        """Get all available STT models.

        Returns:
            A dict mapping each registered STT model name to a newly
            constructed instance of that model. A fresh set of instances is
            created on every call.
        """
        return {
            name: model_cls()
            for name, model_cls in ModelFactory._stt_model_classes().items()
        }

    @staticmethod
    def get_tts_model(model_name):
        """Get a specific TTS model by name.

        Only the requested model class is instantiated, so the constructors
        of the other registered models are not run.

        Args:
            model_name: Registry key, e.g. ``"hexgrad/kokoro"``.

        Returns:
            A new instance of the named TTS model, or ``None`` if the name
            is not registered.
        """
        model_cls = ModelFactory._tts_model_classes().get(model_name)
        return None if model_cls is None else model_cls()

    @staticmethod
    def get_stt_model(model_name):
        """Get a specific STT model by name.

        Only the requested model class is instantiated, so the constructors
        of the other registered models are not run.

        Args:
            model_name: Registry key, e.g. ``"SYSTRAN/faster-whisper"``.

        Returns:
            A new instance of the named STT model, or ``None`` if the name
            is not registered.
        """
        model_cls = ModelFactory._stt_model_classes().get(model_name)
        return None if model_cls is None else model_cls()

    @staticmethod
    def get_model_descriptions():
        """Get descriptions for all models.

        Returns:
            A dict mapping model name to that model's ``description``, with
            TTS models first and then STT models. If the same name were
            registered in both registries, the STT entry would overwrite the
            TTS one.
        """
        # Instances (not classes) are queried because ``description`` may be
        # assigned per-instance in ``__init__``. TODO confirm: if it is a
        # class attribute, read it from the registry classes instead to
        # avoid constructing every model here.
        descriptions = {
            name: model.description
            for name, model in ModelFactory.get_tts_models().items()
        }
        descriptions.update(
            (name, model.description)
            for name, model in ModelFactory.get_stt_models().items()
        )
        return descriptions