Spaces:
Sleeping
Sleeping
| import torch | |
| import librosa | |
| import numpy as np | |
| from transformers import Wav2Vec2Processor, Wav2Vec2Model | |
| from pathlib import Path | |
class AudioFeatureExtractor:
    """Turns an audio file into one fixed-size Wav2Vec2 embedding vector."""

    def __init__(self, model_name="facebook/wav2vec2-base-960h", cache_dir="models/hub"):
        """
        Load the Wav2Vec2 processor and model, caching downloads locally.

        Args:
            model_name: Checkpoint identifier handed to ``from_pretrained``.
            cache_dir: Directory used as the ``from_pretrained`` cache; it is
                created (including parents) when it does not exist yet.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        print(f"Loading model: {model_name}...")
        print(f"Cache directory: {self.cache_dir.absolute()}")

        # Processor and model share the same explicit cache location.
        pretrained_kwargs = {"cache_dir": self.cache_dir}
        self.processor = Wav2Vec2Processor.from_pretrained(model_name, **pretrained_kwargs)
        self.model = Wav2Vec2Model.from_pretrained(model_name, **pretrained_kwargs)

        # Prefer CUDA when torch reports it; otherwise stay on the CPU.
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)
        self.model.to(self.device)
        self.model.eval()  # inference only
        print(f"Model loaded on {self.device}")

    def extract(self, audio_path):
        """
        Compute a single embedding for the audio file at ``audio_path``.

        The audio is loaded at 16 kHz (the rate Wav2Vec2 expects), passed
        through the model without gradient tracking, and the last hidden
        state is mean-pooled over the time dimension.

        Args:
            audio_path: Path of the audio file to embed.

        Returns:
            A flattened NumPy array (768 values for the default model,
            according to the original note), or ``None`` when any step
            fails; the error is printed rather than raised.
        """
        try:
            # Step 1: decode and resample to 16 kHz; the returned rate is not needed.
            waveform, _ = librosa.load(audio_path, sr=16000)

            # Step 2: convert the waveform into model-ready tensors.
            encoded = self.processor(
                waveform,
                sampling_rate=16000,
                return_tensors="pt",
                padding=True,
            )
            model_input = encoded.input_values.to(self.device)

            # Step 3: forward pass with autograd disabled, then average the
            # frame-level vectors (dim=1) into one vector per input.
            with torch.no_grad():
                hidden_states = self.model(model_input).last_hidden_state
                pooled = hidden_states.mean(dim=1)

            return pooled.cpu().numpy().flatten()
        except Exception as err:
            # Best-effort contract: report the problem and signal failure with None.
            print(f"Error extracting features from {audio_path}: {err}")
            return None
if __name__ == "__main__":
    # Smoke test: embed the first file listed in the processed metadata CSV.
    import pandas as pd

    metadata_path = "data/processed/metadata.csv"
    if not Path(metadata_path).exists():
        print("Metadata not found. Please run harmonization first.")
    else:
        first_row = pd.read_csv(metadata_path).iloc[0]
        sample_path = first_row['path']

        extractor = AudioFeatureExtractor()
        embedding = extractor.extract(sample_path)

        # extract() returns None on failure and has already printed the error.
        if embedding is not None:
            report_lines = (
                "\nSuccess!",
                f"File: {sample_path}",
                f"Embedding shape: {embedding.shape}",
                f"First 5 values: {embedding[:5]}",
            )
            for report_line in report_lines:
                print(report_line)