| | import tensorflow as tf
|
| | import numpy as np
|
| | import os
|
| | import io
|
| | import tempfile
|
| | import logging
|
| | import time
|
# Configure TensorFlow via environment variables. These must be set BEFORE
# TensorFlow is imported or they have no effect.
#   TF_ENABLE_ONEDNN_OPTS=0  -> disable oneDNN custom ops (avoids small
#                               run-to-run numeric differences).
#   TF_CPP_MIN_LOG_LEVEL=2   -> suppress TF C++ INFO and WARNING messages.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
| |
|
| | from keras.models import load_model
|
| | from keras.layers import Layer
|
| |
|
| |
|
| | logging.basicConfig(
|
| | level=logging.INFO,
|
| | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| | handlers=[
|
| | logging.StreamHandler()
|
| | ]
|
| | )
|
| | logger = logging.getLogger('audio_inference')
|
| |
|
class WavToMelLayer(Layer):
    """Keras layer that maps WAV file paths to log-mel spectrogram images.

    Input is a 1-D string tensor of paths to mono WAV files; output is a
    float32 tensor of shape (batch, 128, 128, 3): each file is decoded,
    converted to a log-mel power spectrogram, resized to 128x128 and
    replicated to 3 channels so image-style (RGB-input) models can consume it.
    """

    def __init__(self, sample_rate=16000, n_mels=128, fft_size=1024,
                 hop_size=512, fmin=20.0, fmax=4000.0, **kwargs):
        """Create the layer.

        Args:
            sample_rate: Sample rate (Hz) assumed for the mel filterbank.
            n_mels: Number of mel bands.
            fft_size: STFT frame length in samples.
            hop_size: STFT frame step in samples.
            fmin: Lower edge (Hz) of the mel filterbank. Previously a
                hard-coded 20.0; the default preserves that behavior.
            fmax: Upper edge (Hz) of the mel filterbank. Previously a
                hard-coded 4000.0; the default preserves that behavior.
            **kwargs: Forwarded to ``keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.fmin = fmin
        self.fmax = fmax

    def call(self, inputs):
        """Convert each path in ``inputs`` to a (128, 128, 3) float32 image."""

        def process_audio(input_path):
            # Handles one file path; mapped over the batch below.
            logger.debug("Processing audio file: %s", input_path)
            try:
                audio = tf.io.read_file(input_path)
                audio, sr = tf.audio.decode_wav(audio, desired_channels=1)
                logger.debug("Decoded WAV file with sample rate: %s, shape: %s",
                             sr, audio.shape)
                audio = tf.squeeze(audio, axis=-1)  # (samples, 1) -> (samples,)

                # Power spectrogram via the short-time Fourier transform.
                stft = tf.signal.stft(audio, frame_length=self.fft_size,
                                      frame_step=self.hop_size)
                logger.debug("STFT shape: %s", stft.shape)
                spectrogram = tf.abs(stft) ** 2

                # Project linear-frequency bins onto the mel scale, then
                # log-compress; the epsilon avoids log(0).
                mel_weights = tf.signal.linear_to_mel_weight_matrix(
                    self.n_mels, self.fft_size // 2 + 1, self.sample_rate,
                    self.fmin, self.fmax
                )
                mel_spectrogram = tf.tensordot(spectrogram, mel_weights, axes=1)
                mel_spectrogram = tf.math.log(mel_spectrogram + 1e-6)
                logger.debug("Mel spectrogram shape: %s", mel_spectrogram.shape)

                # Fixed-size "image" for the downstream model; grayscale ->
                # RGB duplicates the single channel three times.
                mel_spectrogram = tf.image.resize(
                    mel_spectrogram[..., tf.newaxis], [128, 128])
                mel_spectrogram = tf.image.grayscale_to_rgb(mel_spectrogram)
                logger.debug("Final mel spectrogram shape: %s", mel_spectrogram.shape)

                return mel_spectrogram
            except Exception as e:
                logger.error("Error in process_audio: %s", e)
                raise

        # NOTE: `dtype=` is deprecated in favor of `fn_output_signature=` in
        # newer TF; kept for compatibility with older TF 2.x runtimes.
        return tf.map_fn(process_audio, inputs, dtype=tf.float32)

    def get_config(self):
        """Return the serialization config, including all constructor args."""
        config = super().get_config()
        config.update({
            "sample_rate": self.sample_rate,
            "n_mels": self.n_mels,
            "fft_size": self.fft_size,
            "hop_size": self.hop_size,
            "fmin": self.fmin,
            "fmax": self.fmax,
        })
        return config
|
| |
|
class EndpointHandler:
    """Inference endpoint handler for the audio classification model.

    ``__call__`` expects a dict with raw WAV bytes under ``"inputs"`` and
    returns a list of ``{"word": int, "confidence": float}`` results, or a
    single-item list with an ``{"error": str}`` dict on failure.
    """

    def __init__(self, model_dir):
        """Load the Keras model from ``model_dir``.

        Args:
            model_dir: Directory containing ``model/bestModel.keras``. When
                None, the directory of this source file is used.

        Raises:
            Exception: Whatever ``load_model`` raises is re-raised after
                being logged.
        """
        logger.info("Initializing EndpointHandler")
        if model_dir is None:
            model_dir = os.path.dirname(os.path.abspath(__file__))
            logger.info(f"Model directory not provided, using current directory: {model_dir}")
        else:
            logger.info(f"Using provided model directory: {model_dir}")

        # Separate path components keep this portable; the original embedded
        # a "/" inside the second join argument.
        model_path = os.path.join(model_dir, "model", "bestModel.keras")
        logger.info(f"Loading model from: {model_path}")

        try:
            self.model = load_model(model_path, custom_objects={"WavToMelLayer": WavToMelLayer})
            # model.summary() prints to stdout and returns None, so the
            # original logged "None"; capture the text via print_fn instead.
            summary_lines = []
            self.model.summary(print_fn=summary_lines.append)
            logger.info("Model loaded successfully:\n%s", "\n".join(summary_lines))
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    def __call__(self, requests):
        """Run inference on a single request.

        Args:
            requests: Dict expected to hold raw WAV file bytes at "inputs".

        Returns:
            List of ``{"word": int, "confidence": float}`` dicts on success,
            otherwise a single-item list with an ``{"error": str}`` dict.
        """
        start_time = time.time()
        logger.info("Processing inference request")
        temp_dir = None
        temp_wav_path = None

        try:
            # Log only the request keys: the payload is raw audio bytes and
            # would flood the log if dumped verbatim.
            if isinstance(requests, dict):
                logger.info("Request keys: %s", list(requests.keys()))

            # Use .get() inside the try block: the original indexed
            # requests['inputs'] before the try, so a missing key escaped as
            # an uncaught KeyError instead of a structured error response.
            input_yeah = requests.get('inputs') if isinstance(requests, dict) else None

            if not isinstance(input_yeah, bytes):
                logger.error(f"Expected bytes, got {type(input_yeah)}")
                return [{"error": f"Invalid input type: {type(input_yeah)}, expected bytes"}]

            # Only create the temp dir once the payload has been validated.
            temp_dir = tempfile.mkdtemp()
            temp_wav_path = os.path.join(temp_dir, "wav_input.wav")
            logger.info(f"Created temporary directory: {temp_dir}")

            logger.debug(f"Writing {len(input_yeah)} bytes to temporary file: {temp_wav_path}")
            with open(temp_wav_path, "wb") as f:
                f.write(input_yeah)

            if not os.path.exists(temp_wav_path):
                logger.error(f"Failed to create temporary WAV file: {temp_wav_path}")
                return [{"error": "Failed to create temporary WAV file"}]

            logger.debug(f"File size: {os.path.getsize(temp_wav_path)} bytes")

            # The model's WavToMelLayer consumes a batch of file paths.
            inputs = tf.constant([temp_wav_path])
            logger.info("Running model prediction")

            predictions = self.model.predict(inputs)
            logger.debug(f"Raw predictions: {predictions}")

            results = []
            for i, prediction in enumerate(predictions):
                predicted_class_index = np.argmax(prediction)
                confidence = float(prediction[predicted_class_index])
                logger.info(f"Result {i}: class={predicted_class_index}, confidence={confidence:.4f}")
                results.append({"word": int(predicted_class_index), "confidence": confidence})

            elapsed_time = time.time() - start_time
            logger.info(f"Inference completed in {elapsed_time:.3f} seconds")
            return results

        except Exception as e:
            logger.error(f"Error during inference: {str(e)}", exc_info=True)
            return [{"error": str(e)}]

        finally:
            # Best-effort cleanup: never let a cleanup failure mask the
            # request's actual result or exception.
            try:
                if temp_wav_path and os.path.exists(temp_wav_path):
                    os.remove(temp_wav_path)
                    logger.debug(f"Removed temporary file: {temp_wav_path}")
                if temp_dir and os.path.exists(temp_dir):
                    os.rmdir(temp_dir)
                    logger.debug(f"Removed temporary directory: {temp_dir}")
            except Exception as cleanup_error:
                logger.error(f"Error during cleanup: {str(cleanup_error)}")