import os
import shutil

import matplotlib

# Select the non-interactive Agg backend before pyplot is imported: this app
# only ever writes figures to disk, and the Gradio handler that renders them
# does not run on the main thread, where GUI backends are required to live.
matplotlib.use("Agg")

import gradio as gr  # noqa: E402
import librosa  # noqa: E402
import librosa.display  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
from tensorflow.keras import models  # noqa: E402

# ---------------- 1. Configuration ---------------- #
# Scratch directory for the generated spectrogram images.
TEMP_DIR = "temp_gradio_specs"
os.makedirs(TEMP_DIR, exist_ok=True)
# (height, width) every spectrogram image is resized to before inference.
IMG_SIZE = (224, 224)


def _discover_classes(directory, fallback):
    """Return the sorted sub-directory names of *directory* as class labels.

    Only directories are kept, so stray files (``.DS_Store``, ``Thumbs.db``,
    notes, ...) cannot shift the label indices.
    NOTE(review): assumes the models were trained with one sub-directory per
    class, in sorted order - confirm against the training script.

    Args:
        directory: Path of the folder holding one sub-directory per class.
        fallback: Labels to use when *directory* is missing, unreadable or
            contains no sub-directory.

    Returns:
        A new list of class labels.
    """
    try:
        found = sorted(
            entry
            for entry in os.listdir(directory)
            if os.path.isdir(os.path.join(directory, entry))
        )
    except OSError:
        # Missing path, not a directory, or permission problem.
        return list(fallback)
    return found or list(fallback)


# ---------------- 2. Load Models ------------------ #
print("🚀 Loading machine learning models...")
try:
    stage1_model = models.load_model("saved_models/stage1_model.h5")
    abnormal_model = models.load_model("saved_models/abnormal_model.h5")
    normal_model = models.load_model("saved_models/normal_model.h5")
    print("✅ Models loaded successfully.")
except Exception as e:  # noqa: BLE001 - start-up boundary, reported below
    print(f"❌ Error loading models: {e}")
    # Do not exit — allows app to show error gracefully
    stage1_model = abnormal_model = normal_model = None

# Default class lists – replace with actual labels if available
stage1_classes = ["00 - Abnormal", "01 - Normal"]
abnormal_classes = _discover_classes(
    "MelSpectrograms/00 - Abnormal",
    ["Bearing noise", "Dehydration mode noise"],
)
normal_classes = _discover_classes(
    "MelSpectrograms/01 - Normal",
    ["Wash mode", "Spin mode"],
)

print(f"Stage 1 Classes: {stage1_classes}")
print(f"Abnormal Sub-classes: {abnormal_classes}")
print(f"Normal Sub-classes: {normal_classes}")
# ---------------- 3. Helper Functions -------------- #
def save_mel_spectrogram(file_path, save_dir, sr=22050, n_mels=128,
                         hop_length=512, n_fft=2048):
    """Render the Mel spectrogram of an audio file to a PNG image.

    Args:
        file_path: Path of the audio file to load.
        save_dir: Directory the PNG is written to (created if missing).
        sr: Sample rate the audio is resampled to when loaded.
        n_mels: Number of Mel bands.
        hop_length: Hop length, in samples, between analysis frames.
        n_fft: FFT window size.

    Returns:
        Path of the written PNG, or ``None`` if anything failed (the error is
        printed rather than raised).
    """
    fig = None
    try:
        y, sr = librosa.load(file_path, sr=sr, mono=True)
        S = librosa.feature.melspectrogram(
            y=y, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length
        )
        S_db = librosa.power_to_db(S, ref=np.max)

        # Swap whatever extension the upload has (.wav, .WAV, .mp3, none, ...)
        # for .png; a plain str.replace(".wav", ...) would leave other
        # extensions in place and savefig would be handed an audio suffix.
        stem = os.path.splitext(os.path.basename(file_path))[0]
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f"{stem}.png")

        # Draw on explicit axes instead of pyplot's global "current figure",
        # which is shared state when several requests are handled at once.
        fig = plt.figure(figsize=(4, 4))
        ax = fig.add_subplot(1, 1, 1)
        librosa.display.specshow(
            S_db, sr=sr, hop_length=hop_length,
            x_axis="time", y_axis="mel", cmap="magma", ax=ax,
        )
        ax.axis("off")
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
        return save_path
    except Exception as e:  # noqa: BLE001 - callers rely on the None result
        print(f"❌ Error creating spectrogram: {e}")
        return None
    finally:
        # Always release the figure, including on failure, so a long-running
        # server does not accumulate open figures.
        if fig is not None:
            plt.close(fig)


class HierarchicalClassifier:
    """A wrapper class for the two-stage prediction logic.

    Stage 1 picks a main class from ``stage1_classes``; stage 2 then runs the
    abnormal or the normal sub-model to pick the sub-class.
    """

    def __init__(self, stage1_model, abnormal_model, normal_model,
                 stage1_classes, abnormal_classes, normal_classes):
        """Store the three models and their label lists.

        Any model may be ``None``; ``predict`` then returns a
        "models not loaded" result instead of running inference.
        """
        self.img_size = IMG_SIZE
        self.stage1_model = stage1_model
        self.abnormal_model = abnormal_model
        self.normal_model = normal_model
        self.stage1_classes = stage1_classes
        self.abnormal_classes = abnormal_classes
        self.normal_classes = normal_classes

    def _preprocess_image(self, image_path):
        """Load an image, resize it, scale it by 1/255 and add a batch axis."""
        img = tf.keras.utils.load_img(image_path, target_size=self.img_size)
        img_array = tf.keras.utils.img_to_array(img) / 255.0
        return tf.expand_dims(img_array, 0)

    def predict(self, image_path):
        """Classify a spectrogram image with the two-stage pipeline.

        Args:
            image_path: Path of the spectrogram image.

        Returns:
            A dict with the keys ``final_prediction``, ``stage1_class``,
            ``stage1_confidence``, ``stage2_class`` and ``stage2_confidence``.
            When a model is missing, the classes are ``"N/A"``, the
            confidences are 0 and ``final_prediction`` holds an error message.
        """
        loaded = (self.stage1_model, self.abnormal_model, self.normal_model)
        # Explicit None test: do not depend on the truthiness of model objects.
        if any(model is None for model in loaded):
            return {
                "final_prediction": (
                    "❌ Models not loaded. "
                    "Please upload models to /saved_models/"
                ),
                "stage1_class": "N/A",
                "stage1_confidence": 0,
                "stage2_class": "N/A",
                "stage2_confidence": 0,
            }

        img_array = self._preprocess_image(image_path)

        stage1_pred = self.stage1_model.predict(img_array, verbose=0)
        main_class = self.stage1_classes[int(np.argmax(stage1_pred))]

        # Route to the sub-model that matches the stage-1 decision.
        if main_class == "00 - Abnormal":
            sub_model, sub_classes = self.abnormal_model, self.abnormal_classes
        else:
            sub_model, sub_classes = self.normal_model, self.normal_classes
        sub_pred = sub_model.predict(img_array, verbose=0)
        sub_class = sub_classes[int(np.argmax(sub_pred))]

        # Drop the "NN - " prefix for display; a label without the separator
        # is shown unchanged instead of raising IndexError.
        main_label = main_class.split(" - ", 1)[-1]

        return {
            "stage1_class": main_class,
            "stage1_confidence": float(np.max(stage1_pred)),
            "stage2_class": sub_class,
            "stage2_confidence": float(np.max(sub_pred)),
            "final_prediction": f"{main_label} → {sub_class}",
        }


classifier = HierarchicalClassifier(
    stage1_model, abnormal_model, normal_model,
    stage1_classes, abnormal_classes, normal_classes
)


# ---------------- 4. Prediction Function ----------- #
def predict_washing_machine_sound(audio_filepath):
    """Gradio handler: turn an uploaded audio file into a prediction.

    Args:
        audio_filepath: Path of the uploaded audio file, or ``None`` when
            nothing was uploaded.

    Returns:
        A ``(text, image_path)`` tuple: the formatted prediction (or an error
        message) and the path of the generated spectrogram image, which is
        ``None`` when no spectrogram could be produced.
    """
    if audio_filepath is None:
        return "Please upload an audio file first.", None

    print(f"Processing file: {audio_filepath}")
    spec_path = save_mel_spectrogram(audio_filepath, TEMP_DIR)
    if not spec_path:
        return "❌ Could not generate spectrogram from the audio file.", None

    result = classifier.predict(spec_path)
    output_text = (
        f"🎯 Final Prediction: {result['final_prediction']}\n\n"
        f"Confidence Scores:\n"
        f"--------------------\n"
        f"Stage 1 ({result['stage1_class']}): {result['stage1_confidence']:.4f}\n"
        f"Stage 2 ({result['stage2_class']}): {result['stage2_confidence']:.4f}"
    )
    return output_text, spec_path
# ---------------- 5. Gradio Interface -------------- #
def cleanup_temp_dir(path):
    """Remove the directory tree at *path*, best effort.

    A failure (for example the directory is already gone) is printed as a
    warning and never raised.
    """
    try:
        shutil.rmtree(path)
        print("✅ Cleaned up temporary files.")
    except OSError as e:
        print(f"⚠️ Cleanup warning: {e}")


if __name__ == "__main__":
    try:
        demo = gr.Interface(
            fn=predict_washing_machine_sound,
            inputs=gr.Audio(
                type="filepath",
                label="Upload Washing-Machine Audio (.wav)",
            ),
            outputs=[
                gr.Textbox(label="Prediction Result"),
                gr.Image(label="Generated Mel-Spectrogram"),
            ],
            title="Washing-Machine Sound Classifier",
            description=(
                "Upload a WAV file of washing-machine audio to classify "
                "its operation status."
            ),
            allow_flagging="never",
            # No bundled examples: the local example files were removed.
        )
        demo.launch()
    finally:
        # Cleanup temp dir after app stops, including when building or
        # launching the interface fails.
        cleanup_temp_dir(TEMP_DIR)