| | |
| | |
| | |
| | |
| | |
| | """Base classes for the datasets that also provide non-audio metadata, |
| | e.g. description, text transcription etc. |
| | """ |
| | from dataclasses import dataclass |
| | import logging |
| | import math |
| | import re |
| | import typing as tp |
| |
|
| | import torch |
| |
|
| | from .audio_dataset import AudioDataset, AudioMeta |
| | from ..environment import AudioCraftEnvironment |
| | from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes |
| |
|
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
    """Rewrite the paths of a single meta entry using the dataset mappers.

    Both ``meta.path`` and, when ``meta.info_path`` is set,
    ``meta.info_path.zip_path`` are passed through
    ``AudioCraftEnvironment.apply_dataset_mappers``.

    Args:
        meta (AudioMeta): Entry to patch. It is modified in place.
    Returns:
        AudioMeta: The same ``meta`` object, after patching.
    """
    remap = AudioCraftEnvironment.apply_dataset_mappers
    meta.path = remap(meta.path)
    info_path = meta.info_path
    if info_path is not None:
        # The info path carries its own zip location that needs the same mapping.
        info_path.zip_path = remap(info_path.zip_path)
    return meta
| |
|
| |
|
def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
    """Apply `_clusterify_meta` to every entry of a list of meta.

    Args:
        meta (list of AudioMeta): Entries to patch; each one is modified in place.
    Returns:
        list of AudioMeta: A new list holding the patched entries, in the same order.
    """
    return list(map(_clusterify_meta, meta))
| |
|
| |
|
@dataclass
class AudioInfo(SegmentWithAttributes):
    """Placeholder segment info that exposes no conditioning attributes.

    InfoAudioDataset is expected to return metadata inheriting from
    SegmentWithAttributes, i.e. metadata able to produce conditioning
    attributes. Providing this minimal implementation keeps every dataset
    compatible with solvers whose conditioners rely on that interface.
    """
    # Optional tensor attached to the segment; left unset (None) by default.
    audio_tokens: tp.Optional[torch.Tensor] = None

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Return a fresh, empty set of conditioning attributes."""
        empty_attributes = ConditioningAttributes()
        return empty_attributes
| |
|
| |
|
class InfoAudioDataset(AudioDataset):
    """AudioDataset whose metadata is always returned as a SegmentWithAttributes.

    The meta entries are passed through `clusterify_all_meta` before being
    handed to the parent class.

    See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
    """
    def __init__(self, meta: tp.List[AudioMeta], **kwargs):
        remapped_meta = clusterify_all_meta(meta)
        super().__init__(remapped_meta, **kwargs)

    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
        """Fetch one item from the parent dataset.

        Returns the waveform alone when ``self.return_info`` is falsy, otherwise
        a ``(waveform, AudioInfo)`` pair where the info is rebuilt from the
        parent's metadata via ``to_dict()``.
        """
        if self.return_info:
            wav, meta = super().__getitem__(index)
            info = AudioInfo(**meta.to_dict())
            return wav, info
        wav = super().__getitem__(index)
        # Narrow the parent's union return type to a plain tensor.
        assert isinstance(wav, torch.Tensor)
        return wav
| |
|
| |
|
def get_keyword_or_keyword_list(
    value: tp.Union[str, tp.List[str], None],
) -> tp.Union[str, tp.List[str], None]:
    """Preprocess a single keyword or possibly a list of keywords.

    Args:
        value (str, list of str or None): Raw value. A ``list`` is handled by
            `get_keyword_list`; anything else is handled by `get_keyword`.
    Returns:
        str, list of str or None: Whatever the selected helper returns for
            ``value``.
    """
    # The annotation now reflects that lists are accepted, matching this branch.
    if isinstance(value, list):
        return get_keyword_list(value)
    return get_keyword(value)
| |
|
| |
|
def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single free-form string, preserving its case.

    Args:
        value (str or None): Raw value, possibly missing or not a string.
    Returns:
        str or None: ``value`` with surrounding whitespace removed, or None when
            ``value`` is not a string, is empty or whitespace-only, or is the
            literal ``'None'`` once stripped.
    """
    if not isinstance(value, str):
        return None
    # Strip before testing for emptiness so that whitespace-only input and a
    # padded 'None' are treated as missing rather than returned as '' / 'None'.
    stripped = value.strip()
    if not stripped or stripped == 'None':
        return None
    return stripped
| |
|
| |
|
def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single keyword: strip surrounding whitespace and lowercase.

    Args:
        value (str or None): Raw value, possibly missing or not a string.
    Returns:
        str or None: The stripped, lowercased keyword, or None when ``value`` is
            not a string, is empty or whitespace-only, or is the literal
            ``'None'`` once stripped.
    """
    if not isinstance(value, str):
        return None
    # Strip before testing for emptiness so that whitespace-only input cannot
    # come back as an empty keyword, and a padded 'None' is treated as missing.
    stripped = value.strip()
    if not stripped or stripped == 'None':
        return None
    return stripped.lower()
| |
|
| |
|
def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
    """Preprocess a list of keywords.

    Args:
        values (str or list of str): Either a list of raw keywords, or a single
            string that is split on commas and whitespace. A float NaN is
            treated as an empty list; any other non-list value is logged at
            debug level and wrapped as a one-element list of its ``str()`` form.
    Returns:
        list of str or None: Keywords normalized with `get_keyword`, dropping
            entries it maps to None; None when nothing is left.
    """
    if isinstance(values, str):
        candidates = [chunk.strip() for chunk in re.split(r'[,\s]', values)]
    elif isinstance(values, float) and math.isnan(values):
        # NaN stands for "no keywords".
        candidates = []
    elif isinstance(values, list):
        candidates = values
    else:
        logger.debug(f"Unexpected keyword list {values}")
        candidates = [str(values)]

    cleaned = [kw for kw in map(get_keyword, candidates) if kw is not None]
    return cleaned if cleaned else None
| |
|