File size: 9,117 Bytes
f876b9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import torch
import numpy as np
import io
import base64
import subprocess
import tempfile
import os
from typing import Dict, Any
from transformers import VitsModel, AutoTokenizer
import scipy.io.wavfile as wavfile


class EndpointHandler:
    """
    Text-to-speech inference handler built around a Hugging Face ``VitsModel``.

    Calling the handler with ``{"inputs": "<text>"}`` synthesizes speech and
    returns it base64-encoded, nominally as MP3 (see ``__call__``).
    """

    # Fallback output rate when the model config does not expose one; this is
    # the value the handler previously hard-coded everywhere.
    _DEFAULT_SAMPLE_RATE = 16000

    def __init__(self, path=""):
        """
        Load the VITS model and its tokenizer from ``path``.

        Written for ``facebook/mms-tts-asm``. The model is put in eval mode and
        moved to CUDA when available, otherwise it stays on the CPU.
        """
        self.model = VitsModel.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Inference only: disable training-time behaviour such as dropout.
        self.model.eval()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def wav_to_mp3_ffmpeg(
        self, wav_data: bytes, sample_rate: int = _DEFAULT_SAMPLE_RATE
    ) -> bytes:
        """
        Encode WAV bytes as MP3 (libmp3lame, 128 kbps) by invoking ``ffmpeg``.

        Args:
            wav_data: A complete WAV file as bytes.
            sample_rate: Output sampling rate passed to ffmpeg's ``-ar``
                (defaults to 16000, the previously hard-coded value).

        Returns:
            The MP3 file contents.

        Raises:
            Exception: "Error converting to MP3: ..." if ffmpeg is missing,
                exits non-zero, or a file operation fails. The underlying
                error is chained as ``__cause__``.
        """
        try:
            # A private temp directory is deleted together with its contents
            # when the ``with`` exits, on success or error, so no manual
            # cleanup bookkeeping is needed.
            with tempfile.TemporaryDirectory() as tmp_dir:
                wav_path = os.path.join(tmp_dir, "input.wav")
                mp3_path = os.path.join(tmp_dir, "output.mp3")

                with open(wav_path, "wb") as wav_file:
                    wav_file.write(wav_data)

                cmd = [
                    "ffmpeg", "-y",            # overwrite output without prompting
                    "-i", wav_path,            # input file
                    "-codec:a", "libmp3lame",  # MP3 encoder
                    "-b:a", "128k",            # bitrate
                    "-ar", str(sample_rate),   # output sample rate
                    mp3_path,                  # output file
                ]
                # Argument list (no shell). errors="replace" keeps stray
                # non-decodable bytes in ffmpeg's stderr from raising.
                result = subprocess.run(
                    cmd, capture_output=True, text=True, errors="replace"
                )
                if result.returncode != 0:
                    raise Exception(f"FFmpeg error: {result.stderr}")

                with open(mp3_path, "rb") as mp3_file:
                    return mp3_file.read()
        except Exception as e:
            raise Exception(f"Error converting to MP3: {str(e)}") from e

    def wav_to_mp3_manual(self, wav_data: bytes) -> bytes:
        """
        Fallback "conversion" that needs no external tools.

        This does NOT produce a real MP3. It prepends a 10-byte ID3v2.3 tag
        header (declared size 0) and a single MPEG frame header to the
        unmodified WAV bytes. MP3 decoders will generally not play it.
        """
        # ID3v2.3 tag header: "ID3", version 3.0, no flags, tag size 0.
        id3_header = b'ID3\x03\x00\x00\x00\x00\x00\x00'

        # MPEG frame header 0xFFFB9000: MPEG-1 Layer III, no CRC, 128 kbps,
        # 44.1 kHz. It does not describe the WAV data that follows.
        mp3_frame_header = b'\xff\xfb\x90\x00'

        return id3_header + mp3_frame_header + wav_data

    def _model_sample_rate(self) -> int:
        """Return the model config's ``sampling_rate``, or 16000 if unavailable."""
        config = getattr(self.model, "config", None)
        rate = getattr(config, "sampling_rate", None)
        return int(rate) if rate else self._DEFAULT_SAMPLE_RATE

    @staticmethod
    def _waveform_to_wav_bytes(waveform: np.ndarray, sample_rate: int) -> bytes:
        """
        Peak-normalize a float waveform to 0.95 of full scale and serialize it
        as an in-memory 16-bit PCM WAV file.

        An all-zero or empty waveform is written without scaling.
        """
        # A single-sample output squeezes to a 0-d array; wavfile.write
        # expects 1-D or 2-D data.
        samples = np.atleast_1d(np.asarray(waveform))
        if samples.size:  # np.max raises ValueError on an empty array
            peak = np.max(np.abs(samples))
            if peak > 0:
                # Normalize with 5% headroom to prevent clipping in int16.
                samples = samples / peak * 0.95
        pcm = (samples * 32767).astype(np.int16)

        buffer = io.BytesIO()
        wavfile.write(buffer, sample_rate, pcm)
        return buffer.getvalue()

    def _encode_mp3(self, wav_data: bytes, conversion_method: str, sample_rate: int) -> tuple:
        """
        Convert WAV bytes with the requested method.

        Returns ``(audio_bytes, method_used)``. ``method_used`` is "manual" when
        any method other than "ffmpeg" was requested, or when ffmpeg failed.
        """
        if conversion_method == "ffmpeg":
            try:
                return self.wav_to_mp3_ffmpeg(wav_data, sample_rate), "ffmpeg"
            except Exception as e:
                # Deliberate best-effort: report and fall back instead of failing.
                print(f"FFmpeg conversion failed: {e}, falling back to manual method")
        return self.wav_to_mp3_manual(wav_data), "manual"

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Synthesize speech for one request.

        Args:
            data: ``{"inputs": "text to convert to speech",
                     "parameters": {"conversion_method": "ffmpeg" | "manual"}}``.
                ``parameters`` is optional (missing or null), and
                ``conversion_method`` defaults to "ffmpeg".

        Returns:
            On success, a dict with:
              - ``audio``: the base64-encoded audio.
              - ``sampling_rate``: the model's output rate.
              - ``format``: "mp3".
              - ``text``: the input text.
              - ``conversion_method``: the method actually used. This is
                "manual" if ffmpeg failed or another value was requested.
              - ``content_type``: "audio/mpeg".
            "manual" output is not a decodable MP3.
            On failure, ``{"error": "<message>"}``. Ordinary exceptions are
            converted to this form rather than raised.
        """
        try:
            inputs = data.get("inputs", "")
            if not inputs or (isinstance(inputs, str) and not inputs.strip()):
                return {"error": "No input text provided"}

            # "parameters" may be missing or explicitly null in the JSON body.
            parameters = data.get("parameters") or {}
            conversion_method = parameters.get("conversion_method", "ffmpeg")

            input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids.to(self.device)
            # NOTE(review): presumably the tokenizer can drop every character of
            # text outside its vocabulary; don't feed an empty sequence to the model.
            if input_ids.shape[-1] == 0:
                return {"error": "Input text contains no characters supported by the tokenizer"}

            with torch.no_grad():
                output = self.model(input_ids)
                waveform = output.waveform.squeeze().cpu().numpy()

            # Use the model's own rate so the WAV header matches the samples.
            sample_rate = self._model_sample_rate()
            wav_data = self._waveform_to_wav_bytes(waveform, sample_rate)
            mp3_data, method_used = self._encode_mp3(wav_data, conversion_method, sample_rate)

            return {
                "audio": base64.b64encode(mp3_data).decode('utf-8'),
                "sampling_rate": sample_rate,
                "format": "mp3",
                "text": inputs,
                "conversion_method": method_used,
                "content_type": "audio/mpeg"
            }

        except Exception as e:
            return {"error": f"Error processing request: {str(e)}"}


# Pure-Python alternative with no external dependencies. NOTE: it only wraps raw PCM in MP3-style header bytes; it does not perform real MP3 encoding.
class SimpleLAMEEncoder:
    """
    Pure-Python "MP3-like" wrapper for WAV audio, for demonstration only.

    Despite the name, no MP3 encoding takes place. The WAV file's PCM payload
    is copied verbatim after a fixed ID3v2.3 tag header and one MPEG audio
    frame header. The result may not be playable by MP3 decoders, so use a
    real encoder (e.g. ffmpeg/LAME) for production output.
    """

    # ID3v2.3 tag header: "ID3", version 2.3.0, no flags, tag size 0.
    # The size field is never filled in, so the tag is declared empty.
    _ID3V2_HEADER = bytes([
        0x49, 0x44, 0x33,        # "ID3"
        0x03, 0x00,              # Version 2.3
        0x00,                    # Flags
        0x00, 0x00, 0x00, 0x00,  # Size (syncsafe integer) = 0
    ])

    # MPEG audio frame header 0xFFFB9000: MPEG-1 Layer III, no CRC, 128 kbps,
    # 44.1 kHz, no padding. It does not describe 16 kHz audio, and it is not
    # followed by real MP3 frame data.
    _MP3_FRAME_HEADER = bytes([0xFF, 0xFB, 0x90, 0x00])

    # Offset of the samples in a canonical PCM WAV file: the 12-byte RIFF/WAVE
    # header, a 24-byte "fmt " chunk and the 8-byte "data" chunk header. Used
    # when the input cannot be parsed as RIFF/WAVE.
    _CANONICAL_WAV_HEADER_SIZE = 44

    @staticmethod
    def _extract_pcm(wav_data: bytes) -> bytes:
        """
        Return the payload of the first ``data`` chunk of a RIFF/WAVE file.

        Walks the chunk list instead of assuming a fixed 44-byte header, so
        extra chunks (e.g. ``LIST``) before or after the samples are excluded.
        A ``data`` chunk whose declared size exceeds the remaining bytes is
        truncated to what is present. Falls back to ``wav_data[44:]`` when no
        ``data`` chunk can be located.
        """
        if len(wav_data) >= 12 and wav_data[:4] == b"RIFF" and wav_data[8:12] == b"WAVE":
            offset = 12
            while offset + 8 <= len(wav_data):
                chunk_id = wav_data[offset:offset + 4]
                chunk_size = int.from_bytes(wav_data[offset + 4:offset + 8], "little")
                body_start = offset + 8
                if chunk_id == b"data":
                    return wav_data[body_start:body_start + chunk_size]
                # RIFF chunks are word-aligned: odd-sized bodies carry a pad byte.
                offset = body_start + chunk_size + (chunk_size & 1)
        return wav_data[SimpleLAMEEncoder._CANONICAL_WAV_HEADER_SIZE:]

    @staticmethod
    def encode_wav_to_mp3_like(wav_data: bytes, sample_rate: int = 16000) -> bytes:
        """
        Build the MP3-like byte string for ``wav_data``.

        Args:
            wav_data: A complete WAV file as bytes.
            sample_rate: Accepted for API compatibility but currently unused.
                The frame header is fixed and the samples are not resampled.

        Returns:
            The ID3v2 header, then the MPEG frame header, then the raw PCM
            payload of ``wav_data``.
        """
        audio_data = SimpleLAMEEncoder._extract_pcm(wav_data)
        return SimpleLAMEEncoder._ID3V2_HEADER + SimpleLAMEEncoder._MP3_FRAME_HEADER + audio_data


# # Example usage and testing
# if __name__ == "__main__":
#     # Test the handler locally
#     handler = EndpointHandler("facebook/mms-tts-asm")
    
#     # Test input with ffmpeg conversion
#     test_data = {
#         "inputs": "Hello, this is a test of the text to speech system.",
#         "parameters": {"conversion_method": "ffmpeg"}
#     }
    
#     result = handler(test_data)
#     print("Handler result keys:", result.keys())
    
#     if "audio" in result:
#         print("MP3 audio generated successfully!")
#         print(f"Sampling rate: {result['sampling_rate']}")
#         print(f"Format: {result['format']}")
#         print(f"Conversion method: {result.get('conversion_method', 'unknown')}")
#         print(f"Audio data length: {len(result['audio'])} characters (base64)")
        
#         # Save the MP3 file for testing
#         with open("test_output.mp3", "wb") as f:
#             f.write(base64.b64decode(result['audio']))
#         print("Test MP3 saved as 'test_output.mp3'")
#     else:
#         print("Error:", result.get("error", "Unknown error"))
        
#     # Test with manual conversion method
#     print("\n--- Testing manual conversion ---")
#     test_data["parameters"]["conversion_method"] = "manual"
#     result_manual = handler(test_data)
    
#     if "audio" in result_manual:
#         print("Manual conversion successful!")
#         with open("test_output_manual.mp3", "wb") as f:
#             f.write(base64.b64decode(result_manual['audio']))
#         print("Manual MP3 saved as 'test_output_manual.mp3'")