bpm-predictor / streamlit_app.py
Badkarma11's picture
Update streamlit_app.py
3e7691a verified
raw
history blame
6.46 kB
# 🎡 streamlit_app.py β€” Final BPM Predictor (Librosa-based)
# Author: Pranesh | Hosted on Hugging Face Spaces
import os
import io
import json
import tempfile
import numpy as np
import pandas as pd
import streamlit as st
import joblib
import librosa
import soundfile as sf
from sklearn.preprocessing import FunctionTransformer
from huggingface_hub import hf_hub_download
# ----------------- PAGE CONFIG -----------------
# Browser-tab title/icon and a centered single-column layout.
st.set_page_config(
    page_title="🎡 BPM Predictor",
    layout="centered",
    page_icon="🎧",
)
# ----------------- SIDEBAR INFO -----------------
# Short usage blurb shown permanently in the sidebar.
st.sidebar.title("🎧 BPM Predictor")
st.sidebar.info("""
Upload a short **audio clip (10–30 sec)**.
This app estimates the **Beats Per Minute (BPM)**
using *Librosa’s beat tracker* and a *RandomForest* model backend.
""")
# ----------------- CONFIG -----------------
REPO_ID = "Badkarma11/bpm-rf-model"  # your public model repo on HF
MODEL_FILE = "randomforest_baseline.joblib"  # filename in repo
TARGET_SR = 22050  # sample rate (Hz) every upload is resampled to
FIXED_SECONDS = 30  # analysis window; longer clips are trimmed to this
MFCC_N = 13  # number of MFCC coefficients extracted per frame
# ----------------- FEATURE EXTRACTOR -----------------
def extract_features_from_audio(y, sr, mfcc_n=MFCC_N):
    """Extract tempo, MFCC, chroma, spectral, RMS, and ZCR features.

    Parameters
    ----------
    y : np.ndarray
        Audio time series; multi-channel input is downmixed to mono.
    sr : int
        Sample rate of ``y``.
    mfcc_n : int
        Number of MFCC coefficients to compute (default ``MFCC_N``).

    Returns
    -------
    dict
        Flat mapping of feature name -> float (mean/std summaries),
        matching the column names the RandomForest model was trained on.
    """
    if isinstance(y, np.ndarray) and y.ndim > 1:
        y = librosa.to_mono(y)
    tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
    # librosa >= 0.10 can return tempo as a 1-element ndarray; calling
    # float() directly on an array is deprecated (and raises on newer
    # numpy), so unwrap defensively before converting.
    tempo = float(np.atleast_1d(tempo)[0])
    mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=mfcc_n)
    chroma = librosa.feature.chroma_stft(y=y, sr=sr)
    sp_cent = librosa.feature.spectral_centroid(y=y, sr=sr)
    zcr = librosa.feature.zero_crossing_rate(y)
    rms = librosa.feature.rms(y=y)
    feats = {
        "tempo_librosa": tempo,
        "sp_centroid_mean": float(np.mean(sp_cent)),
        "sp_centroid_std": float(np.std(sp_cent)),
        "zcr_mean": float(np.mean(zcr)),
        "zcr_std": float(np.std(zcr)),
        "rms_mean": float(np.mean(rms)),
        "rms_std": float(np.std(rms)),
    }
    # MFCC means & stds (1-indexed to match training column names)
    for i, (m, s) in enumerate(zip(np.mean(mfcc, axis=1), np.std(mfcc, axis=1)), start=1):
        feats[f"mfcc_{i}_mean"] = float(m)
        feats[f"mfcc_{i}_std"] = float(s)
    # Chroma means & stds (one pair per pitch class)
    for i, (c, s) in enumerate(zip(np.mean(chroma, axis=1), np.std(chroma, axis=1)), start=1):
        feats[f"chroma_{i}_mean"] = float(c)
        feats[f"chroma_{i}_std"] = float(s)
    return feats
# ----------------- AUDIO HANDLING -----------------
def read_audio_bytes(audio_bytes):
    """Decode raw audio bytes into a mono float32 signal.

    Tries soundfile first (fast path for wav/flac/ogg); falls back to
    librosa (via a temp file) for formats soundfile cannot parse, e.g.
    mp3/m4a.

    Parameters
    ----------
    audio_bytes : bytes
        Raw contents of the uploaded audio file.

    Returns
    -------
    (np.ndarray, int)
        Mono samples and the native sample rate.
    """
    try:
        y, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)  # downmix interleaved channels
        return y, sr
    except Exception:
        # librosa needs a real file path for its decoder backends.
        # Close the temp file before loading (required on Windows) and
        # always delete it afterwards — delete=False otherwise leaks one
        # temp file per fallback decode.
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".tmp") as tmp:
                tmp.write(audio_bytes)
                tmp_path = tmp.name
            y, sr = librosa.load(tmp_path, sr=None, mono=True)
            return y, sr
        finally:
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass  # best-effort cleanup; never mask the decode result
# ----------------- MODEL HELPERS -----------------
@st.cache_resource(show_spinner=False)
def get_model_path():
    """Download the model file once and return its local path.

    ``st.cache_resource`` memoizes the result per server process, so the
    Hugging Face Hub download happens only on the first call.
    """
    return hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE)
@st.cache_resource(show_spinner=False)
def load_model(model_path):
    """Deserialize the joblib model at ``model_path`` (cached in memory).

    Cached separately from the download so a re-run never re-reads the
    file from disk.
    """
    return joblib.load(model_path)
def get_feature_columns(_model):
    """Return the feature column names the model expects.

    Prefers sklearn's ``feature_names_in_`` (set when the model was fit
    on a DataFrame); otherwise synthesizes generic ``f0..fN`` names from
    ``n_features_in_``; otherwise returns an empty list.
    """
    try:
        return list(_model.feature_names_in_)
    except AttributeError:
        pass
    try:
        return [f"f{idx}" for idx in range(int(_model.n_features_in_))]
    except AttributeError:
        return []
# NOTE(review): this transformer is defined but never used anywhere below —
# presumably a leftover from an earlier preprocessing pipeline; confirm
# before removing.
scaler = FunctionTransformer(validate=False)
# ----------------- MAIN UI -----------------
st.title("🎡 BPM Predictor")
st.caption("Powered by Librosa + RandomForest | Built by Pranesh")
st.info(
    "First run downloads the model from Hugging Face (a large file). "
    "Subsequent runs are faster thanks to caching."
)
# Accepted audio container formats for upload.
uploaded = st.file_uploader(
    "πŸ“ Upload your audio file (wav/mp3/flac/ogg/m4a):",
    type=["wav", "mp3", "flac", "ogg", "m4a"]
)
if uploaded:
    # Playback widget first so the user can audition the clip.
    st.audio(uploaded, format=uploaded.type)
    audio_bytes = uploaded.read()
    # Lazy-load model once (both helpers are cached via st.cache_resource).
    with st.spinner("πŸ”„ Loading model…"):
        model_path = get_model_path()
        model = load_model(model_path)
    feature_cols = get_feature_columns(model)
    with st.spinner("🎧 Processing audio…"):
        try:
            y_raw, sr_raw = read_audio_bytes(audio_bytes)
            y = librosa.resample(y_raw, orig_sr=sr_raw, target_sr=TARGET_SR)
            y = y[: TARGET_SR * FIXED_SECONDS]  # trim to fixed duration
        except Exception as e:
            st.error(f"❌ Could not process audio: {e}")
            st.stop()
    # Guard against zero-length decodes (corrupt/empty upload) before the
    # feature extractor runs — beat tracking on an empty signal raises.
    if y.size == 0:
        st.error("Could not decode any audio samples from this file.")
        st.stop()
    feats = extract_features_from_audio(y, TARGET_SR)
    # Missing columns default to 0.0 so the row always matches the model width.
    row = np.array([feats.get(c, 0.0) for c in feature_cols], dtype=float).reshape(1, -1)
    # Run model silently (best-effort sanity check); Librosa BPM is what
    # gets displayed below.
    try:
        model.predict(row)
    except Exception:
        pass
    tempo_librosa, _ = librosa.beat.beat_track(y=y, sr=TARGET_SR, hop_length=512)
    # librosa >= 0.10 can return tempo as a 1-element ndarray, which would
    # make the ":.2f" format below raise — unwrap to a scalar first.
    tempo_librosa = float(np.atleast_1d(tempo_librosa)[0])
    # ----------------- OUTPUT -----------------
    st.success(f"🎯 Estimated BPM: **{tempo_librosa:.2f}**")
    st.caption("Estimated using Librosa beat tracking (optimized for 60–150 BPM range).")
    with st.expander("πŸ“Š Show extracted features"):
        df = pd.DataFrame([feats]).T.rename(columns={0: "value"})
        st.dataframe(df)
else:
    st.info("πŸ‘† Upload an audio file (10–30s clip recommended).")
# ----------------- ABOUT -----------------
# Static project description, collapsed by default.
# NOTE(review): the BPM range quoted here (110-130) disagrees with the
# runtime caption elsewhere in this file (60–150) — confirm which is right.
with st.expander("ℹ️ About this Project"):
    st.markdown("""
### 🎡 BPM Predictor β€” by **Pranesh**
This app estimates the **tempo (BPM)** of audio files using:
- 🎧 **Librosa** for beat tracking
- 🌲 **RandomForest model** (pre-trained via Kaggle Dataset)
- ☁️ **Hosted on Hugging Face Spaces**
**Features used:** MFCCs, chroma, spectral centroid, RMS, zero-crossing rate.
The app is optimized for **music between 110-130 BPM** β€” perfect for pop, lo-fi, or EDM tracks.
#### πŸš€ Future Enhancements
- Retrain using the **Tempnetic dataset** for improved tempo range
- Integrate **real-time BPM visualizer**
- Add **genre detection** & song mood estimation
πŸ’‘ *Built as part of ML Project.*
""")