File size: 2,729 Bytes
6d719fd
 
 
 
 
 
 
 
 
 
 
 
 
 
d0ec0b6
 
6d719fd
 
 
 
 
 
 
 
 
 
d0ec0b6
6d719fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0ec0b6
6d719fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31c3ac7
6d719fd
d0ec0b6
6d719fd
 
d0ec0b6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import librosa
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from pathlib import Path

class AudioFeatureExtractor:
    """Extracts fixed-size Wav2Vec2 embeddings from audio files.

    Weights are downloaded once into ``cache_dir`` and reused on later
    runs. Inference runs on GPU when available (or an explicit device).
    """

    def __init__(self, model_name="facebook/wav2vec2-base-960h",
                 cache_dir="models/hub", target_sr=16000, device=None):
        """
        Initializes the Wav2Vec2 extractor with local caching.

        Args:
            model_name: Hugging Face model identifier to load.
            cache_dir: Local directory used to cache downloaded weights.
            target_sr: Sample rate audio is resampled to before inference.
                Wav2Vec2 base models expect 16 kHz; kept as one parameter
                so the load and preprocess steps cannot disagree.
            device: Optional explicit torch device string (e.g. "cpu",
                "cuda"). When None, CUDA is auto-detected.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.target_sr = target_sr

        print(f"Loading model: {model_name}...")
        print(f"Cache directory: {self.cache_dir.absolute()}")

        # Explicit cache_dir so repeated runs reuse the local copy
        # instead of re-downloading the weights.
        self.processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=self.cache_dir)
        self.model = Wav2Vec2Model.from_pretrained(model_name, cache_dir=self.cache_dir)

        # Honor an explicit device override; otherwise prefer CUDA.
        if device is not None:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()  # disable dropout etc. for deterministic inference

        print(f"Model loaded on {self.device}")

    def extract(self, audio_path):
        """
        Extracts a single embedding vector for an audio file.

        Mean-pools the model's last hidden state over the time axis,
        yielding a 1-D numpy array (768 values for the base model).

        Args:
            audio_path: Path to an audio file readable by librosa.

        Returns:
            A 1-D ``np.ndarray`` embedding, or ``None`` on any failure
            so batch callers can log and skip bad files.
        """
        try:
            # 1. Load and resample to the model's expected rate.
            speech, _ = librosa.load(audio_path, sr=self.target_sr)

            # 2. Preprocess into padded tensors on the model's device.
            inputs = self.processor(speech, sampling_rate=self.target_sr,
                                    return_tensors="pt", padding=True)
            input_values = inputs.input_values.to(self.device)

            # 3. Inference — no gradients needed, saves memory and time.
            with torch.no_grad():
                outputs = self.model(input_values)
                # Mean-pool over the time dimension -> shape (1, hidden_dim)
                embeddings = torch.mean(outputs.last_hidden_state, dim=1)

            return embeddings.cpu().numpy().flatten()

        except Exception as e:
            # Best-effort: report and signal failure rather than abort a batch run.
            print(f"Error extracting features from {audio_path}: {e}")
            return None

if __name__ == "__main__":
    # Smoke test: embed the first file listed in the harmonized metadata.
    import pandas as pd

    metadata_path = "data/processed/metadata.csv"
    if not Path(metadata_path).exists():
        print("Metadata not found. Please run harmonization first.")
    else:
        frame = pd.read_csv(metadata_path)
        sample_path = frame.iloc[0]['path']

        extractor = AudioFeatureExtractor()
        embedding = extractor.extract(sample_path)

        if embedding is not None:
            print(f"\nSuccess!")
            print(f"File: {sample_path}")
            print(f"Embedding shape: {embedding.shape}")
            print(f"First 5 values: {embedding[:5]}")