# dev_caio/models/audio_analyzer.py
# Author: Chaitanya-aitf — "Initializing project from local" (commit ad4e58a, verified)
"""
ShortSmith v2 - Audio Analyzer Module
Audio feature extraction and hype scoring using:
- Librosa for basic audio features (MVP)
- Wav2Vec 2.0 for advanced audio understanding (optional)
Features extracted:
- RMS energy (volume/loudness)
- Spectral flux (sudden changes, beat drops)
- Spectral centroid (brightness, crowd noise)
- Onset strength (beats, impacts)
- Speech activity detection
"""
from pathlib import Path
from typing import List, Optional, Tuple, Dict
from dataclasses import dataclass
import numpy as np
from utils.logger import get_logger, LogTimer
from utils.helpers import ModelLoadError, InferenceError, normalize_scores, batch_list
from config import get_config, ModelConfig
logger = get_logger("models.audio_analyzer")
@dataclass
class AudioFeatures:
"""Audio features for a segment of audio."""
timestamp: float # Start time in seconds
duration: float # Segment duration
rms_energy: float # Root mean square energy (0-1)
spectral_flux: float # Spectral change rate (0-1)
spectral_centroid: float # Frequency centroid (0-1)
onset_strength: float # Beat/impact strength (0-1)
zero_crossing_rate: float # ZCR (speech indicator) (0-1)
# Optional advanced features
speech_probability: float = 0.0 # From Wav2Vec if available
@property
def energy_score(self) -> float:
"""Combined energy-based hype indicator."""
return (self.rms_energy * 0.4 + self.onset_strength * 0.4 +
self.spectral_flux * 0.2)
@property
def excitement_score(self) -> float:
"""Overall audio excitement score."""
return (self.rms_energy * 0.3 + self.spectral_flux * 0.25 +
self.onset_strength * 0.25 + self.spectral_centroid * 0.2)
@dataclass
class AudioSegmentScore:
"""Hype score for an audio segment."""
start_time: float
end_time: float
score: float # Overall hype score (0-1)
features: AudioFeatures # Underlying features
@property
def duration(self) -> float:
return self.end_time - self.start_time
class AudioAnalyzer:
"""
Audio analysis for hype detection.
Uses Librosa for feature extraction and optionally Wav2Vec 2.0
for advanced semantic understanding.
"""
def __init__(
self,
config: Optional[ModelConfig] = None,
use_advanced: Optional[bool] = None,
):
"""
Initialize audio analyzer.
Args:
config: Model configuration (uses default if None)
use_advanced: Override config to use Wav2Vec 2.0
Raises:
ImportError: If librosa is not installed
"""
self.config = config or get_config().model
self.use_advanced = use_advanced if use_advanced is not None else self.config.use_advanced_audio
self._librosa = None
self._wav2vec_model = None
self._wav2vec_processor = None
# Initialize librosa (required)
self._init_librosa()
# Initialize Wav2Vec if requested
if self.use_advanced:
self._init_wav2vec()
logger.info(f"AudioAnalyzer initialized (advanced={self.use_advanced})")
def _init_librosa(self) -> None:
"""Initialize librosa library."""
try:
import librosa
self._librosa = librosa
except ImportError as e:
raise ImportError(
"Librosa is required for audio analysis. "
"Install with: pip install librosa"
) from e
def _init_wav2vec(self) -> None:
"""Initialize Wav2Vec 2.0 model."""
try:
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2Model
logger.info("Loading Wav2Vec 2.0 model...")
self._wav2vec_processor = Wav2Vec2Processor.from_pretrained(
self.config.audio_model_id
)
self._wav2vec_model = Wav2Vec2Model.from_pretrained(
self.config.audio_model_id
)
# Move to device
device = self.config.device
if device == "cuda":
import torch
if torch.cuda.is_available():
self._wav2vec_model = self._wav2vec_model.cuda()
self._wav2vec_model.eval()
logger.info("Wav2Vec 2.0 model loaded successfully")
except Exception as e:
logger.warning(f"Failed to load Wav2Vec 2.0, falling back to Librosa only: {e}")
self.use_advanced = False
def load_audio(
self,
audio_path: str | Path,
sample_rate: int = 22050,
mono: bool = True,
) -> Tuple[np.ndarray, int]:
"""
Load audio file.
Args:
audio_path: Path to audio file
sample_rate: Target sample rate
mono: Convert to mono if True
Returns:
Tuple of (audio_array, sample_rate)
Raises:
InferenceError: If audio loading fails
"""
try:
audio, sr = self._librosa.load(
str(audio_path),
sr=sample_rate,
mono=mono,
)
logger.debug(f"Loaded audio: {len(audio)/sr:.1f}s at {sr}Hz")
return audio, sr
except Exception as e:
raise InferenceError(f"Failed to load audio: {e}") from e
def extract_features(
self,
audio: np.ndarray,
sample_rate: int,
segment_duration: float = 1.0,
hop_duration: float = 0.5,
) -> List[AudioFeatures]:
"""
Extract audio features for overlapping segments.
Args:
audio: Audio array
sample_rate: Sample rate
segment_duration: Duration of each segment in seconds
hop_duration: Hop between segments in seconds
Returns:
List of AudioFeatures for each segment
"""
with LogTimer(logger, "Extracting audio features"):
duration = len(audio) / sample_rate
segment_samples = int(segment_duration * sample_rate)
hop_samples = int(hop_duration * sample_rate)
features = []
position = 0
timestamp = 0.0
while position + segment_samples <= len(audio):
segment = audio[position:position + segment_samples]
try:
feat = self._extract_segment_features(
segment, sample_rate, timestamp, segment_duration
)
features.append(feat)
except Exception as e:
logger.warning(f"Failed to extract features at {timestamp}s: {e}")
position += hop_samples
timestamp += hop_duration
logger.info(f"Extracted features for {len(features)} segments")
return features
def _extract_segment_features(
self,
segment: np.ndarray,
sample_rate: int,
timestamp: float,
duration: float,
) -> AudioFeatures:
"""Extract features from a single audio segment."""
librosa = self._librosa
# RMS energy (loudness)
rms = librosa.feature.rms(y=segment)[0]
rms_mean = float(np.mean(rms))
# Spectral flux (change rate)
spec = np.abs(librosa.stft(segment))
flux = np.mean(np.diff(spec, axis=1) ** 2)
flux_normalized = min(1.0, flux / 100) # Normalize
# Spectral centroid (brightness)
centroid = librosa.feature.spectral_centroid(y=segment, sr=sample_rate)[0]
centroid_mean = float(np.mean(centroid))
centroid_normalized = min(1.0, centroid_mean / 8000) # Normalize
# Onset strength (beats/impacts)
onset_env = librosa.onset.onset_strength(y=segment, sr=sample_rate)
onset_mean = float(np.mean(onset_env))
onset_normalized = min(1.0, onset_mean / 5) # Normalize
# Zero crossing rate
zcr = librosa.feature.zero_crossing_rate(segment)[0]
zcr_mean = float(np.mean(zcr))
return AudioFeatures(
timestamp=timestamp,
duration=duration,
rms_energy=min(1.0, rms_mean * 5), # Scale up
spectral_flux=flux_normalized,
spectral_centroid=centroid_normalized,
onset_strength=onset_normalized,
zero_crossing_rate=zcr_mean,
)
def analyze_file(
self,
audio_path: str | Path,
segment_duration: float = 1.0,
hop_duration: float = 0.5,
) -> List[AudioFeatures]:
"""
Analyze an audio file and extract features.
Args:
audio_path: Path to audio file
segment_duration: Duration of each segment
hop_duration: Hop between segments
Returns:
List of AudioFeatures for the file
"""
audio, sr = self.load_audio(audio_path)
return self.extract_features(audio, sr, segment_duration, hop_duration)
def compute_hype_scores(
self,
features: List[AudioFeatures],
window_size: int = 5,
) -> List[AudioSegmentScore]:
"""
Compute hype scores from audio features.
Uses a sliding window to smooth scores and identify
sustained high-energy regions.
Args:
features: List of AudioFeatures
window_size: Smoothing window size
Returns:
List of AudioSegmentScore objects
"""
if not features:
return []
with LogTimer(logger, "Computing audio hype scores"):
# Compute raw excitement scores
raw_scores = [f.excitement_score for f in features]
# Apply smoothing
smoothed = self._smooth_scores(raw_scores, window_size)
# Normalize to 0-1
normalized = normalize_scores(smoothed)
# Create score objects
scores = []
for feat, score in zip(features, normalized):
scores.append(AudioSegmentScore(
start_time=feat.timestamp,
end_time=feat.timestamp + feat.duration,
score=score,
features=feat,
))
return scores
def _smooth_scores(
self,
scores: List[float],
window_size: int,
) -> List[float]:
"""Apply moving average smoothing to scores."""
if len(scores) < window_size:
return scores
kernel = np.ones(window_size) / window_size
padded = np.pad(scores, (window_size // 2, window_size // 2), mode='edge')
smoothed = np.convolve(padded, kernel, mode='valid')
return smoothed.tolist()
def detect_peaks(
self,
scores: List[AudioSegmentScore],
threshold: float = 0.6,
min_duration: float = 3.0,
) -> List[Tuple[float, float, float]]:
"""
Detect peak regions in audio hype.
Args:
scores: List of AudioSegmentScore objects
threshold: Minimum score to consider a peak
min_duration: Minimum peak duration in seconds
Returns:
List of (start_time, end_time, peak_score) tuples
"""
if not scores:
return []
peaks = []
in_peak = False
peak_start = 0.0
peak_max = 0.0
for score in scores:
if score.score >= threshold:
if not in_peak:
in_peak = True
peak_start = score.start_time
peak_max = score.score
else:
peak_max = max(peak_max, score.score)
else:
if in_peak:
peak_end = score.start_time
if peak_end - peak_start >= min_duration:
peaks.append((peak_start, peak_end, peak_max))
in_peak = False
# Handle peak at end
if in_peak:
peak_end = scores[-1].end_time
if peak_end - peak_start >= min_duration:
peaks.append((peak_start, peak_end, peak_max))
logger.info(f"Detected {len(peaks)} audio peaks above threshold {threshold}")
return peaks
def get_beat_timestamps(
self,
audio: np.ndarray,
sample_rate: int,
) -> List[float]:
"""
Detect beat timestamps in audio.
Args:
audio: Audio array
sample_rate: Sample rate
Returns:
List of beat timestamps in seconds
"""
try:
tempo, beats = self._librosa.beat.beat_track(y=audio, sr=sample_rate)
beat_times = self._librosa.frames_to_time(beats, sr=sample_rate)
logger.debug(f"Detected {len(beat_times)} beats at {tempo:.1f} BPM")
return beat_times.tolist()
except Exception as e:
logger.warning(f"Beat detection failed: {e}")
return []
def get_audio_embedding(
self,
audio: np.ndarray,
sample_rate: int = 16000,
) -> Optional[np.ndarray]:
"""
Get Wav2Vec 2.0 embedding for audio segment.
Only available if use_advanced=True.
Args:
audio: Audio array (should be 16kHz)
sample_rate: Sample rate
Returns:
Embedding array or None if not available
"""
if not self.use_advanced or self._wav2vec_model is None:
return None
try:
import torch
# Resample if needed
if sample_rate != 16000:
audio = self._librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
# Process
inputs = self._wav2vec_processor(
audio, sampling_rate=16000, return_tensors="pt"
)
if self.config.device == "cuda" and torch.cuda.is_available():
inputs = {k: v.cuda() for k, v in inputs.items()}
with torch.no_grad():
outputs = self._wav2vec_model(**inputs)
embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
return embedding[0]
except Exception as e:
logger.warning(f"Wav2Vec embedding extraction failed: {e}")
return None
def compare_audio_similarity(
self,
embedding1: np.ndarray,
embedding2: np.ndarray,
) -> float:
"""
Compare two audio embeddings using cosine similarity.
Args:
embedding1: First embedding
embedding2: Second embedding
Returns:
Similarity score (0-1)
"""
norm1 = np.linalg.norm(embedding1)
norm2 = np.linalg.norm(embedding2)
if norm1 == 0 or norm2 == 0:
return 0.0
return float(np.dot(embedding1, embedding2) / (norm1 * norm2))
# Export public interface
__all__ = ["AudioAnalyzer", "AudioFeatures", "AudioSegmentScore"]