dreemer09 commited on
Commit
23d869f
·
1 Parent(s): 7cd7278

Replace in-graph TF audio preprocessing with librosa-based preprocessing in handler.py

Browse files
Files changed (1) hide show
  1. handler.py +90 -134
handler.py CHANGED
@@ -1,20 +1,14 @@
1
  import tensorflow as tf
2
  import numpy as np
3
  import os
4
- import io
5
  import tempfile
6
  import logging
7
  import time
8
- import json
9
  os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
10
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
11
 
12
  from tensorflow.keras.models import load_model
13
- from tensorflow.keras.layers import (
14
- Input, Conv2D, GlobalAveragePooling2D, Dense, Dropout, Add, LeakyReLU,
15
- MaxPooling2D, SpatialDropout2D, LayerNormalization, Layer, Multiply, Reshape,
16
- InputLayer
17
- )
18
 
19
  # Configure logging
20
  logging.basicConfig(
@@ -24,99 +18,13 @@ logging.basicConfig(
24
  logging.StreamHandler()
25
  ]
26
  )
27
- logger = logging.getLogger('speech_recognition')
28
-
29
- # Custom InputLayer to handle batch_shape compatibility issue
30
- class CustomInputLayer(InputLayer):
31
- @classmethod
32
- def from_config(cls, config):
33
- # Convert batch_shape to input_shape if present
34
- if 'batch_shape' in config:
35
- config['input_shape'] = config['batch_shape'][1:]
36
- del config['batch_shape']
37
- return cls(**config)
38
-
39
- class AudioPreprocessingLayer(Layer):
40
- def __init__(self, sample_rate=16000, n_mels=128, fft_size=1024, hop_size=512, **kwargs):
41
- super(AudioPreprocessingLayer, self).__init__(**kwargs)
42
- self.sample_rate = sample_rate
43
- self.n_mels = n_mels
44
- self.fft_size = fft_size
45
- self.hop_size = hop_size
46
-
47
- def call(self, inputs):
48
- def process_audio(input_path):
49
- logger.debug(f"Processing audio file: {input_path}")
50
- try:
51
- audio = tf.io.read_file(input_path)
52
- audio, sr = tf.audio.decode_wav(audio, desired_channels=1)
53
- logger.debug(f"Decoded WAV file with sample rate: {sr}, shape: {audio.shape}")
54
- audio = tf.squeeze(audio, axis=-1)
55
-
56
- stft = tf.signal.stft(audio, frame_length=self.fft_size, frame_step=self.hop_size)
57
- logger.debug(f"STFT shape: {stft.shape}")
58
- spectrogram = tf.abs(stft) ** 2
59
-
60
- # Create mel filter bank
61
- mel_weights = tf.signal.linear_to_mel_weight_matrix(
62
- self.n_mels, self.fft_size // 2 + 1, self.sample_rate, 20.0, 4000.0
63
- )
64
- mel_spectrogram = tf.tensordot(spectrogram, mel_weights, axes=1)
65
- mel_spectrogram = tf.math.log(mel_spectrogram + 1e-6)
66
- logger.debug(f"Mel spectrogram shape: {mel_spectrogram.shape}")
67
-
68
- # Resize to model's expected input size and keep as single channel
69
- mel_spectrogram = tf.image.resize(mel_spectrogram[..., tf.newaxis], [128, 128])
70
- logger.debug(f"Final mel spectrogram shape: {mel_spectrogram.shape}")
71
-
72
- # Normalize to range 0-1
73
- mel_spectrogram = (mel_spectrogram - tf.reduce_min(mel_spectrogram)) / (
74
- tf.reduce_max(mel_spectrogram) - tf.reduce_min(mel_spectrogram) + 1e-6)
75
-
76
- return mel_spectrogram
77
- except Exception as e:
78
- logger.error(f"Error in process_audio: {str(e)}")
79
- raise
80
-
81
- return tf.map_fn(process_audio, inputs, dtype=tf.float32)
82
-
83
- def get_config(self):
84
- config = super(AudioPreprocessingLayer, self).get_config()
85
- config.update({
86
- "sample_rate": self.sample_rate,
87
- "n_mels": self.n_mels,
88
- "fft_size": self.fft_size,
89
- "hop_size": self.hop_size
90
- })
91
- return config
92
-
93
- # Define model architecture components for loading
94
- def se_block(x, ratio=8):
95
- filters = x.shape[-1]
96
- squeeze = GlobalAveragePooling2D()(x)
97
- excitation = Dense(filters // ratio, activation="relu")(squeeze)
98
- excitation = Dense(filters, activation="sigmoid")(excitation)
99
- excitation = Reshape((1, 1, filters))(excitation)
100
- return Multiply()([x, excitation])
101
-
102
- def residual_block(x, filters):
103
- shortcut = x
104
- x = Conv2D(filters, (3, 3), padding="same", use_bias=False)(x)
105
- x = LayerNormalization()(x)
106
- x = LeakyReLU()(x)
107
 
108
- x = Conv2D(filters, (3, 3), padding="same", use_bias=False)(x)
109
- x = LayerNormalization()(x)
110
- x = se_block(x)
111
-
112
- if shortcut.shape[-1] != filters:
113
- shortcut = Conv2D(filters, (1, 1), padding="same", use_bias=False)(shortcut)
114
- shortcut = LayerNormalization()(shortcut)
115
-
116
- x = Add()([x, shortcut])
117
- x = LeakyReLU()(x)
118
- x = SpatialDropout2D(0.2)(x)
119
- return x
120
 
121
  class EndpointHandler:
122
  def __init__(self, model_dir):
@@ -128,75 +36,123 @@ class EndpointHandler:
128
  logger.info(f"Using provided model directory: {model_dir}")
129
 
130
  # Load the model
131
- model_path = os.path.join(model_dir, "model/speechModelv2.keras")
132
  logger.info(f"Loading model from: {model_path}")
133
 
134
  try:
135
- # Load the model with custom objects
136
- custom_objects = {
137
- "AudioPreprocessingLayer": AudioPreprocessingLayer,
138
- "InputLayer": CustomInputLayer
139
- }
140
- self.model = load_model(model_path, custom_objects=custom_objects)
141
- logger.info(f"Model loaded successfully with input shape: {self.model.input_shape}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  except Exception as e:
144
- logger.error(f"Failed to initialize endpoint: {str(e)}", exc_info=True)
145
  raise
146
 
147
  def __call__(self, requests):
148
  start_time = time.time()
149
- logger.info("Processing speech recognition request")
150
  temp_dir = None
151
  temp_wav_path = None
 
152
 
153
  try:
154
- # Extract input audio bytes
155
- input_audio = requests.get('inputs', None)
156
- if input_audio is None:
157
- logger.error("No input data provided")
158
- return [{"error": "No input data provided"}]
159
 
160
- if not isinstance(input_audio, bytes):
161
- logger.error(f"Expected bytes input, got {type(input_audio)}")
162
- return [{"error": f"Invalid input type: {type(input_audio)}, expected bytes"}]
163
 
164
- # Create temporary file for audio processing
165
  temp_dir = tempfile.mkdtemp()
166
- temp_wav_path = os.path.join(temp_dir, "speech_input.wav")
167
  logger.info(f"Created temporary directory: {temp_dir}")
168
 
169
- # Write audio bytes to temporary file
170
- logger.debug(f"Writing {len(input_audio)} bytes to temporary file: {temp_wav_path}")
171
  with open(temp_wav_path, "wb") as f:
172
- f.write(input_audio)
173
 
 
174
  if not os.path.exists(temp_wav_path):
175
  logger.error(f"Failed to create temporary WAV file: {temp_wav_path}")
176
  return [{"error": "Failed to create temporary WAV file"}]
177
 
178
- logger.debug(f"File size: {os.path.getsize(temp_wav_path)} bytes")
 
 
 
 
 
 
 
 
179
 
180
- # Preprocess and run inference
181
- inputs = tf.constant([temp_wav_path])
182
  logger.info("Running model prediction")
183
-
184
- predictions = self.model.predict(inputs)
185
  logger.debug(f"Raw predictions shape: {predictions.shape}")
186
 
187
  # Process results
188
  results = []
189
  for i, prediction in enumerate(predictions):
190
- # Get top 3 predictions
191
- top_indices = np.argsort(prediction)[-3:][::-1]
 
 
 
 
 
192
 
193
- results.append({
194
- "prediction": int(top_indices[0]),
195
- "confidence": float(prediction[top_indices[0]])
196
- })
197
 
198
  elapsed_time = time.time() - start_time
199
- logger.info(f"Speech recognition completed in {elapsed_time:.3f} seconds")
200
  return results
201
 
202
  except Exception as e:
 
1
  import tensorflow as tf
2
  import numpy as np
3
  import os
4
+ import librosa
5
  import tempfile
6
  import logging
7
  import time
 
8
  os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
9
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
10
 
11
  from tensorflow.keras.models import load_model
 
 
 
 
 
12
 
13
  # Configure logging
14
  logging.basicConfig(
 
18
  logging.StreamHandler()
19
  ]
20
  )
21
+ logger = logging.getLogger('speech_recognition_inference')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ # Constants for audio preprocessing
24
+ SAMPLE_RATE = 16000
25
+ N_MELS = 128
26
+ FFT_SIZE = 1024
27
+ HOP_SIZE = 512
 
 
 
 
 
 
 
28
 
29
  class EndpointHandler:
30
  def __init__(self, model_dir):
 
36
  logger.info(f"Using provided model directory: {model_dir}")
37
 
38
  # Load the model
39
+ model_path = os.path.join(model_dir, "model/speech_modelv2.keras")
40
  logger.info(f"Loading model from: {model_path}")
41
 
42
  try:
43
+ self.model = load_model(model_path)
44
+ logger.info(f"Model loaded successfully")
45
+ logger.debug(f"Model summary: {self.model.summary()}")
46
+ except Exception as e:
47
+ logger.error(f"Failed to load model: {str(e)}")
48
+ raise
49
+
50
+ def preprocess_audio(self, file_path):
51
+ """
52
+ Process audio file to match the training preprocessing exactly
53
+ """
54
+ logger.debug(f"Processing audio file: {file_path}")
55
+ try:
56
+ # Load audio using librosa (same as training)
57
+ audio, sr = librosa.load(file_path, sr=SAMPLE_RATE)
58
+
59
+ # Convert to Mel spectrogram (matching training parameters)
60
+ mel_spectrogram = librosa.feature.melspectrogram(
61
+ y=audio,
62
+ sr=sr,
63
+ n_mels=N_MELS,
64
+ n_fft=FFT_SIZE,
65
+ hop_length=HOP_SIZE
66
+ )
67
+ log_mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)
68
+
69
+ # Ensure fixed size (128x128)
70
+ if log_mel_spectrogram.shape[1] < 128:
71
+ log_mel_spectrogram = np.pad(
72
+ log_mel_spectrogram,
73
+ ((0, 0), (0, 128 - log_mel_spectrogram.shape[1])),
74
+ mode='constant'
75
+ )
76
+ else:
77
+ log_mel_spectrogram = log_mel_spectrogram[:, :128]
78
+
79
+ # Expand dimensions for CNN input (128x128x1)
80
+ mel_spectrogram_processed = np.expand_dims(log_mel_spectrogram, axis=-1)
81
+
82
+ # Convert to RGB by duplicating channels (128x128x3)
83
+ # Matching the model's expectation of RGB input
84
+ mel_spectrogram_rgb = np.repeat(mel_spectrogram_processed, 3, axis=2)
85
+
86
+ logger.debug(f"Final mel spectrogram shape: {mel_spectrogram_rgb.shape}")
87
+ return mel_spectrogram_rgb
88
 
89
  except Exception as e:
90
+ logger.error(f"Error in preprocess_audio: {str(e)}")
91
  raise
92
 
93
  def __call__(self, requests):
94
  start_time = time.time()
95
+ logger.info("Processing speech recognition inference request")
96
  temp_dir = None
97
  temp_wav_path = None
98
+ audio_data = requests.get('inputs', None)
99
 
100
  try:
101
+ # Validate input
102
+ if not audio_data:
103
+ logger.error("No 'inputs' field found in the request")
104
+ return [{"error": "No audio data provided in 'inputs' field"}]
 
105
 
106
+ if not isinstance(audio_data, bytes):
107
+ logger.error(f"Expected bytes, got {type(audio_data)}")
108
+ return [{"error": f"Invalid input type: {type(audio_data)}, expected bytes"}]
109
 
110
+ # Create temporary file for the audio
111
  temp_dir = tempfile.mkdtemp()
112
+ temp_wav_path = os.path.join(temp_dir, "wav_input.wav")
113
  logger.info(f"Created temporary directory: {temp_dir}")
114
 
115
+ # Write audio data to file
116
+ logger.debug(f"Writing {len(audio_data)} bytes to temporary file: {temp_wav_path}")
117
  with open(temp_wav_path, "wb") as f:
118
+ f.write(audio_data)
119
 
120
+ # Verify file was created
121
  if not os.path.exists(temp_wav_path):
122
  logger.error(f"Failed to create temporary WAV file: {temp_wav_path}")
123
  return [{"error": "Failed to create temporary WAV file"}]
124
 
125
+ # Preprocess audio
126
+ logger.info("Preprocessing audio")
127
+ try:
128
+ preprocessed_audio = self.preprocess_audio(temp_wav_path)
129
+ # Add batch dimension
130
+ preprocessed_input = np.expand_dims(preprocessed_audio, axis=0)
131
+ except Exception as e:
132
+ logger.error(f"Error during preprocessing: {str(e)}")
133
+ return [{"error": f"Preprocessing failed: {str(e)}"}]
134
 
135
+ # Run prediction
 
136
  logger.info("Running model prediction")
137
+ predictions = self.model.predict(preprocessed_input)
 
138
  logger.debug(f"Raw predictions shape: {predictions.shape}")
139
 
140
  # Process results
141
  results = []
142
  for i, prediction in enumerate(predictions):
143
+ predicted_class_index = int(np.argmax(prediction))
144
+ confidence = float(prediction[predicted_class_index])
145
+
146
+ result = {
147
+ "word": predicted_class_index,
148
+ "confidence": confidence
149
+ }
150
 
151
+ logger.info(f"Result {i}: class={predicted_class_index}, confidence={confidence:.4f}")
152
+ results.append(result)
 
 
153
 
154
  elapsed_time = time.time() - start_time
155
+ logger.info(f"Inference completed in {elapsed_time:.3f} seconds")
156
  return results
157
 
158
  except Exception as e: