File size: 2,230 Bytes
c849174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import sys
import numpy as np
import tempfile
import logging

# Make the cloned 'neutts-air' checkout importable before the `neuttsair`
# import below runs. The Dockerfile places the clone at /app/neutts-air.
neutts_path = "/app/neutts-air"
# Prepend (index 0) so this directory is searched ahead of the existing
# sys.path entries; the membership test avoids inserting a duplicate entry.
if neutts_path not in sys.path:
    sys.path.insert(0, neutts_path)

# Must come after the sys.path adjustment above, hence not grouped with the
# imports at the top of the file.
from neuttsair.neutts import NeuTTSAir

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

class NeuTTSWrapper:
    """Thin wrapper around a single ``NeuTTSAir`` model instance.

    The model is constructed once in ``__init__``. Two operations are exposed:
    encoding a reference recording supplied as raw bytes, and running
    inference against previously computed reference codes.
    """

    def __init__(self, device: str = "auto"):
        """
        Initializes the NeuTTSAir model and its components.
        The model files are expected to be pre-cached in the Docker image.

        Args:
            device: Device string used for both the backbone and the codec.
                The special value ``"auto"`` always resolves to ``"cpu"``;
                no GPU detection is performed. Any other value is forwarded
                unchanged.

        Raises:
            Exception: Any error raised while constructing ``NeuTTSAir`` is
                logged (with traceback) and re-raised unchanged.
        """
        # No torch.cuda.is_available() probing here: "auto" maps to CPU.
        effective_device = "cpu" if device == "auto" else device

        logger.info("Initializing NeuTTS Air model on device: %s...", effective_device)
        try:
            self.tts_model = NeuTTSAir(
                backbone_repo="neuphonic/neutts-air",
                backbone_device=effective_device,
                codec_repo="neuphonic/neucodec",
                codec_device=effective_device,
            )
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("❌ Failed to initialize NeuTTS Air model: %s", e)
            raise

        self.device = effective_device
        logger.info("✅ NeuTTS Air model initialized successfully.")

    def encode_reference(self, ref_audio_bytes: bytes):
        """
        Encodes reference audio from in-memory bytes.

        The underlying model is given a file path, so the bytes are written to
        a ``.wav`` file inside a private temporary directory. The file is
        closed before the model is called (an open ``NamedTemporaryFile``
        cannot be reopened by name on every platform, notably Windows), and
        the directory is removed afterwards even when encoding fails.

        Args:
            ref_audio_bytes: Raw contents of the reference audio file.

        Returns:
            Whatever ``self.tts_model.encode_reference`` returns for the file.

        Raises:
            ValueError: If ``ref_audio_bytes`` is empty.
        """
        if not ref_audio_bytes:
            raise ValueError("ref_audio_bytes must contain audio data, got empty input.")

        with tempfile.TemporaryDirectory() as tmp_dir:
            ref_path = os.path.join(tmp_dir, "reference.wav")
            # Close the handle first so all data is on disk and the model can
            # open the path on its own.
            with open(ref_path, "wb") as ref_file:
                ref_file.write(ref_audio_bytes)
            return self.tts_model.encode_reference(ref_path)

    def infer(self, gen_text: str, ref_codes, ref_text: str) -> np.ndarray:
        """
        Performs inference using pre-computed reference codes.

        Args:
            gen_text: Text to synthesize.
            ref_codes: Reference codes, e.g. the value returned by
                ``encode_reference``.
            ref_text: Text that accompanies the reference audio (presumably
                its transcript -- confirm against the NeuTTSAir API).

        Returns:
            The result of ``self.tts_model.infer``, annotated as a NumPy array.
        """
        return self.tts_model.infer(gen_text, ref_codes, ref_text)