"""vigilaudio/src/features/extractor.py

Wav2Vec2-based audio feature extractor with local model caching.
(Originally committed as "improved feature extractor with normalization
and caching", commit 31c3ac7.)
"""
import torch
import librosa
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from pathlib import Path
class AudioFeatureExtractor:
    """Extracts fixed-size Wav2Vec2 embeddings from audio files.

    Loads a pretrained Wav2Vec2 processor/model pair once (weights cached
    locally) and produces a single mean-pooled embedding vector per file.
    """

    def __init__(self, model_name="facebook/wav2vec2-base-960h",
                 cache_dir="models/hub", target_sr=16000):
        """
        Initializes the Wav2Vec2 extractor with local caching.

        Args:
            model_name: Hugging Face model identifier to load.
            cache_dir: Directory used to cache downloaded model weights.
            target_sr: Sample rate audio is resampled to before inference.
                Defaults to 16000 — Wav2Vec2 base models expect 16 kHz.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.target_sr = target_sr
        print(f"Loading model: {model_name}...")
        print(f"Cache directory: {self.cache_dir.absolute()}")
        # Load processor and model with explicit cache_dir so repeated runs
        # reuse the downloaded weights instead of re-fetching.
        self.processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=self.cache_dir)
        self.model = Wav2Vec2Model.from_pretrained(model_name, cache_dir=self.cache_dir)
        # Move to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # eval() disables dropout so embeddings are deterministic.
        self.model.eval()
        print(f"Model loaded on {self.device}")

    def extract(self, audio_path):
        """
        Extracts a single embedding vector for an audio file.

        Args:
            audio_path: Path to an audio file in any format librosa can read.

        Returns:
            1-D numpy float array (768 dims for the base model), or None if
            loading or inference failed.
        """
        try:
            # 1. Load audio, resampled to the rate the model was trained on.
            speech, sr = librosa.load(audio_path, sr=self.target_sr)
            # 2. Preprocess raw waveform into model input tensors.
            inputs = self.processor(speech, sampling_rate=self.target_sr,
                                    return_tensors="pt", padding=True)
            input_values = inputs.input_values.to(self.device)
            # 3. Inference without gradient tracking (saves memory/compute).
            with torch.no_grad():
                outputs = self.model(input_values)
            # Mean-pool the last hidden state over the time dimension.
            # Resulting shape: (1, hidden_size) -> flattened to 1-D.
            embeddings = torch.mean(outputs.last_hidden_state, dim=1)
            return embeddings.cpu().numpy().flatten()
        except Exception as e:
            # Best-effort: report the failure and return None so a batch run
            # over many files is not aborted by one unreadable file.
            print(f"Error extracting features from {audio_path}: {e}")
            return None
if __name__ == "__main__":
    # Smoke test: embed the first file listed in the processed metadata index.
    import pandas as pd

    metadata_file = Path("data/processed/metadata.csv")
    if not metadata_file.exists():
        print("Metadata not found. Please run harmonization first.")
    else:
        metadata = pd.read_csv(metadata_file)
        sample_path = metadata.iloc[0]["path"]

        extractor = AudioFeatureExtractor()
        embedding = extractor.extract(sample_path)

        if embedding is not None:
            print(f"\nSuccess!")
            print(f"File: {sample_path}")
            print(f"Embedding shape: {embedding.shape}")
            print(f"First 5 values: {embedding[:5]}")