Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import librosa | |
| import numpy as np | |
| import tensorflow as tf | |
| import matplotlib.pyplot as plt | |
| import firebase_admin | |
| from firebase_admin import credentials, firestore | |
| from datetime import datetime | |
# Optional: microphone recording via the `streamlit-audiorecorder` package.
# That package installs the importable module `audiorecorder`; the former import
# path `streamlit_audiorecorder` is kept only as a fallback.
try:
    from audiorecorder import audiorecorder
except ImportError:
    try:
        from streamlit_audiorecorder import audiorecorder
    except ImportError:
        audiorecorder = None  # recording UI is hidden when the component is absent
RECORDING_ENABLED = audiorecorder is not None
# --- Firebase Initialization ---
# Store the service-account credentials as the "firebase_key" secret
# (a TOML table in .streamlit/secrets.toml, or the raw JSON as a string).
def init_firestore():
    """Return a Firestore client bound to the default Firebase app.

    Streamlit re-executes this script on every interaction, so the default app
    may already exist from an earlier run. It is reused rather than initialized
    again, because ``firebase_admin.initialize_app`` raises ``ValueError`` for a
    duplicate default app.
    """
    import json

    try:
        app = firebase_admin.get_app()
    except ValueError:  # default app not created yet in this process
        cred_info = st.secrets["firebase_key"]
        if isinstance(cred_info, str):
            if cred_info.lstrip().startswith("{"):
                # Raw service-account JSON pasted as a string secret.
                cred_info = json.loads(cred_info)
            # Any other string is passed through as a path to the key file.
        else:
            # st.secrets yields a read-only mapping; Certificate() needs a real dict.
            cred_info = dict(cred_info)
        app = firebase_admin.initialize_app(credentials.Certificate(cred_info))
    return firestore.client(app)


db = init_firestore()
# --- Load ML Model ---
# Streamlit reruns the whole script on each interaction. When the installed
# Streamlit provides st.cache_resource, the model is loaded once per process;
# on older versions the decorator falls back to a no-op.
@getattr(st, "cache_resource", lambda func: func)
def load_model(path='Heart_ResNet.h5'):
    """Load and return the Keras heartbeat classifier stored at *path*."""
    return tf.keras.models.load_model(path)


model = load_model()
# --- Audio Processing & Classification ---
# process_audio() loads at most DURATION seconds of audio resampled to
# SAMPLE_RATE, then zero-pads short clips to INPUT_LEN samples.
SAMPLE_RATE: int = 22050
DURATION: int = 10
INPUT_LEN: int = DURATION * SAMPLE_RATE
# Label for each position of the model's output row; classify() reads scores
# in exactly this order.
CLASS_NAMES = [
    "artifact",
    "murmur",
    "normal",
]
def process_audio(raw_bytes):
    """Decode audio bytes and compute the time-averaged MFCC feature vector.

    Args:
        raw_bytes: Encoded audio file contents (the uploader in this app accepts
            WAV or MP3).

    Returns:
        A tuple ``(mfccs, waveform, sr)``:
        - mfccs: the 52 MFCC coefficients averaged over time;
        - waveform: the decoded samples, zero-padded to INPUT_LEN when shorter;
        - sr: their sample rate (SAMPLE_RATE).
    """
    import os
    import tempfile

    # Each call gets its own scratch directory. A fixed "temp.wav" in the working
    # directory would be overwritten by concurrent sessions and never removed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "temp.wav")
        with open(path, "wb") as f:
            f.write(raw_bytes)
        waveform, sr = librosa.load(path, sr=SAMPLE_RATE, duration=DURATION)

    # Zero-pad short clips so every input covers the same number of samples.
    if len(waveform) < INPUT_LEN:
        waveform = np.pad(waveform, (0, INPUT_LEN - len(waveform)), mode='constant')
    mfccs = np.mean(
        librosa.feature.mfcc(y=waveform, sr=sr, n_mfcc=52, n_fft=512, hop_length=256).T,
        axis=0,
    )
    return mfccs, waveform, sr
def classify(mfccs):
    """Score one MFCC vector with the model and label each score.

    The vector is reshaped into a single-item batch of shape (1, 52, 1). The
    first row of the prediction is read position by position in CLASS_NAMES
    order. Each score is converted with float() so the resulting dict holds
    plain Python numbers (it is later saved to Firestore and shown in the UI).
    """
    batch = mfccs.reshape(1, 52, 1)
    scores = model.predict(batch)[0]
    labelled = {}
    for idx, label in enumerate(CLASS_NAMES):
        labelled[label] = float(scores[idx])
    return labelled
# --- User Authentication Helpers ---
# Passwords are stored as salted PBKDF2-HMAC-SHA256 hashes, never in plain text.
_PBKDF2_ITERATIONS = 200_000
# Fields that must never be handed back to the UI / session state.
_CREDENTIAL_FIELDS = ("password", "password_hash", "password_salt", "password_iterations")


def _hash_password(password, salt, iterations=_PBKDF2_ITERATIONS):
    """Return the hex PBKDF2-HMAC-SHA256 digest of *password* using *salt* (bytes)."""
    import hashlib

    return hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations).hex()


def _new_password_fields(password):
    """Build the Firestore fields holding a freshly salted hash of *password*."""
    import secrets

    salt = secrets.token_bytes(16)
    return {
        "password_hash": _hash_password(password, salt),
        "password_salt": salt.hex(),
        "password_iterations": _PBKDF2_ITERATIONS,
    }


def _check_email(email):
    """Return an error message if *email* cannot be used as a document ID, else None."""
    if not email:
        return "Email is required"
    if "/" in email:
        # A '/' would be interpreted as a Firestore path separator.
        return "Email must not contain '/'"
    return None


def register_user(email, password):
    """Create a user document keyed by *email* with a hashed *password*.

    Returns:
        ``(success, message)``, suitable for showing directly in the UI.
    """
    error = _check_email(email)
    if error:
        return False, error
    if not password:
        return False, "Password is required"
    user_ref = db.collection("users").document(email)
    # NOTE(review): get-then-set is not atomic; two simultaneous sign-ups for the
    # same email could both pass this check.
    if user_ref.get().exists:
        return False, "User already exists"
    user_ref.set({"email": email, "history": [], **_new_password_fields(password)})
    return True, "Registration successful"


def login_user(email, password):
    """Check *email* / *password* against the stored user document.

    Returns:
        ``(True, user_dict)`` on success, with all credential fields removed from
        the dict; ``(False, None)`` otherwise.
    """
    import hmac

    if _check_email(email):
        return False, None
    user_ref = db.collection("users").document(email)
    snapshot = user_ref.get()
    if not snapshot.exists:
        return False, None
    record = snapshot.to_dict() or {}

    if "password_hash" in record:
        salt = bytes.fromhex(record.get("password_salt", ""))
        iterations = int(record.get("password_iterations", _PBKDF2_ITERATIONS))
        candidate = _hash_password(password, salt, iterations)
        ok = hmac.compare_digest(candidate.encode("utf-8"), str(record["password_hash"]).encode("utf-8"))
    else:
        # Legacy account saved before hashing existed: compare the plain-text
        # value once, then replace it with a salted hash.
        stored = record.get("password")
        ok = isinstance(stored, str) and hmac.compare_digest(
            stored.encode("utf-8"), password.encode("utf-8")
        )
        if ok:
            user_ref.update({"password": firestore.DELETE_FIELD, **_new_password_fields(password)})

    if not ok:
        return False, None
    return True, {key: value for key, value in record.items() if key not in _CREDENTIAL_FIELDS}
# --- Firestore History Functions ---
def save_history(email, result):
    """Append a timestamped classification *result* to the user's history array.

    The timestamp is an ISO-8601 string in UTC with an explicit ``+00:00`` offset.
    NOTE(review): history grows without bound inside a single user document;
    Firestore limits document size, so consider a subcollection.
    """
    from datetime import timezone

    record = {"timestamp": datetime.now(timezone.utc).isoformat(), "result": result}
    db.collection("users").document(email).update({
        "history": firestore.ArrayUnion([record])
    })


def load_history(email):
    """Return the user's stored history records, or an empty list.

    A missing user document or a null ``history`` field yields ``[]`` instead of
    crashing, because ``to_dict()`` returns None for a missing document.
    """
    snapshot = db.collection("users").document(email).get()
    user = snapshot.to_dict() or {}
    return user.get("history") or []
# --- Streamlit App Layout ---
# librosa.display (used for the waveform plot) is a submodule that
# `import librosa` is not guaranteed to load, so import it explicitly.
import librosa.display  # noqa: E402

# st.rerun replaced st.experimental_rerun in newer Streamlit releases.
_rerun = getattr(st, "rerun", None) or st.experimental_rerun

st.title("🩺 Heartbeat Sound Classifier with Firestore")

# Session state for the logged-in user (None while logged out).
if "user" not in st.session_state:
    st.session_state.user = None

# Sidebar: login/registration forms, or a logout button once logged in.
with st.sidebar:
    if st.session_state.user:
        st.markdown(f"**Logged in as:** {st.session_state.user['email']}")
        if st.button("Logout"):
            st.session_state.user = None
            _rerun()
    else:
        tab = st.radio("Account", ["Login", "Register"])
        email = st.text_input("Email")
        password = st.text_input("Password", type="password")
        if tab == "Register":
            if st.button("Sign Up"):
                success, msg = register_user(email, password)
                # A statement, not a conditional expression: Streamlit "magic"
                # would st.write() the value of a bare expression statement.
                if success:
                    st.success(msg)
                else:
                    st.error(msg)
        elif st.button("Login"):
            success, user = login_user(email, password)
            if success:
                st.session_state.user = user
                _rerun()
            else:
                st.error("Invalid credentials")

# Main area: only shown after login.
if st.session_state.user:
    st.header("Upload or Record Your Heartbeat")
    # Offer the microphone option only when the recorder component is available,
    # instead of listing "Upload File" twice.
    input_modes = ["Upload File"]
    if RECORDING_ENABLED:
        input_modes.append("Record (mic)")
    mode = st.radio("Input Mode:", input_modes)

    raw_audio = None
    if mode == "Upload File":
        up = st.file_uploader("Select WAV/MP3 file", type=["wav", "mp3"])
        if up:
            raw_audio = up.getvalue()
            # Play with the upload's own MIME type so MP3 files are labelled correctly.
            st.audio(raw_audio, format=up.type or 'audio/wav')
    else:
        audio_data = audiorecorder()
        # Nothing has been recorded while the returned object is empty.
        if audio_data is not None and len(audio_data) > 0:
            if hasattr(audio_data, "export"):
                # Presumably a pydub AudioSegment (newer component versions).
                raw_audio = audio_data.export(format="wav").read()
            else:
                # Array-like return value (older component versions).
                raw_audio = audio_data.tobytes()
            st.audio(raw_audio, format='audio/wav')

    if raw_audio and st.button("Classify Heartbeat"):
        results = None
        with st.spinner("Analyzing..."):
            try:
                mfccs, waveform, sr = process_audio(raw_audio)
            except Exception as exc:  # UI boundary: decoder errors vary by backend
                st.error(f"Could not decode the audio file: {exc}")
            else:
                results = classify(mfccs)
                save_history(st.session_state.user["email"], results)

        if results is not None:
            # Display metrics
            st.subheader("Results")
            cols = st.columns(len(CLASS_NAMES))
            for col, name in zip(cols, CLASS_NAMES):
                col.metric(name.title(), f"{results[name] * 100:.2f}%")
            # Waveform plot
            fig, ax = plt.subplots(figsize=(8, 3))
            librosa.display.waveshow(waveform, sr=sr, ax=ax)
            ax.set(title="Heartbeat Waveform", xlabel="Time (s)", ylabel="Amplitude")
            st.pyplot(fig)
            # pyplot keeps every open figure alive; release it after rendering.
            plt.close(fig)

    # Show history, newest first.
    st.header("Your Classification History")
    history = load_history(st.session_state.user["email"])
    if history:
        for rec in sorted(history, key=lambda x: x['timestamp'], reverse=True):
            st.write(f"**{rec['timestamp']}**")
            st.json(rec['result'])
    else:
        st.info("No history yet. Classify your first heartbeat!")
else:
    st.info("Please login or register to continue.")

# Footer
st.markdown("---")
st.markdown("Built with ❤️ and deployed on Hugging Face Spaces")