jam-tracks / backend /services /key_detector.py
Mina Emadi
updated the MVP-Initial upload
a0fcd39
"""Key detection using essentia ensemble with fallback to librosa."""
import numpy as np
from ..utils.audio_utils import to_mono, to_float32
from ..utils.music_theory import KEY_NAMES
# Try to import essentia, fall back to librosa if not available
try:
import essentia.standard as es
ESSENTIA_AVAILABLE = True
except ImportError:
ESSENTIA_AVAILABLE = False
import librosa
# Temperley key profiles used by the librosa fallback (pop/rock optimized).
# Each array has one weight per pitch class: index 0 is the tonic and every
# following index is one semitone higher. The fallback rotates these with
# np.roll to score every possible tonic.
TEMPERLEY_MAJOR = np.array(
    (5.0, 2.0, 3.5, 2.0, 4.5, 4.0, 2.0, 4.5, 2.0, 3.5, 1.5, 4.0),
    dtype=np.float64,
)
TEMPERLEY_MINOR = np.array(
    (5.0, 2.0, 3.5, 4.5, 2.0, 4.0, 2.0, 4.5, 3.5, 2.0, 1.5, 4.0),
    dtype=np.float64,
)
def detect_key(
    audio: np.ndarray,
    sr: int,
    bass_audio: np.ndarray = None
) -> dict:
    """
    Detect musical key from audio using essentia ensemble or librosa.

    Only the first 60 seconds of each input are analysed, for speed. Both
    inputs are down-mixed with to_mono and cast with to_float32 before being
    handed to the essentia or librosa backend.

    Args:
        audio: Audio array (down-mixed via to_mono).
        sr: Sample rate in Hz, used for both audio and bass_audio.
        bass_audio: Optional bass stem to improve accuracy, or None.

    Returns:
        dict with "key", "mode", "confidence" (confidence rounded to 3 decimals)
    """
    # int() so a float sample rate still yields a valid slice bound.
    max_samples = int(sr * 60)

    # Down-mix BEFORE truncating: len() of a multi-channel array may count
    # channels rather than samples, which silently skipped the truncation.
    # NOTE(review): assumes to_mono returns a 1-D array — confirm in audio_utils.
    audio_mono = to_float32(to_mono(audio)[:max_samples])

    bass_mono = None
    if bass_audio is not None:
        # Prepare the bass stem the same way so backends get consistent input.
        bass_mono = to_float32(to_mono(bass_audio)[:max_samples])

    if ESSENTIA_AVAILABLE:
        return _detect_key_essentia(audio_mono, sr, bass_mono)
    return _detect_key_librosa(audio_mono, sr, bass_mono)
def _detect_key_essentia(
    audio: np.ndarray,
    sr: int,
    bass_audio: np.ndarray = None
) -> dict:
    """Key detection using an ensemble of essentia KeyExtractor profiles.

    Each profile casts one vote for the (key, scale) it detects, weighted by
    the strength it reports. If a bass stem is given, the same ensemble is run
    on it; when the bass ensemble's top (key, scale) holds more than 30% of
    the bass vote weight, every bass vote is added to the main tally at half
    weight.

    Args:
        audio: Mono float32 audio.
        sr: Sample rate of audio and bass_audio.
        bass_audio: Optional bass stem, or None.

    Returns:
        dict with "key", "mode" and "confidence" — the winner's share of the
        total vote weight (0 when the total weight is 0), rounded to 3 places.
    """
    profiles = ("temperley", "krumhansl", "edma", "bgate")

    # KeyExtractor is built without a sampleRate argument, so bring audio to
    # 44.1 kHz first. NOTE(review): relies on essentia's default of 44100.
    resampler = None
    if sr != 44100:
        resampler = es.Resample(inputSampleRate=sr, outputSampleRate=44100)
        audio = resampler(audio)

    votes = _essentia_key_votes(audio, profiles)

    if bass_audio is not None:
        # Always down-mix AND cast: previously a 1-D bass stem at 44.1 kHz
        # skipped to_float32. NOTE(review): essentia bindings expect float32.
        bass_mono = to_float32(to_mono(bass_audio))
        if resampler is not None:
            bass_mono = resampler(bass_mono)

        bass_votes = _essentia_key_votes(bass_mono, profiles)
        bass_total = sum(bass_votes.values())
        # A silent bass stem can report zero strength for every profile;
        # guard the division instead of raising ZeroDivisionError.
        if bass_total > 0:
            bass_conf = max(bass_votes.values()) / bass_total
            # Blend the bass in only when its own ensemble is fairly decisive.
            if bass_conf > 0.3:
                for key_mode, weight in bass_votes.items():
                    votes[key_mode] = votes.get(key_mode, 0) + weight * 0.5

    # First-inserted key wins ties, matching max() over dict iteration order.
    winner = max(votes, key=votes.get)
    total_weight = sum(votes.values())
    confidence = votes[winner] / total_weight if total_weight > 0 else 0

    return {
        "key": winner[0],
        "mode": winner[1],
        "confidence": round(float(confidence), 3)
    }


def _essentia_key_votes(signal: np.ndarray, profiles) -> dict:
    """Run KeyExtractor once per profile; return {(key, scale): summed strength}."""
    votes = {}
    for profile in profiles:
        key, scale, strength = es.KeyExtractor(profileType=profile)(signal)
        votes[(key, scale)] = votes.get((key, scale), 0) + strength
    return votes
def _detect_key_librosa(
    audio: np.ndarray,
    sr: int,
    bass_audio: np.ndarray = None
) -> dict:
    """Fallback key detection using librosa chroma features.

    Averages CQT chroma over time, normalizes it to sum to 1, and picks the
    (key, mode) whose rotated Temperley profile has the highest Pearson
    correlation with it. When bass_audio is given, a 60/40 blend of the main
    and bass chroma is scored too; its match replaces the main-mix match only
    if it correlates strictly more strongly.

    Args:
        audio: Mono audio.
        sr: Sample rate of audio and bass_audio.
        bass_audio: Optional bass stem, or None.

    Returns:
        dict with "key", "mode" ("major"/"minor") and "confidence" — the
        winning correlation mapped from [-1, 1] to [0, 1]. "key" and "mode"
        are None (confidence 0.0) when no profile correlates, e.g. silence.
    """
    chroma_mean = _mean_chroma(audio, sr)
    best = _best_profile_match(chroma_mean, (None, None, -1.0))

    if bass_audio is not None:
        bass_mono = to_float32(to_mono(bass_audio))
        bass_chroma_mean = _mean_chroma(bass_mono, sr)
        combined = _normalize_sum(chroma_mean * 0.6 + bass_chroma_mean * 0.4)
        # NOTE(review): the blend is compared against the main-mix best, so
        # the bass can only change the answer when it raises the correlation.
        best = _best_profile_match(combined, best)

    best_key, best_mode, best_corr = best
    confidence = (best_corr + 1) / 2  # Map correlation [-1, 1] to [0, 1]

    return {
        "key": best_key,
        "mode": best_mode,
        "confidence": round(float(confidence), 3)
    }


def _normalize_sum(vec: np.ndarray) -> np.ndarray:
    """Scale vec to sum to 1; an all-zero vector is returned as-is (no 0/0 NaNs)."""
    total = np.sum(vec)
    return vec / total if total > 0 else vec


def _mean_chroma(signal: np.ndarray, sr: int) -> np.ndarray:
    """Time-averaged CQT chroma of signal, normalized to sum to 1."""
    chroma = librosa.feature.chroma_cqt(y=signal, sr=sr)
    return _normalize_sum(np.mean(chroma, axis=1))


def _best_profile_match(chroma_vec: np.ndarray, best: tuple) -> tuple:
    """Score chroma_vec against all 24 rotated Temperley profiles.

    best is the current (key, mode, correlation); a candidate replaces it only
    when its correlation is strictly greater, so earlier candidates win ties.
    Profiles are not pre-normalized: Pearson correlation is scale-invariant.
    """
    # Correlation is undefined for a constant vector (silence / flat chroma);
    # skip it rather than let np.corrcoef emit NaN and RuntimeWarnings.
    if np.ptp(chroma_vec) == 0:
        return best

    best_key, best_mode, best_corr = best
    for semitones in range(12):
        for mode, profile in (("major", TEMPERLEY_MAJOR), ("minor", TEMPERLEY_MINOR)):
            corr = np.corrcoef(chroma_vec, np.roll(profile, semitones))[0, 1]
            if corr > best_corr:
                best_key, best_mode, best_corr = KEY_NAMES[semitones], mode, corr
    return best_key, best_mode, best_corr