| |
| |
| |
| |
| |
| """Base classes for the datasets that also provide non-audio metadata, |
| e.g. description, text transcription etc. |
| """ |
| from dataclasses import dataclass |
| import logging |
| import math |
| import re |
| import typing as tp |
|
|
| import torch |
|
|
| from .audio_dataset import AudioDataset, AudioMeta |
| from ..environment import AudioCraftEnvironment |
| from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes |
|
|
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
    """Rewrite the paths stored on ``meta`` through the environment's dataset mappers.

    ``meta.path`` is always remapped; ``meta.info_path.zip_path`` is remapped
    only when ``meta.info_path`` is set. The object is modified in place and
    the same instance is returned.
    """
    remap = AudioCraftEnvironment.apply_dataset_mappers
    meta.path = remap(meta.path)
    info_path = meta.info_path
    if info_path is not None:
        info_path.zip_path = remap(info_path.zip_path)
    return meta
|
|
|
|
def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
    """Run ``_clusterify_meta`` on every entry and return the results as a new list, in order."""
    return list(map(_clusterify_meta, meta))
|
|
|
|
@dataclass
class AudioInfo(SegmentWithAttributes):
    """Placeholder segment info whose conditioning attributes are empty.

    InfoAudioDataset is expected to hand back metadata deriving from
    SegmentWithAttributes, so that it can produce conditioning attributes.
    Providing this minimal subclass keeps every dataset usable with solvers
    whose conditioners rely on that interface.
    """
    # Optional tensor of audio tokens attached to the segment; unset by default.
    audio_tokens: tp.Optional[torch.Tensor] = None

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Return a default-constructed ``ConditioningAttributes``."""
        return ConditioningAttributes()
|
|
|
|
class InfoAudioDataset(AudioDataset):
    """AudioDataset whose metadata, when requested, is returned as an ``AudioInfo``.

    The metadata list is passed through ``clusterify_all_meta`` before being
    handed to the parent constructor.

    See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
    """

    def __init__(self, meta: tp.List[AudioMeta], **kwargs):
        remapped_meta = clusterify_all_meta(meta)
        super().__init__(remapped_meta, **kwargs)

    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
        """Return the waveform at ``index``, paired with an ``AudioInfo`` when ``return_info`` is set."""
        if self.return_info:
            waveform, segment_meta = super().__getitem__(index)
            # Rebuild the parent's metadata as an AudioInfo from its dict form.
            info_fields = segment_meta.to_dict()
            return waveform, AudioInfo(**info_fields)

        waveform = super().__getitem__(index)
        assert isinstance(waveform, torch.Tensor)
        return waveform
|
|
|
|
def get_keyword_or_keyword_list(
    value: tp.Union[str, tp.List[str], None],
) -> tp.Union[str, tp.List[str], None]:
    """Preprocess a single keyword or possibly a list of keywords.

    Args:
        value: A keyword string, a list of keyword strings, or None.
    Returns:
        The result of `get_keyword_list` when `value` is a list, otherwise
        the result of `get_keyword`.
    """
    # The parameter annotation includes lists because this branch handles them.
    if isinstance(value, list):
        return get_keyword_list(value)
    return get_keyword(value)
|
|
|
|
def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a free-form string: strip surrounding whitespace, keep the case.

    Args:
        value: Raw value. Anything that is not a `str` (including None) is
            treated as missing.
    Returns:
        The stripped string, or None when the value is missing, empty,
        whitespace-only, or the literal 'None'.
    """
    if not isinstance(value, str):
        return None
    stripped = value.strip()
    # Validate after stripping so that '   ' and ' None ' are also treated as
    # missing instead of being returned as '' / 'None'.
    if not stripped or stripped == 'None':
        return None
    return stripped
|
|
|
|
def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single keyword: strip surrounding whitespace and lowercase it.

    Args:
        value: Raw keyword. Anything that is not a `str` (including None) is
            treated as missing.
    Returns:
        The stripped, lowercased keyword, or None when the value is missing,
        empty, whitespace-only, or the literal 'None'.
    """
    if not isinstance(value, str):
        return None
    stripped = value.strip()
    # Validate after stripping so that whitespace-only input does not produce
    # an empty '' keyword and a padded ' None ' is still recognised as missing.
    if not stripped or stripped == 'None':
        return None
    return stripped.lower()
|
|
|
|
def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
    """Preprocess a list of keywords.

    Args:
        values: Either a list of keywords, or a single string that is split
            on commas and whitespace. A float NaN is treated as an empty
            list; any other non-list value is logged at debug level and
            converted with `str()` into a one-element list.
    Returns:
        The keywords as normalized by `get_keyword`, with missing and empty
        entries dropped, or None when no keyword remains.
    """
    if isinstance(values, str):
        values = [v.strip() for v in re.split(r'[,\s]', values)]
    elif isinstance(values, float) and math.isnan(values):
        values = []
    if not isinstance(values, list):
        # Lazy %-formatting: the message is only built if debug logging is enabled.
        logger.debug("Unexpected keyword list %s", values)
        values = [str(values)]

    # Drop both None and '' so that whitespace-only list entries never
    # surface as empty keywords.
    kw_list = [kw for kw in map(get_keyword, values) if kw]
    return kw_list or None
|
|