Files changed (1) hide show
  1. handler.py +249 -0
handler.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import io
4
+ import base64
5
+ import subprocess
6
+ import tempfile
7
+ import os
8
+ from typing import Dict, Any
9
+ from transformers import VitsModel, AutoTokenizer
10
+ import scipy.io.wavfile as wavfile
11
+
12
+
13
class EndpointHandler:
    """Inference-endpoint handler that synthesizes speech from text.

    The handler loads a VITS text-to-speech model and its tokenizer, turns the
    request text into a waveform, encodes it as 16-bit PCM WAV and then
    converts that WAV to MP3 (via ffmpeg) or to the "manual" MP3-like wrapper.
    The audio is returned base64-encoded inside a JSON-serializable dict.
    """

    def __init__(self, path=""):
        """
        Initialize the handler for facebook/mms-tts-asm model

        Args:
            path: Model directory or identifier passed to ``from_pretrained``
                for both the model and the tokenizer.
        """
        # Load the model and tokenizer
        self.model = VitsModel.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Set model to evaluation mode
        self.model.eval()

        # Run on the GPU when one is available, otherwise on the CPU
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def wav_to_mp3_ffmpeg(self, wav_data: bytes, sample_rate: int = 16000) -> bytes:
        """
        Convert WAV data to MP3 using ffmpeg directly

        Args:
            wav_data: Complete WAV file contents.
            sample_rate: Output sample rate handed to ffmpeg's ``-ar`` option.
                Defaults to 16000, the value that used to be hard-coded.

        Returns:
            The bytes of the MP3 file written by ffmpeg.

        Raises:
            Exception: If a temporary file cannot be written, ffmpeg cannot be
                started, ffmpeg exits with a non-zero status, or the result
                cannot be read. The original error is attached as the cause.
        """
        temp_wav_path = None
        temp_mp3_path = None
        try:
            # Record each path as soon as the file exists so that the
            # ``finally`` clause can remove it even if a later step fails.
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_wav:
                temp_wav_path = temp_wav.name
                temp_wav.write(wav_data)

            with tempfile.NamedTemporaryFile(suffix='.mp3', delete=False) as temp_mp3:
                temp_mp3_path = temp_mp3.name

            # Use ffmpeg to convert WAV to MP3
            cmd = [
                'ffmpeg', '-y',             # -y to overwrite output file
                '-i', temp_wav_path,        # input file
                '-codec:a', 'libmp3lame',   # MP3 encoder
                '-b:a', '128k',             # bitrate
                '-ar', str(sample_rate),    # sample rate
                temp_mp3_path,              # output file
            ]

            result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode != 0:
                raise Exception(f"FFmpeg error: {result.stderr}")

            with open(temp_mp3_path, 'rb') as f:
                return f.read()

        except Exception as e:
            raise Exception(f"Error converting to MP3: {str(e)}") from e

        finally:
            # Best-effort cleanup: a file that cannot be removed must not hide
            # the conversion result or the conversion error, and one failed
            # removal must not prevent the other.
            for temp_path in (temp_wav_path, temp_mp3_path):
                if temp_path is None:
                    continue
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass

    def wav_to_mp3_manual(self, wav_data: bytes) -> bytes:
        """
        Alternative: Create a simple MP3-like format manually
        Note: This creates a basic audio format, not true MP3

        The WAV bytes are not re-encoded; they are only prefixed with an
        ID3v2 tag header and four bytes shaped like an MPEG frame header.
        For true MP3, ffmpeg or a similar encoder is needed.

        Args:
            wav_data: Bytes to wrap (the handler passes a complete WAV file).

        Returns:
            ``ID3 header + frame header + wav_data``.
        """
        # Simple ID3v2 header for MP3 (version 2.3, no flags, size 0)
        id3_header = b'ID3\x03\x00\x00\x00\x00\x00\x00'

        # Basic MP3 frame header (simplified)
        mp3_frame_header = b'\xff\xfb\x90\x00'

        # Note: This is NOT a proper MP3 file, just a wrapper
        return id3_header + mp3_frame_header + wav_data

    def _sampling_rate(self) -> int:
        """Return the model's configured sampling rate, or 16000 if unavailable."""
        config = getattr(self.model, "config", None)
        return int(getattr(config, "sampling_rate", None) or 16000)

    @staticmethod
    def _waveform_to_wav(waveform, sample_rate: int) -> bytes:
        """Peak-normalize a float waveform and encode it as 16-bit PCM WAV bytes."""
        # Normalize audio to prevent clipping; skip silent or empty audio to
        # avoid dividing by zero.
        peak = float(np.max(np.abs(waveform))) if waveform.size else 0.0
        if peak > 0:
            waveform = waveform / peak * 0.95

        # Convert to 16-bit PCM
        pcm = (waveform * 32767).astype(np.int16)

        # Create WAV file in memory
        wav_buffer = io.BytesIO()
        wavfile.write(wav_buffer, sample_rate, pcm)
        return wav_buffer.getvalue()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the request

        Args:
            data (Dict): The input data containing text to convert to speech
                Expected format: {"inputs": "text to convert to speech"}
                Optional: {"parameters": {"conversion_method": "ffmpeg"}}
                where the method is "ffmpeg" (default) or "manual"; any other
                value is handled like "manual".

        Returns:
            Dict: On success, the base64-encoded audio under "audio" together
            with "sampling_rate", "format", "text", "conversion_method" (the
            method that actually produced the audio, i.e. "manual" when the
            ffmpeg conversion failed and the fallback was used) and
            "content_type". On failure, a dict with a single "error" key;
            this method does not raise.
        """
        try:
            # Extract input text
            inputs = data.get("inputs", "")

            if not inputs or (isinstance(inputs, str) and not inputs.strip()):
                return {"error": "No input text provided"}
            if not isinstance(inputs, str):
                return {"error": "Input text must be a string"}

            # Additional parameters (optional); tolerate an explicit null
            parameters = data.get("parameters") or {}
            requested_method = parameters.get("conversion_method", "ffmpeg")

            # Process the text with tokenizer
            input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids.to(self.device)

            # Generate speech
            with torch.no_grad():
                output = self.model(input_ids)

            # atleast_1d keeps a single-sample result from collapsing to 0-d
            waveform = np.atleast_1d(output.waveform.squeeze().cpu().numpy())

            sample_rate = self._sampling_rate()
            wav_data = self._waveform_to_wav(waveform, sample_rate)

            # Convert to MP3, tracking which method really produced the bytes
            used_method = "manual"
            if requested_method == "ffmpeg":
                try:
                    mp3_data = self.wav_to_mp3_ffmpeg(wav_data, sample_rate)
                    used_method = "ffmpeg"
                except Exception as e:
                    # Fallback to manual method if ffmpeg fails
                    print(f"FFmpeg conversion failed: {e}, falling back to manual method")
                    mp3_data = self.wav_to_mp3_manual(wav_data)
            else:
                mp3_data = self.wav_to_mp3_manual(wav_data)

            # Convert to base64 for JSON response
            audio_base64 = base64.b64encode(mp3_data).decode('utf-8')

            return {
                "audio": audio_base64,
                "sampling_rate": sample_rate,
                "format": "mp3",
                "text": inputs,
                "conversion_method": used_method,
                "content_type": "audio/mpeg"
            }

        except Exception as e:
            # Endpoint boundary: report the failure to the caller as data
            return {"error": f"Error processing request: {str(e)}"}
168
+
169
+
170
# Pure Python "MP3-like" wrapper (no external dependencies).
# NOTE: despite the name this does not encode anything; it only prepends
# MP3-style headers to raw PCM bytes.
class SimpleLAMEEncoder:
    """
    A very basic MP3-like wrapper using pure Python
    Note: This is a simplified implementation for demonstration
    For production use, proper MP3 encoding libraries are recommended
    """

    # ID3v2 tag header: "ID3", version 2.3, no flags, size field left at 0
    _ID3V2_HEADER = bytes([0x49, 0x44, 0x33, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])

    # Four bytes shaped like an MPEG audio frame header (sync word first)
    _MP3_FRAME_HEADER = bytes([0xFF, 0xFB, 0x90, 0x00])

    # Header length of a WAV file that has only "fmt " and "data" chunks
    _CANONICAL_WAV_HEADER_SIZE = 44

    @staticmethod
    def _extract_pcm(wav_data: bytes) -> bytes:
        """Return the payload of the WAV ``data`` chunk.

        Walks the RIFF chunk list so that files carrying extra chunks before
        ``data`` (for example ``LIST``) are handled. Input that is not a
        RIFF/WAVE container, or that has no ``data`` chunk, falls back to
        skipping a fixed 44-byte header.
        """
        if len(wav_data) >= 12 and wav_data[:4] == b"RIFF" and wav_data[8:12] == b"WAVE":
            offset = 12
            while offset + 8 <= len(wav_data):
                chunk_id = wav_data[offset:offset + 4]
                chunk_size = int.from_bytes(wav_data[offset + 4:offset + 8], "little")
                body_start = offset + 8
                if chunk_id == b"data":
                    return wav_data[body_start:body_start + chunk_size]
                # RIFF chunks are word-aligned: odd sizes carry one pad byte
                offset = body_start + chunk_size + (chunk_size & 1)
        return wav_data[SimpleLAMEEncoder._CANONICAL_WAV_HEADER_SIZE:]

    @staticmethod
    def encode_wav_to_mp3_like(wav_data: bytes, sample_rate: int = 16000) -> bytes:
        """
        Create a basic MP3-like file structure
        This is a simplified approach and may not be compatible with all players

        Args:
            wav_data: Contents of a WAV file.
            sample_rate: Accepted for interface compatibility; it is not used,
                because the emitted frame header bytes are fixed.

        Returns:
            ID3v2 header + frame header + the PCM bytes taken from ``wav_data``.
            The PCM is copied verbatim, so the result is not real MP3 audio.
        """
        audio_data = SimpleLAMEEncoder._extract_pcm(wav_data)

        # Combine to create MP3-like structure
        return (
            SimpleLAMEEncoder._ID3V2_HEADER
            + SimpleLAMEEncoder._MP3_FRAME_HEADER
            + audio_data
        )
210
+
211
+
212
+ # # Example usage and testing
213
+ # if __name__ == "__main__":
214
+ # # Test the handler locally
215
+ # handler = EndpointHandler("facebook/mms-tts-asm")
216
+
217
+ # # Test input with ffmpeg conversion
218
+ # test_data = {
219
+ # "inputs": "Hello, this is a test of the text to speech system.",
220
+ # "parameters": {"conversion_method": "ffmpeg"}
221
+ # }
222
+
223
+ # result = handler(test_data)
224
+ # print("Handler result keys:", result.keys())
225
+
226
+ # if "audio" in result:
227
+ # print("MP3 audio generated successfully!")
228
+ # print(f"Sampling rate: {result['sampling_rate']}")
229
+ # print(f"Format: {result['format']}")
230
+ # print(f"Conversion method: {result.get('conversion_method', 'unknown')}")
231
+ # print(f"Audio data length: {len(result['audio'])} characters (base64)")
232
+
233
+ # # Save the MP3 file for testing
234
+ # with open("test_output.mp3", "wb") as f:
235
+ # f.write(base64.b64decode(result['audio']))
236
+ # print("Test MP3 saved as 'test_output.mp3'")
237
+ # else:
238
+ # print("Error:", result.get("error", "Unknown error"))
239
+
240
+ # # Test with manual conversion method
241
+ # print("\n--- Testing manual conversion ---")
242
+ # test_data["parameters"]["conversion_method"] = "manual"
243
+ # result_manual = handler(test_data)
244
+
245
+ # if "audio" in result_manual:
246
+ # print("Manual conversion successful!")
247
+ # with open("test_output_manual.mp3", "wb") as f:
248
+ # f.write(base64.b64decode(result_manual['audio']))
249
+ # print("Manual MP3 saved as 'test_output_manual.mp3'")