Spaces:
Sleeping
Sleeping
File size: 2,537 Bytes
2279ae0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, IterableDataset
from datasets import register
import datasets
class BaseWrapperAudioCAE:
    """Base wrapper for audio Convolutional Autoencoder (CAE) training.

    Similar to the image wrapper, but for audio data. Wraps an underlying
    dataset (built via ``datasets.make``) and converts each item into the
    ``{'inp': ..., 'gt': ...}`` dict format expected by CAE training.
    """

    def __init__(
        self,
        dataset,
        sample_rate=24000,
        duration=0.38,          # Duration in seconds
        n_samples=None,         # Alternative: specify exact number of samples
        return_gt=True,
        gt_sample_rate=None,    # Ground truth sample rate (if different)
        mono=True,
        normalize=True,
        return_coords=True,     # Whether to return coordinate grids
    ):
        # `dataset` is a config/spec resolved by the project's dataset registry.
        self.dataset = datasets.make(dataset)
        self.sample_rate = sample_rate
        self.duration = duration
        # Fix: honor an explicitly supplied n_samples. Previously the
        # parameter was accepted but always overwritten by the
        # duration-derived value. Default (None) preserves old behavior.
        self.n_samples = int(n_samples) if n_samples is not None else int(duration * sample_rate)
        self.return_gt = return_gt
        # Fall back to the input sample rate when no separate GT rate given.
        self.gt_sample_rate = gt_sample_rate or sample_rate
        # NOTE(review): `mono`, `normalize`, and `return_coords` are stored
        # but not used anywhere in this class — presumably consumed by
        # subclasses or future processing; confirm before removing.
        self.mono = mono
        self.normalize = normalize
        self.return_coords = return_coords

    def process(self, audio_data):
        """Process audio data for DiTo training.

        Args:
            audio_data: Dictionary with 'signal' key containing AudioSignal
                or AudioSignal directly.

        Returns:
            dict with 'inp' (and 'gt' when ``self.return_gt`` is True),
            both holding the squeezed audio tensor.
        """
        ret = {}
        # Extract AudioSignal from a dict item or use it directly.
        signal = audio_data['signal'] if isinstance(audio_data, dict) else audio_data
        # assumes signal.audio_data is [1, channels, samples]; squeeze drops
        # the leading batch dim — TODO confirm against AudioSignal's layout.
        audio_tensor = signal.audio_data.squeeze(0)
        ret['inp'] = audio_tensor
        if not self.return_gt:
            return ret
        # Ground truth is the same tensor (no separate resampling here).
        ret['gt'] = audio_tensor
        return ret
@register('wrapper_audio_cae')
class WrapperAudioCAE(BaseWrapperAudioCAE, Dataset):
    """Map-style dataset wrapper for audio CAE training.

    Delegates storage to the wrapped dataset and runs each fetched item
    through :meth:`BaseWrapperAudioCAE.process`.
    """

    def __len__(self):
        # Length is entirely determined by the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, idx):
        # Fetch and convert in one step.
        return self.process(self.dataset[idx])
@register('wrapper_audio_cae_iterable')
class WrapperAudioCAEIterable(BaseWrapperAudioCAE, IterableDataset):
    """Iterable dataset wrapper for audio CAE training.

    Streams items from the wrapped dataset, processing each one lazily.
    """

    def __iter__(self):
        # Lazily map process() over the underlying stream.
        return (self.process(item) for item in self.dataset)
|