Spaces:
Runtime error
Runtime error
| # Copyright 2024 Xiaomi Corp. (authors: Wei Kang | |
| # Han Zhu) | |
| # | |
| # See ../../../../LICENSE for clarification regarding multiple authors | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import List, Optional | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
| from zipvoice.models.modules.solver import EulerSolver | |
| from zipvoice.models.modules.zipformer import TTSZipformer | |
| from zipvoice.utils.common import ( | |
| condition_time_mask, | |
| get_tokens_index, | |
| make_pad_mask, | |
| pad_labels, | |
| prepare_avg_tokens_durations, | |
| ) | |
| def score_tokens(A): | |
| B = [9, 14, 18, 21, 27, 33, 37, 39, 42, 45, 50, 51, 52, 54, 58, 59, 61, 62, 63, 69, 73, 74, 79, 85, 99, 100, 102, 105, 119, 120, 121, 122, 123, 124, 141, 143, 144, 145, 146, 157, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 349, 350, 353, 356, 357, 358, 359] | |
| total_score = 0 | |
| # ThΓͺm 3 vΓ o ΔαΊ§u vΓ cuα»i | |
| tokens = [3] + A + [3] | |
| # TΓ‘ch chuα»i theo sα» 3 | |
| segment = [] | |
| for t in tokens: | |
| if t == 3: | |
| if segment: # xα» lΓ½ 1 ΔoαΊ‘n | |
| count = 0 | |
| for i in range(len(segment) - 1): | |
| if (segment[i] in B and segment[i+1] not in B): | |
| # print(f"{segment[i]} in B and {segment[i+1]} not in B)") | |
| count += 1 | |
| if segment[-1] in B: | |
| # print(f"{segment[-1]} in B") | |
| count += 1 | |
| if count > 0: | |
| total_score += 1 + (count - 1) * 0.85 | |
| segment = [] | |
| else: | |
| segment.append(t) | |
| return total_score | |
| class ZipVoice(nn.Module): | |
| """The ZipVoice model.""" | |
| def __init__( | |
| self, | |
| fm_decoder_downsampling_factor: List[int] = [1, 2, 4, 2, 1], | |
| fm_decoder_num_layers: List[int] = [2, 2, 4, 4, 4], | |
| fm_decoder_cnn_module_kernel: List[int] = [31, 15, 7, 15, 31], | |
| fm_decoder_feedforward_dim: int = 1536, | |
| fm_decoder_num_heads: int = 4, | |
| fm_decoder_dim: int = 512, | |
| text_encoder_num_layers: int = 4, | |
| text_encoder_feedforward_dim: int = 512, | |
| text_encoder_cnn_module_kernel: int = 9, | |
| text_encoder_num_heads: int = 4, | |
| text_encoder_dim: int = 192, | |
| time_embed_dim: int = 192, | |
| text_embed_dim: int = 192, | |
| query_head_dim: int = 32, | |
| value_head_dim: int = 12, | |
| pos_head_dim: int = 4, | |
| pos_dim: int = 48, | |
| feat_dim: int = 100, | |
| vocab_size: int = 26, | |
| pad_id: int = 0, | |
| ): | |
| """ | |
| Initialize the model with specified configuration parameters. | |
| Args: | |
| fm_decoder_downsampling_factor: List of downsampling factors for each layer | |
| in the flow-matching decoder. | |
| fm_decoder_num_layers: List of the number of layers for each block in the | |
| flow-matching decoder. | |
| fm_decoder_cnn_module_kernel: List of kernel sizes for CNN modules in the | |
| flow-matching decoder. | |
| fm_decoder_feedforward_dim: Dimension of the feedforward network in the | |
| flow-matching decoder. | |
| fm_decoder_num_heads: Number of attention heads in the flow-matching | |
| decoder. | |
| fm_decoder_dim: Hidden dimension of the flow-matching decoder. | |
| text_encoder_num_layers: Number of layers in the text encoder. | |
| text_encoder_feedforward_dim: Dimension of the feedforward network in the | |
| text encoder. | |
| text_encoder_cnn_module_kernel: Kernel size for the CNN module in the | |
| text encoder. | |
| text_encoder_num_heads: Number of attention heads in the text encoder. | |
| text_encoder_dim: Hidden dimension of the text encoder. | |
| time_embed_dim: Dimension of the time embedding. | |
| text_embed_dim: Dimension of the text embedding. | |
| query_head_dim: Dimension of the query attention head. | |
| value_head_dim: Dimension of the value attention head. | |
| pos_head_dim: Dimension of the position attention head. | |
| pos_dim: Dimension of the positional encoding. | |
| feat_dim: Dimension of the acoustic features. | |
| vocab_size: Size of the vocabulary. | |
| pad_id: ID used for padding tokens. | |
| """ | |
| super().__init__() | |
| self.fm_decoder = TTSZipformer( | |
| in_dim=feat_dim * 3, | |
| out_dim=feat_dim, | |
| downsampling_factor=fm_decoder_downsampling_factor, | |
| num_encoder_layers=fm_decoder_num_layers, | |
| cnn_module_kernel=fm_decoder_cnn_module_kernel, | |
| encoder_dim=fm_decoder_dim, | |
| feedforward_dim=fm_decoder_feedforward_dim, | |
| num_heads=fm_decoder_num_heads, | |
| query_head_dim=query_head_dim, | |
| pos_head_dim=pos_head_dim, | |
| value_head_dim=value_head_dim, | |
| pos_dim=pos_dim, | |
| use_time_embed=True, | |
| time_embed_dim=time_embed_dim, | |
| ) | |
| self.text_encoder = TTSZipformer( | |
| in_dim=text_embed_dim, | |
| out_dim=feat_dim, | |
| downsampling_factor=1, | |
| num_encoder_layers=text_encoder_num_layers, | |
| cnn_module_kernel=text_encoder_cnn_module_kernel, | |
| encoder_dim=text_encoder_dim, | |
| feedforward_dim=text_encoder_feedforward_dim, | |
| num_heads=text_encoder_num_heads, | |
| query_head_dim=query_head_dim, | |
| pos_head_dim=pos_head_dim, | |
| value_head_dim=value_head_dim, | |
| pos_dim=pos_dim, | |
| use_time_embed=False, | |
| ) | |
| self.feat_dim = feat_dim | |
| self.text_embed_dim = text_embed_dim | |
| self.pad_id = pad_id | |
| self.embed = nn.Embedding(vocab_size, text_embed_dim) | |
| self.solver = EulerSolver(self, func_name="forward_fm_decoder") | |
| def forward_fm_decoder( | |
| self, | |
| t: torch.Tensor, | |
| xt: torch.Tensor, | |
| text_condition: torch.Tensor, | |
| speech_condition: torch.Tensor, | |
| padding_mask: Optional[torch.Tensor] = None, | |
| guidance_scale: Optional[torch.Tensor] = None, | |
| ) -> torch.Tensor: | |
| """Compute velocity. | |
| Args: | |
| t: A tensor of shape (N, 1, 1) or a tensor of a float, | |
| in the range of (0, 1). | |
| xt: the input of the current timestep, including condition | |
| embeddings and noisy acoustic features. | |
| text_condition: the text condition embeddings, with the | |
| shape (batch, seq_len, emb_dim). | |
| speech_condition: the speech condition embeddings, with the | |
| shape (batch, seq_len, emb_dim). | |
| padding_mask: The mask for padding, True means masked | |
| position, with the shape (N, T). | |
| guidance_scale: The guidance scale in classifier-free guidance, | |
| which is a tensor of shape (N, 1, 1) or a tensor of a float. | |
| Returns: | |
| predicted velocity, with the shape (batch, seq_len, emb_dim). | |
| """ | |
| xt = torch.cat([xt, text_condition, speech_condition], dim=2) | |
| assert t.dim() in (0, 3) | |
| # Handle t with the shape (N, 1, 1): | |
| # squeeze the last dimension if it's size is 1. | |
| while t.dim() > 1 and t.size(-1) == 1: | |
| t = t.squeeze(-1) | |
| # Handle t with a single value: expand to the size of batch size. | |
| if t.dim() == 0: | |
| t = t.repeat(xt.shape[0]) | |
| if guidance_scale is not None: | |
| while guidance_scale.dim() > 1 and guidance_scale.size(-1) == 1: | |
| guidance_scale = guidance_scale.squeeze(-1) | |
| if guidance_scale.dim() == 0: | |
| guidance_scale = guidance_scale.repeat(xt.shape[0]) | |
| vt = self.fm_decoder( | |
| x=xt, t=t, padding_mask=padding_mask, guidance_scale=guidance_scale | |
| ) | |
| else: | |
| vt = self.fm_decoder(x=xt, t=t, padding_mask=padding_mask) | |
| return vt | |
| def forward_text_embed( | |
| self, | |
| tokens: List[List[int]], | |
| ): | |
| """ | |
| Get the text embeddings. | |
| Args: | |
| tokens: a list of list of token ids. | |
| Returns: | |
| embed: the text embeddings, shape (batch, seq_len, emb_dim). | |
| tokens_lens: the length of each token sequence, shape (batch,). | |
| """ | |
| device = ( | |
| self.device if isinstance(self, DDP) else next(self.parameters()).device | |
| ) | |
| tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device) # (B, S) | |
| embed = self.embed(tokens_padded) # (B, S, C) | |
| tokens_lens = torch.tensor( | |
| [len(token) for token in tokens], dtype=torch.int64, device=device | |
| ) | |
| tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1]) # (B, S) | |
| embed = self.text_encoder( | |
| x=embed, t=None, padding_mask=tokens_padding_mask | |
| ) # (B, S, C) | |
| return embed, tokens_lens | |
| def forward_text_condition( | |
| self, | |
| embed: torch.Tensor, | |
| tokens_lens: torch.Tensor, | |
| features_lens: torch.Tensor, | |
| ): | |
| """ | |
| Get the text condition with the same length of the acoustic feature. | |
| Args: | |
| embed: the text embeddings, shape (batch, token_seq_len, emb_dim). | |
| tokens_lens: the length of each token sequence, shape (batch,). | |
| features_lens: the length of each acoustic feature sequence, | |
| shape (batch,). | |
| Returns: | |
| text_condition: the text condition, shape | |
| (batch, feature_seq_len, emb_dim). | |
| padding_mask: the padding mask of text condition, shape | |
| (batch, feature_seq_len). | |
| """ | |
| num_frames = int(features_lens.max()) | |
| padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T) | |
| tokens_durations = prepare_avg_tokens_durations(features_lens, tokens_lens) | |
| tokens_index = get_tokens_index(tokens_durations, num_frames).to( | |
| embed.device | |
| ) # (B, T) | |
| text_condition = torch.gather( | |
| embed, | |
| dim=1, | |
| index=tokens_index.unsqueeze(-1).expand( | |
| embed.size(0), num_frames, embed.size(-1) | |
| ), | |
| ) # (B, T, F) | |
| return text_condition, padding_mask | |
| def forward_text_train( | |
| self, | |
| tokens: List[List[int]], | |
| features_lens: torch.Tensor, | |
| ): | |
| """ | |
| Process text for training, given text tokens and real feature lengths. | |
| """ | |
| embed, tokens_lens = self.forward_text_embed(tokens) | |
| text_condition, padding_mask = self.forward_text_condition( | |
| embed, tokens_lens, features_lens | |
| ) | |
| return ( | |
| text_condition, | |
| padding_mask, | |
| ) | |
| def forward_text_inference_gt_duration( | |
| self, | |
| tokens: List[List[int]], | |
| features_lens: torch.Tensor, | |
| prompt_tokens: List[List[int]], | |
| prompt_features_lens: torch.Tensor, | |
| ): | |
| """ | |
| Process text for inference, given text tokens, real feature lengths and prompts. | |
| """ | |
| tokens = [ | |
| prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens) | |
| ] | |
| features_lens = prompt_features_lens + features_lens | |
| embed, tokens_lens = self.forward_text_embed(tokens) | |
| text_condition, padding_mask = self.forward_text_condition( | |
| embed, tokens_lens, features_lens | |
| ) | |
| return text_condition, padding_mask | |
| def forward_text_inference_ratio_duration( | |
| self, | |
| tokens: List[List[int]], | |
| prompt_tokens: List[List[int]], | |
| prompt_features_lens: torch.Tensor, | |
| speed: float, | |
| num_space_text=[-1], | |
| num_space_prompt=[-1] | |
| ): | |
| """ | |
| Process text for inference, length Ζ°α»c lượng theo sα» khoαΊ£ng trαΊ―ng (token=3) + 1. | |
| """ | |
| device = ( | |
| self.device if isinstance(self, DDP) else next(self.parameters()).device | |
| ) | |
| cat_tokens = [ | |
| prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens) | |
| ] | |
| prompt_tokens_lens = torch.tensor( | |
| [len(token) for token in prompt_tokens], | |
| dtype=torch.int64, | |
| device=device, | |
| ) | |
| tokens_lens = torch.tensor( | |
| [len(token) for token in tokens], | |
| dtype=torch.int64, | |
| device=device, | |
| ) | |
| # π sα» khoαΊ£ng trαΊ―ng + 1 | |
| prompt_space_lens = torch.tensor( | |
| [(score_tokens(token)-(token.count(3) - numsp - 1) + token.count(8)*0.5)*100 for token, numsp in zip(prompt_tokens,num_space_prompt)], | |
| dtype=torch.int64, | |
| device=device, | |
| ) | |
| tokens_space_lens = torch.tensor( | |
| [(score_tokens(token)-(token.count(3) - numsp - 1) + token.count(8)*0.5)*100 for token, numsp in zip (tokens,num_space_text)], | |
| dtype=torch.int64, | |
| device=device, | |
| ) | |
| print("tokens_space_lens: ", tokens_space_lens) | |
| print("prompt_space_lens: ", prompt_space_lens) | |
| cat_embed, cat_tokens_lens = self.forward_text_embed(cat_tokens) | |
| def alpha(prompt_space_lens: float) -> float: | |
| if prompt_space_lens <= 1: | |
| return 1.1 | |
| elif prompt_space_lens >= 30: | |
| return 1.03 | |
| else: | |
| return 1.1 - (prompt_space_lens - 1) / (30 - 1) * (1.1 - 1.03) | |
| # frames_per_word * sα»_word_trong_text | |
| features_lens = prompt_features_lens + torch.ceil( | |
| (prompt_features_lens / prompt_space_lens * tokens_space_lens / speed * alpha(tokens_space_lens/100)) | |
| ).to(dtype=torch.int64) | |
| text_condition, padding_mask = self.forward_text_condition( | |
| cat_embed, cat_tokens_lens, features_lens | |
| ) | |
| return text_condition, padding_mask | |
| def forward( | |
| self, | |
| tokens: List[List[int]], | |
| features: torch.Tensor, | |
| features_lens: torch.Tensor, | |
| noise: torch.Tensor, | |
| t: torch.Tensor, | |
| condition_drop_ratio: float = 0.0, | |
| ) -> torch.Tensor: | |
| """Forward pass of the model for training. | |
| Args: | |
| tokens: a list of list of token ids. | |
| features: the acoustic features, with the shape (batch, seq_len, feat_dim). | |
| features_lens: the length of each acoustic feature sequence, shape (batch,). | |
| noise: the intitial noise, with the shape (batch, seq_len, feat_dim). | |
| t: the time step, with the shape (batch, 1, 1). | |
| condition_drop_ratio: the ratio of dropped text condition. | |
| Returns: | |
| fm_loss: the flow-matching loss. | |
| """ | |
| (text_condition, padding_mask,) = self.forward_text_train( | |
| tokens=tokens, | |
| features_lens=features_lens, | |
| ) | |
| speech_condition_mask = condition_time_mask( | |
| features_lens=features_lens, | |
| mask_percent=(0.7, 1.0), | |
| max_len=features.size(1), | |
| ) | |
| speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features) | |
| if condition_drop_ratio > 0.0: | |
| drop_mask = ( | |
| torch.rand(text_condition.size(0), 1, 1).to(text_condition.device) | |
| > condition_drop_ratio | |
| ) | |
| text_condition = text_condition * drop_mask | |
| xt = features * t + noise * (1 - t) | |
| ut = features - noise # (B, T, F) | |
| vt = self.forward_fm_decoder( | |
| t=t, | |
| xt=xt, | |
| text_condition=text_condition, | |
| speech_condition=speech_condition, | |
| padding_mask=padding_mask, | |
| ) | |
| loss_mask = speech_condition_mask & (~padding_mask) | |
| fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2) | |
| return fm_loss | |
| def sample( | |
| self, | |
| tokens: List[List[int]], | |
| prompt_tokens: List[List[int]], | |
| prompt_features: torch.Tensor, | |
| prompt_features_lens: torch.Tensor, | |
| features_lens: Optional[torch.Tensor] = None, | |
| speed: float = 1.0, | |
| t_shift: float = 1.0, | |
| duration: str = "predict", | |
| num_step: int = 5, | |
| guidance_scale: float = 0.5, | |
| num_space_text=[-1], | |
| num_space_prompt=[-1] | |
| ) -> torch.Tensor: | |
| """ | |
| Generate acoustic features, given text tokens, prompts feature | |
| and prompt transcription's text tokens. | |
| Args: | |
| tokens: a list of list of text tokens. | |
| prompt_tokens: a list of list of prompt tokens. | |
| prompt_features: the prompt feature with the shape | |
| (batch_size, seq_len, feat_dim). | |
| prompt_features_lens: the length of each prompt feature, | |
| with the shape (batch_size,). | |
| features_lens: the length of the predicted eature, with the | |
| shape (batch_size,). It is used only when duration is "real". | |
| duration: "real" or "predict". If "real", the predicted | |
| feature length is given by features_lens. | |
| num_step: the number of steps to use in the ODE solver. | |
| guidance_scale: the guidance scale for classifier-free guidance. | |
| """ | |
| assert duration in ["real", "predict"] | |
| if duration == "predict": | |
| ( | |
| text_condition, | |
| padding_mask, | |
| ) = self.forward_text_inference_ratio_duration( | |
| tokens=tokens, | |
| prompt_tokens=prompt_tokens, | |
| prompt_features_lens=prompt_features_lens, | |
| speed=speed, | |
| num_space_text=num_space_text, | |
| num_space_prompt=num_space_prompt, | |
| ) | |
| else: | |
| assert features_lens is not None | |
| text_condition, padding_mask = self.forward_text_inference_gt_duration( | |
| tokens=tokens, | |
| features_lens=features_lens, | |
| prompt_tokens=prompt_tokens, | |
| prompt_features_lens=prompt_features_lens, | |
| ) | |
| batch_size, num_frames, _ = text_condition.shape | |
| speech_condition = torch.nn.functional.pad( | |
| prompt_features, (0, 0, 0, num_frames - prompt_features.size(1)) | |
| ) # (B, T, F) | |
| # False means speech condition positions. | |
| speech_condition_mask = make_pad_mask(prompt_features_lens, num_frames) | |
| speech_condition = torch.where( | |
| speech_condition_mask.unsqueeze(-1), | |
| torch.zeros_like(speech_condition), | |
| speech_condition, | |
| ) | |
| x0 = torch.randn( | |
| batch_size, | |
| num_frames, | |
| prompt_features.size(-1), | |
| device=text_condition.device, | |
| ) | |
| x1 = self.solver.sample( | |
| x=x0, | |
| text_condition=text_condition, | |
| speech_condition=speech_condition, | |
| padding_mask=padding_mask, | |
| num_step=num_step, | |
| guidance_scale=guidance_scale, | |
| t_shift=t_shift, | |
| ) | |
| x1_wo_prompt_lens = (~padding_mask).sum(-1) - prompt_features_lens | |
| x1_prompt = torch.zeros( | |
| x1.size(0), prompt_features_lens.max(), x1.size(2), device=x1.device | |
| ) | |
| x1_wo_prompt = torch.zeros( | |
| x1.size(0), x1_wo_prompt_lens.max(), x1.size(2), device=x1.device | |
| ) | |
| for i in range(x1.size(0)): | |
| x1_wo_prompt[i, : x1_wo_prompt_lens[i], :] = x1[ | |
| i, | |
| prompt_features_lens[i] : prompt_features_lens[i] | |
| + x1_wo_prompt_lens[i], | |
| ] | |
| x1_prompt[i, : prompt_features_lens[i], :] = x1[ | |
| i, : prompt_features_lens[i] | |
| ] | |
| return x1_wo_prompt, x1_wo_prompt_lens, x1_prompt, prompt_features_lens | |
| def sample_intermediate( | |
| self, | |
| tokens: List[List[int]], | |
| features: torch.Tensor, | |
| features_lens: torch.Tensor, | |
| noise: torch.Tensor, | |
| speech_condition_mask: torch.Tensor, | |
| t_start: float, | |
| t_end: float, | |
| num_step: int = 1, | |
| guidance_scale: torch.Tensor = None, | |
| ) -> torch.Tensor: | |
| """ | |
| Generate acoustic features in intermediate timesteps. | |
| Args: | |
| tokens: List of list of token ids. | |
| features: The acoustic features, with the shape (batch, seq_len, feat_dim). | |
| features_lens: The length of each acoustic feature sequence, | |
| with the shape (batch,). | |
| noise: The initial noise, with the shape (batch, seq_len, feat_dim). | |
| speech_condition_mask: The mask for speech condition, True means | |
| non-condition positions, with the shape (batch, seq_len). | |
| t_start: The start timestep. | |
| t_end: The end timestep. | |
| num_step: The number of steps for sampling. | |
| guidance_scale: The scale for classifier-free guidance inference, | |
| with the shape (batch, 1, 1). | |
| """ | |
| (text_condition, padding_mask,) = self.forward_text_train( | |
| tokens=tokens, | |
| features_lens=features_lens, | |
| ) | |
| speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features) | |
| x_t_end = self.solver.sample( | |
| x=noise, | |
| text_condition=text_condition, | |
| speech_condition=speech_condition, | |
| padding_mask=padding_mask, | |
| num_step=num_step, | |
| guidance_scale=guidance_scale, | |
| t_start=t_start, | |
| t_end=t_end, | |
| ) | |
| x_t_end_lens = (~padding_mask).sum(-1) | |
| return x_t_end, x_t_end_lens | |