""" Audio processing service for VoiceAuth API. Handles Base64 decoding, format conversion, and audio preprocessing. """ import base64 import io from typing import TYPE_CHECKING import numpy as np from pydub import AudioSegment from app.config import get_settings from app.utils.constants import MP3_MAGIC_BYTES from app.utils.constants import TARGET_SAMPLE_RATE from app.utils.exceptions import AudioDecodeError from app.utils.exceptions import AudioDurationError from app.utils.exceptions import AudioFormatError from app.utils.exceptions import AudioProcessingError from app.utils.logger import get_logger if TYPE_CHECKING: import torch logger = get_logger(__name__) class AudioProcessor: """ Audio processing service for preparing audio for ML inference. Handles the complete pipeline from Base64-encoded MP3 to normalized numpy arrays suitable for Wav2Vec2. """ def __init__(self) -> None: """Initialize AudioProcessor with settings.""" self.settings = get_settings() self.target_sample_rate = TARGET_SAMPLE_RATE def decode_base64_audio(self, base64_string: str) -> bytes: """ Decode Base64 string to raw audio bytes. Args: base64_string: Base64-encoded audio data Returns: Raw audio bytes Raises: AudioDecodeError: If decoding fails """ try: # Handle potential padding issues base64_string = base64_string.strip() padding = 4 - len(base64_string) % 4 if padding != 4: base64_string += "=" * padding audio_bytes = base64.b64decode(base64_string) if len(audio_bytes) < 100: raise AudioDecodeError( "Decoded audio data is too small", details={"size_bytes": len(audio_bytes)}, ) logger.debug( "Decoded base64 audio", size_bytes=len(audio_bytes), ) return audio_bytes except AudioDecodeError: raise except Exception as e: raise AudioDecodeError( f"Failed to decode Base64 audio: {e}", details={"error": str(e)}, ) from e def validate_mp3_format(self, audio_bytes: bytes) -> bool: """ Validate that the audio bytes represent a valid MP3 file. Args: audio_bytes: Raw audio bytes Returns: True if valid MP3 Raises: AudioFormatError: If not a valid MP3 file """ # Check for MP3 magic bytes is_valid = any(audio_bytes.startswith(magic) for magic in MP3_MAGIC_BYTES) if not is_valid: raise AudioFormatError( "Invalid MP3 format: file does not have valid MP3 header", details={"header_bytes": audio_bytes[:10].hex()}, ) return True def convert_mp3_to_wav_array(self, mp3_bytes: bytes) -> np.ndarray: """ Convert MP3 bytes to normalized WAV numpy array. Args: mp3_bytes: Raw MP3 audio bytes Returns: Normalized numpy array of audio samples Raises: AudioProcessingError: If conversion fails """ try: # Load MP3 using pydub audio_buffer = io.BytesIO(mp3_bytes) audio_segment = AudioSegment.from_mp3(audio_buffer) # Convert to mono if stereo if audio_segment.channels > 1: audio_segment = audio_segment.set_channels(1) # Resample to target sample rate if audio_segment.frame_rate != self.target_sample_rate: audio_segment = audio_segment.set_frame_rate(self.target_sample_rate) # Convert to numpy array samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32) # Normalize to [-1, 1] range samples = samples / 32768.0 # 16-bit audio normalization logger.debug( "Converted MP3 to WAV array", original_channels=audio_segment.channels, sample_rate=self.target_sample_rate, num_samples=len(samples), ) return samples except Exception as e: raise AudioProcessingError( f"Failed to convert MP3 to WAV: {e}", details={"error": str(e)}, ) from e def validate_audio_duration( self, audio_array: np.ndarray, sample_rate: int | None = None, ) -> float: """ Validate audio duration is within allowed bounds. Args: audio_array: Numpy array of audio samples sample_rate: Sample rate (uses target_sample_rate if not provided) Returns: Duration in seconds Raises: AudioDurationError: If duration is out of bounds """ if sample_rate is None: sample_rate = self.target_sample_rate duration = len(audio_array) / sample_rate if duration < self.settings.MIN_AUDIO_DURATION: raise AudioDurationError( f"Audio too short: {duration:.2f}s (minimum: {self.settings.MIN_AUDIO_DURATION}s)", duration=duration, min_duration=self.settings.MIN_AUDIO_DURATION, ) if duration > self.settings.MAX_AUDIO_DURATION: raise AudioDurationError( f"Audio too long: {duration:.2f}s (maximum: {self.settings.MAX_AUDIO_DURATION}s)", duration=duration, max_duration=self.settings.MAX_AUDIO_DURATION, ) logger.debug("Audio duration validated", duration_seconds=round(duration, 2)) return duration def normalize_audio(self, audio_array: np.ndarray) -> np.ndarray: """ Normalize audio amplitude to [-1, 1] range. Applies peak normalization to maximize dynamic range. Args: audio_array: Input audio array Returns: Normalized audio array """ # Avoid division by zero for silent audio max_amplitude = np.abs(audio_array).max() if max_amplitude < 1e-8: logger.warning("Audio appears to be silent or near-silent") return audio_array normalized = audio_array / max_amplitude return normalized def extract_audio_metadata( self, audio_array: np.ndarray, sample_rate: int | None = None, ) -> dict: """ Extract metadata from audio for explainability. Args: audio_array: Numpy array of audio samples sample_rate: Sample rate Returns: Dictionary of audio metadata """ if sample_rate is None: sample_rate = self.target_sample_rate duration = len(audio_array) / sample_rate # Calculate RMS energy rms_energy = float(np.sqrt(np.mean(audio_array**2))) # Calculate zero crossing rate zero_crossings = np.sum(np.abs(np.diff(np.sign(audio_array)))) / 2 zcr = float(zero_crossings / len(audio_array)) # Calculate peak amplitude peak_amplitude = float(np.abs(audio_array).max()) return { "duration_seconds": round(duration, 3), "num_samples": len(audio_array), "sample_rate": sample_rate, "rms_energy": round(rms_energy, 6), "zero_crossing_rate": round(zcr, 6), "peak_amplitude": round(peak_amplitude, 6), } def process_audio(self, audio_base64: str) -> tuple[np.ndarray, dict]: """ Complete audio processing pipeline. Takes Base64-encoded MP3 and returns normalized audio array with metadata. Args: audio_base64: Base64-encoded MP3 audio Returns: Tuple of (normalized audio array, metadata dict) Raises: AudioDecodeError: If Base64 decoding fails AudioFormatError: If not valid MP3 AudioDurationError: If duration out of bounds AudioProcessingError: If processing fails """ logger.info("Starting audio processing pipeline") # Decode Base64 audio_bytes = self.decode_base64_audio(audio_base64) # Validate MP3 format self.validate_mp3_format(audio_bytes) # Convert to WAV array audio_array = self.convert_mp3_to_wav_array(audio_bytes) # Validate duration self.validate_audio_duration(audio_array) # Normalize normalized_audio = self.normalize_audio(audio_array) # Extract metadata metadata = self.extract_audio_metadata(normalized_audio) logger.info( "Audio processing complete", duration=metadata["duration_seconds"], samples=metadata["num_samples"], ) return normalized_audio, metadata