h3rsh committed on
Commit
6a1466b
·
verified ·
1 Parent(s): 0a63270

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +145 -113
inference.py CHANGED
@@ -1,12 +1,16 @@
1
  import os
 
2
  import numpy as np
3
  import librosa
4
  import pickle
5
  import tensorflow as tf
6
- import gradio as gr
7
  from scipy import signal
8
  import warnings
9
  import tempfile
 
 
 
 
10
 
11
  warnings.filterwarnings("ignore", message="Trying to estimate tuning from empty frequency set.")
12
 
@@ -17,47 +21,64 @@ n_fft = 512
17
  hop_length = 512
18
 
19
  class RespiratoryPredictor:
20
- def __init__(self, model_path='respiratory_model.keras', scalers_path='scalers.pkl',
21
- norm_params_path='norm_params.pkl', class_names_path='class_names.pkl'):
22
  """Initialize the predictor with trained model and scalers."""
23
  self.target_sr = target_sr
24
  self.target_duration = target_duration
25
  self.n_fft = n_fft
26
  self.hop_length = hop_length
27
 
28
- # Load model
29
- try:
30
- self.model = tf.keras.models.load_model(model_path)
31
- print(f"✓ Model loaded from {model_path}")
32
- except Exception as e:
33
- print(f"✗ Error loading model: {e}")
34
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  # Load scalers
37
  try:
38
- with open(scalers_path, 'rb') as f:
39
  self.scalers = pickle.load(f)
40
- print(f"Scalers loaded from {scalers_path}")
41
  except Exception as e:
42
- print(f"Error loading scalers: {e}")
43
  raise
44
 
45
  # Load normalization parameters
46
  try:
47
- with open(norm_params_path, 'rb') as f:
48
  self.norm_params = pickle.load(f)
49
- print(f"Normalization parameters loaded from {norm_params_path}")
50
  except Exception as e:
51
- print(f"Error loading normalization parameters: {e}")
52
  raise
53
 
54
  # Load class names
55
  try:
56
- with open(class_names_path, 'rb') as f:
57
  self.class_names = pickle.load(f)
58
- print(f"Class names loaded from {class_names_path}")
59
  except Exception as e:
60
- print(f"Error loading class names: {e}")
61
  raise
62
 
63
  def denoise_audio(self, audio, sr, methods=['adaptive_median', 'bandpass']):
@@ -134,19 +155,59 @@ class RespiratoryPredictor:
134
 
135
  return X_mfcc_norm, X_chroma_norm, X_mspec_norm
136
 
137
- def predict_audio(self, audio_file_path):
138
- """
139
- Predict the class of an audio file for Gradio interface.
140
-
141
- Args:
142
- audio_file_path: Path to the uploaded audio file
 
 
 
 
 
 
 
143
 
144
- Returns:
145
- tuple: (prediction_text, confidence_text, probabilities_dict)
146
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  try:
148
- # Load and process audio
149
- audio, sr = librosa.load(audio_file_path, sr=self.target_sr, duration=self.target_duration)
 
 
 
 
 
 
 
 
 
150
 
151
  # Ensure audio is the right length
152
  target_samples = self.target_sr * self.target_duration
@@ -180,103 +241,74 @@ class RespiratoryPredictor:
180
  # Get class name
181
  class_name = self.class_names[prediction] if prediction < len(self.class_names) else f"Class {prediction}"
182
 
183
- # Format results for Gradio
184
- prediction_text = f"**Prediction**: {class_name}"
185
- confidence_text = f"**Confidence**: {confidence:.2%}"
 
186
 
187
- # Create probabilities dictionary for all classes
188
- probabilities_dict = {}
189
- for i, (class_name_item, prob) in enumerate(zip(self.class_names, prediction_prob[0])):
190
- probabilities_dict[class_name_item] = float(prob)
191
-
192
- return prediction_text, confidence_text, probabilities_dict
193
 
194
  except Exception as e:
195
- error_msg = f"Error processing audio: {str(e)}"
196
- return error_msg, "", {}
 
 
 
197
 
198
- # Initialize the predictor
199
- print("Loading model and components...")
200
- try:
201
- predictor = RespiratoryPredictor()
202
- print("All components loaded successfully!")
203
- except Exception as e:
204
- print(f"Failed to initialize predictor: {e}")
205
- raise
206
 
207
- def predict_respiratory_sound(audio_file):
208
  """
209
- Gradio interface function for respiratory sound prediction.
210
 
211
  Args:
212
- audio_file: Uploaded audio file from Gradio
213
-
 
 
 
214
  Returns:
215
- tuple: (prediction, confidence, probabilities)
216
  """
217
- if audio_file is None:
218
- return "Please upload an audio file", "", {}
219
 
220
- return predictor.predict_audio(audio_file)
221
-
222
- # Create Gradio interface
223
- with gr.Blocks(title="Respiratory Sound Classifier", theme=gr.themes.Soft()) as demo:
224
- gr.Markdown(
225
- """
226
- # Respiratory Sound Classification
227
-
228
- Upload an audio file containing respiratory sounds to classify the type of breathing pattern.
229
-
230
- **Supported formats**: WAV, MP3, M4A, FLAC
231
- **Duration**: Audio will be processed as 4-second segments
232
- """
233
- )
234
 
235
- with gr.Row():
236
- with gr.Column():
237
- audio_input = gr.Audio(
238
- label="Upload Respiratory Sound",
239
- type="filepath",
240
- sources=["upload"]
241
- )
242
-
243
- predict_btn = gr.Button("🔍 Analyze Sound", variant="primary")
244
 
245
- with gr.Column():
246
- prediction_output = gr.Markdown(label="Prediction")
247
- confidence_output = gr.Markdown(label="Confidence")
248
-
249
- probabilities_output = gr.Label(
250
- label="Class Probabilities",
251
- num_top_classes=len(predictor.class_names)
252
- )
253
-
254
- # Event handlers
255
- predict_btn.click(
256
- fn=predict_respiratory_sound,
257
- inputs=[audio_input],
258
- outputs=[prediction_output, confidence_output, probabilities_output]
259
- )
260
-
261
- # Auto-predict when file is uploaded
262
- audio_input.change(
263
- fn=predict_respiratory_sound,
264
- inputs=[audio_input],
265
- outputs=[prediction_output, confidence_output, probabilities_output]
266
- )
267
-
268
- gr.Markdown(
269
- """
270
- ---
271
 
272
- ### About
273
- This model classifies respiratory sounds into different categories.
274
- Upload clear audio recordings of breathing sounds for best results.
275
 
276
- **Note**: This is for research/educational purposes only and should not be used for medical diagnosis.
277
- """
278
- )
279
 
280
- # Launch the app
281
  if __name__ == "__main__":
282
- demo.launch()
 
 
 
 
 
 
 
1
  import os
2
+ import json
3
  import numpy as np
4
  import librosa
5
  import pickle
6
  import tensorflow as tf
 
7
  from scipy import signal
8
  import warnings
9
  import tempfile
10
+ import base64
11
+ from typing import Dict, List, Any, Union
12
+ from io import BytesIO
13
+ import soundfile as sf
14
 
15
  warnings.filterwarnings("ignore", message="Trying to estimate tuning from empty frequency set.")
16
 
 
21
  hop_length = 512
22
 
23
  class RespiratoryPredictor:
24
+ def __init__(self):
 
25
  """Initialize the predictor with trained model and scalers."""
26
  self.target_sr = target_sr
27
  self.target_duration = target_duration
28
  self.n_fft = n_fft
29
  self.hop_length = hop_length
30
 
31
+ # Load model with multiple fallback methods
32
+ model_loaded = False
33
+ model_path = 'respiratory_model.keras'
34
+
35
+ # Method 1: Try .keras format
36
+ if os.path.exists(model_path) and not model_loaded:
37
+ try:
38
+ self.model = tf.keras.models.load_model(model_path, compile=False)
39
+ print(f"Model loaded from .keras format: {model_path}")
40
+ model_loaded = True
41
+ except Exception as e:
42
+ print(f"Failed to load .keras format: {e}")
43
+
44
+ # Method 2: Try TensorFlow SavedModel format
45
+ tf_model_path = model_path.replace('.keras', '_tf')
46
+ if os.path.exists(tf_model_path) and not model_loaded:
47
+ try:
48
+ self.model = tf.keras.models.load_model(tf_model_path)
49
+ print(f"Model loaded from TF SavedModel format: {tf_model_path}")
50
+ model_loaded = True
51
+ except Exception as e:
52
+ print(f"Failed to load TF SavedModel format: {e}")
53
+
54
+ if not model_loaded:
55
+ raise RuntimeError("Failed to load model with any available method")
56
 
57
  # Load scalers
58
  try:
59
+ with open('scalers.pkl', 'rb') as f:
60
  self.scalers = pickle.load(f)
61
+ print("Scalers loaded successfully")
62
  except Exception as e:
63
+ print(f"Error loading scalers: {e}")
64
  raise
65
 
66
  # Load normalization parameters
67
  try:
68
+ with open('norm_params.pkl', 'rb') as f:
69
  self.norm_params = pickle.load(f)
70
+ print("Normalization parameters loaded successfully")
71
  except Exception as e:
72
+ print(f"Error loading normalization parameters: {e}")
73
  raise
74
 
75
  # Load class names
76
  try:
77
+ with open('class_names.pkl', 'rb') as f:
78
  self.class_names = pickle.load(f)
79
+ print(f"Class names loaded: {self.class_names}")
80
  except Exception as e:
81
+ print(f"Error loading class names: {e}")
82
  raise
83
 
84
  def denoise_audio(self, audio, sr, methods=['adaptive_median', 'bandpass']):
 
155
 
156
  return X_mfcc_norm, X_chroma_norm, X_mspec_norm
157
 
158
+ def process_audio_from_bytes(self, audio_bytes: bytes) -> np.ndarray:
159
+ """Process audio from raw bytes data."""
160
+ try:
161
+ # Create a temporary file to write the audio bytes
162
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
163
+ temp_file.write(audio_bytes)
164
+ temp_file_path = temp_file.name
165
+
166
+ # Load audio using librosa
167
+ audio, sr = librosa.load(temp_file_path, sr=self.target_sr, duration=self.target_duration)
168
+
169
+ # Clean up temporary file
170
+ os.unlink(temp_file_path)
171
 
172
+ return audio
173
+
174
+ except Exception as e:
175
+ # Fallback: try to read directly with soundfile
176
+ try:
177
+ audio_io = BytesIO(audio_bytes)
178
+ audio, sr = sf.read(audio_io)
179
+
180
+ # Resample if necessary
181
+ if sr != self.target_sr:
182
+ audio = librosa.resample(audio, orig_sr=sr, target_sr=self.target_sr)
183
+
184
+ # Ensure mono
185
+ if len(audio.shape) > 1:
186
+ audio = np.mean(audio, axis=1)
187
+
188
+ # Crop to target duration
189
+ target_samples = int(self.target_sr * self.target_duration)
190
+ if len(audio) > target_samples:
191
+ audio = audio[:target_samples]
192
+
193
+ return audio
194
+ except Exception as e2:
195
+ raise Exception(f"Failed to process audio: {str(e)}, {str(e2)}")
196
+
197
+ def predict(self, audio_input: Union[str, bytes, np.ndarray]) -> Dict[str, Any]:
198
+ """Make prediction on audio input."""
199
  try:
200
+ # Handle different input types
201
+ if isinstance(audio_input, str):
202
+ # Assume it's base64 encoded
203
+ audio_bytes = base64.b64decode(audio_input)
204
+ audio = self.process_audio_from_bytes(audio_bytes)
205
+ elif isinstance(audio_input, bytes):
206
+ audio = self.process_audio_from_bytes(audio_input)
207
+ elif isinstance(audio_input, np.ndarray):
208
+ audio = audio_input
209
+ else:
210
+ raise ValueError(f"Unsupported audio input type: {type(audio_input)}")
211
 
212
  # Ensure audio is the right length
213
  target_samples = self.target_sr * self.target_duration
 
241
  # Get class name
242
  class_name = self.class_names[prediction] if prediction < len(self.class_names) else f"Class {prediction}"
243
 
244
+ # Create probabilities dictionary
245
+ probabilities = {}
246
+ for i, (cls_name, prob) in enumerate(zip(self.class_names, prediction_prob[0])):
247
+ probabilities[cls_name] = float(prob)
248
 
249
+ return {
250
+ "label": class_name,
251
+ "score": confidence,
252
+ "probabilities": probabilities
253
+ }
 
254
 
255
  except Exception as e:
256
+ return {
257
+ "error": str(e),
258
+ "label": None,
259
+ "score": 0.0
260
+ }
261
 
262
# Lazily-initialized singleton predictor shared across pipeline() calls.
_predictor = None


def pipeline(inputs: Union[str, bytes, Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Hugging Face pipeline function for respiratory sound classification.

    Args:
        inputs: Can be:
            - Base64 encoded audio string
            - Raw audio bytes
            - Dictionary with 'inputs' (or 'audio') key containing audio data

    Returns:
        List containing one result dict: on success
        {"label", "score", "probabilities"}; on failure
        {"error", "label": None, "score": 0.0}.
    """
    global _predictor

    # Unwrap the request envelope first, and validate BEFORE initializing the
    # predictor, so empty/bad requests do not pay the model-loading cost.
    if isinstance(inputs, dict):
        audio_data = inputs.get('inputs', inputs.get('audio', ''))
    else:
        audio_data = inputs

    # NOTE: a plain `not audio_data` would raise ValueError on a numpy array
    # (ambiguous truth value), so only None and empty str/bytes count as
    # "no data"; any other payload is forwarded to the predictor.
    if audio_data is None or (isinstance(audio_data, (str, bytes)) and not audio_data):
        return [{"error": "No audio data provided", "label": None, "score": 0.0}]

    # Lazy one-time initialization; load failures propagate to the caller
    # (matching the original behavior of init outside the try block).
    if _predictor is None:
        print("Initializing respiratory sound predictor...")
        _predictor = RespiratoryPredictor()
        print("Predictor initialized successfully!")

    try:
        result = _predictor.predict(audio_data)
        # Return as list (Hugging Face expects list format)
        return [result]
    except Exception as e:
        return [{"error": str(e), "label": None, "score": 0.0}]
 
305
 
306
# Local smoke-test entry point.
if __name__ == "__main__":
    # Exercise the pipeline function once; in production this module is
    # driven by the Hugging Face infrastructure, and a real check would
    # need actual audio data.
    print("Testing pipeline function...")

    test_result = pipeline("")
    print(f"Pipeline ready! Test result: {test_result}")