Spaces:
Sleeping
Sleeping
| import os | |
| import shutil | |
| import gradio as gr | |
| import numpy as np | |
| import librosa | |
| import librosa.display | |
| import matplotlib.pyplot as plt | |
| import tensorflow as tf | |
| from tensorflow.keras import models | |
# ---------------- 1. Configuration ---------------- #
# Scratch directory for the spectrogram PNGs displayed in the UI; the
# __main__ section removes it again once the app has stopped.
TEMP_DIR = "temp_gradio_specs"
# Passed as `target_size` to tf.keras.utils.load_img by the classifier.
IMG_SIZE = (224, 224)
os.makedirs(TEMP_DIR, exist_ok=True)
# ---------------- 2. Load Models ------------------ #
print("🚀 Loading machine learning models...")
try:
    # Loaded in order: stage-1 router, then the abnormal and normal
    # sub-classifiers. Any failure lands in the except branch below.
    stage1_model, abnormal_model, normal_model = [
        models.load_model(model_path)
        for model_path in (
            "saved_models/stage1_model.h5",
            "saved_models/abnormal_model.h5",
            "saved_models/normal_model.h5",
        )
    ]
    print("✅ Models loaded successfully.")
except Exception as e:
    print(f"❌ Error loading models: {e}")
    # Deliberately keep the module importable: with all three set to None
    # the classifier returns a "Models not loaded" message instead of the
    # process exiting.
    stage1_model = abnormal_model = normal_model = None
def load_class_names(class_dir, default):
    """Return the sorted sub-directory names of *class_dir* as class labels.

    Only directories count as classes, so stray files in the folder (e.g.
    ``.DS_Store`` or a README) cannot shift the index -> label mapping that
    the classifier relies on. The sorted order presumably mirrors the order
    used when the models were trained from this directory tree -- TODO
    confirm against the training script.

    Args:
        class_dir: Folder whose sub-directories are the class names.
        default: Labels to use when *class_dir* is missing, is not a
            directory, or contains no sub-directories.

    Returns:
        A new list of class-name strings.
    """
    if not os.path.isdir(class_dir):
        return list(default)
    names = sorted(
        entry
        for entry in os.listdir(class_dir)
        if os.path.isdir(os.path.join(class_dir, entry))
    )
    return names or list(default)


# Default class lists – replace with actual labels if available
stage1_classes = ["00 - Abnormal", "01 - Normal"]
abnormal_classes = load_class_names(
    "MelSpectrograms/00 - Abnormal",
    ["Bearing noise", "Dehydration mode noise"],
)
normal_classes = load_class_names(
    "MelSpectrograms/01 - Normal",
    ["Wash mode", "Spin mode"],
)
print(f"Stage 1 Classes: {stage1_classes}")
print(f"Abnormal Sub-classes: {abnormal_classes}")
print(f"Normal Sub-classes: {normal_classes}")
# ---------------- 3. Helper Functions -------------- #
def save_mel_spectrogram(file_path, save_dir, sr=22050,
                         n_mels=128, hop_length=512, n_fft=2048):
    """Generate a Mel spectrogram of an audio file and save it as a PNG.

    Args:
        file_path: Path to an audio file readable by ``librosa.load``.
        save_dir: Directory in which the PNG is written.
        sr: Target sample rate passed to ``librosa.load`` (audio is mixed
            down to mono).
        n_mels: Number of Mel bands.
        hop_length: Hop length, used for both the STFT and the plot axis.
        n_fft: FFT window size.

    Returns:
        Path of the saved PNG (``<save_dir>/<input stem>.png``), or ``None``
        if any step failed; the error is printed rather than raised.
    """
    fig = None
    try:
        y, sr = librosa.load(file_path, sr=sr, mono=True)
        S = librosa.feature.melspectrogram(
            y=y, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length
        )
        S_db = librosa.power_to_db(S, ref=np.max)
        # Swap only the final extension. The previous
        # str.replace(".wav", ".png") left ".mp3"/".WAV"/... uploads with
        # their audio extension, which savefig rejects as an unknown
        # image format.
        stem = os.path.splitext(os.path.basename(file_path))[0]
        save_path = os.path.join(save_dir, stem + ".png")
        # Draw on this call's own figure/axes rather than pyplot's global
        # "current" figure.
        fig = plt.figure(figsize=(4, 4))
        ax = fig.gca()
        librosa.display.specshow(S_db, sr=sr, hop_length=hop_length,
                                 x_axis="time", y_axis="mel", cmap="magma",
                                 ax=ax)
        ax.axis("off")
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
        return save_path
    except Exception as e:
        print(f"❌ Error creating spectrogram: {e}")
        return None
    finally:
        # Release the figure even when plotting or saving failed, so
        # repeated failures do not accumulate open figures.
        if fig is not None:
            plt.close(fig)
class HierarchicalClassifier:
    """Two-stage classifier over Mel-spectrogram images.

    Stage 1 picks one of ``stage1_classes``. When the winner is
    ``"00 - Abnormal"`` the abnormal sub-model refines the prediction;
    otherwise the normal sub-model does.
    """

    def __init__(self, stage1_model, abnormal_model, normal_model,
                 stage1_classes, abnormal_classes, normal_classes):
        self.img_size = IMG_SIZE
        self.stage1_model = stage1_model
        self.abnormal_model = abnormal_model
        self.normal_model = normal_model
        self.stage1_classes = stage1_classes
        self.abnormal_classes = abnormal_classes
        self.normal_classes = normal_classes

    def _preprocess_image(self, image_path):
        """Load *image_path* as a one-image batch for the Keras models.

        The image is resized to ``self.img_size``, its pixel values are
        divided by 255, and a leading batch axis of size 1 is added.
        """
        img = tf.keras.utils.load_img(image_path, target_size=self.img_size)
        img_array = tf.keras.utils.img_to_array(img) / 255.0
        return tf.expand_dims(img_array, 0)

    @staticmethod
    def _lookup_label(classes, idx):
        """Return ``classes[idx]``, or a placeholder if *idx* is out of range.

        The label lists may be hard-coded fallbacks that do not match the
        number of model outputs; indexing them blindly would raise
        IndexError and fail the whole request.
        """
        idx = int(idx)
        if 0 <= idx < len(classes):
            return classes[idx]
        return f"Unknown class #{idx}"

    def predict(self, image_path):
        """Classify one spectrogram image with the two-stage pipeline.

        Args:
            image_path: Path to the spectrogram image.

        Returns:
            A dict with ``stage1_class``, ``stage1_confidence``,
            ``stage2_class``, ``stage2_confidence`` (the highest value of
            each model's output) and a human-readable ``final_prediction``.
            If any model is missing, the class fields are ``"N/A"``, the
            confidences are 0, and ``final_prediction`` holds an error
            message.
        """
        # Explicit None checks: a loaded model object should never be
        # judged by its truthiness.
        models_in_use = (self.stage1_model, self.abnormal_model,
                         self.normal_model)
        if any(model is None for model in models_in_use):
            return {
                "final_prediction": "❌ Models not loaded. Please upload models to /saved_models/",
                "stage1_class": "N/A",
                "stage1_confidence": 0,
                "stage2_class": "N/A",
                "stage2_confidence": 0
            }

        img_array = self._preprocess_image(image_path)

        stage1_pred = self.stage1_model.predict(img_array, verbose=0)
        main_class = self._lookup_label(self.stage1_classes,
                                        np.argmax(stage1_pred))

        if main_class == "00 - Abnormal":
            sub_model, sub_classes = self.abnormal_model, self.abnormal_classes
        else:
            sub_model, sub_classes = self.normal_model, self.normal_classes
        sub_pred = sub_model.predict(img_array, verbose=0)
        sub_class = self._lookup_label(sub_classes, np.argmax(sub_pred))

        # "00 - Abnormal" -> "Abnormal". partition() avoids the IndexError
        # that split(...)[1] raised for names without " - ".
        _, sep, short_name = main_class.partition(" - ")
        coarse_label = short_name if sep else main_class

        return {
            "stage1_class": main_class,
            "stage1_confidence": float(np.max(stage1_pred)),
            "stage2_class": sub_class,
            "stage2_confidence": float(np.max(sub_pred)),
            "final_prediction": f"{coarse_label} → {sub_class}"
        }
# Single module-level instance used by predict_washing_machine_sound.
classifier = HierarchicalClassifier(
    stage1_model=stage1_model,
    abnormal_model=abnormal_model,
    normal_model=normal_model,
    stage1_classes=stage1_classes,
    abnormal_classes=abnormal_classes,
    normal_classes=normal_classes,
)
# ---------------- 4. Prediction Function ----------- #
def predict_washing_machine_sound(audio_filepath):
    """Gradio callback: classify one uploaded washing-machine recording.

    Args:
        audio_filepath: File path from the ``gr.Audio(type="filepath")``
            input, or ``None`` when nothing was uploaded.

    Returns:
        A ``(text, image_path)`` pair for the Textbox and Image outputs.
        ``image_path`` is ``None`` when no spectrogram could be produced.
    """
    if audio_filepath is None:
        return "Please upload an audio file first.", None

    print(f"Processing file: {audio_filepath}")
    spectrogram_png = save_mel_spectrogram(audio_filepath, TEMP_DIR)
    if not spectrogram_png:
        return "❌ Could not generate spectrogram from the audio file.", None

    res = classifier.predict(spectrogram_png)
    report_lines = [
        f"🎯 Final Prediction: {res['final_prediction']}",
        "",
        "Confidence Scores:",
        "--------------------",
        f"Stage 1 ({res['stage1_class']}): {res['stage1_confidence']:.4f}",
        f"Stage 2 ({res['stage2_class']}): {res['stage2_confidence']:.4f}",
    ]
    return "\n".join(report_lines), spectrogram_png
# ---------------- 5. Gradio Interface -------------- #
if __name__ == "__main__":
    audio_input = gr.Audio(type="filepath", label="Upload Washing-Machine Audio (.wav)")
    result_outputs = [
        gr.Textbox(label="Prediction Result"),
        gr.Image(label="Generated Mel-Spectrogram"),
    ]
    demo = gr.Interface(
        fn=predict_washing_machine_sound,
        inputs=audio_input,
        outputs=result_outputs,
        title="Washing-Machine Sound Classifier",
        description="Upload a WAV file of washing-machine audio to classify its operation status.",
        allow_flagging="never",
    )
    demo.launch()

    # Reached once launch() returns (the app has stopped): remove the
    # generated spectrograms, reporting rather than raising on failure.
    try:
        shutil.rmtree(TEMP_DIR)
        print("✅ Cleaned up temporary files.")
    except Exception as e:
        print(f"⚠️ Cleanup warning: {e}")