# samacebuV0 / handler.py
# dreemer09
# agafgfgdgs
# bf1c3a7
"""Inference handler: decodes WAV audio into log-mel images and classifies it with a Keras model."""
import os

# TensorFlow reads these variables when its native runtime is loaded, i.e. during
# `import tensorflow`. They must therefore be set BEFORE that import; assigning
# them afterwards (as this file previously did) has no effect.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"  # turn off oneDNN custom ops
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # suppress INFO/WARNING logs from TF's C++ runtime

import io
import logging
import tempfile
import time

import numpy as np
import tensorflow as tf
from keras.layers import Layer
from keras.models import load_model

# Configure logging: INFO and above to stderr via a single stream handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger('audio_inference')
class WavToMelLayer(Layer):
    """Keras layer mapping a batch of WAV file paths to log-mel spectrogram images.

    For every path in the batch the layer reads the file, decodes it to mono,
    computes a power spectrogram with an STFT, projects it onto a mel
    filterbank, takes ``log(x + 1e-6)``, resizes the result to 128x128 and
    repeats it over three channels.

    Args:
        sample_rate: Sample rate (Hz) used to build the mel filterbank.
            NOTE(review): the sample rate stored in each WAV file is decoded
            but otherwise unused, so audio recorded at a different rate gets a
            mis-scaled filterbank — confirm inputs always use this rate.
        n_mels: Number of mel bands.
        fft_size: STFT frame length and FFT length, in samples.
        hop_size: STFT frame step, in samples.
        lower_edge_hertz: Lower edge of the mel filterbank, in Hz.
        upper_edge_hertz: Upper edge of the mel filterbank, in Hz.
        **kwargs: Forwarded to the Keras ``Layer`` constructor.
    """

    def __init__(
        self,
        sample_rate=16000,
        n_mels=128,
        fft_size=1024,
        hop_size=512,
        lower_edge_hertz=20.0,
        upper_edge_hertz=4000.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.lower_edge_hertz = lower_edge_hertz
        self.upper_edge_hertz = upper_edge_hertz

    def call(self, inputs):
        """Convert WAV file paths into mel "images".

        Args:
            inputs: String tensor whose elements along axis 0 are WAV file
                paths (``tf.io.read_file`` needs each element to be a scalar
                string, so this is presumably shape ``(batch,)``).

        Returns:
            float32 tensor of shape ``(batch, 128, 128, 3)``.
        """
        # The filterbank depends only on the layer configuration, so build it
        # once per call rather than once per batch element.
        mel_weights = tf.signal.linear_to_mel_weight_matrix(
            self.n_mels,
            self.fft_size // 2 + 1,
            self.sample_rate,
            self.lower_edge_hertz,
            self.upper_edge_hertz,
        )

        def process_audio(input_path):
            logger.debug(f"Processing audio file: {input_path}")
            try:
                audio = tf.io.read_file(input_path)
                audio, sr = tf.audio.decode_wav(audio, desired_channels=1)
                logger.debug(f"Decoded WAV file with sample rate: {sr}, shape: {audio.shape}")
                audio = tf.squeeze(audio, axis=-1)
                # fft_length is passed explicitly: by default tf.signal.stft rounds
                # it up to the next power of two, which would not match the
                # fft_size // 2 + 1 bins of mel_weights for non-power-of-two sizes.
                stft = tf.signal.stft(
                    audio,
                    frame_length=self.fft_size,
                    frame_step=self.hop_size,
                    fft_length=self.fft_size,
                )
                logger.debug(f"STFT shape: {stft.shape}")
                spectrogram = tf.abs(stft) ** 2
                mel_spectrogram = tf.tensordot(spectrogram, mel_weights, axes=1)
                # Small offset keeps log() finite on silent (all-zero) bins.
                mel_spectrogram = tf.math.log(mel_spectrogram + 1e-6)
                logger.debug(f"Mel spectrogram shape: {mel_spectrogram.shape}")
                mel_spectrogram = tf.image.resize(mel_spectrogram[..., tf.newaxis], [128, 128])
                mel_spectrogram = tf.image.grayscale_to_rgb(mel_spectrogram)
                logger.debug(f"Final mel spectrogram shape: {mel_spectrogram.shape}")
                return mel_spectrogram
            except Exception as e:
                logger.error(f"Error in process_audio: {str(e)}")
                raise

        # fn_output_signature replaces map_fn's deprecated `dtype` argument.
        return tf.map_fn(process_audio, inputs, fn_output_signature=tf.float32)

    def get_config(self):
        """Return the constructor arguments so Keras can re-create the layer on load."""
        config = super().get_config()
        config.update({
            "sample_rate": self.sample_rate,
            "n_mels": self.n_mels,
            "fft_size": self.fft_size,
            "hop_size": self.hop_size,
            "lower_edge_hertz": self.lower_edge_hertz,
            "upper_edge_hertz": self.upper_edge_hertz,
        })
        return config
class EndpointHandler:
    """Loads the Keras audio model and classifies WAV payloads.

    ``__init__`` loads ``model/bestModel.keras``. ``__call__`` takes a request
    dict whose ``"inputs"`` entry holds raw WAV bytes and returns a list of
    result dicts (or a single error dict).
    """

    def __init__(self, model_dir=None):
        """Load the model from ``<model_dir>/model/bestModel.keras``.

        Args:
            model_dir: Directory containing the ``model/`` folder. If None, the
                directory containing this file is used.

        Raises:
            Exception: Whatever ``load_model`` (or ``summary``) raises; it is
                logged and re-raised.
        """
        logger.info("Initializing EndpointHandler")
        if model_dir is None:
            model_dir = os.path.dirname(os.path.abspath(__file__))
            logger.info(f"Model directory not provided, using handler directory: {model_dir}")
        else:
            logger.info(f"Using provided model directory: {model_dir}")
        model_path = os.path.join(model_dir, "model", "bestModel.keras")
        logger.info(f"Loading model from: {model_path}")
        try:
            self.model = load_model(model_path, custom_objects={"WavToMelLayer": WavToMelLayer})
            # Model.summary() prints and returns None, so interpolating its return
            # value logged "...: None". Collect the printed lines instead; the extra
            # *args/**kwargs absorb any additional arguments Keras passes to print_fn.
            summary_lines = []
            self.model.summary(
                print_fn=lambda line, *args, **kwargs: summary_lines.append(str(line))
            )
            logger.info("Model loaded successfully:\n%s", "\n".join(summary_lines))
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    def __call__(self, requests):
        """Classify one WAV payload.

        Args:
            requests: Mapping whose ``"inputs"`` entry holds the raw WAV file
                contents as ``bytes`` (``bytearray`` is also accepted).

        Returns:
            ``[{"word": class_index, "confidence": score}, ...]`` with one entry per
            row of the model output, where ``score`` is the model's output value
            for the argmax class; or ``[{"error": message}]`` on any failure.
            This method does not raise for bad input or inference errors.
        """
        start_time = time.perf_counter()
        logger.info("Processing inference request")
        temp_dir = None
        temp_wav_path = None
        try:
            # Fetch the payload inside the try so a malformed request yields an
            # error dict like every other failure instead of an exception.
            try:
                payload = requests["inputs"]
            except (KeyError, TypeError):
                logger.error("Request has no 'inputs' field")
                return [{"error": "Missing 'inputs' field in request"}]

            # Log only metadata: the payload is raw audio and may be large.
            logger.debug(f"Received payload of type {type(payload).__name__}")
            if not isinstance(payload, (bytes, bytearray)):
                logger.error(f"Expected bytes, got {type(payload)}")
                return [{"error": f"Invalid input type: {type(payload)}, expected bytes"}]
            if not payload:
                logger.error("Received empty input")
                return [{"error": "Empty input, expected WAV bytes"}]

            # The model takes WAV *paths* (WavToMelLayer reads them from disk),
            # so the bytes are written to a temporary file first.
            temp_dir = tempfile.mkdtemp()
            temp_wav_path = os.path.join(temp_dir, "wav_input.wav")
            logger.info(f"Created temporary directory: {temp_dir}")
            logger.debug(f"Writing {len(payload)} bytes to temporary file: {temp_wav_path}")
            with open(temp_wav_path, "wb") as f:
                f.write(payload)
            logger.debug(f"File size: {os.path.getsize(temp_wav_path)} bytes")

            inputs = tf.constant([temp_wav_path])
            logger.info("Running model prediction")
            # verbose=0 keeps Keras progress bars out of the service logs.
            predictions = self.model.predict(inputs, verbose=0)
            logger.debug(f"Raw predictions: {predictions}")

            results = []
            for i, prediction in enumerate(predictions):
                predicted_class_index = int(np.argmax(prediction))
                confidence = float(prediction[predicted_class_index])
                logger.info(f"Result {i}: class={predicted_class_index}, confidence={confidence:.4f}")
                results.append({"word": predicted_class_index, "confidence": confidence})

            elapsed_time = time.perf_counter() - start_time
            logger.info(f"Inference completed in {elapsed_time:.3f} seconds")
            return results
        except Exception as e:
            logger.error(f"Error during inference: {str(e)}", exc_info=True)
            return [{"error": str(e)}]
        finally:
            self._cleanup_temp_files(temp_dir, temp_wav_path)

    @staticmethod
    def _cleanup_temp_files(temp_dir, temp_wav_path):
        """Best-effort removal of a request's temp file and directory; errors are logged, never raised."""
        try:
            if temp_wav_path and os.path.exists(temp_wav_path):
                os.remove(temp_wav_path)
                logger.debug(f"Removed temporary file: {temp_wav_path}")
            if temp_dir and os.path.exists(temp_dir):
                os.rmdir(temp_dir)
                logger.debug(f"Removed temporary directory: {temp_dir}")
        except Exception as cleanup_error:
            logger.error(f"Error during cleanup: {str(cleanup_error)}")