Spaces:
Sleeping
Sleeping
Michael Hu
refactor: replace inline model definitions with ModelFactory and remove unused imports
ef4db28
| import tempfile | |
| import os | |
| import soundfile as sf | |
| import numpy as np | |
| from kittentts import KittenTTS | |
| from ..base import TTSModel | |
class KittenTTSModel(TTSModel):
    """KittenTTS model implementation.

    The underlying ``KittenTTS`` object is created lazily, on the first call
    to :meth:`initialize` or :meth:`generate_speech`, and then reused.
    """

    # Output rate used when the loaded model object does not expose a
    # ``sample_rate`` attribute. 24000 is the rate used in the KittenTTS
    # usage example -- NOTE(review): confirm if the model path is changed.
    _DEFAULT_SAMPLE_RATE = 24000

    def __init__(self):
        self._model = None
        self._initialized = False
        self._model_path = "KittenML/kitten-tts-nano-0.2"

    def name(self) -> str:
        """Return the display name of this model."""
        return "KittenML/KittenTTS"

    def description(self) -> str:
        """Return a short human-readable description of this model."""
        return "High-quality TTS with voice cloning capabilities using reference audio"

    def initialize(self) -> bool:
        """Initialize the KittenTTS model.

        Returns:
            bool: True if the model is loaded (now or previously), False if
            constructing it raised. The error is printed, not re-raised.
        """
        if self._initialized:
            return True
        try:
            self._model = KittenTTS(self._model_path)
        except Exception as e:
            # Deliberately best-effort: callers check the boolean result.
            print(f"Error initializing KittenTTS model: {e}")
            return False
        self._initialized = True
        return True

    def generate_speech(self, text, audio_prompt=None, **kwargs):
        """Generate speech from text using KittenTTS.

        Args:
            text (str): Text to convert to speech.
            audio_prompt (str, optional): Path to a reference audio file for
                voice cloning. Ignored when empty or when the path does not
                exist; the default voice is used instead.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            str: Path to the generated ``.wav`` file. The file is not deleted
            automatically; the caller owns it.

        Raises:
            RuntimeError: If the model could not be initialized.
        """
        if not self._initialized and not self.initialize():
            raise RuntimeError("Failed to initialize KittenTTS model")

        if audio_prompt and os.path.exists(audio_prompt):
            # Use audio prompt for voice cloning.
            # NOTE(review): confirm that the installed kittentts release
            # actually provides ``generate_with_voice``.
            audio_array = self._model.generate_with_voice(text, audio_prompt)
        else:
            # Generate with default voice.
            audio_array = self._model.generate(text)

        # Prefer the rate reported by the model; fall back to the default
        # instead of failing with AttributeError after synthesis is done.
        sample_rate = getattr(self._model, "sample_rate", self._DEFAULT_SAMPLE_RATE)

        # Reserve a unique path, then close the handle *before* writing:
        # reopening a still-open NamedTemporaryFile by name fails on Windows.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            output_path = tmp_file.name

        try:
            sf.write(output_path, audio_array, sample_rate)
        except Exception:
            # delete=False means nobody else cleans up; do not leave an
            # empty or partial file behind when the write fails.
            try:
                os.remove(output_path)
            except OSError:
                pass
            raise
        return output_path

    def supports_voice_cloning(self):
        """Report that this model advertises voice-cloning support."""
        return True