""" Custom Handler for Hugging Face Inference Endpoints Model: IbrahimSalah/Arabic-TTS-Spark Repository: https://huggingface.co/IbrahimSalah/Arabic-TTS-Spark This handler provides Text-to-Speech inference for Arabic with: - Voice cloning (with reference audio) - Controllable TTS (with gender, pitch, speed parameters) """ import base64 import io import logging import os import tempfile from pathlib import Path from typing import Any, Dict, Optional import numpy as np import soundfile as sf import torch logger = logging.getLogger(__name__) class EndpointHandler: """ Hugging Face Inference Endpoints handler for Arabic-TTS-Spark. Supports two modes: 1. Voice Cloning: Provide reference audio to clone the voice 2. Controllable TTS: Specify gender, pitch, and speed parameters """ def __init__(self, path: str = ""): """ Initialize the handler by loading the model and processor. Args: path: Path to the model directory (provided by HF Inference Endpoints) """ from transformers import AutoModel, AutoProcessor self.device = self._get_device() logger.info(f"Initializing Arabic-TTS-Spark on device: {self.device}") # Determine model path model_path = path if path else "IbrahimSalah/Arabic-TTS-Spark" logger.info(f"Loading model from: {model_path}") # Load processor and model with trust_remote_code=True (required for custom classes) self.processor = AutoProcessor.from_pretrained( model_path, trust_remote_code=True ) self.model = AutoModel.from_pretrained( model_path, trust_remote_code=True, torch_dtype=torch.bfloat16 if self.device.type == "cuda" else torch.float32 ) # Move model to device and set to eval mode self.model = self.model.to(self.device).eval() # Link processor to model (required for voice cloning) self.processor.link_model(self.model) # Store default reference audio path self.default_reference_path = Path(model_path) / "reference.wav" if not self.default_reference_path.exists(): # Try to find it in the resolved path self.default_reference_path = Path(path) / "reference.wav" if path else None logger.info("Model loaded successfully") def _get_device(self) -> torch.device: """Determine the best available device.""" if torch.cuda.is_available(): return torch.device("cuda") elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): return torch.device("mps") return torch.device("cpu") def _decode_audio_base64(self, audio_base64: str) -> tuple: """ Decode base64 audio to numpy array. Args: audio_base64: Base64 encoded audio data Returns: Tuple of (audio_data, sample_rate) """ audio_bytes = base64.b64decode(audio_base64) audio_buffer = io.BytesIO(audio_bytes) audio_data, sample_rate = sf.read(audio_buffer) return audio_data, sample_rate def _encode_audio_base64(self, audio_data: np.ndarray, sample_rate: int) -> str: """ Encode audio numpy array to base64. Args: audio_data: Audio waveform as numpy array sample_rate: Sample rate of the audio Returns: Base64 encoded audio string """ audio_buffer = io.BytesIO() sf.write(audio_buffer, audio_data, sample_rate, format='WAV') audio_buffer.seek(0) return base64.b64encode(audio_buffer.read()).decode('utf-8') def _validate_inputs(self, data: Dict[str, Any]) -> tuple: """ Validate and extract inputs from request data. Args: data: Request data dictionary Returns: Tuple of (text, parameters, mode) """ # Extract text input text = data.get("inputs", "") if not text: raise ValueError("No input text provided. 
    def _validate_inputs(self, data: Dict[str, Any]) -> Tuple[str, Dict[str, Any], str]:
        """
        Validate and extract inputs from request data.

        Args:
            data: Request data dictionary

        Returns:
            Tuple of (text, parameters, mode)
        """
        # Extract text input
        text = data.get("inputs", "")
        if not text:
            raise ValueError("No input text provided. Use the 'inputs' field.")

        # Extract parameters
        parameters = data.get("parameters", {})

        # Determine mode
        has_audio = "prompt_audio_base64" in parameters or "prompt_audio" in parameters
        has_control = all(k in parameters for k in ["gender", "pitch", "speed"])

        if has_audio:
            mode = "voice_cloning"
        elif has_control:
            mode = "controllable"
        else:
            # Default to controllable with default parameters
            mode = "controllable"
            parameters.setdefault("gender", "male")
            parameters.setdefault("pitch", "moderate")
            parameters.setdefault("speed", "moderate")

        return text, parameters, mode
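    # Example request payloads accepted by __call__ below (illustrative):
    #
    #   Voice cloning:
    #     {"inputs": "<Arabic text with diacritics>",
    #      "parameters": {"prompt_audio_base64": "<base64-encoded WAV>",
    #                     "prompt_text": "<transcript of the reference audio>"}}
    #
    #   Controllable TTS:
    #     {"inputs": "<Arabic text with diacritics>",
    #      "parameters": {"gender": "female", "pitch": "moderate", "speed": "high"}}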
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an inference request.

        Args:
            data: Request data with the following structure:
                {
                    "inputs": "Arabic text with diacritics",
                    "parameters": {
                        # For voice cloning:
                        "prompt_audio_base64": "<base64-encoded WAV>",  # or "prompt_audio"
                        "prompt_text": "reference transcript",

                        # For controllable TTS:
                        "gender": "male" or "female",
                        "pitch": one of "very_low", "low", "moderate", "high", "very_high",
                        "speed": one of "very_low", "low", "moderate", "high", "very_high",

                        # Generation parameters (optional):
                        "temperature": 0.8,
                        "max_new_tokens": 3000,
                        "top_p": 0.95,
                        "top_k": 50
                    }
                }

        Returns:
            Dictionary with:
                {
                    "audio": "<base64-encoded WAV>",
                    "sampling_rate": 16000
                }
        """
        try:
            # Validate inputs
            text, parameters, mode = self._validate_inputs(data)
            logger.info(f"Processing request - Mode: {mode}, Text length: {len(text)}")

            # Extract generation parameters
            temperature = parameters.get("temperature", 0.8)
            max_new_tokens = parameters.get("max_new_tokens", 3000)
            top_p = parameters.get("top_p", 0.95)
            top_k = parameters.get("top_k", 50)

            # Prepare processor inputs based on mode
            if mode == "voice_cloning":
                # Handle voice cloning mode
                audio_base64 = parameters.get("prompt_audio_base64") or parameters.get("prompt_audio")
                prompt_text = parameters.get("prompt_text", "")

                # Save audio to a temporary file (the processor expects a file path)
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                    audio_data, sample_rate = self._decode_audio_base64(audio_base64)
                    # Downmix stereo reference audio to mono
                    if audio_data.ndim > 1:
                        audio_data = audio_data.mean(axis=1)
                    # Write with the decoded sample rate; labeling it 16000
                    # unconditionally would pitch-shift non-16 kHz references
                    sf.write(tmp_file.name, audio_data, sample_rate)
                    tmp_audio_path = tmp_file.name

                try:
                    # Process inputs for voice cloning
                    inputs = self.processor(
                        text=text,
                        prompt_speech_path=tmp_audio_path,
                        prompt_text=prompt_text if prompt_text else None,
                        return_tensors="pt"
                    )
                finally:
                    # Clean up the temporary file
                    os.unlink(tmp_audio_path)
            else:
                # Handle controllable TTS mode
                gender = parameters.get("gender", "male")
                pitch = parameters.get("pitch", "moderate")
                speed = parameters.get("speed", "moderate")

                # Validate parameter values
                valid_genders = ["male", "female"]
                valid_levels = ["very_low", "low", "moderate", "high", "very_high"]

                if gender not in valid_genders:
                    raise ValueError(f"Invalid gender: {gender}. Must be one of {valid_genders}")
                if pitch not in valid_levels:
                    raise ValueError(f"Invalid pitch: {pitch}. Must be one of {valid_levels}")
                if speed not in valid_levels:
                    raise ValueError(f"Invalid speed: {speed}. Must be one of {valid_levels}")

                # Process inputs for controllable TTS
                inputs = self.processor(
                    text=text,
                    gender=gender,
                    pitch=pitch,
                    speed=speed,
                    return_tensors="pt"
                )

            # Move tensor inputs to the model device
            inputs = {
                k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in inputs.items()
            }

            # Store input length for decoding
            input_ids_len = inputs["input_ids"].shape[1]

            # Generate audio tokens
            with torch.no_grad():
                output_ids = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs.get("attention_mask"),
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    pad_token_id=self.processor.tokenizer.pad_token_id
                    or self.processor.tokenizer.eos_token_id,
                    eos_token_id=self.processor.tokenizer.eos_token_id,
                )

            # Decode audio
            global_tokens = inputs.get("global_token_ids_prompt")
            output = self.processor.decode(
                generated_ids=output_ids,
                global_token_ids_prompt=global_tokens,
                input_ids_len=input_ids_len
            )

            # Get audio data
            audio_data = output["audio"]
            sampling_rate = output["sampling_rate"]

            # Ensure the audio is valid
            if audio_data is None or len(audio_data) == 0:
                raise RuntimeError("Model generated empty audio output")

            # Encode audio to base64
            audio_base64 = self._encode_audio_base64(audio_data, sampling_rate)

            logger.info(f"Generated audio: {len(audio_data)} samples at {sampling_rate}Hz")

            return {
                "audio": audio_base64,
                "sampling_rate": sampling_rate
            }

        except Exception as e:
            logger.error(f"Inference error: {str(e)}")
            return {
                "error": str(e),
                "error_type": type(e).__name__
            }
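

# Minimal local smoke test (illustrative sketch): assumes the model weights are
# reachable (Hub download or local checkout) and that enough memory is available.
# Inference Endpoints never runs this block; it only imports EndpointHandler and
# invokes it per request.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    handler = EndpointHandler()
    response = handler({
        "inputs": "السَّلَامُ عَلَيْكُمْ وَرَحْمَةُ اللَّهِ",  # diacritized Arabic text
        "parameters": {"gender": "male", "pitch": "moderate", "speed": "moderate"},
    })

    if "error" in response:
        print(f"Inference failed: {response['error_type']}: {response['error']}")
    else:
        # Decode the base64 payload back into a WAV file on disk
        wav_bytes = base64.b64decode(response["audio"])
        with open("output.wav", "wb") as f:
            f.write(wav_bytes)
        print(f"Wrote output.wav at {response['sampling_rate']} Hz")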