# Arabic-TTS-Spark / handler.py
# Fork of IbrahimSalah/Arabic-TTS-Spark adapted for HF Inference Endpoints
# (fork by it-support-mumz, revision a6afb46).
"""
Custom Handler for Hugging Face Inference Endpoints
Model: IbrahimSalah/Arabic-TTS-Spark
Repository: https://huggingface.co/IbrahimSalah/Arabic-TTS-Spark
This handler provides Text-to-Speech inference for Arabic with:
- Voice cloning (with reference audio)
- Controllable TTS (with gender, pitch, speed parameters)
"""
import base64
import io
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional
import numpy as np
import soundfile as sf
import torch
logger = logging.getLogger(__name__)
class EndpointHandler:
    """
    Hugging Face Inference Endpoints handler for Arabic-TTS-Spark.

    Supports two modes:
      1. Voice Cloning: provide reference audio (base64 WAV) to clone the voice.
      2. Controllable TTS: specify gender, pitch, and speed parameters.
    """

    # Closed sets accepted by controllable-TTS mode (lists so error messages
    # render exactly as before, e.g. "['male', 'female']").
    VALID_GENDERS = ["male", "female"]
    VALID_LEVELS = ["very_low", "low", "moderate", "high", "very_high"]

    def __init__(self, path: str = ""):
        """
        Initialize the handler by loading the model and processor.

        Args:
            path: Path to the model directory (provided by HF Inference
                Endpoints). Falls back to the hub repo id when empty.
        """
        # Imported lazily so the module can be imported/inspected without
        # pulling in transformers.
        from transformers import AutoModel, AutoProcessor

        self.device = self._get_device()
        logger.info(f"Initializing Arabic-TTS-Spark on device: {self.device}")

        model_path = path if path else "IbrahimSalah/Arabic-TTS-Spark"
        logger.info(f"Loading model from: {model_path}")

        # trust_remote_code=True is required: the repo ships custom
        # processor/model classes.
        self.processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        # bfloat16 halves memory on CUDA; fp32 elsewhere (MPS/CPU bfloat16
        # support is unreliable).
        self.model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if self.device.type == "cuda" else torch.float32,
        )
        self.model = self.model.to(self.device).eval()

        # Link processor to model (required for voice cloning).
        self.processor.link_model(self.model)

        # Optional bundled reference audio. Store None when absent (the old
        # fallback recomputed the same non-existent path) so callers can
        # test availability directly.
        candidate = Path(model_path) / "reference.wav"
        self.default_reference_path: Optional[Path] = (
            candidate if candidate.exists() else None
        )

        logger.info("Model loaded successfully")

    def _get_device(self) -> torch.device:
        """Return the best available device: CUDA, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def _decode_audio_base64(self, audio_base64: str) -> tuple:
        """
        Decode base64 audio to a numpy array.

        Args:
            audio_base64: Base64-encoded audio data (any format soundfile reads).

        Returns:
            Tuple of (audio_data, sample_rate).
        """
        audio_bytes = base64.b64decode(audio_base64)
        audio_data, sample_rate = sf.read(io.BytesIO(audio_bytes))
        return audio_data, sample_rate

    def _encode_audio_base64(self, audio_data: np.ndarray, sample_rate: int) -> str:
        """
        Encode an audio numpy array as a base64 WAV string.

        Args:
            audio_data: Audio waveform as numpy array.
            sample_rate: Sample rate of the audio.

        Returns:
            Base64-encoded WAV string.
        """
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, audio_data, sample_rate, format='WAV')
        audio_buffer.seek(0)
        return base64.b64encode(audio_buffer.read()).decode('utf-8')

    def _validate_inputs(self, data: Dict[str, Any]) -> tuple:
        """
        Validate the request payload and decide the inference mode.

        Args:
            data: Request data dictionary.

        Returns:
            Tuple of (text, parameters, mode) with mode one of
            "voice_cloning" / "controllable".

        Raises:
            ValueError: when the 'inputs' field is missing or empty.
        """
        text = data.get("inputs", "")
        if not text:
            raise ValueError("No input text provided. Use 'inputs' field.")

        parameters = data.get("parameters", {})

        # Voice cloning wins whenever reference audio is supplied; otherwise
        # controllable mode, filling in defaults for any missing control.
        # (setdefault is a no-op for keys the caller already provided.)
        has_audio = "prompt_audio_base64" in parameters or "prompt_audio" in parameters
        if has_audio:
            mode = "voice_cloning"
        else:
            mode = "controllable"
            parameters.setdefault("gender", "male")
            parameters.setdefault("pitch", "moderate")
            parameters.setdefault("speed", "moderate")
        return text, parameters, mode

    def _prepare_cloning_inputs(self, text: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build processor inputs for voice-cloning mode.

        The reference audio is written to a temporary WAV file because the
        processor expects a file path; the file is always removed afterwards.

        Raises:
            ValueError: when the reference audio payload is empty.
        """
        audio_base64 = parameters.get("prompt_audio_base64") or parameters.get("prompt_audio")
        if not audio_base64:
            raise ValueError("Voice cloning requested but reference audio is empty")
        prompt_text = parameters.get("prompt_text", "")

        audio_data, sample_rate = self._decode_audio_base64(audio_base64)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            # Preserve the clip's real sample rate: hard-coding 16 kHz would
            # pitch/speed-shift any non-16 kHz reference audio.
            sf.write(tmp_file.name, audio_data, sample_rate)
            tmp_audio_path = tmp_file.name
        try:
            return self.processor(
                text=text,
                prompt_speech_path=tmp_audio_path,
                prompt_text=prompt_text if prompt_text else None,
                return_tensors="pt",
            )
        finally:
            os.unlink(tmp_audio_path)

    def _prepare_controllable_inputs(self, text: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build processor inputs for controllable-TTS mode.

        Raises:
            ValueError: when gender/pitch/speed have values outside the
                accepted sets.
        """
        gender = parameters.get("gender", "male")
        pitch = parameters.get("pitch", "moderate")
        speed = parameters.get("speed", "moderate")

        if gender not in self.VALID_GENDERS:
            raise ValueError(f"Invalid gender: {gender}. Must be one of {self.VALID_GENDERS}")
        if pitch not in self.VALID_LEVELS:
            raise ValueError(f"Invalid pitch: {pitch}. Must be one of {self.VALID_LEVELS}")
        if speed not in self.VALID_LEVELS:
            raise ValueError(f"Invalid speed: {speed}. Must be one of {self.VALID_LEVELS}")

        return self.processor(
            text=text,
            gender=gender,
            pitch=pitch,
            speed=speed,
            return_tensors="pt",
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process one inference request.

        Args:
            data: Request payload:
                {
                    "inputs": "Arabic text with diacritics",
                    "parameters": {
                        # Voice cloning:
                        "prompt_audio_base64": "<base64-wav>",  # or "prompt_audio"
                        "prompt_text": "reference transcript",
                        # Controllable TTS:
                        "gender": "male" | "female",
                        "pitch": "very_low"|"low"|"moderate"|"high"|"very_high",
                        "speed": "very_low"|"low"|"moderate"|"high"|"very_high",
                        # Generation parameters (optional):
                        "temperature": 0.8, "max_new_tokens": 3000,
                        "top_p": 0.95, "top_k": 50
                    }
                }

        Returns:
            {"audio": "<base64-encoded-wav>", "sampling_rate": int} on
            success, or {"error": str, "error_type": str} on failure.
        """
        try:
            text, parameters, mode = self._validate_inputs(data)
            logger.info(f"Processing request - Mode: {mode}, Text length: {len(text)}")

            if mode == "voice_cloning":
                inputs = self._prepare_cloning_inputs(text, parameters)
            else:
                inputs = self._prepare_controllable_inputs(text, parameters)

            # Move tensors to the model device; leave non-tensor entries as-is.
            inputs = {
                k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in inputs.items()
            }
            # Needed by processor.decode() to strip the prompt tokens.
            input_ids_len = inputs["input_ids"].shape[1]

            # `pad_token_id or eos_token_id` would wrongly discard a pad
            # token whose id is 0 (falsy); test for None explicitly.
            pad_id = self.processor.tokenizer.pad_token_id
            if pad_id is None:
                pad_id = self.processor.tokenizer.eos_token_id

            with torch.no_grad():
                output_ids = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs.get("attention_mask"),
                    max_new_tokens=parameters.get("max_new_tokens", 3000),
                    temperature=parameters.get("temperature", 0.8),
                    top_p=parameters.get("top_p", 0.95),
                    top_k=parameters.get("top_k", 50),
                    do_sample=True,
                    pad_token_id=pad_id,
                    eos_token_id=self.processor.tokenizer.eos_token_id,
                )

            # Decode generated token ids back into a waveform.
            output = self.processor.decode(
                generated_ids=output_ids,
                global_token_ids_prompt=inputs.get("global_token_ids_prompt"),
                input_ids_len=input_ids_len,
            )

            audio_data = output["audio"]
            sampling_rate = output["sampling_rate"]
            if audio_data is None or len(audio_data) == 0:
                raise RuntimeError("Model generated empty audio output")

            audio_base64 = self._encode_audio_base64(audio_data, sampling_rate)
            logger.info(f"Generated audio: {len(audio_data)} samples at {sampling_rate}Hz")
            return {
                "audio": audio_base64,
                "sampling_rate": sampling_rate,
            }
        except Exception as e:
            # Endpoint boundary: report the failure to the client instead of
            # crashing the worker; logger.exception keeps the traceback.
            logger.exception(f"Inference error: {e}")
            return {
                "error": str(e),
                "error_type": type(e).__name__,
            }