"""
Custom Handler for Hugging Face Inference Endpoints
Model: IbrahimSalah/Arabic-TTS-Spark
Repository: https://huggingface.co/IbrahimSalah/Arabic-TTS-Spark
This handler provides Text-to-Speech inference for Arabic with:
- Voice cloning (with reference audio)
- Controllable TTS (with gender, pitch, speed parameters)
"""
import base64
import io
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, Tuple
import numpy as np
import soundfile as sf
import torch
logger = logging.getLogger(__name__)
class EndpointHandler:
"""
Hugging Face Inference Endpoints handler for Arabic-TTS-Spark.
Supports two modes:
1. Voice Cloning: Provide reference audio to clone the voice
2. Controllable TTS: Specify gender, pitch, and speed parameters
"""
def __init__(self, path: str = ""):
"""
Initialize the handler by loading the model and processor.
Args:
path: Path to the model directory (provided by HF Inference Endpoints)
"""
from transformers import AutoModel, AutoProcessor
self.device = self._get_device()
logger.info(f"Initializing Arabic-TTS-Spark on device: {self.device}")
# Determine model path
model_path = path if path else "IbrahimSalah/Arabic-TTS-Spark"
logger.info(f"Loading model from: {model_path}")
# Load processor and model with trust_remote_code=True (required for custom classes)
self.processor = AutoProcessor.from_pretrained(
model_path,
trust_remote_code=True
)
self.model = AutoModel.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16 if self.device.type == "cuda" else torch.float32
)
# Move model to device and set to eval mode
self.model = self.model.to(self.device).eval()
# Link processor to model (required for voice cloning)
self.processor.link_model(self.model)
        # Store default reference audio path, if one is bundled with the model files
        self.default_reference_path = Path(model_path) / "reference.wav"
        if not self.default_reference_path.exists():
            # No local reference.wav is available (e.g. when loading by repo ID)
            self.default_reference_path = None
logger.info("Model loaded successfully")
def _get_device(self) -> torch.device:
"""Determine the best available device."""
if torch.cuda.is_available():
return torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
    def _decode_audio_base64(self, audio_base64: str) -> Tuple[np.ndarray, int]:
"""
Decode base64 audio to numpy array.
Args:
audio_base64: Base64 encoded audio data
Returns:
Tuple of (audio_data, sample_rate)
"""
audio_bytes = base64.b64decode(audio_base64)
audio_buffer = io.BytesIO(audio_bytes)
audio_data, sample_rate = sf.read(audio_buffer)
return audio_data, sample_rate
def _encode_audio_base64(self, audio_data: np.ndarray, sample_rate: int) -> str:
"""
Encode audio numpy array to base64.
Args:
audio_data: Audio waveform as numpy array
sample_rate: Sample rate of the audio
Returns:
Base64 encoded audio string
"""
audio_buffer = io.BytesIO()
sf.write(audio_buffer, audio_data, sample_rate, format='WAV')
audio_buffer.seek(0)
return base64.b64encode(audio_buffer.read()).decode('utf-8')
    def _validate_inputs(self, data: Dict[str, Any]) -> Tuple[str, Dict[str, Any], str]:
"""
Validate and extract inputs from request data.
Args:
data: Request data dictionary
Returns:
Tuple of (text, parameters, mode)
"""
# Extract text input
text = data.get("inputs", "")
if not text:
raise ValueError("No input text provided. Use 'inputs' field.")
# Extract parameters
parameters = data.get("parameters", {})
# Determine mode
has_audio = "prompt_audio_base64" in parameters or "prompt_audio" in parameters
has_control = all(k in parameters for k in ["gender", "pitch", "speed"])
if has_audio:
mode = "voice_cloning"
elif has_control:
mode = "controllable"
else:
# Default to controllable with default parameters
mode = "controllable"
parameters.setdefault("gender", "male")
parameters.setdefault("pitch", "moderate")
parameters.setdefault("speed", "moderate")
return text, parameters, mode
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Process inference request.
Args:
data: Request data with the following structure:
{
"inputs": "Arabic text with diacritics",
"parameters": {
# For voice cloning:
"prompt_audio_base64": "<base64-wav>", # or "prompt_audio"
"prompt_text": "reference transcript",
# For controllable TTS:
"gender": "male" or "female",
"pitch": "very_low", "low", "moderate", "high", "very_high",
"speed": "very_low", "low", "moderate", "high", "very_high",
# Generation parameters (optional):
"temperature": 0.8,
"max_new_tokens": 3000,
"top_p": 0.95,
"top_k": 50
}
}
Returns:
Dictionary with:
{
"audio": "<base64-encoded-wav>",
"sampling_rate": 16000
}
"""
try:
# Validate inputs
text, parameters, mode = self._validate_inputs(data)
logger.info(f"Processing request - Mode: {mode}, Text length: {len(text)}")
# Extract generation parameters
temperature = parameters.get("temperature", 0.8)
max_new_tokens = parameters.get("max_new_tokens", 3000)
top_p = parameters.get("top_p", 0.95)
top_k = parameters.get("top_k", 50)
# Prepare processor inputs based on mode
if mode == "voice_cloning":
# Handle voice cloning mode
                audio_base64 = parameters.get("prompt_audio_base64") or parameters.get("prompt_audio")
                if not audio_base64:
                    raise ValueError(
                        "Voice cloning requires base64 WAV data in 'prompt_audio_base64' (or 'prompt_audio')."
                    )
                prompt_text = parameters.get("prompt_text", "")
                # Save audio to a temporary file (processor expects a file path),
                # preserving the decoded sample rate so the WAV header is not mislabeled
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                    audio_data, sample_rate = self._decode_audio_base64(audio_base64)
                    sf.write(tmp_file.name, audio_data, sample_rate)
                    tmp_audio_path = tmp_file.name
try:
# Process inputs for voice cloning
inputs = self.processor(
text=text,
prompt_speech_path=tmp_audio_path,
prompt_text=prompt_text if prompt_text else None,
return_tensors="pt"
)
finally:
# Clean up temporary file
os.unlink(tmp_audio_path)
else:
# Handle controllable TTS mode
gender = parameters.get("gender", "male")
pitch = parameters.get("pitch", "moderate")
speed = parameters.get("speed", "moderate")
# Validate parameter values
valid_genders = ["male", "female"]
valid_levels = ["very_low", "low", "moderate", "high", "very_high"]
if gender not in valid_genders:
raise ValueError(f"Invalid gender: {gender}. Must be one of {valid_genders}")
if pitch not in valid_levels:
raise ValueError(f"Invalid pitch: {pitch}. Must be one of {valid_levels}")
if speed not in valid_levels:
raise ValueError(f"Invalid speed: {speed}. Must be one of {valid_levels}")
# Process inputs for controllable TTS
inputs = self.processor(
text=text,
gender=gender,
pitch=pitch,
speed=speed,
return_tensors="pt"
)
# Move inputs to device
inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
for k, v in inputs.items()}
# Store input length for decoding
input_ids_len = inputs["input_ids"].shape[1]
# Generate audio tokens
with torch.no_grad():
output_ids = self.model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs.get("attention_mask"),
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
do_sample=True,
pad_token_id=self.processor.tokenizer.pad_token_id or self.processor.tokenizer.eos_token_id,
eos_token_id=self.processor.tokenizer.eos_token_id,
)
# Decode audio
global_tokens = inputs.get("global_token_ids_prompt")
output = self.processor.decode(
generated_ids=output_ids,
global_token_ids_prompt=global_tokens,
input_ids_len=input_ids_len
)
# Get audio data
audio_data = output["audio"]
sampling_rate = output["sampling_rate"]
# Ensure audio is valid
if audio_data is None or len(audio_data) == 0:
raise RuntimeError("Model generated empty audio output")
# Encode audio to base64
audio_base64 = self._encode_audio_base64(audio_data, sampling_rate)
logger.info(f"Generated audio: {len(audio_data)} samples at {sampling_rate}Hz")
return {
"audio": audio_base64,
"sampling_rate": sampling_rate
}
except Exception as e:
logger.error(f"Inference error: {str(e)}")
return {
"error": str(e),
"error_type": type(e).__name__
}
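

# ---------------------------------------------------------------------------
# Local smoke test: a minimal sketch for running the handler outside Inference
# Endpoints. The sample text, output filename, and reference-audio path below
# are illustrative assumptions, not part of the endpoint contract.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    handler = EndpointHandler()

    # Controllable-TTS mode with explicit style parameters
    result = handler({
        "inputs": "مَرْحَبًا بِكُمْ",  # short diacritized sample (assumption)
        "parameters": {"gender": "male", "pitch": "moderate", "speed": "moderate"},
    })

    # Voice-cloning variant (requires a local reference WAV; the path and
    # transcript below are placeholders):
    # with open("reference.wav", "rb") as f:
    #     ref_b64 = base64.b64encode(f.read()).decode("utf-8")
    # result = handler({
    #     "inputs": "مَرْحَبًا بِكُمْ",
    #     "parameters": {"prompt_audio_base64": ref_b64,
    #                    "prompt_text": "transcript of the reference audio"},
    # })

    if "error" in result:
        print(f"Inference failed: {result['error_type']}: {result['error']}")
    else:
        # Decode the base64 WAV returned by __call__ and save it for listening
        with open("smoke_test.wav", "wb") as f:
            f.write(base64.b64decode(result["audio"]))
        print(f"Wrote smoke_test.wav at {result['sampling_rate']} Hz")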