import torch
from torch.utils.data import IterableDataset
from torch.fft import fft, fftshift
import torch.nn.functional as F
import torchaudio.transforms as T
import hashlib
from typing import NamedTuple, Tuple, Union

from scipy.signal import savgol_filter as savgol

from .transforms import compute_all_features

class WeightsBatch(NamedTuple):
    """A batch of per-layer weights and biases with an associated label."""
    weights: Tuple
    biases: Tuple
    label: Union[torch.Tensor, int]

    def _assert_same_len(self):
        # All fields must agree on their leading length.
        assert len(set(len(t) for t in self)) == 1

    def as_dict(self):
        return self._asdict()

    def to(self, device):
        """Move the batch to a device."""
        return self.__class__(
            weights=tuple(w.to(device) for w in self.weights),
            biases=tuple(b.to(device) for b in self.biases),
            label=self.label.to(device),
        )

    def __len__(self):
        return len(self.weights[0])
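
# Illustrative usage sketch (not part of the original module): constructing a
# WeightsBatch from per-layer weight/bias tensors and moving it to a device.
# The layer shapes and batch size below are arbitrary assumptions.
def _example_weights_batch():
    batch_size = 4
    weights = (torch.randn(batch_size, 16, 8), torch.randn(batch_size, 8, 2))
    biases = (torch.randn(batch_size, 8), torch.randn(batch_size, 2))
    labels = torch.zeros(batch_size, dtype=torch.long)
    batch = WeightsBatch(weights=weights, biases=biases, label=labels)
    batch = batch.to('cpu')           # or 'cuda' on a GPU machine
    assert len(batch) == batch_size   # __len__ reads weights[0]
    return batch.as_dict()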

class SplitDataset(IterableDataset):
    def __init__(self, dataset, is_train=True, train_ratio=0.8):
        self.dataset = dataset
        self.is_train = is_train
        self.train_ratio = train_ratio

    def __iter__(self):
        count = 0
        for item in self.dataset:
            # Deterministic split over repeating blocks of 100 items: the
            # first train_ratio * 100 items of each block go to train, the
            # remainder to validation.
            is_train_item = count < int(self.train_ratio * 100)
            if is_train_item == self.is_train:
                yield item
            count = (count + 1) % 100
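
# Illustrative usage sketch (not part of the original module): the split is
# deterministic, cycling through blocks of 100 items. With train_ratio=0.8,
# items 0-79 of each block go to train and items 80-99 to validation.
def _example_split_dataset():
    source = range(250)  # any re-iterable works as the wrapped dataset
    train = list(SplitDataset(source, is_train=True, train_ratio=0.8))
    val = list(SplitDataset(source, is_train=False, train_ratio=0.8))
    assert len(train) == 210 and len(val) == 40
    assert val[:2] == [80, 81]  # validation starts at item 80 of each block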

class FFTDataset(IterableDataset):
    def __init__(self, original_dataset,
                 max_len=72000,
                 orig_sample_rate=12000,
                 target_sample_rate=3000,
                 features=False):
        super().__init__()
        self.dataset = original_dataset
        self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)
        self.target_sample_rate = target_sample_rate
        self.max_len = max_len
        self.features = features

    def normalize_audio(self, audio):
        """Normalize audio to the [0, 1] range (assumes a non-constant signal)."""
        audio_min = audio.min()
        audio_max = audio.max()
        audio = (audio - audio_min) / (audio_max - audio_min)
        return audio

    def generate_unique_id(self, array):
        """Derive a stable identifier by hashing the raw sample bytes with SHA-256."""
        array_bytes = array.tobytes()
        hash_object = hashlib.sha256(array_bytes)
        return hash_object.hexdigest()
    def __iter__(self):
        for item in self.dataset:
            # Optional Savitzky-Golay smoothing before the FFT:
            # audio_data = savgol(item['audio']['array'], 500, polyorder=1)
            audio_data = item['audio']['array']
            # item['id'] = self.generate_unique_id(audio_data)
            audio_data = torch.tensor(audio_data).float()
            # Pad (or, via negative padding, truncate) to max_len samples.
            pad_len = self.max_len - len(audio_data)
            audio_data = F.pad(audio_data, (0, pad_len), mode='constant')
            audio_data = self.resampler(audio_data)
            audio_data = self.normalize_audio(audio_data)
            fft_data = fft(audio_data)
            magnitude = torch.abs(fft_data)
            phase = torch.angle(fft_data)
            if self.features:
                features = compute_all_features(audio_data, sample_rate=self.target_sample_rate)
                item['audio']['features'] = features
            # Center the spectrum and zero out the DC component.
            magnitude_centered = fftshift(magnitude)
            phase_centered = fftshift(phase)
            magnitude_centered[len(magnitude_centered) // 2] = 0
            item['audio']['fft_mag'] = torch.nan_to_num(magnitude_centered, 0)
            item['audio']['fft_phase'] = torch.nan_to_num(phase_centered, 0)
            item['audio']['array'] = torch.nan_to_num(audio_data, 0)
            yield item
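
# Illustrative usage sketch (not part of the original module): running
# FFTDataset over a tiny in-memory iterable shaped like the expected records.
# The random signal and lengths are placeholder assumptions; features=True
# additionally requires compute_all_features from .transforms.
def _example_fft_dataset():
    import numpy as np
    toy = [{'audio': {'array': np.random.randn(12000)}}]  # 1 s at 12 kHz
    ds = FFTDataset(toy, max_len=72000, orig_sample_rate=12000,
                    target_sample_rate=3000)
    item = next(iter(ds))
    # Padding to 72000 samples, then resampling 12 kHz -> 3 kHz, gives 18000.
    assert item['audio']['array'].shape[0] == 18000
    assert item['audio']['fft_mag'].shape == item['audio']['array'].shape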

class AudioINRDataset(IterableDataset):
    def __init__(self, original_dataset, max_len=18000, sample_size=1024, dim=1, normalize=True):
        """
        Convert audio data into coordinate-value pairs for INR training.
        Args:
            original_dataset: Original audio dataset
            max_len: Maximum length of audio to process
            sample_size: Number of points to sample per audio clip
            dim: Number of coordinate dimensions
            normalize: Whether to normalize the audio values to [0, 1]
        """
        super().__init__()
        self.dataset = original_dataset
        self.max_len = max_len
        self.dim = dim
        self.normalize = normalize
        self.sample_size = sample_size

    def get_coordinates(self, audio_len):
        """Generate normalized time coordinates in [0, 1]."""
        coords = torch.linspace(0, 1, audio_len).unsqueeze(-1).expand(audio_len, self.dim)
        return coords  # Shape: [audio_len, dim]
    def sample_points(self, coords, values):
        """Randomly sample points from the audio"""
        if len(coords) > self.sample_size:
            idx = torch.randperm(len(coords))[:self.sample_size]
            coords = coords[idx]
            values = values[idx]
        return coords, values

    def __iter__(self):
        for item in self.dataset:
            # Get audio data
            audio_data = torch.tensor(item['audio']['array']).float()
            # Generate coordinates
            coords = self.get_coordinates(len(audio_data))
            item['audio']['coords'] = coords
            # Sample random points
            # coords, values = self.sample_points(coords, audio_data)
            # Create the INR training sample
            yield item
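
# Illustrative usage sketch (not part of the original module): generating the
# coordinate grid for a short clip. Coordinates are normalized times in
# [0, 1], replicated across `dim` columns; the toy signal is an assumption.
def _example_audio_inr_dataset():
    import numpy as np
    toy = [{'audio': {'array': np.linspace(-1.0, 1.0, 100)}}]
    ds = AudioINRDataset(toy, sample_size=32, dim=1)
    item = next(iter(ds))
    coords = item['audio']['coords']
    assert coords.shape == (100, 1)
    assert coords[0, 0] == 0.0 and coords[-1, 0] == 1.0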