File size: 9,117 Bytes
f876b9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 |
import torch
import numpy as np
import io
import base64
import subprocess
import tempfile
import os
from typing import Dict, Any
from transformers import VitsModel, AutoTokenizer
import scipy.io.wavfile as wavfile
class EndpointHandler:
    """Text-to-speech endpoint handler built around a VITS model.

    The handler tokenizes the request text, runs the model to obtain a
    waveform, serializes it as 16-bit PCM WAV and then converts that WAV to
    MP3 (via ffmpeg) before returning it base64-encoded in a JSON-friendly
    dict.
    """

    # Used when the loaded model's config does not expose ``sampling_rate``.
    DEFAULT_SAMPLE_RATE = 16000

    def __init__(self, path=""):
        """
        Initialize the handler for facebook/mms-tts-asm model

        Args:
            path: Directory or hub id passed to ``from_pretrained`` for both
                the model and the tokenizer.
        """
        # Load the model and tokenizer
        self.model = VitsModel.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Set model to evaluation mode
        self.model.eval()
        # Set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Prefer the rate the model was trained with over a hard-coded value,
        # so the WAV header and the ffmpeg output rate always match the audio.
        self.sample_rate = int(
            getattr(self.model.config, "sampling_rate", None)
            or self.DEFAULT_SAMPLE_RATE
        )

    def wav_to_mp3_ffmpeg(self, wav_data: bytes, sample_rate: int = 16000) -> bytes:
        """
        Convert WAV data to MP3 using ffmpeg directly

        Args:
            wav_data: Complete WAV file contents.
            sample_rate: Output sample rate handed to ffmpeg's ``-ar`` flag.

        Returns:
            The bytes of the MP3 file written by ffmpeg.

        Raises:
            Exception: If ffmpeg is missing, exits with a non-zero status, or
                any file operation fails. The message is prefixed with
                "Error converting to MP3: ".
        """
        try:
            # A temporary directory removes both files on every exit path
            # (success, error, interruption) without manual unlink bookkeeping.
            with tempfile.TemporaryDirectory() as work_dir:
                wav_path = os.path.join(work_dir, "input.wav")
                mp3_path = os.path.join(work_dir, "output.mp3")
                with open(wav_path, "wb") as wav_file:
                    wav_file.write(wav_data)
                # Use ffmpeg to convert WAV to MP3
                cmd = [
                    'ffmpeg', '-y',             # -y to overwrite output file
                    '-i', wav_path,             # input file
                    '-codec:a', 'libmp3lame',   # MP3 encoder
                    '-b:a', '128k',             # bitrate
                    '-ar', str(sample_rate),    # sample rate
                    mp3_path,                   # output file
                ]
                # errors="replace" keeps a non-UTF-8 byte in ffmpeg's log from
                # masking the real failure with a UnicodeDecodeError.
                result = subprocess.run(
                    cmd, capture_output=True, text=True, errors="replace"
                )
                if result.returncode != 0:
                    raise Exception(f"FFmpeg error: {result.stderr}")
                with open(mp3_path, 'rb') as mp3_file:
                    return mp3_file.read()
        except Exception as e:
            raise Exception(f"Error converting to MP3: {str(e)}") from e

    def wav_to_mp3_manual(self, wav_data: bytes) -> bytes:
        """
        Alternative: Create a simple MP3-like format manually
        Note: This creates a basic audio format, not true MP3

        The result is the untouched WAV bytes prefixed with a 10-byte ID3v2
        header and a 4-byte frame-header-like marker. No encoding happens.

        Args:
            wav_data: Complete WAV file contents.

        Returns:
            ``ID3 header + frame marker + wav_data``.
        """
        # This is a simplified approach - not recommended for production.
        # For true MP3, ffmpeg or a similar encoder is needed.
        id3_header = b'ID3\x03\x00\x00\x00\x00\x00\x00'
        mp3_frame_header = b'\xff\xfb\x90\x00'
        return id3_header + mp3_frame_header + wav_data

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the request

        Args:
            data (Dict): The input data containing text to convert to speech
                Expected format: {"inputs": "text to convert to speech"}
                Optional: {"parameters": {"conversion_method": "ffmpeg" | "manual"}}

        Returns:
            Dict: On success, the keys "audio" (base64 string), "sampling_rate",
            "format", "text", "conversion_method" (the method that actually
            produced the bytes) and "content_type". On failure, a single
            "error" key with a message; exceptions are not propagated.
        """
        try:
            # Extract and validate input text
            inputs = data.get("inputs", "")
            if not inputs:
                return {"error": "No input text provided"}
            if not isinstance(inputs, str):
                return {"error": "Input text must be a string"}
            if not inputs.strip():
                return {"error": "No input text provided"}

            # Additional parameters (optional); tolerate an explicit null.
            parameters = data.get("parameters") or {}
            conversion_method = parameters.get("conversion_method", "ffmpeg")  # "ffmpeg" or "manual"

            # Process the text with tokenizer
            input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids.to(self.device)

            # Generate speech
            with torch.no_grad():
                output = self.model(input_ids)
            # atleast_1d: squeeze() collapses a single-sample result to 0-d,
            # which the WAV writer cannot handle.
            waveform = np.atleast_1d(output.waveform.squeeze().cpu().numpy())
            if waveform.size == 0:
                return {"error": "Model produced no audio for the given text"}

            sample_rate = self.sample_rate

            # Normalize audio to prevent clipping
            peak = float(np.max(np.abs(waveform)))
            if peak > 0:
                waveform = waveform / peak * 0.95

            # Convert to 16-bit PCM
            waveform_int16 = (waveform * 32767).astype(np.int16)

            # Create WAV file in memory
            wav_buffer = io.BytesIO()
            wavfile.write(wav_buffer, sample_rate, waveform_int16)
            wav_data = wav_buffer.getvalue()

            # Convert to MP3, remembering which path actually produced the
            # bytes so the response does not claim "ffmpeg" after a fallback.
            used_method = "manual"
            if conversion_method == "ffmpeg":
                try:
                    mp3_data = self.wav_to_mp3_ffmpeg(wav_data, sample_rate)
                    used_method = "ffmpeg"
                except Exception as e:
                    # Fallback to manual method if ffmpeg fails
                    print(f"FFmpeg conversion failed: {e}, falling back to manual method")
                    mp3_data = self.wav_to_mp3_manual(wav_data)
            else:
                mp3_data = self.wav_to_mp3_manual(wav_data)

            # Convert to base64 for JSON response
            audio_base64 = base64.b64encode(mp3_data).decode('utf-8')

            # NOTE(review): "format"/"content_type" are kept as-is for existing
            # callers, but the "manual" path does not yield a real MP3 stream.
            return {
                "audio": audio_base64,
                "sampling_rate": sample_rate,
                "format": "mp3",
                "text": inputs,
                "conversion_method": used_method,
                "content_type": "audio/mpeg"
            }
        except Exception as e:
            return {"error": f"Error processing request: {str(e)}"}
# Pure Python "MP3-like" wrapper (no external dependencies).
# NOTE: despite the name, nothing here performs MP3 encoding.
class SimpleLAMEEncoder:
    """
    A very basic MP3-like wrapper using pure Python
    Note: This is a simplified implementation for demonstration
    For production use, proper MP3 encoding libraries are recommended

    The output is raw PCM sample bytes prefixed with an ID3v2 header and a
    4-byte frame-header-like marker; it is not compressed.
    """

    # Empty ID3v2.3 tag header: "ID3", version 2.3.0, no flags, size 0.
    # The size field stays 0 because no tag frames follow.
    _ID3V2_HEADER = bytes([
        0x49, 0x44, 0x33,        # "ID3"
        0x03, 0x00,              # Version 2.3
        0x00,                    # Flags
        0x00, 0x00, 0x00, 0x00,  # Size
    ])
    # NOTE(review): these bytes appear to decode as an MPEG-1 Layer III,
    # 128 kbps, 44.1 kHz frame header - not 16 kHz - confirm before relying
    # on it. They are kept unchanged for output compatibility.
    _FRAME_HEADER = bytes([0xFF, 0xFB, 0x90, 0x00])
    # Header length of a minimal PCM WAV file (RIFF + 16-byte fmt + data tag).
    _CANONICAL_WAV_HEADER_SIZE = 44

    @staticmethod
    def encode_wav_to_mp3_like(wav_data: bytes, sample_rate: int = 16000) -> bytes:
        """
        Create a basic MP3-like file structure
        This is a simplified approach and may not be compatible with all players

        Args:
            wav_data: Complete WAV file contents.
            sample_rate: Currently unused; kept so existing callers that pass
                it keep working.

        Returns:
            ID3v2 header + frame marker + the PCM bytes of the WAV "data"
            chunk. If ``wav_data`` cannot be parsed as WAV, everything after
            the first 44 bytes is used instead.
        """
        import wave  # local import: only this method needs it

        try:
            # Parse the RIFF chunks instead of assuming a fixed header size,
            # so files with extra chunks (e.g. LIST) do not leak metadata
            # bytes into the audio payload.
            with wave.open(io.BytesIO(wav_data), "rb") as reader:
                audio_data = reader.readframes(reader.getnframes())
        except (wave.Error, EOFError):
            # Not a WAV the stdlib parser understands: fall back to skipping
            # a canonical 44-byte header.
            audio_data = wav_data[SimpleLAMEEncoder._CANONICAL_WAV_HEADER_SIZE:]

        # Combine to create MP3-like structure
        return (
            SimpleLAMEEncoder._ID3V2_HEADER
            + SimpleLAMEEncoder._FRAME_HEADER
            + audio_data
        )
# # Example usage and testing
# if __name__ == "__main__":
# # Test the handler locally
# handler = EndpointHandler("facebook/mms-tts-asm")
# # Test input with ffmpeg conversion
# test_data = {
# "inputs": "Hello, this is a test of the text to speech system.",
# "parameters": {"conversion_method": "ffmpeg"}
# }
# result = handler(test_data)
# print("Handler result keys:", result.keys())
# if "audio" in result:
# print("MP3 audio generated successfully!")
# print(f"Sampling rate: {result['sampling_rate']}")
# print(f"Format: {result['format']}")
# print(f"Conversion method: {result.get('conversion_method', 'unknown')}")
# print(f"Audio data length: {len(result['audio'])} characters (base64)")
# # Save the MP3 file for testing
# with open("test_output.mp3", "wb") as f:
# f.write(base64.b64decode(result['audio']))
# print("Test MP3 saved as 'test_output.mp3'")
# else:
# print("Error:", result.get("error", "Unknown error"))
# # Test with manual conversion method
# print("\n--- Testing manual conversion ---")
# test_data["parameters"]["conversion_method"] = "manual"
# result_manual = handler(test_data)
# if "audio" in result_manual:
# print("Manual conversion successful!")
# with open("test_output_manual.mp3", "wb") as f:
# f.write(base64.b64decode(result_manual['audio']))
# print("Manual MP3 saved as 'test_output_manual.mp3'")