| import random
|
| from dataclasses import dataclass
|
| from itertools import chain
|
| from pathlib import Path
|
| from random import Random
|
| from typing import Optional, Union
|
|
|
| import numpy as np
|
| import pyarrow.parquet as pq
|
| import torch
|
| import torch.nn.functional as F
|
| from datasets.download.streaming_download_manager import xopen
|
| from huggingface_hub import HfApi
|
| from lightning import LightningDataModule
|
| from torch.distributed import get_rank, get_world_size, is_initialized
|
| from torch.utils.data import DataLoader, IterableDataset, get_worker_info
|
| from transformers import AutoTokenizer
|
|
|
| from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
|
| from fish_speech.datasets.protos.text_data_pb2 import SampledData
|
| from fish_speech.datasets.protos.text_data_stream import read_pb_stream
|
| from fish_speech.text.clean import clean_text
|
| from fish_speech.utils import RankedLogger
|
| from fish_speech.utils.braceexpand import braceexpand
|
|
|
# Module-level logger. `rank_zero_only=True` presumably restricts output to the
# rank-0 process in distributed runs — confirm against RankedLogger.
log = RankedLogger(__name__, rank_zero_only=True)
|
|
|
|
|
def split_by_rank_worker(files):
    """Return the subset of *files* assigned to this rank and DataLoader worker.

    The list is first split across distributed ranks (when
    ``torch.distributed`` is initialized), then across DataLoader workers
    (when called from inside a worker process). If there are fewer files than
    rank x worker slots, the list is repeated so that every slot receives at
    least one file.

    Args:
        files: Ordered list of items (e.g. file paths) to shard. Every
            rank/worker should pass the same list in the same order so that
            the resulting shards are disjoint.

    Returns:
        The slice of *files* for the current rank/worker. When neither
        distributed training nor a DataLoader worker is active, *files* is
        returned unchanged.

    Raises:
        ValueError: If *files* is empty (there is nothing to shard).
    """
    if not files:
        raise ValueError("split_by_rank_worker() received an empty file list")

    # Total number of consumers: ranks x DataLoader workers per rank.
    total_devices = get_world_size() if is_initialized() else 1

    worker_info = get_worker_info()
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the list so that every consumer gets at least one file.
        files = files * (total_devices // len(files) + 1)

    # Shard across distributed ranks first ...
    if is_initialized():
        files = files[get_rank() :: get_world_size()]

    # ... then across this rank's DataLoader workers.
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files
|
|
|
|
|
class AutoTextSemanticInstructionDataset(IterableDataset):
    """
    Auto Augment Dataset by Speaker

    1. Randomly concatenate multiple sentences from the same group (speaker)
       to form a longer sample
    2. Automatically normalize the text with ``clean_text``

    Every packed segment (see ``pack_sentences``) has the layout

        <|im_start|>user {text}<|im_end|><|im_start|>{speaker} <|semantic|>*N <|im_end|>

    where a newline follows ``user`` and ``{speaker}``. ``{speaker}`` is the
    group name when speaker conditioning is enabled and ``assistant``
    otherwise.

    Interactive mode packs every sampled sentence as its own segment and
    concatenates the segments. Non-interactive mode joins all sampled
    sentences with spaces into a single segment, using a random token budget
    drawn from a truncated normal distribution.
    """

    def __init__(
        self,
        proto_files: list[str],
        seed: int = 42,
        interactive_prob: float = 0.5,
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
        use_speaker: bool | float = True,
        causal: bool = True,
        num_codebooks: Optional[int] = None,
        skip_text_prob: float = 0.0,
    ):
        """
        Args:
            proto_files: protobuf files, directories (searched recursively for
                ``*.proto`` / ``*.protos``) or brace-expandable patterns of them
            seed: random seed used to shuffle the file list and the groups
            interactive_prob: probability to use interactive mode
            max_length: max length of the text; also bounds how many sentences
                are sampled per group (``max_length // 20``, at least 1)
            tokenizer: tokenizer (required)
            use_speaker: include speaker information in the prompt; a float is
                used as the per-sample probability of doing so
            causal: sample a contiguous run of sentences from a group;
                disabling it samples sentences randomly (with replacement)
            num_codebooks: number of codebooks, if None, it will be
                automatically detected from the data
            skip_text_prob: probability to skip the text (audio only), this
                only applies to interactive mode

        Raises:
            ValueError: if ``tokenizer`` is None.
        """
        super().__init__()

        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
        if tokenizer is None:
            raise ValueError("AutoTextSemanticInstructionDataset requires a tokenizer")

        self.seed = seed
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.interactive_prob = interactive_prob
        self.use_speaker = use_speaker
        self.proto_files = proto_files
        self.causal = causal
        self.num_codebooks = num_codebooks
        self.skip_text_prob = skip_text_prob

        self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
        # Loaded lazily by init_mock_data_server().
        self.groups = None

    def init_mock_data_server(self):
        """Load this rank/worker's shard of protobuf groups (no-op if loaded).

        It is called lazily from ``sample_data`` rather than ``__init__``, so
        ``split_by_rank_worker`` runs in the process that actually iterates
        the dataset.

        Raises:
            ValueError: if a path is neither a file nor a directory, or if no
                proto files were found at all.
        """
        if self.groups is not None:
            return

        # Expand brace patterns and directories into a flat list of files.
        expanded_proto_files = []
        for filename in self.proto_files:
            for i in braceexpand(filename):
                i = Path(i)
                if i.is_file():
                    expanded_proto_files.append(i)
                elif i.is_dir():
                    expanded_proto_files.extend(i.rglob("*.proto"))
                    expanded_proto_files.extend(i.rglob("*.protos"))
                else:
                    raise ValueError(f"{i} is not a file or directory")

        if not expanded_proto_files:
            raise ValueError(f"No proto files found in {self.proto_files}")

        # Sort before the seeded shuffle so every rank/worker sees the same
        # order and split_by_rank_worker produces disjoint shards.
        expanded_proto_files = sorted(expanded_proto_files)
        Random(self.seed).shuffle(expanded_proto_files)

        shard_proto_files = split_by_rank_worker(expanded_proto_files)
        log.info(
            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
        )

        # Build locally and publish at the end, so a failed read leaves
        # self.groups as None instead of half-filled.
        groups = []
        for filename in shard_proto_files:
            with open(filename, "rb") as f:
                groups.extend(read_pb_stream(f))

        log.info(f"Read total {len(groups)} groups of data")

        Random(self.seed).shuffle(groups)
        self.groups = groups
        # Groups are later sampled proportionally to their sentence count.
        self.group_weights = [len(i.sentences) for i in self.groups]

    def __iter__(self):
        """Yield augmented samples forever, skipping empty draws."""
        while True:
            sample = self.augment()
            # augment() returns None when no sentence was sampled; yielding it
            # would crash the collator, so draw again instead.
            if sample is not None:
                yield sample

    def tokenize_sentence(self, sentence: str):
        """Normalize *sentence* and return ``(cleaned_sentence, num_tokens)``."""
        sentence = clean_text(sentence)
        tokens = self.tokenizer.encode(
            sentence,
            max_length=10**6,
            add_special_tokens=False,
            truncation=False,
        )
        return sentence, len(tokens)

    def sample_data(self):
        """Draw one group (weighted by sentence count) and a subset of its sentences.

        Returns:
            A ``SampledData`` message carrying the group's source, name and
            the sampled sentences.
        """
        if self.groups is None:
            self.init_mock_data_server()

        # Upper bound on sentences per sample (presumably ~20 tokens each);
        # clamp to 1 so small max_length values still yield data.
        num_samples = max(1, self.max_length // 20)

        # Groups with more sentences are proportionally more likely.
        group = random.choices(self.groups, weights=self.group_weights, k=1)[0]

        if self.causal:
            # Take a contiguous window of sentences.
            if num_samples >= len(group.sentences):
                samples = group.sentences
            else:
                begin = random.randint(0, len(group.sentences) - num_samples)
                samples = group.sentences[begin : begin + num_samples]
        else:
            samples = random.choices(
                group.sentences, k=min(num_samples, len(group.sentences))
            )

        return SampledData(
            source=group.source,
            name=group.name,
            samples=samples,
        )

    def augment(self):
        """Build one training sample ``{"tokens": ..., "labels": ...}``.

        Returns:
            A dict of ``(1 + num_codebooks, T)`` long tensors, or None when
            the sampled group produced no sentences.
        """
        final_text, final_semantic = [], []
        response = self.sample_data()
        if len(response.samples) == 0:
            # Invalid group
            return None

        samples = list(response.samples)
        use_interactive = random.random() < self.interactive_prob

        if not use_interactive:
            # Random token budget from a normal distribution truncated to
            # [10, max_length]. NOTE(review): the -4 presumably reserves room
            # for prompt scaffolding tokens; confirm.
            a = torch.tensor([0], dtype=torch.float32)
            torch.nn.init.trunc_normal_(
                a,
                mean=self.max_length // 2,
                std=self.max_length // 4,
                a=10,
                b=self.max_length,
            )
            remaining_tokens = a.long().item() - 4
        else:
            remaining_tokens = self.max_length

        # A float use_speaker is a per-sample probability.
        if isinstance(self.use_speaker, float):
            use_speaker = random.random() < self.use_speaker
        else:
            use_speaker = self.use_speaker
        speaker = response.name if use_speaker else None

        all_tokens, all_labels = [], []
        while remaining_tokens > 0 and len(samples) > 0:
            sentence = samples.pop(0)

            text = random.choice(sentence.texts)
            text, length = self.tokenize_sentence(text)
            remaining_tokens -= length + len(sentence.semantics[0].values)

            if not use_interactive:
                final_text.append(text)
                final_semantic.append(sentence.semantics)
            else:
                # Each sentence becomes its own user/speaker segment.
                tokens, labels = self.pack_sentences(
                    sentences=[text],
                    semantics=[sentence.semantics],
                    speaker=speaker,
                    skip_text=random.random() < self.skip_text_prob,
                )
                all_tokens.append(tokens)
                all_labels.append(labels)

        if not use_interactive:
            # All collected sentences go into one long segment.
            tokens, labels = self.pack_sentences(
                final_text,
                semantics=final_semantic,
                speaker=speaker,
            )
            all_tokens.append(tokens)
            all_labels.append(labels)

        tokens = torch.cat(all_tokens, dim=1)
        labels = torch.cat(all_labels, dim=1)

        # Verify that the sequence lengths match
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

        data = {"tokens": tokens, "labels": labels}

        return data

    def pack_sentences(
        self,
        sentences: list[str],
        semantics: list,
        speaker: Optional[str] = None,
        skip_text: bool = False,
    ):
        """Pack texts and their semantic codes into one chat segment.

        Args:
            sentences: texts joined with single spaces into the user turn
            semantics: one entry per sentence; each entry is a sequence of
                codebooks where ``entry[k].values`` holds codebook ``k``'s codes
            speaker: role name of the reply turn, ``"assistant"`` if None
            skip_text: replace the text with ``<|skip_text|>`` and mask every
                label, so the segment only serves as context

        Returns:
            ``(tokens, labels)`` long tensors of shape ``(1 + num_codebooks, T)``.
            Row 0 holds text-token ids. Rows 1.. hold codebook ids shifted by
            +1, padded with ``CODEBOOK_PAD_TOKEN_ID`` outside the semantic
            span. Unless ``skip_text`` is set, tokens/labels are shifted by one
            position for next-token prediction.
        """
        if speaker is None:
            speaker = "assistant"

        cated_sentences = " ".join(sentences)
        if skip_text:
            cated_sentences = "<|skip_text|>"

        final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
        final_text = final_text + f"<|im_start|>{speaker}\n"

        encoded = self.tokenizer.encode(
            final_text,
            add_special_tokens=False,
            truncation=False,
            max_length=10**6,
        )
        semantic_length = sum([len(i[0].values) for i in semantics])
        prompt_length = len(encoded)
        num_codebooks = (
            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
        )

        # Text row: prompt, one <|semantic|> placeholder per frame, <|im_end|>.
        tokens = (
            encoded
            + [self.semantic_token_id] * semantic_length
            + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
        )

        # Codebook rows: padding under the prompt, codes (+1) under the
        # <|semantic|> placeholders.
        codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
        for segment in semantics:
            for book_idx, book in zip(range(num_codebooks), segment):
                for j in book.values:
                    codes[book_idx].append(int(j) + 1)

        # One more pad under the trailing <|im_end|>.
        for book in codes:
            book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)

        tokens = [tokens] + codes

        tokens = torch.tensor(tokens, dtype=torch.long)
        labels = tokens.clone()

        if skip_text:
            # Condition-only segment: nothing here contributes to the loss.
            labels.fill_(-100)
            return tokens, labels

        # Mask codebook labels under the prompt; the text row keeps its
        # labels, so language modeling on the prompt still applies.
        labels[1:, :prompt_length] = -100

        # Shift by one for next-token prediction.
        tokens = tokens[:, :-1]
        labels = labels[:, 1:]

        # Verify the padding is correct, and the last token is eos
        assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
        assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()

        return tokens, labels
|
|
|
|
|
@dataclass
class TextDataCollator:
    """Collate dataset samples into right-padded batches.

    Each sample is a dict with ``tokens``/``labels`` tensors of shape
    ``(rows, T)``. Batching pads or truncates them to a common length, capped
    at ``max_length``.
    """

    # Its eos_token_id is used to pad the text row (row 0) of ``tokens``.
    tokenizer: AutoTokenizer
    # Upper bound on the batched sequence length; longer samples are truncated.
    max_length: int = 1024

    def __call__(self, examples):
        """Batch *examples*; paired negatives are appended after the positives.

        If the samples carry ``negative_tokens``/``negative_labels``, the batch
        holds all positive samples first, then all negatives in the same order.
        """
        # `examples` is a list of dicts, so check the keys of an element. The
        # old `"negative_tokens" in examples` tested list membership and never
        # matched.
        if examples and "negative_tokens" in examples[0]:
            positive_examples = [
                {"tokens": i["tokens"], "labels": i["labels"]} for i in examples
            ]
            negative_examples = [
                {"tokens": i["negative_tokens"], "labels": i["negative_labels"]}
                for i in examples
            ]
            examples = positive_examples + negative_examples

        return self.batchify(examples)

    def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
        """Pad/truncate samples to a common length and stack them.

        Args:
            examples: samples holding ``(rows, T)`` tensors under the keys below
            tokens_key: key of the input-token tensor in each sample
            labels_key: key of the label tensor in each sample

        Returns:
            dict with ``inputs`` ``(B, rows, L)``, ``attention_masks``
            ``(B, L)`` bool, where True marks padding positions, and
            ``labels`` ``(B, rows, L)``, where padding is -100.
        """
        tokens, attention_masks, labels = [], [], []

        # Batch length: the longest sample, capped at max_length.
        max_tokens_length = 0
        for example in examples:
            max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
        max_tokens_length = min(max_tokens_length, self.max_length)

        for example in examples:
            _tokens = example[tokens_key][:, :max_tokens_length]
            _labels = example[labels_key][:, :max_tokens_length]
            # True = padding; real positions are set to False below.
            _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
            tokens_length = _tokens.size(1)
            _attention_mask[:tokens_length] = False

            assert tokens_length == _labels.size(
                1
            ), f"{tokens_length} != {_labels.size(1)}"

            if tokens_length < max_tokens_length:
                # Text row padded with EOS, codebook rows with the codebook pad.
                _tokens = F.pad(
                    _tokens,
                    (0, max_tokens_length - tokens_length),
                    value=self.tokenizer.eos_token_id,
                )
                _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
                _labels = F.pad(
                    _labels, (0, max_tokens_length - _labels.size(1)), value=-100
                )

            tokens.append(_tokens)
            attention_masks.append(_attention_mask)
            labels.append(_labels)

        tokens = torch.stack(tokens, dim=0)
        attention_masks = torch.stack(attention_masks, dim=0)
        labels = torch.stack(labels, dim=0)

        return {
            "inputs": tokens,
            "attention_masks": attention_masks,
            "labels": labels,
        }
|
|
|
|
|
class InterleaveDataset(IterableDataset):
    """Endless stream that interleaves several datasets by weighted random choice.

    At each step a source index is drawn with ``numpy.random.Generator.choice``
    using ``probabilities`` as ``p``, and the next item of that source is
    yielded. A source that runs out is restarted with a fresh ``iter()`` call,
    so the stream never terminates on its own.

    Args:
        datasets: Source datasets (anything ``iter()`` accepts).
        probabilities: Sampling weight of each source, in the same order as
            ``datasets``; passed unchanged to ``Generator.choice``.
        seed: Seed of the NumPy generator created at the start of each
            ``__iter__`` call, so every new iterator replays the same choices.
    """

    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()
        self.datasets, self.probabilities, self.seed = datasets, probabilities, seed

    def __iter__(self):
        generator = np.random.default_rng(self.seed)
        live_iterators = list(map(iter, self.datasets))
        source_count = len(self.datasets)

        while True:
            pick = generator.choice(source_count, p=self.probabilities)
            try:
                item = next(live_iterators[pick])
            except StopIteration:
                # Source exhausted: begin a new pass over it and take its first item.
                live_iterators[pick] = iter(self.datasets[pick])
                item = next(live_iterators[pick])
            yield item
|
|
|
|
|
class SemanticDataModule(LightningDataModule):
    """Lightning data module serving the train/val semantic datasets.

    Both loaders batch with ``TextDataCollator(tokenizer, max_length)``.
    """

    def __init__(
        self,
        train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
        val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def _build_dataloader(self, dataset):
        """Create a DataLoader for *dataset* with the module's shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
            # DataLoader raises ValueError for persistent_workers=True when
            # num_workers == 0, so only request it when workers exist.
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        return self._build_dataloader(self.train_dataset)

    def val_dataloader(self):
        return self._build_dataloader(self.val_dataset)
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the dataset from local protos and print one decoded
    # sample (text row only). The unused tqdm import was dropped so this runs
    # without that optional dependency.
    ds = AutoTextSemanticInstructionDataset(
        ["data/protos"],
        tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
        use_speaker=False,
        interactive_prob=1.0,
        skip_text_prob=0.5,
    )

    # The dataset is infinite, so just take the first sample.
    sample = next(iter(ds))
    print(ds.tokenizer.decode(sample["tokens"][0], skip_special_tokens=False))
|
|
|