File size: 10,843 Bytes
a6afb46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
"""
Custom Handler for Hugging Face Inference Endpoints
Model: IbrahimSalah/Arabic-TTS-Spark
Repository: https://huggingface.co/IbrahimSalah/Arabic-TTS-Spark

This handler provides Text-to-Speech inference for Arabic with:
- Voice cloning (with reference audio)
- Controllable TTS (with gender, pitch, speed parameters)
"""

import base64
import io
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional

import numpy as np
import soundfile as sf
import torch

logger = logging.getLogger(__name__)


class EndpointHandler:
    """
    Hugging Face Inference Endpoints handler for Arabic-TTS-Spark.

    Supports two modes:
    1. Voice Cloning: Provide reference audio to clone the voice
    2. Controllable TTS: Specify gender, pitch, and speed parameters
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler by loading the model and processor.

        Args:
            path: Path to the model directory (provided by HF Inference
                Endpoints). Falls back to the Hub repo id when empty.
        """
        # Imported lazily so the module can be imported/inspected without
        # transformers installed.
        from transformers import AutoModel, AutoProcessor

        self.device = self._get_device()
        logger.info("Initializing Arabic-TTS-Spark on device: %s", self.device)

        # Local directory when Endpoints hands us a snapshot path, otherwise
        # the Hub repo id.
        model_path = path if path else "IbrahimSalah/Arabic-TTS-Spark"
        logger.info("Loading model from: %s", model_path)

        # trust_remote_code=True is required: the repo ships custom
        # model/processor classes.
        self.processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True
        )

        # bfloat16 on CUDA to halve memory; float32 elsewhere (bfloat16
        # support on MPS/CPU is unreliable).
        self.model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if self.device.type == "cuda" else torch.float32
        )

        # Move model to device and set to eval mode (inference only).
        self.model = self.model.to(self.device).eval()

        # Link processor to model (required for voice cloning with this
        # repo's custom processor).
        self.processor.link_model(self.model)

        # Optional default reference audio shipped alongside the model.
        # None when the file is absent, so callers can test it cleanly.
        # (The previous fallback recomputed the identical path when `path`
        # was set, which was dead logic.)
        candidate = Path(model_path) / "reference.wav"
        self.default_reference_path: Optional[Path] = candidate if candidate.exists() else None

        logger.info("Model loaded successfully")

    def _get_device(self) -> torch.device:
        """Return the best available device: CUDA > MPS > CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # hasattr guard keeps this working on older torch builds without
        # the mps backend attribute.
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def _decode_audio_base64(self, audio_base64: str) -> tuple:
        """
        Decode base64 audio to a numpy waveform.

        Args:
            audio_base64: Base64 encoded audio data (any format soundfile
                can read).

        Returns:
            Tuple of (audio_data, sample_rate).
        """
        audio_bytes = base64.b64decode(audio_base64)
        audio_buffer = io.BytesIO(audio_bytes)
        audio_data, sample_rate = sf.read(audio_buffer)
        return audio_data, sample_rate

    def _encode_audio_base64(self, audio_data: np.ndarray, sample_rate: int) -> str:
        """
        Encode an audio numpy array as a base64 WAV string.

        Args:
            audio_data: Audio waveform as numpy array
            sample_rate: Sample rate of the audio

        Returns:
            Base64 encoded audio string
        """
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, audio_data, sample_rate, format='WAV')
        audio_buffer.seek(0)
        return base64.b64encode(audio_buffer.read()).decode('utf-8')

    def _validate_inputs(self, data: Dict[str, Any]) -> tuple:
        """
        Validate and extract inputs from request data.

        Args:
            data: Request data dictionary

        Returns:
            Tuple of (text, parameters, mode) where mode is
            "voice_cloning" or "controllable".

        Raises:
            ValueError: If the 'inputs' field is missing or empty.
        """
        text = data.get("inputs", "")
        if not text:
            raise ValueError("No input text provided. Use 'inputs' field.")

        parameters = data.get("parameters", {})

        # Voice cloning wins if any reference-audio key is present; the
        # controllable path requires ALL three control knobs together.
        has_audio = "prompt_audio_base64" in parameters or "prompt_audio" in parameters
        has_control = all(k in parameters for k in ["gender", "pitch", "speed"])

        if has_audio:
            mode = "voice_cloning"
        elif has_control:
            mode = "controllable"
        else:
            # Default to controllable TTS with a neutral male voice.
            mode = "controllable"
            parameters.setdefault("gender", "male")
            parameters.setdefault("pitch", "moderate")
            parameters.setdefault("speed", "moderate")

        return text, parameters, mode

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process inference request.

        Args:
            data: Request data with the following structure:
                {
                    "inputs": "Arabic text with diacritics",
                    "parameters": {
                        # For voice cloning:
                        "prompt_audio_base64": "<base64-wav>",  # or "prompt_audio"
                        "prompt_text": "reference transcript",

                        # For controllable TTS:
                        "gender": "male" or "female",
                        "pitch": "very_low", "low", "moderate", "high", "very_high",
                        "speed": "very_low", "low", "moderate", "high", "very_high",

                        # Generation parameters (optional):
                        "temperature": 0.8,
                        "max_new_tokens": 3000,
                        "top_p": 0.95,
                        "top_k": 50
                    }
                }

        Returns:
            Dictionary with:
                {
                    "audio": "<base64-encoded-wav>",
                    "sampling_rate": 16000
                }
            or {"error": ..., "error_type": ...} on failure (Endpoints
            expect a JSON body rather than a crashed worker).
        """
        try:
            text, parameters, mode = self._validate_inputs(data)
            logger.info("Processing request - Mode: %s, Text length: %d", mode, len(text))

            # Generation parameters with sensible defaults.
            temperature = parameters.get("temperature", 0.8)
            max_new_tokens = parameters.get("max_new_tokens", 3000)
            top_p = parameters.get("top_p", 0.95)
            top_k = parameters.get("top_k", 50)

            if mode == "voice_cloning":
                audio_base64 = parameters.get("prompt_audio_base64") or parameters.get("prompt_audio")
                if not audio_base64:
                    # Key was present but empty/None; fail with a clear message
                    # instead of a confusing b64decode TypeError.
                    raise ValueError(
                        "Voice cloning requires a non-empty 'prompt_audio_base64' "
                        "(or 'prompt_audio') parameter."
                    )
                prompt_text = parameters.get("prompt_text", "")

                # Decode BEFORE creating the temp file so a malformed payload
                # cannot leak a half-written file on disk.
                audio_data, ref_sample_rate = self._decode_audio_base64(audio_base64)

                # The processor expects a file path, so round-trip through a
                # temporary WAV file.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                    # BUGFIX: write at the reference audio's own sample rate.
                    # The previous hard-coded 16000 re-labelled any non-16kHz
                    # reference, pitch/speed-shifting the cloned voice.
                    sf.write(tmp_file.name, audio_data, ref_sample_rate)
                    tmp_audio_path = tmp_file.name

                try:
                    inputs = self.processor(
                        text=text,
                        prompt_speech_path=tmp_audio_path,
                        prompt_text=prompt_text if prompt_text else None,
                        return_tensors="pt"
                    )
                finally:
                    # Always remove the temp file, even if the processor raises.
                    os.unlink(tmp_audio_path)
            else:
                gender = parameters.get("gender", "male")
                pitch = parameters.get("pitch", "moderate")
                speed = parameters.get("speed", "moderate")

                # Reject invalid control values early with actionable messages.
                valid_genders = ["male", "female"]
                valid_levels = ["very_low", "low", "moderate", "high", "very_high"]

                if gender not in valid_genders:
                    raise ValueError(f"Invalid gender: {gender}. Must be one of {valid_genders}")
                if pitch not in valid_levels:
                    raise ValueError(f"Invalid pitch: {pitch}. Must be one of {valid_levels}")
                if speed not in valid_levels:
                    raise ValueError(f"Invalid speed: {speed}. Must be one of {valid_levels}")

                inputs = self.processor(
                    text=text,
                    gender=gender,
                    pitch=pitch,
                    speed=speed,
                    return_tensors="pt"
                )

            # Move only tensors to the device; pass other entries through.
            inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                     for k, v in inputs.items()}

            # Prompt length is needed later to strip the prompt from the
            # generated sequence during decoding.
            input_ids_len = inputs["input_ids"].shape[1]

            with torch.no_grad():
                output_ids = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs.get("attention_mask"),
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    # Some tokenizers ship without a pad token; fall back to EOS.
                    pad_token_id=self.processor.tokenizer.pad_token_id or self.processor.tokenizer.eos_token_id,
                    eos_token_id=self.processor.tokenizer.eos_token_id,
                )

            # Global token ids are only present in voice-cloning mode; the
            # processor tolerates None for controllable TTS.
            global_tokens = inputs.get("global_token_ids_prompt")
            output = self.processor.decode(
                generated_ids=output_ids,
                global_token_ids_prompt=global_tokens,
                input_ids_len=input_ids_len
            )

            audio_data = output["audio"]
            sampling_rate = output["sampling_rate"]

            # Guard against degenerate generations before encoding.
            if audio_data is None or len(audio_data) == 0:
                raise RuntimeError("Model generated empty audio output")

            audio_base64 = self._encode_audio_base64(audio_data, sampling_rate)

            logger.info("Generated audio: %d samples at %dHz", len(audio_data), sampling_rate)

            return {
                "audio": audio_base64,
                "sampling_rate": sampling_rate
            }

        except Exception as e:
            # logger.exception preserves the traceback (logger.error with a
            # formatted string did not); Endpoints still get a JSON error body.
            logger.exception("Inference error: %s", e)
            return {
                "error": str(e),
                "error_type": type(e).__name__
            }