# new_audio / app.py
# Author: Anvit25
# Last commit: 45f9c09 "Update app.py" (verified)
import os
import shutil
import gradio as gr
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import models
# ---------------- 1. Configuration ---------------- #
# Target (height, width) handed to the image loader when preprocessing
# spectrogram PNGs for the classifiers.
IMG_SIZE = (224, 224)

# Scratch directory for generated Mel-spectrogram images; created eagerly at
# import time so the prediction function can write into it immediately.
TEMP_DIR = "temp_gradio_specs"
os.makedirs(TEMP_DIR, exist_ok=True)
# ---------------- 2. Load Models ------------------ #
print("🚀 Loading machine learning models...")
try:
    # Load the three checkpoints in order; the first failure jumps to the
    # except branch below.
    stage1_model, abnormal_model, normal_model = (
        models.load_model(checkpoint)
        for checkpoint in (
            "saved_models/stage1_model.h5",
            "saved_models/abnormal_model.h5",
            "saved_models/normal_model.h5",
        )
    )
    print("✅ Models loaded successfully.")
except Exception as e:
    print(f"❌ Error loading models: {e}")
    # Keep the app alive without models so the UI can report the problem
    # (HierarchicalClassifier.predict checks for missing models).
    stage1_model = abnormal_model = normal_model = None
# Default class lists – replace with actual labels if available
stage1_classes = ["00 - Abnormal", "01 - Normal"]


def subdirectory_labels(directory, fallback):
    """Return the sorted names of *directory*'s sub-directories as class labels.

    Only directories are counted, so stray files (``.DS_Store``, READMEs, ...)
    cannot shift the index-to-label mapping. This presumably mirrors how the
    models' class indices were assigned (one sub-folder per class, sorted by
    name) — NOTE(review): confirm against the training code.

    Args:
        directory: Folder whose sub-directories name the classes.
        fallback: Labels to use when *directory* is missing, is not a
            directory, or contains no sub-directories.

    Returns:
        A new list of labels (never the *fallback* object itself).
    """
    if not os.path.isdir(directory):
        return list(fallback)
    with os.scandir(directory) as entries:
        labels = sorted(entry.name for entry in entries if entry.is_dir())
    # An empty label list would make every prediction index out of range.
    return labels or list(fallback)


abnormal_classes = subdirectory_labels(
    "MelSpectrograms/00 - Abnormal",
    ["Bearing noise", "Dehydration mode noise"],
)
normal_classes = subdirectory_labels(
    "MelSpectrograms/01 - Normal",
    ["Wash mode", "Spin mode"],
)
print(f"Stage 1 Classes: {stage1_classes}")
print(f"Abnormal Sub-classes: {abnormal_classes}")
print(f"Normal Sub-classes: {normal_classes}")
# ---------------- 3. Helper Functions -------------- #
def save_mel_spectrogram(file_path, save_dir, sr=22050,
                         n_mels=128, hop_length=512, n_fft=2048):
    """Render an audio file as a Mel-spectrogram PNG and save it to disk.

    The audio is loaded as mono at sample rate ``sr``. Its power Mel
    spectrogram is converted to dB relative to the peak value and drawn
    with axes hidden, so the image contains only the spectrogram.

    Args:
        file_path: Path to the input audio file (any extension).
        save_dir: Output directory; created if it does not exist.
        sr: Sample rate passed to ``librosa.load``.
        n_mels: Number of Mel bands.
        hop_length: Hop length in samples.
        n_fft: FFT window length in samples.

    Returns:
        Path of the written image, ``<save_dir>/<input file stem>.png``, or
        None if any step fails (the error is printed).
    """
    fig = None
    try:
        os.makedirs(save_dir, exist_ok=True)
        y, sr = librosa.load(file_path, sr=sr, mono=True)
        S = librosa.feature.melspectrogram(
            y=y, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length
        )
        S_db = librosa.power_to_db(S, ref=np.max)
        # Replace whatever extension the upload has (.wav, .WAV, .mp3, ...)
        # with .png: matplotlib picks the output format from this suffix, so
        # keeping an audio extension would make savefig fail.
        stem = os.path.splitext(os.path.basename(file_path))[0]
        save_path = os.path.join(save_dir, stem + ".png")
        fig, ax = plt.subplots(figsize=(4, 4))
        librosa.display.specshow(S_db, sr=sr, hop_length=hop_length,
                                 x_axis="time", y_axis="mel", cmap="magma",
                                 ax=ax)
        ax.axis("off")
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
        return save_path
    except Exception as e:
        # Best-effort: the caller shows a friendly message when None is returned.
        print(f"❌ Error creating spectrogram: {e}")
        return None
    finally:
        # Always release the figure, including on failure, so figures do not
        # accumulate across requests in the long-running server process.
        if fig is not None:
            plt.close(fig)
class HierarchicalClassifier:
    """Two-stage wrapper: a top-level model, then one of two sub-class models.

    ``stage1_model`` picks a label from ``stage1_classes``. If that label is
    ``"00 - Abnormal"`` the same image is passed to ``abnormal_model``
    (labels in ``abnormal_classes``); otherwise to ``normal_model`` (labels
    in ``normal_classes``). Each label list maps a model output index to a
    display name.
    """

    def __init__(self, stage1_model, abnormal_model, normal_model,
                 stage1_classes, abnormal_classes, normal_classes):
        """Store the models and label lists.

        Any model may be None (e.g. when loading failed); ``predict`` then
        returns a placeholder result instead of running inference.
        """
        self.img_size = IMG_SIZE  # target size used by _preprocess_image
        self.stage1_model = stage1_model
        self.abnormal_model = abnormal_model
        self.normal_model = normal_model
        self.stage1_classes = stage1_classes
        self.abnormal_classes = abnormal_classes
        self.normal_classes = normal_classes

    @staticmethod
    def _label(classes, idx):
        """Return ``classes[idx]``, or a placeholder when *idx* is out of range.

        The label lists may be fallbacks that do not match a model's number
        of outputs. A mismatch should degrade the result, not crash the
        request with IndexError.
        """
        idx = int(idx)
        if 0 <= idx < len(classes):
            return classes[idx]
        return f"Unknown (class index {idx})"

    def _preprocess_image(self, image_path):
        """Load *image_path* resized to ``self.img_size``, scale by 1/255, add a batch axis."""
        img = tf.keras.utils.load_img(image_path, target_size=self.img_size)
        img_array = tf.keras.utils.img_to_array(img) / 255.0
        return tf.expand_dims(img_array, 0)

    def predict(self, image_path):
        """Classify one spectrogram image with the two-stage pipeline.

        Args:
            image_path: Path to the spectrogram image.

        Returns:
            dict with ``stage1_class``, ``stage1_confidence`` (max stage-1
            score), ``stage2_class``, ``stage2_confidence`` (max stage-2
            score) and ``final_prediction`` formatted as
            ``"<top-level name> → <sub-class>"``. If any model is missing, a
            placeholder dict with an error message, "N/A" classes and zero
            confidences.
        """
        # Explicit None checks rather than relying on model-object truthiness.
        models_present = (self.stage1_model, self.abnormal_model, self.normal_model)
        if any(model is None for model in models_present):
            return {
                "final_prediction": "❌ Models not loaded. Please upload models to /saved_models/",
                "stage1_class": "N/A",
                "stage1_confidence": 0,
                "stage2_class": "N/A",
                "stage2_confidence": 0
            }
        img_array = self._preprocess_image(image_path)

        stage1_pred = self.stage1_model.predict(img_array, verbose=0)
        main_class = self._label(self.stage1_classes, np.argmax(stage1_pred))

        if main_class == "00 - Abnormal":
            sub_model, sub_classes = self.abnormal_model, self.abnormal_classes
        else:
            sub_model, sub_classes = self.normal_model, self.normal_classes
        sub_pred = sub_model.predict(img_array, verbose=0)
        sub_class = self._label(sub_classes, np.argmax(sub_pred))

        # Drop the numeric prefix ("00 - Abnormal" -> "Abnormal"); labels
        # without " - " are used unchanged instead of raising IndexError.
        main_name = main_class.split(" - ", 1)[-1]
        return {
            "stage1_class": main_class,
            "stage1_confidence": float(np.max(stage1_pred)),
            "stage2_class": sub_class,
            "stage2_confidence": float(np.max(sub_pred)),
            "final_prediction": f"{main_name} → {sub_class}"
        }
# Single shared classifier instance used by the Gradio prediction callback.
classifier = HierarchicalClassifier(
    stage1_model=stage1_model,
    abnormal_model=abnormal_model,
    normal_model=normal_model,
    stage1_classes=stage1_classes,
    abnormal_classes=abnormal_classes,
    normal_classes=normal_classes,
)
# ---------------- 4. Prediction Function ----------- #
def predict_washing_machine_sound(audio_filepath):
    """Gradio callback: audio file path -> (result text, spectrogram path).

    Returns a message and None for the image when no file was given or the
    spectrogram could not be generated.
    """
    if audio_filepath is None:
        return "Please upload an audio file first.", None
    print(f"Processing file: {audio_filepath}")

    spectrogram_path = save_mel_spectrogram(audio_filepath, TEMP_DIR)
    if not spectrogram_path:
        return "❌ Could not generate spectrogram from the audio file.", None

    outcome = classifier.predict(spectrogram_path)
    report_lines = [
        f"🎯 Final Prediction: {outcome['final_prediction']}",
        "",
        "Confidence Scores:",
        "--------------------",
        f"Stage 1 ({outcome['stage1_class']}): {outcome['stage1_confidence']:.4f}",
        f"Stage 2 ({outcome['stage2_class']}): {outcome['stage2_confidence']:.4f}",
    ]
    return "\n".join(report_lines), spectrogram_path
# ---------------- 5. Gradio Interface -------------- #
if __name__ == "__main__":
    demo = gr.Interface(
        fn=predict_washing_machine_sound,
        inputs=gr.Audio(type="filepath", label="Upload Washing-Machine Audio (.wav)"),
        outputs=[
            gr.Textbox(label="Prediction Result"),
            gr.Image(label="Generated Mel-Spectrogram"),
        ],
        title="Washing-Machine Sound Classifier",
        description="Upload a WAV file of washing-machine audio to classify its operation status.",
        allow_flagging="never",
    )
    try:
        # By default launch() blocks here while the server runs.
        demo.launch()
    finally:
        # Cleanup temp dir after app stops — in a finally so it also runs if
        # launch() raises (e.g. the port is unavailable).
        try:
            shutil.rmtree(TEMP_DIR)
            print("✅ Cleaned up temporary files.")
        except Exception as e:
            print(f"⚠️ Cleanup warning: {e}")