"""Inference endpoint that classifies WAV audio with a saved Keras model.

The saved model starts with a custom ``WavToMelLayer`` that takes WAV *file
paths* as string tensors, so the handler writes the request bytes to a
temporary file and feeds the model that path.
"""
import os

# TensorFlow reads these variables while it is being imported, so they must be
# set before the first ``import tensorflow`` / ``import keras``; assigning them
# afterwards has no effect.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import io  # noqa: E402
import logging  # noqa: E402
import tempfile  # noqa: E402
import time  # noqa: E402

import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
from keras.layers import Layer  # noqa: E402
from keras.models import load_model  # noqa: E402

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger("audio_inference")


class WavToMelLayer(Layer):
    """Keras layer mapping a batch of WAV file paths to mel-spectrogram images.

    Each path in the (assumed 1-D) string input tensor is read, decoded to
    mono, turned into a log-power mel spectrogram, resized to 128x128 and
    replicated to three channels, giving one ``(128, 128, 3)`` float32 image
    per path.

    Args:
        sample_rate: Sample rate (Hz) used to build the mel filter bank. The
            decoded file's own sample rate is not checked against it.
        n_mels: Number of mel bands in the filter bank.
        fft_size: STFT frame length in samples.
        hop_size: STFT frame step in samples.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, sample_rate=16000, n_mels=128, fft_size=1024, hop_size=512, **kwargs):
        super().__init__(**kwargs)
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.fft_size = fft_size
        self.hop_size = hop_size

    def call(self, inputs):
        """Convert every WAV path in ``inputs`` to a ``(128, 128, 3)`` float32 image."""
        # The filter bank depends only on the layer configuration, so build it
        # once here instead of once per example inside ``map_fn``.
        mel_weights = tf.signal.linear_to_mel_weight_matrix(
            self.n_mels,
            self.fft_size // 2 + 1,
            self.sample_rate,
            20.0,
            4000.0,
        )

        def process_audio(input_path):
            # NOTE: when run inside a tf.function (e.g. via model.predict) these
            # Python-side log calls fire while tracing, not once per example.
            logger.debug(f"Processing audio file: {input_path}")
            try:
                audio = tf.io.read_file(input_path)
                audio, sr = tf.audio.decode_wav(audio, desired_channels=1)
                logger.debug(f"Decoded WAV file with sample rate: {sr}, shape: {audio.shape}")
                audio = tf.squeeze(audio, axis=-1)

                stft = tf.signal.stft(audio, frame_length=self.fft_size, frame_step=self.hop_size)
                logger.debug(f"STFT shape: {stft.shape}")
                spectrogram = tf.abs(stft) ** 2

                mel_spectrogram = tf.tensordot(spectrogram, mel_weights, axes=1)
                # Small epsilon keeps log() finite on silent frames.
                mel_spectrogram = tf.math.log(mel_spectrogram + 1e-6)
                logger.debug(f"Mel spectrogram shape: {mel_spectrogram.shape}")

                # Fixed 128x128 size regardless of clip length; three channels
                # for an image-style backbone.
                mel_spectrogram = tf.image.resize(mel_spectrogram[..., tf.newaxis], [128, 128])
                mel_spectrogram = tf.image.grayscale_to_rgb(mel_spectrogram)
                logger.debug(f"Final mel spectrogram shape: {mel_spectrogram.shape}")
                return mel_spectrogram
            except Exception as e:
                logger.error(f"Error in process_audio: {str(e)}")
                raise

        # ``fn_output_signature`` replaces the deprecated ``dtype`` argument.
        return tf.map_fn(process_audio, inputs, fn_output_signature=tf.float32)

    def get_config(self):
        """Return the layer config, including the constructor arguments, for serialization."""
        config = super().get_config()
        config.update(
            {
                "sample_rate": self.sample_rate,
                "n_mels": self.n_mels,
                "fft_size": self.fft_size,
                "hop_size": self.hop_size,
            }
        )
        return config


class EndpointHandler:
    """Callable inference handler: loads the Keras model once, then classifies WAV bytes."""

    def __init__(self, model_dir=None):
        """Load ``model/bestModel.keras`` from ``model_dir``.

        Args:
            model_dir: Directory that contains the ``model/`` sub-directory.
                When ``None``, the directory of this source file is used.

        Raises:
            Exception: Whatever ``load_model`` raises; it is logged and re-raised.
        """
        logger.info("Initializing EndpointHandler")
        if model_dir is None:
            model_dir = os.path.dirname(os.path.abspath(__file__))
            logger.info(f"Model directory not provided, using current directory: {model_dir}")
        else:
            logger.info(f"Using provided model directory: {model_dir}")

        model_path = os.path.join(model_dir, "model/bestModel.keras")
        logger.info(f"Loading model from: {model_path}")
        try:
            self.model = load_model(model_path, custom_objects={"WavToMelLayer": WavToMelLayer})
            logger.info(f"Model loaded successfully:\n{self._summarize(self.model)}")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    @staticmethod
    def _summarize(model):
        """Return ``model.summary()`` as text (``summary`` itself prints and returns None)."""
        lines = []
        # Accept extra arguments in case the Keras version passes more than the line.
        model.summary(print_fn=lambda line, *args, **kwargs: lines.append(str(line)))
        return "\n".join(lines)

    @staticmethod
    def _cleanup(temp_dir, temp_wav_path):
        """Best-effort removal of the temporary WAV file and its directory; failures are logged."""
        try:
            if temp_wav_path and os.path.exists(temp_wav_path):
                os.remove(temp_wav_path)
                logger.debug(f"Removed temporary file: {temp_wav_path}")
            if temp_dir and os.path.exists(temp_dir):
                os.rmdir(temp_dir)
                logger.debug(f"Removed temporary directory: {temp_dir}")
        except Exception as cleanup_error:
            logger.error(f"Error during cleanup: {str(cleanup_error)}")

    def __call__(self, requests):
        """Classify the WAV audio carried in ``requests["inputs"]``.

        Args:
            requests: Mapping whose ``"inputs"`` entry holds the raw WAV file
                content as ``bytes`` (``bytearray`` is also accepted).

        Returns:
            A list with one ``{"word": int, "confidence": float}`` dict per
            prediction row, or ``[{"error": str}]`` when the request is invalid
            or inference fails. Errors are reported in the return value rather
            than raised.
        """
        start_time = time.time()
        logger.info("Processing inference request")

        try:
            audio_bytes = requests["inputs"]
        except (KeyError, TypeError):
            logger.error("Request has no 'inputs' field")
            return [{"error": "Missing 'inputs' field in request"}]

        # Log only metadata: the payload is raw audio and can be large.
        logger.debug(f"Request keys: {list(requests)}")

        # Validate before touching the filesystem.
        if not isinstance(audio_bytes, (bytes, bytearray)):
            logger.error(f"Expected bytes, got {type(audio_bytes)}")
            return [{"error": f"Invalid input type: {type(audio_bytes)}, expected bytes"}]

        temp_dir = None
        temp_wav_path = None
        try:
            temp_dir = tempfile.mkdtemp()
            temp_wav_path = os.path.join(temp_dir, "wav_input.wav")
            logger.info(f"Created temporary directory: {temp_dir}")

            logger.debug(f"Writing {len(audio_bytes)} bytes to temporary file: {temp_wav_path}")
            with open(temp_wav_path, "wb") as f:
                f.write(audio_bytes)
            logger.debug(f"File size: {os.path.getsize(temp_wav_path)} bytes")

            # The model's first layer reads the WAV from disk, so it receives the path.
            inputs = tf.constant([temp_wav_path])
            logger.info("Running model prediction")
            predictions = self.model.predict(inputs)
            logger.debug(f"Raw predictions: {predictions}")

            results = []
            for i, prediction in enumerate(predictions):
                predicted_class_index = np.argmax(prediction)
                confidence = float(prediction[predicted_class_index])
                logger.info(f"Result {i}: class={predicted_class_index}, confidence={confidence:.4f}")
                results.append({"word": int(predicted_class_index), "confidence": confidence})

            elapsed_time = time.time() - start_time
            logger.info(f"Inference completed in {elapsed_time:.3f} seconds")
            return results
        except Exception as e:
            logger.error(f"Error during inference: {str(e)}", exc_info=True)
            return [{"error": str(e)}]
        finally:
            self._cleanup(temp_dir, temp_wav_path)