# Heartbeat Sound Classifier (Streamlit + Firestore).
#
# Users register / log in with an email and password (stored in the Firestore
# "users" collection, one document per email).  After login they upload (or,
# when the optional recorder component is installed, record) an audio clip,
# which is reduced to a 52-coefficient mean-MFCC vector and classified by the
# Keras model in "Heart_ResNet.h5".  Each result is appended to the user's
# "history" array and the full history is listed below the classifier.
#
# NOTE: a `#` header is used instead of a module docstring so Streamlit "magic"
# never renders it on the page.

import hashlib
import hmac
import os
import secrets
import tempfile
from datetime import datetime, timezone

import firebase_admin
import librosa
import librosa.display  # provides waveshow(); `import librosa` alone does not load it
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import tensorflow as tf
from firebase_admin import credentials, firestore

# Optional: microphone recording.  The PyPI package "streamlit-audiorecorder"
# installs the module `audiorecorder`; the `streamlit_audiorecorder` name the
# original code used is still tried as a fallback.
try:
    from audiorecorder import audiorecorder
    RECORDING_ENABLED = True
except ImportError:
    try:
        from streamlit_audiorecorder import audiorecorder
        RECORDING_ENABLED = True
    except ImportError:
        RECORDING_ENABLED = False


def _rerun():
    """Rerun the script using `st.rerun` when available, else the legacy API."""
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


# --- Firebase Initialization ---
# The service-account JSON is read from st.secrets["firebase_key"] (e.g. a
# [firebase_key] table in .streamlit/secrets.toml or the Space's secrets).
@st.cache_resource
def init_firestore():
    """Initialise the default Firebase app at most once and return a Firestore client."""
    try:
        # Raises ValueError when no default app exists yet; calling
        # initialize_app() twice would raise instead.
        firebase_admin.get_app()
    except ValueError:
        # credentials.Certificate accepts a path or a plain dict; the secrets
        # section is a mapping object, so convert it explicitly.
        cred = credentials.Certificate(dict(st.secrets["firebase_key"]))
        firebase_admin.initialize_app(cred)
    return firestore.client()


db = init_firestore()


# --- Load ML Model ---
@st.cache_resource
def load_model():
    """Load the Keras classifier once per server process."""
    return tf.keras.models.load_model('Heart_ResNet.h5')


model = load_model()

# --- Audio Processing & Classification ---
SAMPLE_RATE = 22050
DURATION = 10
INPUT_LEN = SAMPLE_RATE * DURATION
N_MFCC = 52  # must match the model's expected input length
CLASS_NAMES = ["artifact", "murmur", "normal"]


# Bounded cache: each entry holds a ~10 s waveform, so an unbounded cache
# would grow with every distinct upload.
@st.cache_data(show_spinner=False, max_entries=64)
def process_audio(raw_bytes):
    """Decode audio bytes and compute the model's MFCC feature vector.

    Args:
        raw_bytes: encoded audio (WAV/MP3) as uploaded or recorded.

    Returns:
        (mfccs, waveform, sr): the per-coefficient mean of N_MFCC MFCCs, the
        waveform resampled to SAMPLE_RATE and zero-padded to INPUT_LEN samples
        when shorter (loading stops after DURATION seconds), and the sample rate.
    """
    # A unique temp file per call: a fixed "temp.wav" would be shared (and
    # overwritten) by concurrent sessions.
    fd, path = tempfile.mkstemp(suffix=".wav")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(raw_bytes)
        signal, sr = librosa.load(path, sr=SAMPLE_RATE, duration=DURATION)
    finally:
        os.remove(path)

    if len(signal) < INPUT_LEN:
        signal = np.pad(signal, (0, INPUT_LEN - len(signal)), mode='constant')

    mfcc = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=N_MFCC, n_fft=512, hop_length=256)
    # mfcc is (n_mfcc, frames); average over frames -> one value per coefficient.
    mfccs = np.mean(mfcc.T, axis=0)
    return mfccs, signal, sr


def classify(mfccs):
    """Run the model on an MFCC vector and return {class_name: probability}."""
    feats = mfccs.reshape(1, N_MFCC, 1)
    preds = model.predict(feats)[0]
    return {name: float(p) for name, p in zip(CLASS_NAMES, preds)}


# --- Password hashing helpers ---
_PBKDF2_ITERATIONS = 200_000


def _pbkdf2(password, salt, iterations):
    """Return the hex PBKDF2-HMAC-SHA256 digest of *password* with *salt*."""
    return hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations).hex()


def _password_fields(password):
    """Build the Firestore fields that store a freshly salted hash of *password*."""
    salt = secrets.token_bytes(16)
    return {
        "password_hash": _pbkdf2(password, salt, _PBKDF2_ITERATIONS),
        "password_salt": salt.hex(),
        # Stored so the work factor can be raised later without breaking logins.
        "password_iterations": _PBKDF2_ITERATIONS,
    }


def _verify_password(record, password):
    """Check *password* against a user document (hashed, or legacy plaintext)."""
    if "password_hash" in record:
        expected = _pbkdf2(
            password,
            bytes.fromhex(record["password_salt"]),
            int(record.get("password_iterations", _PBKDF2_ITERATIONS)),
        )
        return hmac.compare_digest(expected, record["password_hash"])
    # Accounts created by earlier versions stored the password in plaintext.
    legacy = record.get("password")
    if legacy is None:
        return False
    return hmac.compare_digest(str(legacy).encode("utf-8"), password.encode("utf-8"))


def _clean_email(email):
    """Strip *email*; return None when it is empty or unusable as a document id."""
    email = (email or "").strip()
    # "/" would turn the document id into a nested Firestore path.
    if not email or "/" in email:
        return None
    return email


# --- User Authentication Helpers ---
def register_user(email, password):
    """Create a user document; return (success, message)."""
    email = _clean_email(email)
    if email is None or not password:
        return False, "A valid email and password are required"
    user_ref = db.collection("users").document(email)
    # NOTE(review): check-then-set is not atomic; two simultaneous sign-ups for
    # the same email could both pass. DocumentReference.create() would close this.
    if user_ref.get().exists:
        return False, "User already exists"
    user_ref.set({"email": email, "history": [], **_password_fields(password)})
    return True, "Registration successful"


def login_user(email, password):
    """Validate credentials; return (True, user_profile) or (False, None).

    The returned profile never contains password fields, so credentials are
    not kept in Streamlit session state.
    """
    email = _clean_email(email)
    if email is None or not password:
        return False, None
    user_ref = db.collection("users").document(email)
    snapshot = user_ref.get()
    if not snapshot.exists:
        return False, None
    data = snapshot.to_dict() or {}
    if not _verify_password(data, password):
        return False, None
    if "password_hash" not in data:
        # Upgrade a legacy plaintext record to salted-hash storage.
        user_ref.update({**_password_fields(password), "password": firestore.DELETE_FIELD})
    profile = {k: v for k, v in data.items() if not k.startswith("password")}
    profile.setdefault("email", email)
    return True, profile


# --- Firestore History Functions ---
def save_history(email, result):
    """Append a timestamped (UTC, ISO-8601) classification result to the user's history."""
    record = {"timestamp": datetime.now(timezone.utc).isoformat(), "result": result}
    db.collection("users").document(email).update({
        "history": firestore.ArrayUnion([record])
    })


def load_history(email):
    """Return the user's history list, or [] when the user or field is missing."""
    # to_dict() returns None for a non-existent document.
    user = db.collection("users").document(email).get().to_dict() or {}
    return user.get("history", [])


# --- Streamlit App Layout ---
def _recording_to_bytes(audio):
    """Convert the recorder component's return value to WAV bytes, or None if empty."""
    if audio is None or len(audio) == 0:
        return None
    if hasattr(audio, "export"):
        # Newer component versions return a pydub.AudioSegment.
        return audio.export(format="wav").read()
    # Older versions return a numpy array holding the WAV file bytes.
    return audio.tobytes()


def _render_sidebar():
    """Show the logged-in user with a Logout button, or the login/register form."""
    with st.sidebar:
        user = st.session_state.user
        if user:
            st.markdown(f"**Logged in as:** {user['email']}")
            if st.button("Logout"):
                st.session_state.user = None
                _rerun()
            return

        tab = st.radio("Account", ["Login", "Register"])
        email = st.text_input("Email")
        password = st.text_input("Password", type="password")
        if tab == "Register":
            if st.button("Sign Up"):
                success, msg = register_user(email, password)
                # Plain if/else: a bare conditional expression would be
                # picked up and rendered by Streamlit magic.
                if success:
                    st.success(msg)
                else:
                    st.error(msg)
        elif st.button("Login"):
            success, user = login_user(email, password)
            if success:
                st.session_state.user = user
                _rerun()
            else:
                st.error("Invalid credentials")


def _get_audio_input():
    """Let the user upload (or record) audio; return its bytes, or None if none yet."""
    modes = ["Upload File"]
    if RECORDING_ENABLED:
        modes.append("Record (mic)")
    mode = st.radio("Input Mode:", modes)

    if mode == "Upload File":
        up = st.file_uploader("Select WAV/MP3 file", type=["wav", "mp3"])
        if not up:
            return None
        # getvalue() is independent of the buffer's read position across reruns.
        raw_audio = up.getvalue()
        st.audio(raw_audio, format=up.type or 'audio/wav')
        return raw_audio

    raw_audio = _recording_to_bytes(audiorecorder())
    if raw_audio:
        st.audio(raw_audio, format='audio/wav')
    return raw_audio


def _show_results(results, waveform, sr):
    """Render per-class probabilities and the waveform plot."""
    st.subheader("Results")
    cols = st.columns(len(CLASS_NAMES))
    for col, name in zip(cols, CLASS_NAMES):
        col.metric(name.title(), f"{results[name] * 100:.2f}%")

    fig, ax = plt.subplots(figsize=(8, 3))
    librosa.display.waveshow(waveform, sr=sr, ax=ax)
    ax.set(title="Heartbeat Waveform", xlabel="Time (s)", ylabel="Amplitude")
    st.pyplot(fig)
    plt.close(fig)  # pyplot keeps every open figure alive until closed


def _classify_and_report(raw_audio):
    """Classify *raw_audio*, store the result in history and display it."""
    try:
        with st.spinner("Analyzing..."):
            mfccs, waveform, sr = process_audio(raw_audio)
            results = classify(mfccs)
    except Exception as err:  # UI boundary: decoders/model raise many error types
        st.error(f"Could not analyze this audio: {err}")
        return
    save_history(st.session_state.user["email"], results)
    _show_results(results, waveform, sr)


def _render_history():
    """List the logged-in user's past classifications, newest first."""
    st.header("Your Classification History")
    history = load_history(st.session_state.user["email"])
    if not history:
        st.info("No history yet. Classify your first heartbeat!")
        return
    for rec in sorted(history, key=lambda r: r.get('timestamp', ''), reverse=True):
        st.write(f"**{rec.get('timestamp', '')}**")
        st.json(rec.get('result', {}))


def main():
    """Render the whole page: title, auth sidebar, classifier, history, footer."""
    st.title("🩺 Heartbeat Sound Classifier with Firestore")

    if "user" not in st.session_state:
        st.session_state.user = None

    _render_sidebar()

    if st.session_state.user:
        st.header("Upload or Record Your Heartbeat")
        raw_audio = _get_audio_input()
        if raw_audio and st.button("Classify Heartbeat"):
            _classify_and_report(raw_audio)
        _render_history()
    else:
        st.info("Please login or register to continue.")

    # Footer
    st.markdown("---")
    st.markdown("Built with ❤️ and deployed on Hugging Face Spaces")


# Streamlit executes this file top to bottom on every interaction.
main()