Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| def create_tts_mask(seq_len, max_seq_len, mask_range): | |
| bs = seq_len.size(0) | |
| device = seq_len.device | |
| # 1. Sample random fractional lengths for each sequence | |
| frac_lengths = torch.zeros(bs, device=device).uniform_(*mask_range) | |
| # 2. Convert fractional lengths to integer lengths | |
| lengths = (frac_lengths * seq_len).long() | |
| # 3. Compute valid start indices based on sequence length | |
| max_start = seq_len - lengths | |
| # 4. Sample random start positions (clamped at 0) | |
| rand = torch.rand(bs, device=device) | |
| start = (max_start * rand).long().clamp(min=0) | |
| end = start + lengths | |
| # 5. Build the final boolean mask | |
| # max_seq_len = seq_len.max().item() | |
| seq = torch.arange(max_seq_len, device=device).long() | |
| start_mask = seq[None, :] >= start[:, None] | |
| end_mask = seq[None, :] < end[:, None] | |
| mask = start_mask & end_mask | |
| return mask | |
| if __name__ == "__main__": | |
| # Example: 3 sequences of lengths [5, 7, 6] | |
| lengths = torch.tensor([5, 7, 6]) | |
| mask_range = (0.7, 1.0) # Sample fractional lengths between 30% and 70% of each seq | |
| mask = create_tts_mask(lengths, mask_range) | |
| print("Mask shape:", mask.shape) # Should be [3, 7], since max_seq_len is 7 | |
| print(mask) |