import glob
import logging
import os
import sys
import warnings
from pathlib import Path

import numpy as np
import torch
import torchaudio
from tqdm import tqdm
from transformers import Wav2Vec2Processor, Wav2Vec2Model
| |
|
| | |
# Silence verbose transformers logging and generic warnings so tqdm output stays clean.
logging.getLogger("transformers").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")


# Glob pattern matching the input WAV files to embed.
AUDIO_DIR = "/home/vikrant/Conversational-AI-Model/embedding_vocoder/non_empty_wavs/*.wav"
# Directory where one .npy embedding file is written per input WAV.
EMBEDDING_DIR = "/home/vikrant/Conversational-AI-Model/embedding_vocoder/embeddings"
# Hugging Face model id of the Wav2Vec 2.0 encoder.
MODEL_NAME = "facebook/wav2vec2-large-lv60"
# Sample rate (Hz) fed to the processor; audio is resampled to this rate.
SAMPLE_RATE = 16000
# Prefer GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
| |
|
def initialize_models():
    """Load the Wav2Vec 2.0 processor and encoder and move the model to DEVICE.

    Returns:
        tuple: (Wav2Vec2Processor, Wav2Vec2Model) — the input processor and
        the encoder, placed on DEVICE and in eval mode, ready for inference.
    """
    print("Loading Wav2Vec 2.0 model...")

    # NOTE(review): transformers is already imported at module load, so this
    # env var is likely set too late to influence any XLA import — confirm
    # the workaround is still needed.
    os.environ["NO_XLA_IMPORT"] = "1"

    processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
    model = Wav2Vec2Model.from_pretrained(
        MODEL_NAME,
        attn_implementation="eager",  # force the eager attention code path
    ).to(DEVICE)

    # Explicit eval() guarantees dropout/layer-drop are disabled during
    # feature extraction, regardless of the library's default mode.
    model.eval()

    return processor, model
| |
|
def process_audio_file(audio_path, processor, model):
    """Extract frame-level Wav2Vec 2.0 embeddings for a single audio file.

    Args:
        audio_path: path to an audio file readable by torchaudio.
        processor: Wav2Vec2Processor used to prepare the model inputs.
        model: Wav2Vec2Model already placed on DEVICE.

    Returns:
        np.ndarray with the encoder's last hidden state for the clip
        (batch dimension squeezed away), or None on any failure.
    """
    try:
        waveform, orig_sr = torchaudio.load(audio_path)

        # Downmix multi-channel audio to mono by averaging the channels.
        if waveform.dim() > 1 and waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample to the rate the processor/model expects.
        if orig_sr != SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(orig_sr, SAMPLE_RATE)
            waveform = resampler(waveform)

        # Peak-normalize. Guard against all-zero (silent) audio: dividing by
        # a zero peak would fill the waveform with NaNs (0/0) and poison the
        # embeddings downstream.
        peak = torch.max(torch.abs(waveform))
        if peak > 0:
            waveform = waveform / peak

        # Inference only — no gradients needed.
        with torch.no_grad():
            inputs = processor(
                waveform.squeeze().numpy(),
                sampling_rate=SAMPLE_RATE,
                return_tensors="pt"
            ).to(DEVICE)

            outputs = model(**inputs)

        # Move to CPU and drop the batch dimension before converting to numpy.
        embeddings = outputs.last_hidden_state.squeeze(0).cpu().numpy()
        return embeddings

    except Exception as e:
        # Best-effort batch processing: report the failure and let the
        # caller skip this file.
        print(f"Error processing {audio_path}: {str(e)}")
        return None
| |
|
def generate_embeddings():
    """Embed every audio file matching AUDIO_DIR and save one .npy per file.

    Prints a summary of how many files succeeded and lists (up to five of)
    the files that could not be processed.
    """
    Path(EMBEDDING_DIR).mkdir(parents=True, exist_ok=True)

    audio_files = glob.glob(AUDIO_DIR)
    print(f"Found {len(audio_files)} audio files")

    processor, model = initialize_models()

    skipped_files = []
    processed_count = 0

    for audio_path in tqdm(audio_files, desc="Processing audio files"):
        try:
            embeddings = process_audio_file(audio_path, processor, model)
            if embeddings is None:
                # Extraction reported a failure; remember the path and move on.
                skipped_files.append(audio_path)
                continue

            target = os.path.join(EMBEDDING_DIR, f"{Path(audio_path).stem}.npy")
            np.save(target, embeddings)
            processed_count += 1
        except Exception as exc:
            # Unexpected failure outside the extractor (e.g. while saving).
            skipped_files.append((audio_path, str(exc)))

    print(f"\nSuccessfully processed {processed_count}/{len(audio_files)} files")
    if skipped_files:
        print(f"\nFailed to process {len(skipped_files)} files:")
        # Show at most five failures; entries are either a bare path or a
        # (path, error message) pair.
        for entry in skipped_files[:5]:
            if isinstance(entry, tuple):
                print(f"- {entry[0]}: {entry[1]}")
            else:
                print(f"- {entry}")
| |
|
if __name__ == "__main__":
    # Report which compute device will be used for inference.
    if torch.cuda.is_available():
        print("CUDA is available. Using GPU.")
    else:
        print("Using CPU.")

    # Fail fast when the input glob matches nothing. sys.exit is used instead
    # of the bare `exit` builtin: `exit` is injected by the site module and is
    # not guaranteed to exist (e.g. under `python -S` or in frozen builds).
    if not glob.glob(AUDIO_DIR):
        print(f"No audio files found at {AUDIO_DIR}")
        sys.exit(1)

    generate_embeddings()