File size: 6,305 Bytes
dd97ba3
 
a806fea
bf1c3a7
ecacdf1
a806fea
 
dd97ba3
 
 
bf1c3a7
 
a806fea
 
 
 
 
 
 
 
 
bf1c3a7
a806fea
bf1c3a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a806fea
d13deea
 
bf1c3a7
dd97ba3
 
a806fea
dd97ba3
a806fea
 
bf1c3a7
a806fea
 
 
bf1c3a7
 
23d869f
 
 
 
d13deea
a806fea
bf1c3a7
ecacdf1
 
bf1c3a7
 
5dae612
d13deea
23d869f
a806fea
d13deea
bf1c3a7
 
 
 
 
 
 
d13deea
bf1c3a7
a806fea
 
 
 
 
bf1c3a7
 
 
a806fea
d13deea
bf1c3a7
 
 
a806fea
 
bf1c3a7
23d869f
 
bf1c3a7
d13deea
a806fea
23d869f
a806fea
d13deea
5dae612
a806fea
 
d13deea
 
a806fea
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# NOTE: TF_* environment variables are only honored if they are set BEFORE
# tensorflow is imported, so configure the environment first.
import os

os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"  # disable oneDNN custom ops (reproducible float results)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"   # silence TF C++ INFO/WARNING logs

import io
import logging
import tempfile
import time

import numpy as np
import tensorflow as tf
from keras.layers import Layer
from keras.models import load_model

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler()
    ]
)
logger = logging.getLogger('audio_inference')

class WavToMelLayer(Layer):
    """Preprocessing layer: turns a batch of WAV file paths into 128x128 RGB
    log-mel-spectrogram images.

    Each element of the input batch is a string tensor holding the path to a
    mono WAV file on disk. The output is a float32 tensor of shape
    (batch, 128, 128, 3), suitable for image-style classification backbones.
    """

    def __init__(self, sample_rate=16000, n_mels=128, fft_size=1024,
                 hop_size=512, fmin=20.0, fmax=4000.0, **kwargs):
        """Configure the mel-spectrogram extraction.

        Args:
            sample_rate: Sample rate (Hz) assumed when building the mel filterbank.
            n_mels: Number of mel bands.
            fft_size: STFT frame length in samples (also the FFT size).
            hop_size: STFT frame step in samples.
            fmin: Lower edge (Hz) of the mel filterbank. Default preserves the
                previously hard-coded 20.0.
            fmax: Upper edge (Hz) of the mel filterbank. Default preserves the
                previously hard-coded 4000.0.
            **kwargs: Forwarded to keras.layers.Layer.
        """
        super(WavToMelLayer, self).__init__(**kwargs)
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.fmin = fmin
        self.fmax = fmax

    def call(self, inputs):
        """Map each WAV path in `inputs` to a (128, 128, 3) log-mel image."""
        def process_audio(input_path):
            logger.debug(f"Processing audio file: {input_path}")
            try:
                audio = tf.io.read_file(input_path)
                audio, sr = tf.audio.decode_wav(audio, desired_channels=1)
                logger.debug(f"Decoded WAV file with sample rate: {sr}, shape: {audio.shape}")
                audio = tf.squeeze(audio, axis=-1)  # (samples, 1) -> (samples,)

                stft = tf.signal.stft(audio, frame_length=self.fft_size, frame_step=self.hop_size)
                logger.debug(f"STFT shape: {stft.shape}")
                spectrogram = tf.abs(stft) ** 2  # power spectrogram

                mel_weights = tf.signal.linear_to_mel_weight_matrix(
                    self.n_mels, self.fft_size // 2 + 1, self.sample_rate,
                    self.fmin, self.fmax
                )
                mel_spectrogram = tf.tensordot(spectrogram, mel_weights, axes=1)
                # Small epsilon avoids log(0) on silent frames.
                mel_spectrogram = tf.math.log(mel_spectrogram + 1e-6)
                logger.debug(f"Mel spectrogram shape: {mel_spectrogram.shape}")

                # Resize to a fixed 128x128 and replicate the single channel to
                # 3 so downstream models can treat it as an RGB image.
                mel_spectrogram = tf.image.resize(mel_spectrogram[..., tf.newaxis], [128, 128])
                mel_spectrogram = tf.image.grayscale_to_rgb(mel_spectrogram)
                logger.debug(f"Final mel spectrogram shape: {mel_spectrogram.shape}")

                return mel_spectrogram
            except Exception as e:
                logger.error(f"Error in process_audio: {str(e)}")
                raise

        return tf.map_fn(process_audio, inputs, dtype=tf.float32)

    def get_config(self):
        """Return the layer config so the model can be saved and reloaded."""
        config = super(WavToMelLayer, self).get_config()
        config.update({
            "sample_rate": self.sample_rate,
            "n_mels": self.n_mels,
            "fft_size": self.fft_size,
            "hop_size": self.hop_size,
            # New keys: absent in older saved configs, in which case the
            # constructor defaults reproduce the original behavior.
            "fmin": self.fmin,
            "fmax": self.fmax
        })
        return config

class EndpointHandler:
    """Inference endpoint: accepts raw WAV bytes under the "inputs" key and
    returns the predicted word class index plus confidence from the bundled
    Keras model."""

    def __init__(self, model_dir=None):
        """Load the model from `<model_dir>/model/bestModel.keras`.

        Args:
            model_dir: Directory containing the model. When None, the
                directory of this file is used (same behavior as before;
                the default just makes the existing None branch reachable
                without passing None explicitly).

        Raises:
            Exception: re-raised unchanged when the model fails to load.
        """
        logger.info("Initializing EndpointHandler")
        if model_dir is None:
            model_dir = os.path.dirname(os.path.abspath(__file__))
            logger.info(f"Model directory not provided, using current directory: {model_dir}")
        else:
            logger.info(f"Using provided model directory: {model_dir}")

        model_path = os.path.join(model_dir, "model/bestModel.keras")
        logger.info(f"Loading model from: {model_path}")

        try:
            # The custom preprocessing layer must be registered for deserialization.
            self.model = load_model(model_path, custom_objects={"WavToMelLayer": WavToMelLayer})
            # BUG FIX: model.summary() prints to stdout and returns None, so
            # interpolating it into the log line only ever logged "None".
            logger.info("Model loaded successfully")
            self.model.summary()
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    def __call__(self, requests):
        """Run inference on one request.

        Args:
            requests: Mapping with an "inputs" key holding raw WAV bytes.

        Returns:
            On success, a list of {"word": int, "confidence": float} dicts
            (one per prediction row). On failure, a single-element list
            containing an {"error": str} dict.
        """
        start_time = time.time()
        logger.info("Processing inference request")
        temp_dir = None
        temp_wav_path = None

        try:
            # Validate the payload before touching the filesystem, and return
            # the handler's error shape instead of raising KeyError.
            if "inputs" not in requests:
                logger.error("Request is missing required 'inputs' key")
                return [{"error": "Missing 'inputs' key in request"}]
            audio_bytes = requests["inputs"]

            # Log request metadata only — dumping the mapping itself would put
            # the raw audio bytes into the log.
            logger.debug(f"Request keys: {list(requests.keys())}")

            if not isinstance(audio_bytes, bytes):
                logger.error(f"Expected bytes, got {type(audio_bytes)}")
                return [{"error": f"Invalid input type: {type(audio_bytes)}, expected bytes"}]

            temp_dir = tempfile.mkdtemp()
            temp_wav_path = os.path.join(temp_dir, "wav_input.wav")
            logger.info(f"Created temporary directory: {temp_dir}")

            logger.debug(f"Writing {len(audio_bytes)} bytes to temporary file: {temp_wav_path}")
            with open(temp_wav_path, "wb") as f:
                f.write(audio_bytes)

            if not os.path.exists(temp_wav_path):
                logger.error(f"Failed to create temporary WAV file: {temp_wav_path}")
                return [{"error": "Failed to create temporary WAV file"}]

            logger.debug(f"File size: {os.path.getsize(temp_wav_path)} bytes")

            # The model's WavToMelLayer reads the file from this path itself.
            inputs = tf.constant([temp_wav_path])
            logger.info("Running model prediction")

            predictions = self.model.predict(inputs)
            logger.debug(f"Raw predictions: {predictions}")

            results = []
            for i, prediction in enumerate(predictions):
                predicted_class_index = np.argmax(prediction)
                confidence = float(prediction[predicted_class_index])
                logger.info(f"Result {i}: class={predicted_class_index}, confidence={confidence:.4f}")
                results.append({"word": int(predicted_class_index), "confidence": confidence})

            elapsed_time = time.time() - start_time
            logger.info(f"Inference completed in {elapsed_time:.3f} seconds")
            return results

        except Exception as e:
            logger.error(f"Error during inference: {str(e)}", exc_info=True)
            return [{"error": str(e)}]

        finally:
            # Best-effort cleanup: the temp dir only ever contains the one file.
            try:
                if temp_wav_path and os.path.exists(temp_wav_path):
                    os.remove(temp_wav_path)
                    logger.debug(f"Removed temporary file: {temp_wav_path}")
                if temp_dir and os.path.exists(temp_dir):
                    os.rmdir(temp_dir)
                    logger.debug(f"Removed temporary directory: {temp_dir}")
            except Exception as cleanup_error:
                logger.error(f"Error during cleanup: {str(cleanup_error)}")