import tensorflow as tf
import numpy as np
import os
import io
import tempfile
import logging
import time
# Quiet TensorFlow: disable oneDNN custom-op notices and suppress C++
# INFO/WARNING messages.
# NOTE(review): these are assigned *after* `import tensorflow` above — the
# C++ log level is usually read at import time, so this may be too late to
# take effect; consider setting them before the tensorflow import.
os.environ.update({
    "TF_ENABLE_ONEDNN_OPTS": "0",
    "TF_CPP_MIN_LOG_LEVEL": "2",
})
from keras.models import load_model
from keras.layers import Layer
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler()
]
)
logger = logging.getLogger('audio_inference')
class WavToMelLayer(Layer):
    """Convert a batch of WAV file paths into 128x128x3 log-mel spectrogram
    "images".

    Input: a 1-D string tensor of WAV file paths.
    Output: a float32 tensor of shape (batch, 128, 128, 3) — the log-mel
    spectrogram resized to 128x128 and replicated to 3 channels so it can be
    fed to image-style model backbones.
    """

    def __init__(self, sample_rate=16000, n_mels=128, fft_size=1024, hop_size=512, **kwargs):
        """
        Args:
            sample_rate: Sample rate (Hz) assumed by the mel filterbank.
                NOTE(review): the decoded file's actual rate (`sr` below) is
                never validated against this value — mismatched files are
                accepted silently; confirm whether that is intended.
            n_mels: Number of mel bands.
            fft_size: STFT frame length in samples; the spectrogram has
                fft_size // 2 + 1 frequency bins.
            hop_size: STFT frame step in samples.
            **kwargs: Forwarded to keras.layers.Layer (e.g. ``name``).
        """
        super(WavToMelLayer, self).__init__(**kwargs)
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.fft_size = fft_size
        self.hop_size = hop_size

    def call(self, inputs):
        def process_audio(input_path):
            logger.debug(f"Processing audio file: {input_path}")
            try:
                audio = tf.io.read_file(input_path)
                audio, sr = tf.audio.decode_wav(audio, desired_channels=1)
                logger.debug(f"Decoded WAV file with sample rate: {sr}, shape: {audio.shape}")
                # (samples, 1) -> (samples,)
                audio = tf.squeeze(audio, axis=-1)
                stft = tf.signal.stft(audio, frame_length=self.fft_size, frame_step=self.hop_size)
                logger.debug(f"STFT shape: {stft.shape}")
                # Power spectrogram.
                spectrogram = tf.abs(stft) ** 2
                # Mel filterbank restricted to the 20 Hz – 4 kHz band.
                mel_weights = tf.signal.linear_to_mel_weight_matrix(
                    self.n_mels, self.fft_size // 2 + 1, self.sample_rate, 20.0, 4000.0
                )
                mel_spectrogram = tf.tensordot(spectrogram, mel_weights, axes=1)
                # Log-compress; the epsilon guards against log(0).
                mel_spectrogram = tf.math.log(mel_spectrogram + 1e-6)
                logger.debug(f"Mel spectrogram shape: {mel_spectrogram.shape}")
                # Resize to a fixed 128x128 "image" and replicate gray -> RGB.
                mel_spectrogram = tf.image.resize(mel_spectrogram[..., tf.newaxis], [128, 128])
                mel_spectrogram = tf.image.grayscale_to_rgb(mel_spectrogram)
                logger.debug(f"Final mel spectrogram shape: {mel_spectrogram.shape}")
                return mel_spectrogram
            except Exception as e:
                logger.error(f"Error in process_audio: {str(e)}")
                raise
        # Fix: `dtype=` is deprecated for tf.map_fn in TF2; declare the
        # per-element output type with `fn_output_signature` instead.
        return tf.map_fn(process_audio, inputs, fn_output_signature=tf.float32)

    def get_config(self):
        """Return the layer config so models using this layer can be
        serialized and reloaded via `custom_objects`."""
        config = super(WavToMelLayer, self).get_config()
        config.update({
            "sample_rate": self.sample_rate,
            "n_mels": self.n_mels,
            "fft_size": self.fft_size,
            "hop_size": self.hop_size
        })
        return config
class EndpointHandler:
    """Inference endpoint wrapper.

    Loads the bundled Keras model once at construction time and serves
    ``__call__(requests)``, where ``requests["inputs"]`` must be the raw bytes
    of a WAV file. Returns a list of ``{"word": class_index, "confidence":
    float}`` dicts on success, or ``[{"error": ...}]`` on failure.
    """

    def __init__(self, model_dir=None):
        """Load the model from ``<model_dir>/model/bestModel.keras``.

        Args:
            model_dir: Directory containing the model bundle. When None (or
                omitted), the directory this file lives in is used.

        Raises:
            Exception: re-raised from ``load_model`` when loading fails.
        """
        logger.info("Initializing EndpointHandler")
        if model_dir is None:
            model_dir = os.path.dirname(os.path.abspath(__file__))
            logger.info(f"Model directory not provided, using current directory: {model_dir}")
        else:
            logger.info(f"Using provided model directory: {model_dir}")
        model_path = os.path.join(model_dir, "model/bestModel.keras")
        logger.info(f"Loading model from: {model_path}")
        try:
            self.model = load_model(model_path, custom_objects={"WavToMelLayer": WavToMelLayer})
            # Fix: model.summary() *prints* and returns None, so the old
            # f-string logged "Model loaded successfully: None". Capture the
            # summary text via print_fn so it actually lands in the log.
            summary_lines = []
            self.model.summary(print_fn=summary_lines.append)
            logger.info("Model loaded successfully:\n%s", "\n".join(summary_lines))
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    def __call__(self, requests):
        """Run inference on one request.

        Args:
            requests: Mapping with an "inputs" key holding raw WAV bytes.

        Returns:
            list[dict]: per-prediction ``{"word", "confidence"}`` results, or
            a single-element list with an ``{"error": ...}`` dict.
        """
        start_time = time.time()
        logger.info("Processing inference request")
        temp_dir = None
        temp_wav_path = None
        try:
            # Fix: read the payload inside the try block so a missing
            # "inputs" key becomes an error response instead of an
            # unhandled KeyError escaping the handler.
            wav_bytes = requests['inputs']
            # Log only the request keys — dumping the whole dict would put
            # the raw audio bytes into the log.
            logger.debug(f"Request keys: {list(requests)}")
            if not isinstance(wav_bytes, bytes):
                logger.error(f"Expected bytes, got {type(wav_bytes)}")
                return [{"error": f"Invalid input type: {type(wav_bytes)}, expected bytes"}]
            temp_dir = tempfile.mkdtemp()
            temp_wav_path = os.path.join(temp_dir, "wav_input.wav")
            logger.info(f"Created temporary directory: {temp_dir}")
            logger.debug(f"Writing {len(wav_bytes)} bytes to temporary file: {temp_wav_path}")
            with open(temp_wav_path, "wb") as f:
                f.write(wav_bytes)
            if not os.path.exists(temp_wav_path):
                logger.error(f"Failed to create temporary WAV file: {temp_wav_path}")
                return [{"error": "Failed to create temporary WAV file"}]
            logger.debug(f"File size: {os.path.getsize(temp_wav_path)} bytes")
            # The model's first layer (WavToMelLayer) consumes file paths,
            # hence the detour through a temp file on disk.
            inputs = tf.constant([temp_wav_path])
            logger.info("Running model prediction")
            predictions = self.model.predict(inputs)
            logger.debug(f"Raw predictions: {predictions}")
            results = []
            for i, prediction in enumerate(predictions):
                predicted_class_index = int(np.argmax(prediction))
                confidence = float(prediction[predicted_class_index])
                logger.info(f"Result {i}: class={predicted_class_index}, confidence={confidence:.4f}")
                results.append({"word": predicted_class_index, "confidence": confidence})
            elapsed_time = time.time() - start_time
            logger.info(f"Inference completed in {elapsed_time:.3f} seconds")
            return results
        except Exception as e:
            logger.error(f"Error during inference: {str(e)}", exc_info=True)
            return [{"error": str(e)}]
        finally:
            # Best-effort cleanup of the scratch file and directory;
            # cleanup failures are logged but never mask the response.
            try:
                if temp_wav_path and os.path.exists(temp_wav_path):
                    os.remove(temp_wav_path)
                    logger.debug(f"Removed temporary file: {temp_wav_path}")
                if temp_dir and os.path.exists(temp_dir):
                    os.rmdir(temp_dir)
                    logger.debug(f"Removed temporary directory: {temp_dir}")
            except Exception as cleanup_error:
                logger.error(f"Error during cleanup: {str(cleanup_error)}")