|
|
""" |
|
|
Custom Handler for Hugging Face Inference Endpoints |
|
|
Model: IbrahimSalah/Arabic-TTS-Spark |
|
|
Repository: https://huggingface.co/IbrahimSalah/Arabic-TTS-Spark |
|
|
|
|
|
This handler provides Text-to-Speech inference for Arabic with: |
|
|
- Voice cloning (with reference audio) |
|
|
- Controllable TTS (with gender, pitch, speed parameters) |
|
|
""" |
|
|
|
|
|
import base64 |
|
|
import io |
|
|
import logging |
|
|
import os |
|
|
import tempfile |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, Optional |
|
|
|
|
|
import numpy as np |
|
|
import soundfile as sf |
|
|
import torch |
|
|
|
|
|
# Module-level logger, one per module per stdlib convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class EndpointHandler:
    """
    Hugging Face Inference Endpoints handler for Arabic-TTS-Spark.

    Supports two modes:
    1. Voice Cloning: Provide reference audio to clone the voice
    2. Controllable TTS: Specify gender, pitch, and speed parameters
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler by loading the model and processor.

        Args:
            path: Path to the model directory (provided by HF Inference
                Endpoints). Falls back to the Hub repo id when empty.
        """
        # Imported lazily so the module can be imported without transformers
        # installed (e.g. for tooling/inspection).
        from transformers import AutoModel, AutoProcessor

        self.device = self._get_device()
        logger.info("Initializing Arabic-TTS-Spark on device: %s", self.device)

        model_path = path if path else "IbrahimSalah/Arabic-TTS-Spark"
        logger.info("Loading model from: %s", model_path)

        # trust_remote_code: the repo ships its own processor/model classes.
        self.processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
        )

        # bfloat16 halves memory on CUDA; CPU/MPS stay float32 for compatibility.
        self.model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if self.device.type == "cuda" else torch.float32,
        )
        self.model = self.model.to(self.device).eval()

        # Let the processor call back into the model (remote-code API).
        self.processor.link_model(self.model)

        # Optional bundled reference clip for voice cloning.
        # FIX: the previous fallback re-checked the same location when `path`
        # was set and could leave a non-existent Path; the attribute is now
        # None unless the file actually exists.
        candidate = Path(model_path) / "reference.wav"
        self.default_reference_path: Optional[Path] = candidate if candidate.exists() else None

        logger.info("Model loaded successfully")

    def _get_device(self) -> torch.device:
        """Determine the best available device (CUDA > MPS > CPU)."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # hasattr guard keeps this working on torch builds without MPS support.
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def _decode_audio_base64(self, audio_base64: str) -> tuple:
        """
        Decode base64 audio to a numpy array.

        Args:
            audio_base64: Base64 encoded audio data (any format soundfile reads).

        Returns:
            Tuple of (audio_data, sample_rate).

        Raises:
            ValueError: If the payload is empty/None (previously an opaque
                TypeError surfaced from base64.b64decode).
        """
        if not audio_base64:
            raise ValueError("Empty reference audio payload")
        audio_bytes = base64.b64decode(audio_base64)
        audio_data, sample_rate = sf.read(io.BytesIO(audio_bytes))
        return audio_data, sample_rate

    def _encode_audio_base64(self, audio_data: np.ndarray, sample_rate: int) -> str:
        """
        Encode an audio numpy array to a base64 WAV string.

        Args:
            audio_data: Audio waveform as numpy array.
            sample_rate: Sample rate of the audio.

        Returns:
            Base64 encoded WAV audio string.
        """
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, audio_data, sample_rate, format='WAV')
        audio_buffer.seek(0)
        return base64.b64encode(audio_buffer.read()).decode('utf-8')

    def _validate_inputs(self, data: Dict[str, Any]) -> tuple:
        """
        Validate and extract inputs from request data.

        Args:
            data: Request data dictionary.

        Returns:
            Tuple of (text, parameters, mode) where mode is
            "voice_cloning" or "controllable".

        Raises:
            ValueError: If no input text is provided.
        """
        text = data.get("inputs", "")
        if not text:
            raise ValueError("No input text provided. Use 'inputs' field.")

        parameters = data.get("parameters", {})

        # Mode selection: reference audio wins; else explicit control params;
        # else controllable TTS with defaults filled in.
        has_audio = "prompt_audio_base64" in parameters or "prompt_audio" in parameters
        has_control = all(k in parameters for k in ["gender", "pitch", "speed"])

        if has_audio:
            mode = "voice_cloning"
        elif has_control:
            mode = "controllable"
        else:
            mode = "controllable"
            parameters.setdefault("gender", "male")
            parameters.setdefault("pitch", "moderate")
            parameters.setdefault("speed", "moderate")

        return text, parameters, mode

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an inference request.

        Args:
            data: Request data with the following structure:
                {
                    "inputs": "Arabic text with diacritics",
                    "parameters": {
                        # For voice cloning:
                        "prompt_audio_base64": "<base64-wav>",  # or "prompt_audio"
                        "prompt_text": "reference transcript",

                        # For controllable TTS:
                        "gender": "male" or "female",
                        "pitch": "very_low", "low", "moderate", "high", "very_high",
                        "speed": "very_low", "low", "moderate", "high", "very_high",

                        # Generation parameters (optional):
                        "temperature": 0.8,
                        "max_new_tokens": 3000,
                        "top_p": 0.95,
                        "top_k": 50
                    }
                }

        Returns:
            On success: {"audio": "<base64-encoded-wav>", "sampling_rate": <int>}
            On failure: {"error": <message>, "error_type": <exception class name>}
        """
        try:
            text, parameters, mode = self._validate_inputs(data)
            logger.info("Processing request - Mode: %s, Text length: %d", mode, len(text))

            # Generation hyper-parameters with sensible defaults.
            temperature = parameters.get("temperature", 0.8)
            max_new_tokens = parameters.get("max_new_tokens", 3000)
            top_p = parameters.get("top_p", 0.95)
            top_k = parameters.get("top_k", 50)

            if mode == "voice_cloning":
                audio_base64 = parameters.get("prompt_audio_base64") or parameters.get("prompt_audio")
                prompt_text = parameters.get("prompt_text", "")

                # Decode before touching the filesystem so a malformed payload
                # cannot leak a temp file (previously the decode happened inside
                # the NamedTemporaryFile block, before the try/finally started).
                audio_data, ref_sample_rate = self._decode_audio_base64(audio_base64)

                # The processor API takes a file path, so round-trip via a temp WAV.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                    tmp_audio_path = tmp_file.name
                try:
                    # FIX: write with the reference clip's actual sample rate.
                    # The previous hard-coded 16000 mislabeled non-16kHz
                    # references, shifting pitch/speed of the cloned voice.
                    sf.write(tmp_audio_path, audio_data, ref_sample_rate)
                    inputs = self.processor(
                        text=text,
                        prompt_speech_path=tmp_audio_path,
                        prompt_text=prompt_text if prompt_text else None,
                        return_tensors="pt"
                    )
                finally:
                    os.unlink(tmp_audio_path)
            else:
                gender = parameters.get("gender", "male")
                pitch = parameters.get("pitch", "moderate")
                speed = parameters.get("speed", "moderate")

                valid_genders = ["male", "female"]
                valid_levels = ["very_low", "low", "moderate", "high", "very_high"]

                if gender not in valid_genders:
                    raise ValueError(f"Invalid gender: {gender}. Must be one of {valid_genders}")
                if pitch not in valid_levels:
                    raise ValueError(f"Invalid pitch: {pitch}. Must be one of {valid_levels}")
                if speed not in valid_levels:
                    raise ValueError(f"Invalid speed: {speed}. Must be one of {valid_levels}")

                inputs = self.processor(
                    text=text,
                    gender=gender,
                    pitch=pitch,
                    speed=speed,
                    return_tensors="pt"
                )

            # Move tensors to the model's device; leave non-tensor entries as-is.
            inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                      for k, v in inputs.items()}

            # Needed by processor.decode to strip prompt tokens from the output.
            input_ids_len = inputs["input_ids"].shape[1]

            with torch.no_grad():
                output_ids = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs.get("attention_mask"),
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    pad_token_id=self.processor.tokenizer.pad_token_id or self.processor.tokenizer.eos_token_id,
                    eos_token_id=self.processor.tokenizer.eos_token_id,
                )

            # Present only in voice-cloning mode; None otherwise — presumably
            # the remote-code decode handles None (original passed it the same way).
            global_tokens = inputs.get("global_token_ids_prompt")
            output = self.processor.decode(
                generated_ids=output_ids,
                global_token_ids_prompt=global_tokens,
                input_ids_len=input_ids_len
            )

            audio_data = output["audio"]
            sampling_rate = output["sampling_rate"]

            if audio_data is None or len(audio_data) == 0:
                raise RuntimeError("Model generated empty audio output")

            audio_base64 = self._encode_audio_base64(audio_data, sampling_rate)

            logger.info("Generated audio: %d samples at %sHz", len(audio_data), sampling_rate)

            return {
                "audio": audio_base64,
                "sampling_rate": sampling_rate
            }

        except Exception as e:
            # Top-level boundary: report the failure to the caller instead of
            # raising. logger.exception keeps the traceback (error() dropped it).
            logger.exception("Inference error: %s", str(e))
            return {
                "error": str(e),
                "error_type": type(e).__name__
            }
|
|
|