Rajhuggingface4253 committed on
Commit
c849174
·
verified ·
1 Parent(s): 5a4f9b8

Create neutts_wrapper.py

Browse files
Files changed (1) hide show
  1. neutts_wrapper.py +61 -0
neutts_wrapper.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+ import tempfile
5
+ import logging
6
+
7
# Make the cloned 'neutts-air' checkout importable.
# The Dockerfile places it at /app/neutts-air.
neutts_path = "/app/neutts-air"
if neutts_path not in sys.path:
    # Prepend in place so this checkout is searched before other entries.
    sys.path[:0] = [neutts_path]
12
+
13
+ from neuttsair.neutts import NeuTTSAir
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
class NeuTTSWrapper:
    """Thin convenience wrapper around a single ``NeuTTSAir`` model instance.

    The wrapper builds the model once in ``__init__`` and exposes two
    operations that delegate to it: encoding a reference audio clip and
    running inference with previously computed reference codes.
    """

    def __init__(self, device: str = "auto"):
        """
        Initializes the NeuTTSAir model and its components.
        The model files are expected to be pre-cached in the Docker image.

        Args:
            device: Device string passed to both the backbone and the codec.
                The special value ``"auto"`` currently always resolves to
                ``"cpu"``; no CUDA availability check is performed.

        Raises:
            Exception: Any error raised while constructing ``NeuTTSAir`` is
                logged and re-raised unchanged.
        """
        # In a real GPU setup, "auto" would check torch.cuda.is_available().
        # For this project "auto" is pinned to CPU; any other value is
        # respected as passed.
        effective_device = "cpu" if device == "auto" else device

        logger.info(f"Initializing NeuTTS Air model on device: {effective_device}...")
        try:
            self.tts_model = NeuTTSAir(
                backbone_repo="neuphonic/neutts-air",
                backbone_device=effective_device,
                codec_repo="neuphonic/neucodec",
                codec_device=effective_device,
            )
            # Only recorded once the model has been constructed successfully.
            self.device = effective_device
            logger.info("✅ NeuTTS Air model initialized successfully.")
        except Exception as e:
            logger.error(f"❌ Failed to initialize NeuTTS Air model: {e}")
            raise

    def encode_reference(self, ref_audio_bytes: bytes):
        """
        Encodes reference audio from in-memory bytes.
        Uses a temporary file as the underlying model requires a file path.

        Args:
            ref_audio_bytes: Raw bytes of the reference audio; they are
                written verbatim to a temporary ``.wav`` file.

        Returns:
            Whatever ``self.tts_model.encode_reference`` returns for the
            temporary file path.

        Raises:
            ValueError: If ``ref_audio_bytes`` is empty.
        """
        if not ref_audio_bytes:
            raise ValueError("ref_audio_bytes must contain audio data, got empty input")

        # delete=False plus an explicit close: re-opening a NamedTemporaryFile
        # by name while it is still open is not portable (it fails on
        # Windows), so the file is fully written and closed before the model
        # is handed its path.
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        try:
            try:
                tmp.write(ref_audio_bytes)
            finally:
                tmp.close()
            return self.tts_model.encode_reference(tmp.name)
        finally:
            # Best-effort cleanup: a failure to remove the temp file must not
            # mask the encode result or the original exception.
            try:
                os.unlink(tmp.name)
            except OSError as cleanup_error:
                logger.warning(
                    "Could not remove temporary reference file %s: %s",
                    tmp.name,
                    cleanup_error,
                )

    def infer(self, gen_text: str, ref_codes, ref_text: str) -> np.ndarray:
        """
        Performs inference using pre-computed reference codes.
        Returns the audio as a NumPy array.

        Args:
            gen_text: Text to synthesize.
            ref_codes: Reference codes, e.g. the value returned by
                ``encode_reference``.
            ref_text: Text associated with the reference audio
                (presumably its transcript -- confirm against NeuTTSAir).

        Returns:
            The value returned by ``self.tts_model.infer``, unmodified.
        """
        return self.tts_model.infer(gen_text, ref_codes, ref_text)