Spaces:
Sleeping
Sleeping
| import os | |
| import glob | |
| import torch | |
| import torchaudio | |
| import torchvision | |
| from torch.utils.data import Dataset | |
| from concurrent.futures import ThreadPoolExecutor | |
| from preprocess import process_audio_data, process_image_data, resample_rate | |
class PreprocessedDataset(Dataset):
    """Dataset over a directory of ``.pt`` sample files.

    Each file is expected to hold a ``(mfcc, image, label)`` tuple, as
    written by ``process_sample`` in this module. Those tensors were already
    passed through ``process_audio_data`` / ``process_image_data`` before
    saving, so by default they are returned unchanged. Running the
    processing again would feed MFCC features back in as if they were a raw
    waveform.

    Args:
        data_dir: Directory containing the ``.pt`` files. Other files are ignored.
        reprocess: If True, re-apply ``process_audio_data`` (with
            ``resample_rate``) and ``process_image_data`` on every load. This
            is the previous behaviour; it is only meaningful if the ``.pt``
            files contain unprocessed data.
    """

    def __init__(self, data_dir, reprocess=False):
        self.data_dir = data_dir
        self.reprocess = reprocess
        # os.listdir order is filesystem-dependent. Sorting makes index -> file
        # stable across runs and machines.
        self.samples = sorted(
            os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".pt")
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(mfcc, image, label)`` tuple stored in sample ``idx``."""
        sample_path = self.samples[idx]
        mfcc, image, label = torch.load(sample_path)
        if self.reprocess:
            mfcc = process_audio_data(mfcc, resample_rate)
            image = process_image_data(image)
        return mfcc, image, label
def load_audio_file(audio_path):
    """Load an audio file and return ``(waveform, sample_rate)``.

    torchaudio is tried first. If it raises for any reason, a warning is
    printed and librosa (an optional dependency, imported lazily) is used
    instead.

    Args:
        audio_path: Path to the audio file on disk.

    Returns:
        ``(waveform, sample_rate)``. On the torchaudio path this is whatever
        ``torchaudio.load`` returns. On the librosa path the waveform is a
        float tensor shaped ``[1, length]`` to mirror torchaudio's
        ``[channels, length]`` layout.

    Raises:
        FileNotFoundError: If ``audio_path`` does not exist.
        RuntimeError: If both loaders fail. The last underlying error is
            chained as ``__cause__``.
    """
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")
    try:
        # Try the default torchaudio loader first
        waveform, sample_rate = torchaudio.load(audio_path)
    except Exception as e:
        print(f"Warning: Could not load {audio_path} with torchaudio: {e}")
        # Fall back to librosa (you'll need to install it: pip install librosa)
        try:
            import librosa

            # sr=None keeps the file's native sample rate instead of resampling.
            # NOTE: librosa.load downmixes to mono by default (mono=True).
            waveform_np, sample_rate = librosa.load(audio_path, sr=None)
            # Add a leading channel axis: [length] -> [1, length].
            waveform = torch.from_numpy(waveform_np).unsqueeze(0).float()
            print(f"Successfully loaded with librosa: {audio_path}")
        except Exception as final_e:
            raise RuntimeError(
                f"Failed to load audio file {audio_path} with all available methods: {final_e}"
            ) from final_e
    return waveform, sample_rate
def load_image_file(image_path):
    """Decode an image file into a tensor with ``torchvision.io.read_image``.

    Presumably this yields a uint8 ``[C, H, W]`` tensor (per the torchvision
    docs). The channel count follows the file (e.g. grayscale or RGBA), so
    confirm this against what ``process_image_data`` expects.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    if os.path.exists(image_path):
        return torchvision.io.read_image(image_path)
    raise FileNotFoundError(f"Image file not found: {image_path}")
def _find_media_files(sample_path):
    """Recursively collect audio and image paths under ``sample_path``, sorted."""
    audio_files = []
    image_files = []
    for root, _, files in os.walk(sample_path):
        for file in files:
            lowered = file.lower()
            if lowered.endswith((".wav", ".mp3", ".flac")):
                audio_files.append(os.path.join(root, file))
            elif lowered.endswith((".jpg", ".jpeg", ".png")):
                image_files.append(os.path.join(root, file))
    # os.walk / os.listdir order is filesystem-dependent. Sorting makes the
    # "first file" chosen by the caller reproducible.
    return sorted(audio_files), sorted(image_files)


def process_sample(sample_path, save_dir):
    """Process one sample directory into ``<save_dir>/<dirname>.pt``.

    The first audio file and the first image file (in sorted path order)
    found anywhere under ``sample_path`` are loaded, run through
    ``process_audio_data`` / ``process_image_data``, and saved as a
    ``(mfcc, processed_image, label)`` tuple. The label is the directory
    name parsed as a float.

    Args:
        sample_path: Directory holding the sample's media files. Its name
            must be numeric, because it becomes the label.
        save_dir: Existing directory to write the ``.pt`` file into.

    Returns:
        True if the sample was saved. False if it was skipped because it has
        no audio or no image file.

    Raises:
        ValueError: If the directory name is not a numeric label. This is
            checked before any file is loaded.
    """
    audio_files, image_files = _find_media_files(sample_path)
    if not audio_files:
        print(f"Warning: No audio file found in {sample_path}. Skipping this sample.")
        return False
    if not image_files:
        print(f"Warning: No image file found in {sample_path}. Skipping this sample.")
        return False

    # normpath strips a trailing separator, which would otherwise make the
    # basename (and therefore the label and file name) empty.
    sample_name = os.path.basename(os.path.normpath(sample_path))
    try:
        label = float(sample_name)
    except ValueError as err:
        raise ValueError(
            f"Sample directory name {sample_name!r} is not a numeric label: {sample_path}"
        ) from err

    # Use the first found audio and image files
    audio_path = audio_files[0]
    image_path = image_files[0]
    print(f"Processing audio: {audio_path}")
    print(f"Processing image: {image_path}")
    waveform, sample_rate = load_audio_file(audio_path)
    image = load_image_file(image_path)

    mfcc = process_audio_data(waveform, sample_rate)
    processed_image = process_image_data(image)

    save_path = os.path.join(save_dir, f"{sample_name}.pt")
    torch.save((mfcc, processed_image, label), save_path)
    print(f"Processed and saved: {save_path}")
    return True
def process_and_save(data_dir, save_dir):
    """Run ``process_sample`` on every immediate sub-directory of ``data_dir``.

    ``save_dir`` is created if needed. Samples are processed on a thread pool
    and their results are collected in sorted path order.

    A sample counts as:
      * failed if ``process_sample`` raised (the error and path are printed);
      * skipped if it returned exactly ``False``;
      * successful otherwise (including a ``None`` return).

    Args:
        data_dir: Directory whose sub-directories are individual samples.
        save_dir: Output directory for the processed ``.pt`` files.
    """
    os.makedirs(save_dir, exist_ok=True)
    sample_paths = sorted(
        os.path.join(data_dir, d)
        for d in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, d))
    )
    if not sample_paths:
        print(f"Warning: No sample directories found in {data_dir}")
        return
    print(f"Found {len(sample_paths)} sample directories to process")

    successful = skipped = failed = 0
    with ThreadPoolExecutor() as executor:
        # Keep each future paired with its path so failures can be attributed.
        futures = [
            (path, executor.submit(process_sample, path, save_dir)) for path in sample_paths
        ]
        for path, future in futures:
            try:
                result = future.result()  # Wait for this sample to complete
            except Exception as e:
                failed += 1
                print(f"Error processing {path}: {e}")
                continue
            if result is False:
                skipped += 1
            else:
                successful += 1
    print(
        f"Processing complete. Successfully processed: {successful}, "
        f"Skipped: {skipped}, Failed: {failed}"
    )
if __name__ == "__main__":
    import argparse

    arg_parser = argparse.ArgumentParser(description="Preprocess the dataset")
    # (flag, default, help) for each CLI option, registered in this order.
    # Both defaults are paths relative to the current working directory.
    for flag, default_value, help_text in (
        ("--data_dir", "cleaned", "Path to the cleaned dataset directory"),
        ("--save_dir", "processed", "Path to the processed dataset directory"),
    ):
        arg_parser.add_argument(flag, type=str, default=default_value, help=help_text)
    cli_args = arg_parser.parse_args()

    print(f"Processing dataset from: {cli_args.data_dir}")
    print(f"Saving processed data to: {cli_args.save_dir}")
    process_and_save(cli_args.data_dir, cli_args.save_dir)
    print("Preprocessing complete")