dreemer09 commited on
Commit
a806fea
·
1 Parent(s): 7e66a7c

alksdhlahk

Browse files
Files changed (1) hide show
  1. handler.py +180 -45
handler.py CHANGED
@@ -1,71 +1,206 @@
1
  import tensorflow as tf
2
- import os
3
- import librosa
4
  import numpy as np
5
- import time
 
6
  import tempfile
7
-
 
 
8
  os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
9
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  class EndpointHandler:
12
  def __init__(self, model_dir):
 
13
  if model_dir is None:
14
  model_dir = os.path.dirname(os.path.abspath(__file__))
15
-
16
- # Model path
17
- model_path = os.path.join(model_dir, "model/speechModelv2.keras")
18
-
19
- # Load the model with custom_objects to handle any custom layers
20
- self.model = tf.keras.models.load_model(model_path)
21
-
22
- def preprocess_audio(self, audio_path):
23
- SAMPLE_RATE = 16000
24
- N_MELS = 128
25
- FFT_SIZE = 1024
26
- HOP_SIZE = 512
27
-
28
- audio, sr = librosa.load(file_path, sr=SAMPLE_RATE)
29
- mel_spectrogram = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=N_MELS, n_fft=FFT_SIZE, hop_length=HOP_SIZE)
30
- log_mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)
31
-
32
- # Ensure fixed size (128x128)
33
- if log_mel_spectrogram.shape[1] < 128:
34
- log_mel_spectrogram = np.pad(log_mel_spectrogram, ((0, 0), (0, 128 - log_mel_spectrogram.shape[1])), mode='constant')
35
  else:
36
- log_mel_spectrogram = log_mel_spectrogram[:, :128]
37
-
38
- return np.expand_dims(log_mel_spectrogram, axis=[0, -1])
39
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def __call__(self, requests):
 
 
41
  temp_dir = None
42
  temp_wav_path = None
43
 
44
  try:
45
- # Create temporary directory and file
 
 
 
 
 
 
 
 
 
 
46
  temp_dir = tempfile.mkdtemp()
47
- temp_wav_path = os.path.join(temp_dir, "wav_input.wav")
 
48
 
49
- # Write audio data to temporary file
 
50
  with open(temp_wav_path, "wb") as f:
51
- f.write(requests)
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- # Preprocess audio
54
- input_data = self.preprocess_audio(temp_wav_path)
55
- predictions = self.model.predict(input_data)
56
- predicted_class = int(np.argmax(predictions, axis=1)[0])
57
- confidence = float(predictions[0][predicted_class])
 
 
 
 
 
 
 
58
 
59
- # Prepare response
60
- response = {"class_id": predicted_class, "confidence": confidence}
61
- return response
62
 
63
  except Exception as e:
64
- return {"error": str(e)}
 
65
 
66
  finally:
67
  # Clean up temporary files
68
- if temp_wav_path and os.path.exists(temp_wav_path):
69
- os.remove(temp_wav_path)
70
- if temp_dir and os.path.exists(temp_dir):
71
- os.rmdir(temp_dir)
 
 
 
 
 
 
1
  import tensorflow as tf
 
 
2
  import numpy as np
3
+ import os
4
+ import io
5
  import tempfile
6
+ import logging
7
+ import time
8
+ import json
9
  os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
10
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
11
 
12
+ from tensorflow.keras.models import load_model
13
+ from tensorflow.keras.layers import (
14
+ Input, Conv2D, GlobalAveragePooling2D, Dense, Dropout, Add, LeakyReLU,
15
+ MaxPooling2D, SpatialDropout2D, LayerNormalization, Layer, Multiply, Reshape
16
+ )
17
+
18
# Configure logging: INFO-level messages to stderr via a stream handler.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[logging.StreamHandler()],
)
# Module-wide logger used by the preprocessing layer and the endpoint handler.
logger = logging.getLogger('speech_recognition')
27
+
28
class AudioPreprocessingLayer(Layer):
    """Keras layer that converts a batch of WAV *file paths* into 128x128
    single-channel log-mel spectrograms normalized to [0, 1].

    The layer reads each file from disk inside the graph (tf.io.read_file /
    tf.audio.decode_wav), so its inputs are string tensors of file paths,
    not raw audio samples.
    """

    def __init__(self, sample_rate=16000, n_mels=128, fft_size=1024, hop_size=512, **kwargs):
        # sample_rate: rate handed to the mel filter bank. NOTE(review): the
        #   decoded WAV is never resampled here — presumably input files are
        #   already at this rate; confirm with the training pipeline.
        # n_mels: number of mel bands in the filter bank.
        # fft_size / hop_size: STFT frame length and step, in samples.
        super(AudioPreprocessingLayer, self).__init__(**kwargs)
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.fft_size = fft_size
        self.hop_size = hop_size

    def call(self, inputs):
        # inputs: 1-D string tensor of WAV file paths, one per batch element.
        def process_audio(input_path):
            logger.debug(f"Processing audio file: {input_path}")
            try:
                # Read and decode the WAV to mono float samples.
                audio = tf.io.read_file(input_path)
                audio, sr = tf.audio.decode_wav(audio, desired_channels=1)
                logger.debug(f"Decoded WAV file with sample rate: {sr}, shape: {audio.shape}")
                # Drop the trailing channel axis -> 1-D waveform.
                audio = tf.squeeze(audio, axis=-1)

                # Power spectrogram via STFT.
                stft = tf.signal.stft(audio, frame_length=self.fft_size, frame_step=self.hop_size)
                logger.debug(f"STFT shape: {stft.shape}")
                spectrogram = tf.abs(stft) ** 2

                # Create mel filter bank (band limited to 20 Hz - 4 kHz).
                mel_weights = tf.signal.linear_to_mel_weight_matrix(
                    self.n_mels, self.fft_size // 2 + 1, self.sample_rate, 20.0, 4000.0
                )
                mel_spectrogram = tf.tensordot(spectrogram, mel_weights, axes=1)
                # Log-compress; the epsilon avoids log(0) on silent frames.
                mel_spectrogram = tf.math.log(mel_spectrogram + 1e-6)
                logger.debug(f"Mel spectrogram shape: {mel_spectrogram.shape}")

                # Resize to model's expected input size and keep as single channel.
                # NOTE(review): the target size is hard-coded to 128x128 even
                # though n_mels is configurable — confirm these are meant to stay in sync.
                mel_spectrogram = tf.image.resize(mel_spectrogram[..., tf.newaxis], [128, 128])
                logger.debug(f"Final mel spectrogram shape: {mel_spectrogram.shape}")

                # Normalize to range 0-1 (per-example min-max; epsilon guards
                # against division by zero on a constant spectrogram).
                mel_spectrogram = (mel_spectrogram - tf.reduce_min(mel_spectrogram)) / (
                    tf.reduce_max(mel_spectrogram) - tf.reduce_min(mel_spectrogram) + 1e-6)

                return mel_spectrogram
            except Exception as e:
                logger.error(f"Error in process_audio: {str(e)}")
                raise

        # Map over the batch of paths; each element yields a float32 (128, 128, 1) tensor.
        return tf.map_fn(process_audio, inputs, dtype=tf.float32)

    def get_config(self):
        # Serialize constructor arguments so the layer can be re-created when
        # the saved model is loaded (see EndpointHandler's custom_objects).
        config = super(AudioPreprocessingLayer, self).get_config()
        config.update({
            "sample_rate": self.sample_rate,
            "n_mels": self.n_mels,
            "fft_size": self.fft_size,
            "hop_size": self.hop_size
        })
        return config
81
+
82
+ # Define model architecture components for loading
83
def se_block(x, ratio=8):
    """Squeeze-and-Excitation gate: reweights channels of `x` by a learned
    per-channel attention factor. `ratio` controls the bottleneck width of
    the excitation MLP. Returns a tensor with the same shape as `x`."""
    n_channels = x.shape[-1]
    # Squeeze: global spatial average -> one value per channel.
    pooled = GlobalAveragePooling2D()(x)
    # Excitation: bottleneck MLP producing sigmoid gates in (0, 1).
    gate = Dense(n_channels // ratio, activation="relu")(pooled)
    gate = Dense(n_channels, activation="sigmoid")(gate)
    # Broadcast the gates back over the spatial dimensions.
    gate = Reshape((1, 1, n_channels))(gate)
    return Multiply()([x, gate])
90
+
91
def residual_block(x, filters):
    """Residual block: two 3x3 convs (LayerNorm + LeakyReLU) with an SE gate,
    added to a shortcut and followed by spatial dropout. The shortcut is
    projected with a 1x1 conv when the channel count changes."""
    identity = x

    # Main path: conv -> norm -> activation, conv -> norm -> SE.
    out = Conv2D(filters, (3, 3), padding="same", use_bias=False)(x)
    out = LayerNormalization()(out)
    out = LeakyReLU()(out)

    out = Conv2D(filters, (3, 3), padding="same", use_bias=False)(out)
    out = LayerNormalization()(out)
    out = se_block(out)

    # Match shortcut channels to the main path before the addition.
    if identity.shape[-1] != filters:
        identity = Conv2D(filters, (1, 1), padding="same", use_bias=False)(identity)
        identity = LayerNormalization()(identity)

    out = Add()([out, identity])
    out = LeakyReLU()(out)
    return SpatialDropout2D(0.2)(out)
109
+
110
class EndpointHandler:
    """Inference endpoint for the speech-recognition model.

    Loads a Keras model (whose first layer reads WAV files from disk, see
    AudioPreprocessingLayer) and serves predictions for raw WAV bytes passed
    under the "inputs" key of a request mapping.
    """

    def __init__(self, model_dir=None):
        """Load the model from `model_dir` (defaults to this file's directory).

        Re-raises any exception from `load_model` after logging it.
        """
        logger.info("Initializing Speech Recognition EndpointHandler")
        if model_dir is None:
            model_dir = os.path.dirname(os.path.abspath(__file__))
            logger.info(f"Model directory not provided, using current directory: {model_dir}")
        else:
            logger.info(f"Using provided model directory: {model_dir}")

        # Load the model
        model_path = os.path.join(model_dir, "model/speech_model.keras")
        logger.info(f"Loading model from: {model_path}")

        try:
            # The saved model contains the custom preprocessing layer, so it
            # must be registered for deserialization.
            custom_objects = {
                "AudioPreprocessingLayer": AudioPreprocessingLayer
            }
            self.model = load_model(model_path, custom_objects=custom_objects)
            logger.info(f"Model loaded successfully with input shape: {self.model.input_shape}")
        except Exception as e:
            logger.error(f"Failed to initialize endpoint: {str(e)}", exc_info=True)
            raise

    def __call__(self, requests):
        """Run inference on one request.

        Parameters:
            requests: mapping with key "inputs" holding raw WAV bytes.

        Returns:
            A list of {"word": int, "confidence": float} dicts — the top-3
            classes per prediction row, best first — or a one-element list
            containing an {"error": str} dict on failure.
        """
        start_time = time.time()
        logger.info("Processing speech recognition request")
        temp_dir = None
        temp_wav_path = None

        try:
            # Extract and validate input audio bytes.
            input_audio = requests.get('inputs', None)
            if input_audio is None:
                logger.error("No input data provided")
                return [{"error": "No input data provided"}]

            if not isinstance(input_audio, bytes):
                logger.error(f"Expected bytes input, got {type(input_audio)}")
                return [{"error": f"Invalid input type: {type(input_audio)}, expected bytes"}]

            # Create temporary file for audio processing
            temp_dir = tempfile.mkdtemp()
            temp_wav_path = os.path.join(temp_dir, "speech_input.wav")
            logger.info(f"Created temporary directory: {temp_dir}")

            # Write audio bytes to temporary file
            logger.debug(f"Writing {len(input_audio)} bytes to temporary file: {temp_wav_path}")
            with open(temp_wav_path, "wb") as f:
                f.write(input_audio)

            if not os.path.exists(temp_wav_path):
                logger.error(f"Failed to create temporary WAV file: {temp_wav_path}")
                return [{"error": "Failed to create temporary WAV file"}]

            logger.debug(f"File size: {os.path.getsize(temp_wav_path)} bytes")

            # Preprocess and run inference: the model's preprocessing layer
            # expects a batch of file paths, not raw samples.
            inputs = tf.constant([temp_wav_path])
            logger.info("Running model prediction")

            predictions = self.model.predict(inputs)
            logger.debug(f"Raw predictions shape: {predictions.shape}")

            # Collect the top-3 classes per prediction row, best first.
            results = []
            for prediction in predictions:
                top_indices = np.argsort(prediction)[-3:][::-1]
                # BUGFIX: the previous code indexed top_indices[0] inside the
                # loop, reporting the top-1 class three times instead of the
                # distinct top-3 classes.
                for idx in top_indices:
                    results.append({
                        "word": int(idx),
                        "confidence": float(prediction[idx])
                    })

            elapsed_time = time.time() - start_time
            logger.info(f"Speech recognition completed in {elapsed_time:.3f} seconds")
            return results

        except Exception as e:
            logger.error(f"Error during inference: {str(e)}", exc_info=True)
            return [{"error": str(e)}]

        finally:
            # Clean up temporary files. The directory only ever contains the
            # single WAV file, so os.rmdir suffices after removing it.
            try:
                if temp_wav_path and os.path.exists(temp_wav_path):
                    os.remove(temp_wav_path)
                    logger.debug(f"Removed temporary file: {temp_wav_path}")
                if temp_dir and os.path.exists(temp_dir):
                    os.rmdir(temp_dir)
                    logger.debug(f"Removed temporary directory: {temp_dir}")
            except Exception as cleanup_error:
                logger.error(f"Error during cleanup: {str(cleanup_error)}")