| | """ |
| | Custom Inference Handler for StutteredSpeechASR Model |
| | Handles audio input and returns transcriptions for stuttered speech. |
| | """ |
| |
|
| | import torch |
| | import librosa |
| | import numpy as np |
| | import base64 |
| | import io |
| | import logging |
| | from typing import Dict, Any |
| | from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor |
| |
|
| | |
# Module-level logger used by the handler below; named after this module.
logger = logging.getLogger(__name__)
# Configure root logging at INFO so the handler's progress messages are emitted.
logging.basicConfig(level=logging.INFO)
| |
|
| |
|
class EndpointHandler:
    """
    Custom handler for StutteredSpeechASR inference endpoint.

    This handler extracts an audio payload from the request (a base64
    string or raw bytes), decodes and resamples it to 16 kHz mono, and
    returns the transcription produced by the fine-tuned Whisper model
    for stuttered Mandarin speech.
    """

    # Sampling rate (Hz) audio is resampled to and reported to the processor.
    SAMPLE_RATE = 16000
    # Identifier echoed back in every successful response.
    MODEL_NAME = "AImpower/StutteredSpeechASR"

    def __init__(self, path: str = ""):
        """
        Initialize the handler by loading the model and processor.

        Args:
            path: Path to the model directory (provided by Inference Endpoints)

        Raises:
            Exception: Whatever ``from_pretrained`` / ``to`` raise is logged
                and re-raised unchanged.
        """
        logger.info("Initializing StutteredSpeechASR handler...")

        # Half precision is only selected when CUDA is available.
        use_cuda = torch.cuda.is_available()
        self.device = "cuda" if use_cuda else "cpu"
        self.torch_dtype = torch.float16 if use_cuda else torch.float32

        logger.info("Using device: %s", self.device)
        logger.info("Using dtype: %s", self.torch_dtype)

        try:
            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                path,
                torch_dtype=self.torch_dtype,
            )
            self.processor = AutoProcessor.from_pretrained(path)
            self.model.to(self.device)
            self.model.eval()

            logger.info("Model and processor loaded successfully!")
        except Exception as e:
            logger.error("Error loading model: %s", e)
            raise

    def _load_audio_from_bytes(self, audio_bytes: bytes) -> np.ndarray:
        """
        Load audio from bytes and resample to 16kHz mono.

        Args:
            audio_bytes: Raw (encoded) audio file bytes

        Returns:
            Audio waveform as numpy array

        Raises:
            ValueError: If the bytes cannot be decoded as audio.
        """
        try:
            waveform, _ = librosa.load(
                io.BytesIO(audio_bytes),
                sr=self.SAMPLE_RATE,
                mono=True,
            )
        except Exception as e:
            logger.error("Error loading audio from bytes: %s", e)
            raise ValueError(f"Failed to load audio: {e}") from e
        return waveform

    @staticmethod
    def _decode_base64(base64_string) -> bytes:
        """
        Decode a base64 payload, tolerating a ``data:`` URI header.

        Args:
            base64_string: Base64 text (or bytes), optionally prefixed with
                a header such as ``data:audio/wav;base64,``.

        Returns:
            The decoded bytes.

        Raises:
            ValueError: If a ``data:`` URI has no ``,`` separator or the
                payload is not valid base64 (``binascii.Error`` is a
                ``ValueError`` subclass).
            TypeError: If the payload is not str or bytes-like.
        """
        if isinstance(base64_string, str) and base64_string.startswith("data:"):
            # ':' is outside the base64 alphabet, so a "data:" prefix can only
            # be a URI header; without stripping it the header characters
            # would be decoded as garbage audio bytes.
            _, separator, body = base64_string.partition(",")
            if not separator:
                raise ValueError("Malformed data URI: missing ',' separator")
            base64_string = body
        return base64.b64decode(base64_string)

    def _load_audio_from_base64(self, base64_string: str) -> np.ndarray:
        """
        Load audio from base64-encoded string.

        Args:
            base64_string: Base64-encoded audio data

        Returns:
            Audio waveform as numpy array

        Raises:
            ValueError: If base64 decoding fails ("Failed to decode base64
                audio") or the decoded bytes are not loadable audio
                ("Failed to load audio").
        """
        # Only the base64 step is guarded here so that an audio-decoding
        # failure is not mislabelled as a base64 failure.
        try:
            audio_bytes = self._decode_base64(base64_string)
        except (ValueError, TypeError) as e:
            logger.error("Error decoding base64 audio: %s", e)
            raise ValueError(f"Failed to decode base64 audio: {e}") from e
        return self._load_audio_from_bytes(audio_bytes)

    @staticmethod
    def _extract_audio_payload(data):
        """
        Locate the audio payload inside a request.

        Supported shapes:
            1. {"inputs": <str | bytes-like>}
            2. {"inputs": {"audio": <str | bytes-like>}}
            3. {"audio": <str | bytes-like>}
            4. bytes-like request body

        Args:
            data: The incoming request object.

        Returns:
            A ``str`` (to be treated as base64) or ``bytes`` (raw audio).

        Raises:
            ValueError: If no payload is present or its type is unsupported.
        """
        binary_types = (bytes, bytearray, memoryview)

        if isinstance(data, binary_types):
            return bytes(data)
        if not isinstance(data, dict):
            raise ValueError(f"Unsupported data type: {type(data)}")

        if "inputs" in data:
            payload = data["inputs"]
            if isinstance(payload, dict):
                if "audio" not in payload:
                    raise ValueError("Missing 'audio' field in inputs dictionary")
                payload = payload["audio"]
        elif "audio" in data:
            payload = data["audio"]
        else:
            raise ValueError(
                "No valid audio data found in request. Expected 'inputs' or 'audio' field."
            )

        # Same rule for every request shape: text is base64, binary is raw.
        if isinstance(payload, str):
            return payload
        if isinstance(payload, binary_types):
            return bytes(payload)
        raise ValueError(f"Unsupported input type: {type(payload)}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process incoming requests and return transcriptions.

        Expected input formats:
            1. {"inputs": "base64_encoded_audio_string"}
            2. {"inputs": {"audio": "base64_encoded_audio_string"}}
            3. {"audio": "base64_encoded_audio_string"}
            4. Binary audio data (as the request body, or in place of any
               of the strings above)

        Args:
            data: Input data dictionary

        Returns:
            On success, a dictionary with "transcription",
            "audio_duration_seconds" and "model". On any failure, a
            dictionary with "error" (message) and "transcription" set to
            None; exceptions are logged, not propagated.
        """
        try:
            logger.info("Processing inference request...")

            payload = self._extract_audio_payload(data)
            if isinstance(payload, str):
                waveform = self._load_audio_from_base64(payload)
            else:
                waveform = self._load_audio_from_bytes(payload)

            # Reject empty audio instead of handing the model nothing to
            # transcribe.
            if len(waveform) == 0:
                raise ValueError("Audio contains no samples")

            logger.info("Audio loaded: %d samples at 16kHz", len(waveform))

            # NOTE(review): Whisper-style processors presumably operate on a
            # fixed-length window, so long recordings may be truncated —
            # confirm against the processor configuration.
            input_features = self.processor(
                waveform,
                sampling_rate=self.SAMPLE_RATE,
                return_tensors="pt",
            ).input_features

            # Match the device and dtype the model was loaded with.
            input_features = input_features.to(self.device, dtype=self.torch_dtype)

            with torch.no_grad():
                predicted_ids = self.model.generate(input_features)

            transcription = self.processor.batch_decode(
                predicted_ids,
                skip_special_tokens=True,
            )[0]

            logger.info("Transcription complete: %s...", transcription[:100])

            return {
                "transcription": transcription.strip(),
                "audio_duration_seconds": float(len(waveform) / self.SAMPLE_RATE),
                "model": self.MODEL_NAME,
            }

        except Exception as e:
            # Best-effort contract: report the failure in the response body
            # rather than raising to the serving framework.
            logger.exception("Error during inference: %s", e)
            return {
                "error": str(e),
                "transcription": None,
            }
| |
|