| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
| | |
| | import random |
| | from collections import defaultdict |
| | from concurrent.futures import ThreadPoolExecutor |
| | from typing import Tuple, Type |
| |
|
| | from lhotse import CutSet |
| | from lhotse.dataset.collation import collate_features |
| | from lhotse.dataset.input_strategies import ( |
| | ExecutorType, |
| | PrecomputedFeatures, |
| | _get_executor, |
| | ) |
| | from lhotse.utils import fastcopy |
| |
|
| |
|
class PromptedFeatures:
    """Pair of a prompt object and a feature object that travel together.

    Only the operations defined here are forwarded; ``sum`` and ``ndim`` are
    answered by ``features`` alone, while ``to`` is applied to both members.
    """

    def __init__(self, prompts, features):
        """Store ``prompts`` and ``features`` without copying or validation."""
        self.prompts = prompts
        self.features = features

    def to(self, device):
        """Return a new ``PromptedFeatures`` with both members moved via ``.to(device)``."""
        moved_prompts = self.prompts.to(device)
        moved_features = self.features.to(device)
        return PromptedFeatures(moved_prompts, moved_features)

    def sum(self):
        """Return ``features.sum()``; the prompts are not included."""
        feats = self.features
        return feats.sum()

    @property
    def ndim(self):
        """Number of dimensions reported by ``features``."""
        feats = self.features
        return feats.ndim

    @property
    def data(self):
        """The ``(prompts, features)`` pair as a tuple."""
        return self.prompts, self.features
| |
|
| |
|
class PromptedPrecomputedFeatures(PrecomputedFeatures):
    """Precomputed-feature input strategy that also collates a prompt per cut.

    At construction time every utterance in ``cuts`` is mapped to a list of
    "neighbor" cuts. When the strategy is called on a batch, one neighbor is
    drawn at random for each cut and used as that cut's prompt.

    Args:
        dataset: Dataset name, matched case-insensitively. Supported values
            are ``"libritts"`` and ``"ljspeech"``.
        cuts: The cuts from which the utterance-to-neighbors map is built.
        num_workers: Forwarded to the parent class and to ``_get_executor``.
        executor_type: Forwarded to the parent class and to ``_get_executor``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """

    def __init__(
        self,
        dataset: str,
        cuts: CutSet,
        num_workers: int = 0,
        executor_type: Type[ExecutorType] = ThreadPoolExecutor,
    ) -> None:
        super().__init__(num_workers, executor_type)
        self.utt2neighbors = self._create_utt2neighbors(dataset, cuts)

    def __call__(
        self, cuts: CutSet
    ) -> Tuple[PromptedFeatures, PromptedFeatures]:
        """Collate features and prompts for ``cuts``.

        Returns:
            Two ``PromptedFeatures``: the first holds ``(prompts, features)``,
            the second holds ``(prompts_lens, features_lens)``.
        """
        features, features_lens = self._collate_features(cuts)
        prompts, prompts_lens = self._collate_prompts(cuts)
        return (
            PromptedFeatures(prompts, features),
            PromptedFeatures(prompts_lens, features_lens),
        )

    def _create_utt2neighbors(self, dataset, cuts):
        """Build the mapping from cut id to its list of candidate prompt cuts."""
        # ``defaultdict(list)`` rather than a lambda factory: same behavior,
        # and the resulting mapping stays picklable.
        utt2neighbors = defaultdict(list)
        utt2cut = {cut.id: cut for cut in cuts}
        dataset_name = dataset.lower()
        if dataset_name == "libritts":
            self._process_libritts_dataset(utt2neighbors, utt2cut, cuts)
        elif dataset_name == "ljspeech":
            self._process_ljspeech_dataset(utt2neighbors, utt2cut, cuts)
        else:
            raise ValueError("Unsupported dataset")
        return utt2neighbors

    def _process_libritts_dataset(self, utt2neighbors, utt2cut, cuts):
        """Fill ``utt2neighbors`` with same-speaker neighbors.

        Cuts are grouped by the speaker of their first supervision, ids are
        sorted within each speaker, and each utterance receives the cuts
        adjacent to it in that order. A speaker with a single utterance gets
        that utterance as its own neighbor.
        """
        speaker2utts = defaultdict(list)
        for cut in cuts:
            speaker2utts[cut.supervisions[0].speaker].append(cut.id)

        for uttids in speaker2utts.values():
            sorted_uttids = sorted(uttids)
            if len(sorted_uttids) == 1:
                only_utt = sorted_uttids[0]
                utt2neighbors[only_utt].append(utt2cut[only_utt])
                continue

            # The first utterance has no predecessor, so its "previous"
            # neighbor is the second utterance.
            utt2prevutt = dict(
                zip(sorted_uttids, [sorted_uttids[1]] + sorted_uttids[:-1])
            )
            utt2postutt = dict(zip(sorted_uttids[:-1], sorted_uttids[1:]))
            for utt in sorted_uttids:
                # Every id is a key of utt2prevutt; only the last one lacks a
                # successor.
                utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]])
                if utt in utt2postutt:
                    utt2neighbors[utt].append(utt2cut[utt2postutt[utt]])

    def _process_ljspeech_dataset(self, utt2neighbors, utt2cut, cuts):
        """Fill ``utt2neighbors`` with adjacent cuts that share an id prefix.

        Adjacency follows the iteration order of ``cuts`` (ids are not
        sorted). A neighbor is accepted only when the first five characters
        of its id equal those of the utterance's id (presumably the LJSpeech
        chapter prefix such as ``LJ001`` - confirm against the corpus).
        An utterance left without any accepted neighbor gets itself as its
        only neighbor, so that prompt sampling never sees an empty list.
        """
        uttids = [cut.id for cut in cuts]
        if not uttids:
            # Nothing to index; avoids indexing uttids[1] below.
            return
        if len(uttids) == 1:
            utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
            return

        # The first utterance has no predecessor, so its "previous"
        # neighbor is the second utterance.
        utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
        utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
        for utt in uttids:
            prevutt, postutt = utt2prevutt.get(utt), utt2postutt.get(utt)
            if prevutt is not None and utt[:5] == prevutt[:5]:
                utt2neighbors[utt].append(utt2cut[prevutt])
            if postutt is not None and utt[:5] == postutt[:5]:
                utt2neighbors[utt].append(utt2cut[postutt])
            if not utt2neighbors[utt]:
                # No adjacent cut shares the prefix: fall back to the cut
                # itself, mirroring the single-utterance case above.
                utt2neighbors[utt].append(utt2cut[utt])

    def _collate_features(self, cuts):
        """Collate the features of ``cuts`` using the configured executor."""
        return collate_features(
            cuts,
            executor=_get_executor(
                self.num_workers, executor_type=self._executor_type
            ),
        )

    def _collate_prompts(self, cuts):
        """Sample one prompt per cut, truncate them to a common cap, and collate.

        Each sampled prompt is copied under the id ``"<cut.id>-<position>"``.
        All prompts are then truncated, at a random offset, to at most the
        shortest sampled prompt's duration, itself capped at 3.0.
        """
        prompts_cuts = []
        for k, cut in enumerate(cuts):
            # ``.get`` keeps the defaultdict from growing on unknown ids, and
            # the fallback prevents random.choice from failing on an empty
            # candidate list by prompting the cut with itself.
            candidates = self.utt2neighbors.get(cut.id) or [cut]
            prompts_cut = random.choice(candidates)
            prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{k}"))

        mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0])
        prompts_cuts = CutSet(
            cuts={k: cut for k, cut in enumerate(prompts_cuts)}
        ).truncate(
            max_duration=mini_duration,
            offset_type="random",
            preserve_id=False,
        )

        return collate_features(
            prompts_cuts,
            executor=_get_executor(
                self.num_workers, executor_type=self._executor_type
            ),
        )
| |
|