import typing as tp
import warnings
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field, fields
from itertools import chain

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from codeclm.modules.streaming import StreamingModule
from codeclm.utils.utils import length_to_mask, collate

ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]


class AudioCondition(tp.NamedTuple):
    wav: torch.Tensor
    length: torch.Tensor
    sample_rate: tp.List[int]
    path: tp.List[tp.Optional[str]] = []
    seek_time: tp.List[tp.Optional[float]] = []


@dataclass
class ConditioningAttributes:
    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
    audio: tp.Dict[str, AudioCondition] = field(default_factory=dict)

    def __getitem__(self, item):
        return getattr(self, item)

    @property
    def text_attributes(self):
        return self.text.keys()

    @property
    def audio_attributes(self):
        return self.audio.keys()

    @property
    def attributes(self):
        return {
            "text": self.text_attributes,
            "audio": self.audio_attributes,
        }

    def to_flat_dict(self):
        return {
            **{f"text.{k}": v for k, v in self.text.items()},
            **{f"audio.{k}": v for k, v in self.audio.items()},
        }

    @classmethod
    def from_flat_dict(cls, x):
        out = cls()
        for k, v in x.items():
            kind, att = k.split(".")
            out[kind][att] = v
        return out
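
# A minimal sketch of the flat-dict round trip above (illustrative only, kept in comments
# so nothing runs on import; the attribute name "description" is an assumption):
#
#   attrs = ConditioningAttributes(text={"description": "a rock song"})
#   flat = attrs.to_flat_dict()                      # {"text.description": "a rock song"}
#   restored = ConditioningAttributes.from_flat_dict(flat)
#   assert restored.text == attrs.text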


class BaseConditioner(nn.Module):
    """Base model for all conditioner modules.
    We allow the output dim to be different from the hidden dim for two reasons:
    1) keep our LUTs small when the vocab is large;
    2) make all condition dims consistent.

    Args:
        dim (int): Hidden dim of the model.
        output_dim (int): Output dim of the conditioner.
        input_token (bool): If True, inputs are token ids and the projection is an embedding
            lookup; otherwise inputs are dense vectors and the projection is linear.
        padding_idx (int): Padding index of the embedding table (used only when input_token is True).
    """
    def __init__(self, dim: int, output_dim: int, input_token=False, padding_idx=0):
        super().__init__()
        self.dim = dim
        self.output_dim = output_dim
        if input_token:
            self.output_proj = nn.Embedding(dim, output_dim, padding_idx)
        else:
            self.output_proj = nn.Linear(dim, output_dim)

    def tokenize(self, *args, **kwargs) -> tp.Any:
        """Should be any part of the processing that will lead to a synchronization
        point, e.g. BPE tokenization with transfer to the GPU.

        The returned value will be saved and returned later when calling forward().
        """
        raise NotImplementedError()

    def forward(self, inputs: tp.Any) -> ConditionType:
        """Gets input that should be used as conditioning (e.g., genre, description or a waveform).
        Outputs a ConditionType, after the input data was embedded as a dense vector.

        Returns:
            ConditionType:
                - A tensor of size [B, T, D] where B is the batch size, T is the length of the
                  output embedding and D is the dimension of the embedding.
                - A mask indicating where the padding tokens are.
        """
        raise NotImplementedError()
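
# A minimal sketch of the tokenize()/forward() contract, assuming a toy conditioner over
# whitespace-separated integer ids (illustrative only; kept in comments so nothing runs on
# import; note that the concrete conditioners below return (embeds1, embeds2, mask) triples
# rather than the two-element ConditionType documented above):
#
#   class ToyConditioner(BaseConditioner):
#       def __init__(self, vocab_size: int, output_dim: int):
#           super().__init__(vocab_size, output_dim, input_token=True, padding_idx=0)
#
#       def tokenize(self, texts):
#           # CPU-side work only, done before any real GPU work starts.
#           return [torch.tensor([int(t) for t in s.split()]) for s in texts]
#
#       def forward(self, token_lists):
#           tokens = pad_sequence(token_lists, batch_first=True)
#           embeds = self.output_proj(tokens.to(self.output_proj.weight.device))
#           mask = (tokens != 0).int()
#           return embeds, mask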


class TextConditioner(BaseConditioner):
    ...


class PhonemeTokenizerConditioner(TextConditioner):
    def __init__(self,
                 output_dim: int,
                 vocab_list,
                 max_len=600,
                 max_sentence_per_structure=50,
                 structure_tokens=None,
                 structure_split_tokens=[','],
                 sentence_split_tokens=['.'],
                 mode='sum',
                 structure_output_dim=64,
                 sentence_output_dim=64,
                 max_duration=120,
                 ):
        self.vocab_list = vocab_list
        self.max_len = max_len
        self.mode = mode
        self.max_sentence_per_structure = max_sentence_per_structure
        voc_size = len(self.vocab_list)

        if structure_tokens is None:
            structure_tokens = [i for i in vocab_list if len(i) > 1 and i[0] == '[' and i[-1] == ']']
        self.structure_token_ids = [vocab_list.index(i) for i in structure_tokens if i in vocab_list]
        self.structure_split_token_ids = [vocab_list.index(i) for i in structure_split_tokens]
        self.sentence_split_token_ids = [vocab_list.index(i) for i in sentence_split_tokens]

        if mode == 'sum':
            content_output_dim = output_dim
            sentence_output_dim = output_dim
            structure_output_dim = output_dim
        else:
            raise NotImplementedError("concat mode is not implemented yet")

        super().__init__(voc_size, content_output_dim, input_token=True, padding_idx=0)
        self.special_emb = nn.Embedding(voc_size, structure_output_dim, padding_idx=0)
        self.blank_emb = nn.Parameter(torch.zeros(1, output_dim), requires_grad=False)

        self.sentence_idx_in_structure_emb = nn.Embedding(max_sentence_per_structure, sentence_output_dim)
        self.sentence_reidx_in_structure_emb = nn.Embedding(max_sentence_per_structure, sentence_output_dim)

        print("max_len", self.max_len)
        print(self.structure_token_ids)

        # Time resolution implied by spreading max_len positions over max_duration seconds.
        self.resolution = max_duration / max_len
        print(self.__class__, f"resolution = {self.resolution}")

    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.List[torch.Tensor]:
        inputs = []
        for xx in x:
            xx = '' if xx is None else xx
            vocab_id = [self.vocab_list.index(item) for item in xx.split(" ") if item in self.vocab_list]
            inputs.append(torch.tensor(vocab_id).long())
        return inputs

    def forward(self, batch_tokens: tp.List, structure_dur=None) -> ConditionType:
        """
        Encode token ids into three types of embeddings:
        1) content embedding: phonemes only (or any meaningful content to be sung out);
        2) structure embedding: structure / separator embeddings, covering structures
           (verse/chorus/...) and separators ('.' / ','). Content and structure currently
           share the same embedding layer, which could be changed to separate layers;
        3) sentence-index embedding (per structure): forward and reverse sentence positions
           within the current structure.
        """
        embeds_batch = []
        for b in range(len(batch_tokens)):
            tokens = batch_tokens[b]
            content_tokens = torch.zeros_like(tokens)
            special_tokens = torch.zeros_like(tokens)
            sentence_idx_in_structure_tokens = torch.zeros_like(tokens)
            sentence_reidx_in_structure_tokens = torch.zeros_like(tokens)

            current_sentence_in_structure_idx = 1
            current_structure = 0
            for i in range(tokens.shape[-1]):
                token = tokens[i]
                if token in self.structure_token_ids:
                    # Structure token ([verse], [chorus], ...): reset the sentence counter.
                    special_tokens[i] = token
                    content_tokens[i] = token
                    current_structure = token
                    current_sentence_in_structure_idx = 1
                    sentence_idx_in_structure_tokens[i] = 0
                elif token in self.sentence_split_token_ids:
                    # Sentence separator: close the current sentence.
                    special_tokens[i] = current_structure
                    content_tokens[i] = token
                    sentence_idx_in_structure_tokens[i] = min(current_sentence_in_structure_idx, self.max_sentence_per_structure - 1)
                    current_sentence_in_structure_idx += 1
                elif token in self.structure_split_token_ids:
                    # Structure separator: keep the sentence index of the previous position.
                    content_tokens[i] = token
                    special_tokens[i] = current_structure
                    sentence_idx_in_structure_tokens[i] = sentence_idx_in_structure_tokens[i - 1]
                else:
                    # Regular content (phoneme) token.
                    content_tokens[i] = token
                    special_tokens[i] = current_structure
                    sentence_idx_in_structure_tokens[i] = min(current_sentence_in_structure_idx, self.max_sentence_per_structure - 1)

            # Second pass (right to left): sentence index counted from the end of its structure.
            current_sentence_num = sentence_idx_in_structure_tokens[-1]
            for i in range(tokens.shape[-1] - 1, -1, -1):
                if current_sentence_num != 0:
                    sentence_reidx_in_structure_tokens[i] = min(current_sentence_num + 1 - sentence_idx_in_structure_tokens[i], self.max_sentence_per_structure - 1)
                if sentence_idx_in_structure_tokens[i] == 0 and i > 0:
                    current_sentence_num = sentence_idx_in_structure_tokens[i - 1]

            device = self.output_proj.weight.device

            content_embeds = self.output_proj(content_tokens.to(device))
            structure_embeds = self.output_proj(special_tokens.to(device))
            sentence_idx_embeds = self.sentence_idx_in_structure_emb(sentence_idx_in_structure_tokens.to(device)) + self.sentence_reidx_in_structure_emb(sentence_reidx_in_structure_tokens.to(device))

            if self.mode == 'sum':
                embeds = content_embeds + structure_embeds + sentence_idx_embeds
            else:
                embeds = torch.cat((content_embeds, structure_embeds, sentence_idx_embeds), -1)
            embeds_batch.append(embeds)

        if self.max_len is not None:
            max_len = self.max_len
        else:
            max_len = max([e.shape[0] for e in embeds_batch])
        embeds, mask = self.pad_2d_tensor(embeds_batch, max_len)

        return embeds, embeds, mask

    def pad_2d_tensor(self, xs, max_len):
        new_tensor = []
        new_mask = []
        for x in xs:
            seq_len, dim = x.size()
            pad_len = max_len - seq_len

            if pad_len > 0:
                pad_tensor = self.blank_emb.repeat(pad_len, 1).to(x.device)
                padded_tensor = torch.cat([x, pad_tensor], dim=0)
                mask = torch.cat((torch.ones_like(x[:, 0]),
                                  torch.zeros_like(pad_tensor[:, 0])), 0)
            elif pad_len < 0:
                padded_tensor = x[:max_len]
                mask = torch.ones_like(padded_tensor[:, 0])
            else:
                padded_tensor = x
                mask = torch.ones_like(x[:, 0])

            new_tensor.append(padded_tensor)
            new_mask.append(mask)

        return torch.stack(new_tensor, 0), torch.stack(new_mask, 0)
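
# A hedged sketch of the three token streams built in PhonemeTokenizerConditioner.forward,
# for an assumed vocabulary where "[verse]" is a structure token, "." a sentence split and
# "," a structure split (token strings shown instead of ids; comments only):
#
#   input tokens:       [verse]   la   la   .    la   .    ,
#   content stream:     same as the input tokens
#   structure stream:   the [verse] id at every position of this structure
#   sentence idx:          0       1    1    1    2    2    2   (forward count)
#   sentence re-idx:       3       2    2    2    1    1    1   (count from the structure's end)
#
# The three streams are embedded separately and summed in 'sum' mode.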


class QwTokenizerConditioner(TextConditioner):
    def __init__(self, output_dim: int,
                 token_path="",
                 max_len=300,
                 add_token_list=[]):
        from transformers import Qwen2Tokenizer
        self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
        if add_token_list != []:
            self.text_tokenizer.add_tokens(add_token_list, special_tokens=True)
        voc_size = len(self.text_tokenizer.get_vocab())

        super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643)
        self.max_len = max_len
        self.padding_idx = '<|endoftext|>'

        vocab = self.text_tokenizer.get_vocab()

        struct_tokens = [i for i in add_token_list if i[0] == '[' and i[-1] == ']']
        self.struct_token_ids = [vocab[i] for i in struct_tokens]
        self.pad_token_idx = 151643

        self.structure_emb = nn.Embedding(200, output_dim, padding_idx=0)

        print("all structure tokens: ", {self.text_tokenizer.convert_ids_to_tokens(i): i for i in self.struct_token_ids})

    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
        x = ['<|im_start|>' + xi if xi is not None else "<|im_start|>" for xi in x]
        inputs = self.text_tokenizer(x, return_tensors="pt", padding=True)
        return inputs

    def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
        """
        Add structure embeddings of {verse, chorus, bridge, ...} to the text/lyric tokens
        that belong to the corresponding structure, then keep (or later drop) these
        structure embeddings.
        """
        mask = inputs['attention_mask']
        tokens = inputs['input_ids']
        B = tokens.shape[0]
        is_sp_embed = torch.any(torch.stack([tokens == i for i in self.struct_token_ids], dim=-1), dim=-1)

        # For each sample, mark the span covered by each structure token with its (offset) id.
        tp_cover_range = torch.zeros_like(tokens)
        for b, is_sp in enumerate(is_sp_embed):
            sp_list = torch.where(is_sp)[0].tolist()
            sp_list.append(mask[b].sum())
            for i, st in enumerate(sp_list[:-1]):
                tp_cover_range[b, st: sp_list[i + 1]] = tokens[b, st] - 151645

        if self.max_len is not None:
            if inputs['input_ids'].shape[-1] > self.max_len:
                warnings.warn(f"Max len limit ({self.max_len}) exceeded! "
                              f"{[self.text_tokenizer.convert_ids_to_tokens(i.tolist()) for i in tokens]} will be cut!")
            tokens = self.pad_2d_tensor(tokens, self.max_len, self.pad_token_idx).to(self.output_proj.weight.device)
            mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device)
            tp_cover_range = self.pad_2d_tensor(tp_cover_range, self.max_len, 0).to(self.output_proj.weight.device)
        device = self.output_proj.weight.device
        content_embeds = self.output_proj(tokens.to(device))
        structure_embeds = self.structure_emb(tp_cover_range.to(device))

        embeds = content_embeds + structure_embeds
        return embeds, embeds, mask

    def pad_2d_tensor(self, x, max_len, pad_id):
        batch_size, seq_len = x.size()
        pad_len = max_len - seq_len

        if pad_len > 0:
            pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)
            padded_tensor = torch.cat([x, pad_tensor], dim=1)
        elif pad_len < 0:
            padded_tensor = x[:, :max_len]
        else:
            padded_tensor = x

        return padded_tensor
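
# A hedged sketch (comments only) of how tp_cover_range is built in QwTokenizerConditioner.forward:
# with assumed added structure tokens "[verse]" and "[chorus]", every lyric token after a structure
# marker shares that marker's offset id (marker token id minus 151645) until the next marker, e.g.
#
#   tokens:          <|im_start|>  [verse]  la   la   [chorus]  oh   oh
#   tp_cover_range:       0           v      v    v      c       c    c
#
# where v and c are the small positive ids looked up in self.structure_emb (padding stays 0).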


class QwTextConditioner(TextConditioner):
    def __init__(self, output_dim: int,
                 token_path="",
                 max_len=300):
        from transformers import Qwen2Tokenizer
        self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
        voc_size = len(self.text_tokenizer.get_vocab())

        super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643)
        self.max_len = max_len

    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
        x = ['<|im_start|>' + xi if xi is not None else "<|im_start|>" for xi in x]
        inputs = self.text_tokenizer(x, return_tensors="pt", padding=True)
        return inputs

    def forward(self, inputs: tp.Dict[str, torch.Tensor], structure_dur=None) -> ConditionType:
        """
        Embed the tokenized text and return the embeddings together with the attention mask
        (unlike QwTokenizerConditioner, no structure embeddings are added here).
        """
        mask = inputs['attention_mask']
        tokens = inputs['input_ids']

        if self.max_len is not None:
            if inputs['input_ids'].shape[-1] > self.max_len:
                warnings.warn(f"Max len limit ({self.max_len}) exceeded! "
                              f"{[self.text_tokenizer.convert_ids_to_tokens(i.tolist()) for i in tokens]} will be cut!")
            tokens = self.pad_2d_tensor(tokens, self.max_len, 151643).to(self.output_proj.weight.device)
            mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device)

        embeds = self.output_proj(tokens)
        return embeds, embeds, mask

    def pad_2d_tensor(self, x, max_len, pad_id):
        batch_size, seq_len = x.size()
        pad_len = max_len - seq_len

        if pad_len > 0:
            pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)
            padded_tensor = torch.cat([x, pad_tensor], dim=1)
        elif pad_len < 0:
            padded_tensor = x[:, :max_len]
        else:
            padded_tensor = x

        return padded_tensor


class AudioConditioner(BaseConditioner):
    ...


class QuantizedEmbeddingConditioner(AudioConditioner):
    def __init__(self, dim: int,
                 code_size: int,
                 code_depth: int,
                 max_len: int,
                 **kwargs):
        super().__init__(dim, dim, input_token=True)
        self.code_depth = code_depth
        # One embedding table per codebook; code_size + 1 is used as the padding id below.
        self.emb = nn.ModuleList([nn.Embedding(code_size + 2, dim, padding_idx=code_size + 1) for _ in range(code_depth)])

        self.EOT_emb = nn.Parameter(torch.randn(1, dim), requires_grad=True)
        self.layer2_EOT_emb = nn.Parameter(torch.randn(1, dim), requires_grad=True)
        self.output_proj = None
        self.max_len = max_len
        self.vocab_size = code_size

    def tokenize(self, x: AudioCondition) -> AudioCondition:
        """No extra ops needed; the condition already carries quantized codes."""
        return x

    def forward(self, x: AudioCondition):
        wav, lengths, *_ = x
        B = wav.shape[0]
        wav = wav.reshape(B, self.code_depth, -1).long()
        # Pad (or cut) the code sequence to max_len - 1, leaving room for the prepended EOT embedding.
        if wav.shape[2] < self.max_len - 1:
            wav = F.pad(wav, [0, self.max_len - 1 - wav.shape[2]], value=self.vocab_size + 1)
        else:
            wav = wav[:, :, :self.max_len - 1]
        embeds1 = self.emb[0](wav[:, 0])
        embeds1 = torch.cat((self.EOT_emb.unsqueeze(0).repeat(B, 1, 1),
                             embeds1), dim=1)
        embeds2 = sum([self.emb[k](wav[:, k]) for k in range(1, self.code_depth)])
        embeds2 = torch.cat((self.layer2_EOT_emb.unsqueeze(0).repeat(B, 1, 1),
                             embeds2), dim=1)
        lengths = lengths + 1
        lengths = torch.clamp(lengths, max=self.max_len)

        if lengths is not None:
            mask = length_to_mask(lengths, max_len=embeds1.shape[1]).int()
        else:
            mask = torch.ones((B, self.code_depth), device=embeds1.device, dtype=torch.int)
        return embeds1, embeds2, mask
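
# A hedged shape sketch for QuantizedEmbeddingConditioner.forward (comments only; the shapes
# are inferred from the code above, not stated in the original):
#
#   wav:     [B, code_depth * T] flattened codes  ->  reshaped to [B, code_depth, T]
#   embeds1: EOT + codebook-0 embeddings          ->  [B, max_len, dim]
#   embeds2: EOT + sum of codebooks 1..depth-1    ->  [B, max_len, dim]
#   mask:    1 for the EOT slot and real frames, 0 for the padded tail -> [B, max_len]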


class ConditionerProvider(nn.Module):
    """Prepare and provide conditions given all the supported conditioners.

    Args:
        conditioners (dict): Dictionary of conditioners.
    """
    def __init__(self, conditioners: tp.Dict[str, BaseConditioner]):
        super().__init__()
        self.conditioners = nn.ModuleDict(conditioners)

    @property
    def text_conditions(self):
        return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]

    @property
    def audio_conditions(self):
        return [k for k, v in self.conditioners.items() if isinstance(v, AudioConditioner)]

    @property
    def has_audio_condition(self):
        return len(self.audio_conditions) > 0

    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
        """Match attributes/audios with the existing conditioners in self, and tokenize them accordingly.
        This should be called before starting any real GPU work to avoid synchronization points.
        Returns a dict matching conditioner names to their arbitrary tokenized representations.

        Args:
            inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
                text and audio conditions.
        """
        assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
            "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
            f" but types were {set([type(x) for x in inputs])}")

        output = {}
        text = self._collate_text(inputs)
        audios = self._collate_audios(inputs)

        assert set(text.keys() | audios.keys()).issubset(set(self.conditioners.keys())), (
            f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
            f"got {text.keys(), audios.keys()}")

        for attribute, batch in chain(text.items(), audios.items()):
            output[attribute] = self.conditioners[attribute].tokenize(batch)
        return output

    def forward(self, tokenized: tp.Dict[str, tp.Any], structure_dur=None) -> tp.Dict[str, ConditionType]:
        """Compute `(embedding_1, embedding_2, mask)` triples using the configured conditioners
        and the tokenized representations. The output is for example:
        {
            "genre": (torch.Tensor([B, 1, D]), torch.Tensor([B, 1, D]), torch.Tensor([B, 1])),
            "description": (torch.Tensor([B, T_desc, D]), torch.Tensor([B, T_desc, D]), torch.Tensor([B, T_desc])),
            ...
        }

        Args:
            tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
        """
        output = {}
        for attribute, inputs in tokenized.items():
            if attribute == 'description' and structure_dur is not None:
                condition1, condition2, mask = self.conditioners[attribute](inputs, structure_dur=structure_dur)
            else:
                condition1, condition2, mask = self.conditioners[attribute](inputs)
            output[attribute] = (condition1, condition2, mask)
        return output

    def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
        """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
        are the attributes and the values are the aggregated input per attribute.
        For example:
        Input:
            [
                ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, audio=...),
                ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, audio=...),
            ]
        Output:
            {
                "genre": ["Rock", "Hip-hop"],
                "description": ["A rock song with a guitar solo", "A hip-hop verse"]
            }

        Args:
            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
        Returns:
            dict[str, list[str, optional]]: A dictionary mapping an attribute name to a text batch.
        """
        out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
        texts = [x.text for x in samples]
        for text in texts:
            for condition in self.text_conditions:
                out[condition].append(text[condition])
        return out

    def _collate_audios(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, AudioCondition]:
        """Generate a dict where the keys are attributes by which we fetch similar audios,
        and the values are tensors of audios according to said attributes.

        *Note*: by the time the samples reach this function, each sample should have some audios
        inside the "audio" attribute. It should be either:
        1. A real audio
        2. A null audio due to the sample having no similar audios (nullified by the dataset)
        3. A null audio due to it being dropped in a dropout module (nullified by dropout)

        Args:
            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
        Returns:
            dict[str, AudioCondition]: A dictionary mapping an attribute name to wavs.
        """
        wavs = defaultdict(list)
        lengths = defaultdict(list)
        sample_rates = defaultdict(list)
        paths = defaultdict(list)
        seek_times = defaultdict(list)
        out: tp.Dict[str, AudioCondition] = {}

        for sample in samples:
            for attribute in self.audio_conditions:
                wav, length, sample_rate, path, seek_time = sample.audio[attribute]
                assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
                assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
                wavs[attribute].append(wav.flatten())
                lengths[attribute].append(length)
                sample_rates[attribute].extend(sample_rate)
                paths[attribute].extend(path)
                seek_times[attribute].extend(seek_time)

        for attribute in self.audio_conditions:
            stacked_wav, _ = collate(wavs[attribute], dim=0)
            out[attribute] = AudioCondition(
                stacked_wav.unsqueeze(1),
                torch.cat(lengths[attribute]), sample_rates[attribute],
                paths[attribute], seek_times[attribute])

        return out
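
# A minimal usage sketch for ConditionerProvider (comments only; the conditioner name,
# tokenizer path and dimension are assumptions for illustration):
#
#   provider = ConditionerProvider({
#       "description": QwTextConditioner(output_dim=1024, token_path="<tokenizer_dir>"),
#   })
#   attrs = [ConditioningAttributes(text={"description": "an upbeat pop song"})]
#   tokenized = provider.tokenize(attrs)   # CPU-side tokenization, no GPU sync
#   conditions = provider(tokenized)       # {"description": (embeds, embeds, mask)}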


class ConditionFuser(StreamingModule):
    """Condition fuser handles the logic to combine the different conditions
    to the actual model input.

    Args:
        fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
            each condition. For example:
            {
                "prepend": ["description"],
                "sum": ["genre", "bpm"],
            }
    """
    FUSING_METHODS = ["sum", "prepend"]

    def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]]):
        super().__init__()
        assert all([k in self.FUSING_METHODS for k in fuse2cond.keys()]
                   ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
        self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
        self.cond2fuse: tp.Dict[str, str] = {}
        for fuse_method, conditions in fuse2cond.items():
            for condition in conditions:
                self.cond2fuse[condition] = fuse_method

    def forward(
        self,
        input1: torch.Tensor,
        input2: torch.Tensor,
        conditions: tp.Dict[str, ConditionType]
    ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Fuse the conditions into the provided model inputs.

        Args:
            input1 (torch.Tensor): First transformer input stream.
            input2 (torch.Tensor): Second transformer input stream.
            conditions (dict[str, ConditionType]): Dict of conditions.
        Returns:
            tuple[torch.Tensor, torch.Tensor]: The two transformer input streams after the
            conditions have been fused (summed into the inputs and/or prepended along the
            time dimension).
        """
        B, T, _ = input1.shape

        if 'offsets' in self._streaming_state:
            first_step = False
            offsets = self._streaming_state['offsets']
        else:
            first_step = True
            offsets = torch.zeros(input1.shape[0], dtype=torch.long, device=input1.device)

        assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
            f"given conditions contain unknown attributes for fuser, " \
            f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"

        fused_input_1 = input1
        fused_input_2 = input2
        for fuse_op in self.fuse2cond.keys():
            fuse_op_conditions = self.fuse2cond[fuse_op]
            if fuse_op == 'sum' and len(fuse_op_conditions) > 0:
                for cond in fuse_op_conditions:
                    this_cond_1, this_cond_2, cond_mask = conditions[cond]
                    # Out-of-place add so the caller's input tensors are not modified.
                    fused_input_1 = fused_input_1 + this_cond_1
                    fused_input_2 = fused_input_2 + this_cond_2
            elif fuse_op == 'prepend' and len(fuse_op_conditions) > 0:
                # Prepended conditions are only injected on the first streaming step.
                if not first_step:
                    continue
                reverse_list = deepcopy(fuse_op_conditions)
                reverse_list.reverse()
                for cond in reverse_list:
                    this_cond_1, this_cond_2, cond_mask = conditions[cond]
                    fused_input_1 = torch.cat((this_cond_1, fused_input_1), dim=1)
                    fused_input_2 = torch.cat((this_cond_2, fused_input_2), dim=1)
            elif fuse_op not in self.FUSING_METHODS:
                raise ValueError(f"unknown op ({fuse_op})")

        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return fused_input_1, fused_input_2
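
# A hedged usage sketch for ConditionFuser (comments only; the condition names and shapes
# are assumptions): with fuse2cond = {"prepend": ["prompt_audio", "description"], "sum": []},
# a [B, T, D] input becomes [B, T_prompt + T_desc + T, D] on the first step, because each
# prepended condition is inserted in front (the list is reversed so the first-listed
# condition ends up first in time). On later streaming steps nothing is prepended again.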


class DropoutModule(nn.Module):
    """Base module for all dropout modules."""
    def __init__(self, seed: int = 1234):
        super().__init__()
        self.rng = torch.Generator()
        self.rng.manual_seed(seed)


class ClassifierFreeGuidanceDropout(DropoutModule):
    """Classifier Free Guidance dropout.
    All attributes are dropped with the same probability.

    Args:
        p (float): Probability to apply condition dropout during training.
        seed (int): Random seed.
    """
    def __init__(self, p: float, seed: int = 1234):
        super().__init__(seed=seed)
        self.p = p

    def check(self, sample, condition_type, condition):
        if condition_type not in ['text', 'audio']:
            raise ValueError("dropout_condition got an unexpected condition type!"
                             f" expected 'text', 'audio' but got '{condition_type}'")

        if condition not in getattr(sample, condition_type):
            raise ValueError(
                "dropout_condition received an unexpected condition!"
                f" expected audio={sample.audio.keys()} and text={sample.text.keys()}"
                f" but got '{condition}' of type '{condition_type}'!")

    def get_null_wav(self, wav, sr=48000) -> AudioCondition:
        # Null prompt: replace all codes with the constant 16385 and set the length to 0.
        out = wav * 0 + 16385
        return AudioCondition(
            wav=out,
            length=torch.Tensor([0]).long(),
            sample_rate=[sr],)

    def dropout_condition(self,
                          sample: ConditioningAttributes,
                          condition_type: str,
                          condition: str) -> ConditioningAttributes:
        """Utility function for nullifying an attribute inside a ConditioningAttributes object.
        If the condition is of type "audio", it is nullified with `get_null_wav`.
        If the condition is of any other type, its value is set to None.
        Works in-place.
        """
        self.check(sample, condition_type, condition)

        if condition_type == 'audio':
            audio_cond = sample.audio[condition]
            sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
        else:
            sample.text[condition] = None

        return sample

    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
        """
        Args:
            samples (list[ConditioningAttributes]): List of conditions.
        Returns:
            list[ConditioningAttributes]: List of conditions where, for the dropped samples,
            all attributes were set to None (or nullified).
        """
        samples = deepcopy(samples)

        for sample in samples:
            drop = torch.rand(1, generator=self.rng).item()
            if drop < self.p:
                for condition_type in ["audio", "text"]:
                    for condition in sample.attributes[condition_type]:
                        self.dropout_condition(sample, condition_type, condition)
        return samples

    def __repr__(self):
        return f"ClassifierFreeGuidanceDropout(p={self.p})"
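
# A hedged training-time sketch (comments only; cfg_p is an assumed hyper-parameter):
#
#   cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_p)
#   attrs = cfg_dropout(attrs)   # with probability p, a sample loses *all* its conditions,
#                                # which is what later enables classifier-free guidance at
#                                # inference (conditional vs. unconditional passes).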


class ClassifierFreeGuidanceDropoutInference(ClassifierFreeGuidanceDropout):
    """Classifier Free Guidance dropout during inference.
    All selected attributes are dropped with probability 1.

    Args:
        seed (int): Random seed.
    """
    def __init__(self, seed: int = 1234):
        super().__init__(p=1, seed=seed)

    def dropout_condition_customized(self,
                                     sample: ConditioningAttributes,
                                     condition_type: str,
                                     condition: str,
                                     customized: list = None) -> ConditioningAttributes:
        """Utility function for nullifying an attribute inside a ConditioningAttributes object.
        If the condition is of type "audio", it is nullified with `get_null_wav`.
        If the condition is of any other type, its value is set to None, unless `customized`
        requests that only parts of the text (structure tags and/or '.') be removed.
        Works in-place.
        """
        self.check(sample, condition_type, condition)

        if condition_type == 'audio':
            audio_cond = sample.audio[condition]
            sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
        else:
            if customized is None:
                sample.text[condition] = None
            else:
                text_cond = deepcopy(sample.text[condition])
                if "structure" in customized:
                    for _s in ['[inst]', '[outro]', '[intro]', '[verse]', '[chorus]', '[bridge]']:
                        text_cond = text_cond.replace(_s, "")
                    text_cond = text_cond.replace(' , ', '')
                    # Collapse the double spaces left by the removals above.
                    text_cond = text_cond.replace("  ", " ")
                if '.' in customized:
                    text_cond = text_cond.replace(" . ", " ")
                    text_cond = text_cond.replace(".", " ")

                sample.text[condition] = text_cond

        return sample

    def forward(self, samples: tp.List[ConditioningAttributes],
                condition_types=["audio", "text"],
                customized=None,
                ) -> tp.List[ConditioningAttributes]:
        """
        100% dropout of some condition attributes (description, prompt_wav) or condition
        types (text, audio) of samples during inference.

        Args:
            samples (list[ConditioningAttributes]): List of conditions.
        Returns:
            list[ConditioningAttributes]: List of conditions after the selected attributes were set to None.
        """
        new_samples = deepcopy(samples)
        for condition_type in condition_types:
            for sample in new_samples:
                for condition in sample.attributes[condition_type]:
                    self.dropout_condition_customized(sample, condition_type, condition, customized)
        return new_samples


class AttributeDropout(ClassifierFreeGuidanceDropout):
    """Dropout with a given probability per attribute.
    This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
    to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
    This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
    must also be dropped.

    Args:
        p (tp.Dict[str, tp.Dict[str, float]]): A dict mapping condition types to per-attribute
            dropout probabilities, for example:
            {
                "text": {"genre": 0.1, "artist": 0.5},
                "audio": {...},
            }
        active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
        seed (int, optional): Random seed.
    """
    def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
        super().__init__(p=p, seed=seed)
        self.active_on_eval = active_on_eval
        # Per condition type, unknown attributes default to probability 0 (never dropped).
        self.p = {}
        for condition_type, probs in p.items():
            self.p[condition_type] = defaultdict(lambda: 0, probs)

    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
        """
        Args:
            samples (list[ConditioningAttributes]): List of conditions.
        Returns:
            list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
        """
        if not self.training and not self.active_on_eval:
            return samples

        samples = deepcopy(samples)
        for condition_type, ps in self.p.items():
            for condition, p in ps.items():
                if torch.rand(1, generator=self.rng).item() < p:
                    for sample in samples:
                        self.dropout_condition(sample, condition_type, condition)
        return samples
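
# A hedged usage sketch for AttributeDropout (comments only; the attribute names and
# probabilities are assumptions for illustration):
#
#   att_dropout = AttributeDropout(p={"text": {"genre": 0.1, "artist": 0.5}})
#   samples = att_dropout(samples)   # during training, "artist" may be dropped while
#                                    # "genre" is kept, independently per attribute.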