# S2S / extract_embeddings.py
# Origin: Vikrantyadav11234 — added via upload-large-folder tool (rev b54655d, verified)
import torch
import torchaudio
import glob
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import numpy as np
import os
from pathlib import Path
from tqdm import tqdm
import warnings
import logging
# Configure logging
logging.getLogger("transformers").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")
# Configuration
AUDIO_DIR = "/home/vikrant/Conversational-AI-Model/embedding_vocoder/non_empty_wavs/*.wav"
EMBEDDING_DIR = "/home/vikrant/Conversational-AI-Model/embedding_vocoder/embeddings"
MODEL_NAME = "facebook/wav2vec2-large-lv60"
SAMPLE_RATE = 16000
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def initialize_models():
    """Load the Wav2Vec 2.0 processor and encoder.

    Returns:
        tuple: (processor, model) where the model has been moved to DEVICE.
    """
    print("Loading Wav2Vec 2.0 model...")
    # Disable XLA integration before model construction.
    os.environ["NO_XLA_IMPORT"] = "1"
    processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
    # Force the eager attention implementation to sidestep XLA issues.
    model = Wav2Vec2Model.from_pretrained(MODEL_NAME, attn_implementation="eager")
    return processor, model.to(DEVICE)
def process_audio_file(audio_path, processor, model):
    """Extract frame-level Wav2Vec 2.0 embeddings for a single audio file.

    Args:
        audio_path: Path to a .wav file.
        processor: Wav2Vec2Processor used for feature extraction.
        model: Wav2Vec2Model already placed on DEVICE.

    Returns:
        np.ndarray of shape (num_frames, hidden_size), or None on failure.
    """
    try:
        # Load and preprocess audio
        waveform, orig_sr = torchaudio.load(audio_path)
        # Downmix stereo to mono by averaging channels.
        if waveform.dim() > 1 and waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        # Resample to the model's expected rate if necessary.
        if orig_sr != SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(orig_sr, SAMPLE_RATE)
            waveform = resampler(waveform)
        # Peak-normalize. Guard against all-zero (silent) audio: the original
        # unconditional division produced 0/0 = NaN for silent inputs.
        peak = torch.max(torch.abs(waveform))
        if peak > 0:
            waveform = waveform / peak
        # Run the encoder without tracking gradients.
        with torch.no_grad():
            inputs = processor(
                waveform.squeeze().numpy(),
                sampling_rate=SAMPLE_RATE,
                return_tensors="pt"
            ).to(DEVICE)
            outputs = model(**inputs)
        # Per-frame hidden states (time, hidden) — note: no pooling is applied,
        # despite what an earlier comment claimed; the full sequence is returned.
        embeddings = outputs.last_hidden_state.squeeze(0).cpu().numpy()
        return embeddings
    except Exception as e:
        # Best-effort: report the failure and let the caller skip this file.
        print(f"Error processing {audio_path}: {str(e)}")
        return None
def generate_embeddings():
    """Extract and save embeddings for every audio file matching AUDIO_DIR."""
    # Ensure the output directory exists.
    Path(EMBEDDING_DIR).mkdir(parents=True, exist_ok=True)
    # Collect the input files.
    audio_files = glob.glob(AUDIO_DIR)
    print(f"Found {len(audio_files)} audio files")
    # Load processor and model once, up front.
    processor, model = initialize_models()

    skipped_files = []
    processed_count = 0
    for audio_path in tqdm(audio_files, desc="Processing audio files"):
        try:
            embeddings = process_audio_file(audio_path, processor, model)
            if embeddings is None:
                skipped_files.append(audio_path)
                continue
            # Persist the embeddings next to their source file's stem.
            out_file = os.path.join(EMBEDDING_DIR, f"{Path(audio_path).stem}.npy")
            np.save(out_file, embeddings)
            processed_count += 1
        except Exception as exc:
            # Record the failure with its reason for the summary below.
            skipped_files.append((audio_path, str(exc)))

    # Summarize the run.
    print(f"\nSuccessfully processed {processed_count}/{len(audio_files)} files")
    if skipped_files:
        print(f"\nFailed to process {len(skipped_files)} files:")
        for entry in skipped_files[:5]:  # Show first 5 errors
            if isinstance(entry, tuple):
                failed_path, reason = entry
                print(f"- {failed_path}: {reason}")
            else:
                print(f"- {entry}")
if __name__ == "__main__":
    # Report which compute device will be used.
    device_msg = (
        "CUDA is available. Using GPU."
        if torch.cuda.is_available()
        else "Using CPU."
    )
    print(device_msg)
    # Abort early when the glob pattern matches no files.
    if not glob.glob(AUDIO_DIR):
        print(f"No audio files found at {AUDIO_DIR}")
        exit(1)
    generate_embeddings()