import base64
import io
import logging
import os
import subprocess
import tempfile
from typing import Any, Dict

import numpy as np
import scipy.io.wavfile as wavfile
import torch
from transformers import AutoTokenizer, VitsModel

logger = logging.getLogger(__name__)

# Fallback output rate used when the model config does not expose one.
_DEFAULT_SAMPLE_RATE = 16000

# Empty ID3v2.3 tag header: "ID3", version 2.3, no flags, zero-length tag body.
_ID3V2_HEADER = b"ID3\x03\x00\x00\x00\x00\x00\x00"

# Single MPEG audio frame header. These four bytes decode as MPEG-1 Layer III,
# 128 kbps, 44.1 kHz -- they do NOT describe the 16 kHz PCM that follows them.
_MP3_FRAME_HEADER = b"\xff\xfb\x90\x00"

# Size of the canonical PCM WAV header that SimpleLAMEEncoder skips.
_WAV_HEADER_SIZE = 44


class EndpointHandler:
    """Inference-endpoint handler that turns text into base64-encoded audio."""

    def __init__(self, path=""):
        """
        Initialize the handler for facebook/mms-tts-asm model

        Args:
            path: Model directory or hub id passed to ``from_pretrained``.
        """
        # Load the model and tokenizer
        self.model = VitsModel.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Set model to evaluation mode
        self.model.eval()

        # Set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Use the rate the checkpoint was trained at instead of assuming 16 kHz;
        # fall back to the previous hard-coded value if the config lacks it.
        self.sample_rate = int(
            getattr(self.model.config, "sampling_rate", _DEFAULT_SAMPLE_RATE)
        )

    def wav_to_mp3_ffmpeg(
        self, wav_data: bytes, sample_rate: int = _DEFAULT_SAMPLE_RATE
    ) -> bytes:
        """
        Convert WAV data to MP3 using ffmpeg directly

        Args:
            wav_data: Complete WAV file contents.
            sample_rate: Output sample rate handed to ffmpeg via ``-ar``.

        Returns:
            The encoded MP3 file contents.

        Raises:
            RuntimeError: If ffmpeg is missing, exits non-zero, or any file
                operation fails. The message keeps the original
                "Error converting to MP3: ..." form.
        """
        try:
            # A temporary directory removes both files on every exit path,
            # including exceptions, without manual bookkeeping.
            with tempfile.TemporaryDirectory() as tmp_dir:
                wav_path = os.path.join(tmp_dir, "input.wav")
                mp3_path = os.path.join(tmp_dir, "output.mp3")

                with open(wav_path, "wb") as wav_file:
                    wav_file.write(wav_data)

                # Use ffmpeg to convert WAV to MP3
                cmd = [
                    "ffmpeg",
                    "-y",  # overwrite output file
                    "-i", wav_path,  # input file
                    "-codec:a", "libmp3lame",  # MP3 encoder
                    "-b:a", "128k",  # bitrate
                    "-ar", str(sample_rate),  # sample rate
                    mp3_path,  # output file
                ]

                # errors="replace" keeps a non-UTF-8 byte in ffmpeg's output
                # from masking the real failure with a UnicodeDecodeError.
                result = subprocess.run(
                    cmd,
                    capture_output=True,
                    text=True,
                    errors="replace",
                    check=False,
                )
                if result.returncode != 0:
                    raise RuntimeError(f"FFmpeg error: {result.stderr}")

                with open(mp3_path, "rb") as mp3_file:
                    return mp3_file.read()
        except Exception as e:
            raise RuntimeError(f"Error converting to MP3: {str(e)}") from e

    def wav_to_mp3_manual(self, wav_data: bytes) -> bytes:
        """
        Alternative: Create a simple MP3-like format manually

        Note: This creates a basic audio format, not true MP3. The WAV bytes
        are only prefixed with an empty ID3v2 tag and one MPEG frame header;
        nothing is encoded. Not recommended for production -- for true MP3,
        ffmpeg or a similar encoder is needed.

        Args:
            wav_data: Complete WAV file contents.

        Returns:
            ``ID3 header + frame header + wav_data``.
        """
        return _ID3V2_HEADER + _MP3_FRAME_HEADER + wav_data

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the request

        Args:
            data (Dict): The input data containing text to convert to speech
                Expected format: {"inputs": "text to convert to speech"}
                Optional: {"parameters": {"conversion_method": "ffmpeg" | "manual"}}

        Returns:
            Dict: Contains the audio file as base64 encoded MP3 together with
            "sampling_rate", "format", "text", "conversion_method" (the method
            that actually produced the bytes) and "content_type"; or a dict
            with a single "error" key on failure.
        """
        try:
            # Extract input text
            inputs = data.get("inputs", "")
            if not isinstance(inputs, str) or not inputs.strip():
                return {"error": "No input text provided"}

            # Additional parameters (optional); tolerate an explicit null.
            parameters = data.get("parameters") or {}
            conversion_method = parameters.get("conversion_method", "ffmpeg")

            # Process the text with tokenizer
            input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids.to(
                self.device
            )

            # Generate speech
            with torch.no_grad():
                output = self.model(input_ids)
            waveform = output.waveform.squeeze().cpu().numpy()

            if waveform.size == 0:
                return {"error": "Model produced no audio for the given text"}

            sample_rate = self.sample_rate

            # Normalize audio to prevent clipping
            peak = np.max(np.abs(waveform))
            if peak > 0:
                waveform = waveform / peak * 0.95

            # Convert to 16-bit PCM
            waveform_int16 = (waveform * 32767).astype(np.int16)

            # Create WAV file in memory
            wav_buffer = io.BytesIO()
            wavfile.write(wav_buffer, sample_rate, waveform_int16)
            wav_data = wav_buffer.getvalue()

            # Convert to MP3, remembering which path really produced the bytes
            # so the response does not claim "ffmpeg" after a fallback.
            used_method = "manual"
            if conversion_method == "ffmpeg":
                try:
                    mp3_data = self.wav_to_mp3_ffmpeg(wav_data, sample_rate)
                    used_method = "ffmpeg"
                except Exception as e:
                    # Fallback to manual method if ffmpeg fails
                    logger.warning(
                        "FFmpeg conversion failed: %s, falling back to manual method",
                        e,
                    )
                    mp3_data = self.wav_to_mp3_manual(wav_data)
            else:
                mp3_data = self.wav_to_mp3_manual(wav_data)

            # Convert to base64 for JSON response
            audio_base64 = base64.b64encode(mp3_data).decode("utf-8")

            # NOTE(review): "format"/"content_type" are kept as before even for
            # the manual path, whose payload is not a decodable MP3 stream.
            return {
                "audio": audio_base64,
                "sampling_rate": sample_rate,
                "format": "mp3",
                "text": inputs,
                "conversion_method": used_method,
                "content_type": "audio/mpeg",
            }

        except Exception as e:
            return {"error": f"Error processing request: {str(e)}"}


# Pure Python MP3 encoder alternative (more complex but no external dependencies)
class SimpleLAMEEncoder:
    """
    A very basic MP3-like encoder using pure Python

    Note: This is a simplified implementation for demonstration.
    For production use, proper MP3 encoding libraries are recommended.
    """

    @staticmethod
    def encode_wav_to_mp3_like(wav_data: bytes, sample_rate: int = 16000) -> bytes:
        """
        Create a basic MP3-like file structure

        This is a simplified approach and may not be compatible with all
        players: the raw PCM payload is copied unencoded after the headers.

        Args:
            wav_data: WAV file contents; the first 44 bytes are assumed to be
                the header and are dropped.
            sample_rate: Currently unused; kept for interface compatibility.

        Returns:
            Empty ID3v2 header + one MPEG frame header + the bytes of
            ``wav_data`` after the 44-byte header (empty if shorter).
        """
        # Skip WAV header (44 bytes); slicing past the end yields b"".
        audio_data = wav_data[_WAV_HEADER_SIZE:]

        # Combine to create MP3-like structure. The ID3 size field stays zero.
        return _ID3V2_HEADER + _MP3_FRAME_HEADER + audio_data


# # Example usage and testing
# if __name__ == "__main__":
#     # Test the handler locally
#     handler = EndpointHandler("facebook/mms-tts-asm")
#
#     # Test input with ffmpeg conversion
#     test_data = {
#         "inputs": "Hello, this is a test of the text to speech system.",
#         "parameters": {"conversion_method": "ffmpeg"}
#     }
#
#     result = handler(test_data)
#     print("Handler result keys:", result.keys())
#
#     if "audio" in result:
#         print("MP3 audio generated successfully!")
#         print(f"Sampling rate: {result['sampling_rate']}")
#         print(f"Format: {result['format']}")
#         print(f"Conversion method: {result.get('conversion_method', 'unknown')}")
#         print(f"Audio data length: {len(result['audio'])} characters (base64)")
#
#         # Save the MP3 file for testing
#         with open("test_output.mp3", "wb") as f:
#             f.write(base64.b64decode(result['audio']))
#         print("Test MP3 saved as 'test_output.mp3'")
#     else:
#         print("Error:", result.get("error", "Unknown error"))
#
#     # Test with manual conversion method
#     print("\n--- Testing manual conversion ---")
#     test_data["parameters"]["conversion_method"] = "manual"
#     result_manual = handler(test_data)
#
#     if "audio" in result_manual:
#         print("Manual conversion successful!")
#         with open("test_output_manual.mp3", "wb") as f:
#             f.write(base64.b64decode(result_manual['audio']))
#         print("Manual MP3 saved as 'test_output_manual.mp3'")