# respiratory/inference.py
# NOTE: the following Hugging Face file-page header text was captured with the
# source and has been commented out to keep the module importable:
#   "h3rsh's picture / Update inference.py / 6a1466b verified"
import os
import json
import numpy as np
import librosa
import pickle
import tensorflow as tf
from scipy import signal
import warnings
import tempfile
import base64
from typing import Dict, List, Any, Union
from io import BytesIO
import soundfile as sf
warnings.filterwarnings("ignore", message="Trying to estimate tuning from empty frequency set.")
# Common parameters (must match training parameters)
target_sr = 22050  # sample rate (Hz) all audio is resampled to
target_duration = 4  # clip length in seconds fed to the model
n_fft = 512  # FFT window size used for spectral features
hop_length = 512  # hop between successive analysis frames
class RespiratoryPredictor:
    """End-to-end respiratory sound classifier.

    Wraps a trained multi-input Keras model (MFCC, chroma and mel-spectrogram
    branches) together with the preprocessing artifacts saved at training
    time (scalers, normalization parameters, class names), and exposes
    ``predict`` for inference on raw audio.
    """

    def __init__(self):
        """Initialize the predictor with trained model and scalers.

        Loads artifacts from the current working directory:
        ``respiratory_model.keras`` (or ``respiratory_model_tf`` SavedModel),
        ``scalers.pkl``, ``norm_params.pkl`` and ``class_names.pkl``.

        Raises:
            RuntimeError: if the model cannot be loaded in any format.
            Exception: re-raised when a pickle artifact fails to load.
        """
        self.target_sr = target_sr
        self.target_duration = target_duration
        self.n_fft = n_fft
        self.hop_length = hop_length

        # Load model with multiple fallback methods.
        model_loaded = False
        model_path = 'respiratory_model.keras'

        # Method 1: native .keras format. compile=False — inference only,
        # so the training configuration does not need to deserialize.
        if os.path.exists(model_path) and not model_loaded:
            try:
                self.model = tf.keras.models.load_model(model_path, compile=False)
                print(f"Model loaded from .keras format: {model_path}")
                model_loaded = True
            except Exception as e:
                print(f"Failed to load .keras format: {e}")

        # Method 2: TensorFlow SavedModel directory.
        tf_model_path = model_path.replace('.keras', '_tf')
        if os.path.exists(tf_model_path) and not model_loaded:
            try:
                self.model = tf.keras.models.load_model(tf_model_path)
                print(f"Model loaded from TF SavedModel format: {tf_model_path}")
                model_loaded = True
            except Exception as e:
                print(f"Failed to load TF SavedModel format: {e}")

        if not model_loaded:
            raise RuntimeError("Failed to load model with any available method")

        # Load feature scalers saved at training time.
        try:
            with open('scalers.pkl', 'rb') as f:
                self.scalers = pickle.load(f)
            print("Scalers loaded successfully")
        except Exception as e:
            print(f"Error loading scalers: {e}")
            raise

        # Load per-feature normalization statistics (means/stds).
        try:
            with open('norm_params.pkl', 'rb') as f:
                self.norm_params = pickle.load(f)
            print("Normalization parameters loaded successfully")
        except Exception as e:
            print(f"Error loading normalization parameters: {e}")
            raise

        # Load ordered class names matching the model's output indices.
        try:
            with open('class_names.pkl', 'rb') as f:
                self.class_names = pickle.load(f)
            print(f"Class names loaded: {self.class_names}")
        except Exception as e:
            print(f"Error loading class names: {e}")
            raise

    def denoise_audio(self, audio, sr, methods=('adaptive_median', 'bandpass')):
        """Denoise an audio signal.

        Args:
            audio: 1-D waveform samples.
            sr: sample rate in Hz.
            methods: ordered denoising steps to apply; supported values are
                'adaptive_median' and 'bandpass'. (Immutable tuple default —
                avoids the shared-mutable-default-argument pitfall.)

        Returns:
            The denoised waveform, same shape as ``audio``.
        """
        denoised_audio = audio.copy()
        for method in methods:
            if method == 'adaptive_median':
                # Median filter over a ~10 ms window; kernel size must be odd.
                window_size = int(sr * 0.01)
                if window_size % 2 == 0:
                    window_size += 1
                denoised_audio = signal.medfilt(denoised_audio, kernel_size=window_size)
            elif method == 'bandpass':
                # 4th-order Butterworth band-pass, 50-2000 Hz, normalized
                # to the Nyquist frequency as scipy expects.
                low_freq = 50
                high_freq = 2000
                nyquist = sr / 2
                low = low_freq / nyquist
                high = high_freq / nyquist
                b, a = signal.butter(4, [low, high], btype='band')
                # filtfilt applies the filter forward and backward:
                # zero phase shift, no time offset in the output.
                denoised_audio = signal.filtfilt(b, a, denoised_audio)
        return denoised_audio

    def extract_features(self, audio_data, sr):
        """Extract features from audio in the same format as during training.

        Args:
            audio_data: 1-D waveform samples.
            sr: sample rate in Hz.

        Returns:
            Dict with 'mel_spec' (128 x frames, dB), 'mfcc' (20 x frames)
            and 'chroma' (12 x frames) arrays.
        """
        # Mel spectrogram, converted to dB relative to the peak power.
        mel_spec = librosa.feature.melspectrogram(
            y=audio_data, sr=sr, n_mels=128, n_fft=self.n_fft, hop_length=self.hop_length)
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
        # MFCC
        mfcc = librosa.feature.mfcc(y=audio_data, sr=sr, n_mfcc=20, hop_length=self.hop_length)
        # Chroma
        chroma = librosa.feature.chroma_stft(y=audio_data, sr=sr, hop_length=self.hop_length)
        return {
            'mel_spec': mel_spec_db,
            'mfcc': mfcc,
            'chroma': chroma,
        }

    def pad_or_crop(self, arr, shape):
        """Pad (with zeros) or crop a 2-D array to the target shape.

        Args:
            arr: 2-D input array (features x frames).
            shape: (n_features, n_frames) target shape.

        Returns:
            A new array of exactly ``shape``; excess data is cropped,
            missing data is zero-filled.
        """
        out = np.zeros(shape, dtype=arr.dtype)
        n_feat, n_fr = arr.shape
        out[:min(n_feat, shape[0]), :min(n_fr, shape[1])] = arr[:shape[0], :shape[1]]
        return out

    def prepare_input_data(self, features, n_frames=259):
        """Prepare input data for the multi-input model.

        Args:
            features: dict produced by ``extract_features``.
            n_frames: frame count expected by the model (training default).

        Returns:
            Tuple (X_mfcc, X_chroma, X_mspec), each with a trailing
            channel dimension added for the CNN inputs.
        """
        mfcc = self.pad_or_crop(features['mfcc'], (20, n_frames))
        chroma = self.pad_or_crop(features['chroma'], (12, n_frames))
        mspec = self.pad_or_crop(features['mel_spec'], (128, n_frames))
        # Add channel dimension (features x frames x 1).
        X_mfcc = mfcc[..., np.newaxis]
        X_chroma = chroma[..., np.newaxis]
        X_mspec = mspec[..., np.newaxis]
        return X_mfcc, X_chroma, X_mspec

    def normalize_features(self, X_mfcc, X_chroma, X_mspec):
        """Normalize features using the same statistics as training.

        Args:
            X_mfcc, X_chroma, X_mspec: arrays from ``prepare_input_data``.

        Returns:
            The three arrays z-scored with the stored mean/std
            (epsilon added to std to avoid division by zero).
        """
        def norm(X, mean, std):
            # Flatten all but the first axis so the stored flat statistics
            # broadcast correctly, then restore the original shape.
            Xf = X.reshape(X.shape[0], -1)
            Xn = (Xf - mean) / (std + 1e-8)
            return Xn.reshape(X.shape)

        X_mfcc_norm = norm(X_mfcc, self.norm_params['mfcc_mean'], self.norm_params['mfcc_std'])
        X_chroma_norm = norm(X_chroma, self.norm_params['chroma_mean'], self.norm_params['chroma_std'])
        X_mspec_norm = norm(X_mspec, self.norm_params['mspec_mean'], self.norm_params['mspec_std'])
        return X_mfcc_norm, X_chroma_norm, X_mspec_norm

    def process_audio_from_bytes(self, audio_bytes: bytes) -> np.ndarray:
        """Decode raw audio bytes into a mono waveform at ``target_sr``.

        Writes the bytes to a temp file for librosa; if that fails, falls
        back to in-memory decoding with soundfile.

        Args:
            audio_bytes: encoded audio file content (e.g. WAV bytes).

        Returns:
            1-D float waveform, at most ``target_duration`` seconds long.

        Raises:
            Exception: if both decode paths fail.
        """
        try:
            temp_file_path = None
            try:
                with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
                    temp_file.write(audio_bytes)
                    temp_file_path = temp_file.name
                # Load audio using librosa (resamples and truncates for us).
                audio, sr = librosa.load(temp_file_path, sr=self.target_sr, duration=self.target_duration)
                return audio
            finally:
                # Always remove the temp file — the original only cleaned up
                # on success, leaking the file when librosa.load raised.
                if temp_file_path is not None and os.path.exists(temp_file_path):
                    os.unlink(temp_file_path)
        except Exception as e:
            # Fallback: decode directly from memory with soundfile.
            try:
                audio_io = BytesIO(audio_bytes)
                audio, sr = sf.read(audio_io)
                # Resample if necessary.
                if sr != self.target_sr:
                    audio = librosa.resample(audio, orig_sr=sr, target_sr=self.target_sr)
                # Down-mix multi-channel audio to mono.
                if len(audio.shape) > 1:
                    audio = np.mean(audio, axis=1)
                # Crop to target duration.
                target_samples = int(self.target_sr * self.target_duration)
                if len(audio) > target_samples:
                    audio = audio[:target_samples]
                return audio
            except Exception as e2:
                raise Exception(f"Failed to process audio: {str(e)}, {str(e2)}")

    def predict(self, audio_input: Union[str, bytes, np.ndarray]) -> Dict[str, Any]:
        """Make a prediction on audio input.

        Args:
            audio_input: base64-encoded string, raw audio bytes, or a
                decoded waveform array.

        Returns:
            On success: ``{"label", "score", "probabilities"}``.
            On failure: ``{"error", "label": None, "score": 0.0}`` —
            errors are returned, never raised, so serving code can
            surface them as JSON.
        """
        try:
            # Normalize the three accepted input types to a waveform.
            if isinstance(audio_input, str):
                # Assume it's base64 encoded.
                audio_bytes = base64.b64decode(audio_input)
                audio = self.process_audio_from_bytes(audio_bytes)
            elif isinstance(audio_input, bytes):
                audio = self.process_audio_from_bytes(audio_input)
            elif isinstance(audio_input, np.ndarray):
                audio = audio_input
            else:
                raise ValueError(f"Unsupported audio input type: {type(audio_input)}")

            # Zero-pad or crop to exactly target_sr * target_duration samples.
            target_samples = self.target_sr * self.target_duration
            if len(audio) < target_samples:
                audio = np.pad(audio, (0, target_samples - len(audio)), mode='constant')
            elif len(audio) > target_samples:
                audio = audio[:target_samples]

            # Full preprocessing chain, mirroring training.
            denoised_audio = self.denoise_audio(audio, self.target_sr)
            features = self.extract_features(denoised_audio, self.target_sr)
            X_mfcc, X_chroma, X_mspec = self.prepare_input_data(features)
            X_mfcc_norm, X_chroma_norm, X_mspec_norm = self.normalize_features(X_mfcc, X_chroma, X_mspec)

            # Add batch dimension for the model's three inputs.
            X_mfcc_batch = np.expand_dims(X_mfcc_norm, axis=0)
            X_chroma_batch = np.expand_dims(X_chroma_norm, axis=0)
            X_mspec_batch = np.expand_dims(X_mspec_norm, axis=0)

            prediction_prob = self.model.predict([X_mfcc_batch, X_chroma_batch, X_mspec_batch], verbose=0)
            prediction = int(np.argmax(prediction_prob[0]))
            confidence = float(np.max(prediction_prob[0]))

            # Guard against a class index beyond the stored name list.
            class_name = self.class_names[prediction] if prediction < len(self.class_names) else f"Class {prediction}"

            # Per-class probabilities keyed by class name.
            probabilities = {
                cls_name: float(prob)
                for cls_name, prob in zip(self.class_names, prediction_prob[0])
            }

            return {
                "label": class_name,
                "score": confidence,
                "probabilities": probabilities
            }
        except Exception as e:
            return {
                "error": str(e),
                "label": None,
                "score": 0.0
            }
# Global predictor instance, created lazily on the first real request.
_predictor = None


def pipeline(inputs: Union[str, bytes, Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Hugging Face pipeline function for respiratory sound classification.

    Args:
        inputs: Can be:
            - Base64 encoded audio string
            - Raw audio bytes
            - A numpy waveform array
            - Dictionary with 'inputs' (or 'audio') key containing audio data

    Returns:
        List of prediction dictionaries; on any failure a single
        ``{"error", "label": None, "score": 0.0}`` entry is returned
        instead of raising.
    """
    global _predictor
    try:
        # Extract the audio payload from the supported input shapes.
        if isinstance(inputs, dict):
            audio_data = inputs.get('inputs', inputs.get('audio', ''))
        else:
            audio_data = inputs

        # Emptiness check must be ndarray-safe: `not audio_data` raises
        # "truth value of an array is ambiguous" for multi-element arrays,
        # which the original code did for the ndarray inputs that
        # predict() explicitly supports.
        if isinstance(audio_data, np.ndarray):
            is_empty = audio_data.size == 0
        else:
            is_empty = not audio_data
        if is_empty:
            return [{"error": "No audio data provided", "label": None, "score": 0.0}]

        # Initialize the (expensive) predictor only when there is real work
        # to do; doing it inside the try means a model-load failure is
        # reported as an error dict instead of crashing the caller.
        if _predictor is None:
            print("Initializing respiratory sound predictor...")
            _predictor = RespiratoryPredictor()
            print("Predictor initialized successfully!")

        # Return as list (Hugging Face expects list format).
        return [_predictor.predict(audio_data)]
    except Exception as e:
        return [{"error": str(e), "label": None, "score": 0.0}]
# For testing locally
if __name__ == "__main__":
    # Smoke-test the pipeline entry point. In production this function is
    # invoked by the Hugging Face serving infrastructure with real audio;
    # here an empty payload simply exercises the error path.
    print("Testing pipeline function...")
    test_result = pipeline("")
    print(f"Pipeline ready! Test result: {test_result}")