Spaces:
Paused
Paused
| import os | |
| import sys | |
| import numpy as np | |
| import tempfile | |
| import logging | |
# This ensures the cloned 'neutts-air' directory is on the Python path.
# The Dockerfile places it at /app/neutts-air; set the NEUTTS_AIR_PATH
# environment variable to point at a different checkout (for example when
# running outside the container). An unset or empty variable falls back to
# the Docker location so the default behaviour is unchanged.
neutts_path = os.environ.get("NEUTTS_AIR_PATH") or "/app/neutts-air"
if neutts_path not in sys.path:
    # Prepend so this checkout wins over any other 'neuttsair' on the path.
    sys.path.insert(0, neutts_path)
| from neuttsair.neutts import NeuTTSAir | |
| logger = logging.getLogger(__name__) | |
class NeuTTSWrapper:
    """Wrapper around ``NeuTTSAir`` exposing reference encoding and inference.

    Attributes set by ``__init__``:
        tts_model: The underlying ``NeuTTSAir`` instance.
        device: The device string the model components were created with.
    """

    def __init__(self, device: str = "auto") -> None:
        """
        Initializes the NeuTTSAir model and its components.
        The model files are expected to be pre-cached in the Docker image.

        Args:
            device: Device string forwarded to both the backbone and the
                codec. ``"auto"`` currently resolves to ``"cpu"``.

        Raises:
            Exception: Any error raised while constructing ``NeuTTSAir`` is
                logged (with traceback) and re-raised unchanged.
        """
        # In a real GPU setup, "auto" would check torch.cuda.is_available().
        # For this project "auto" is pinned to CPU; any other value is
        # respected as passed.
        effective_device = "cpu" if device == "auto" else device

        logger.info(
            "Initializing NeuTTS Air model on device: %s...", effective_device
        )
        try:
            self.tts_model = NeuTTSAir(
                backbone_repo="neuphonic/neutts-air",
                backbone_device=effective_device,
                codec_repo="neuphonic/neucodec",
                codec_device=effective_device,
            )
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception("❌ Failed to initialize NeuTTS Air model: %s", e)
            raise

        self.device = effective_device
        logger.info("✅ NeuTTS Air model initialized successfully.")

    def encode_reference(self, ref_audio_bytes: bytes):
        """
        Encodes reference audio from in-memory bytes.
        Uses a temporary file as the underlying model requires a file path.

        Args:
            ref_audio_bytes: Raw bytes of the reference audio file.

        Returns:
            Whatever ``self.tts_model.encode_reference`` returns for the
            temporary file path.

        Raises:
            ValueError: If ``ref_audio_bytes`` is empty.
        """
        if not ref_audio_bytes:
            raise ValueError(
                "ref_audio_bytes is empty; expected the bytes of an audio file."
            )

        # delete=False plus an explicit close: a NamedTemporaryFile that is
        # still open cannot be reopened by name on every platform (notably
        # Windows), so the handle is closed before the model reads the path.
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        try:
            with tmp:
                tmp.write(ref_audio_bytes)
            # Leaving the 'with' block closed the handle, flushing all data.
            return self.tts_model.encode_reference(tmp.name)
        finally:
            # Always remove the temporary file, even when encoding fails.
            try:
                os.unlink(tmp.name)
            except FileNotFoundError:
                pass

    def infer(self, gen_text: str, ref_codes, ref_text: str) -> np.ndarray:
        """
        Performs inference using pre-computed reference codes.
        Returns the audio as a NumPy array.

        Args:
            gen_text: Text to synthesize.
            ref_codes: Reference codes, e.g. as returned by
                ``encode_reference``.
            ref_text: Text associated with the reference audio.
        """
        return self.tts_model.infer(gen_text, ref_codes, ref_text)