| | |
| | |
| | |
| | |
| |
|
| | import torch |
| |
|
| |
|
def slice_segments(x, ids_str, segment_size=200):
    """Cut one fixed-width window out of every batch item along dim 2.

    Args:
        x: Tensor indexed as ``x[item, :, start:stop]``; windows are taken
            along dim 2 (presumably the time/frame axis -- confirm with callers).
        ids_str: Per-item start indices; ``ids_str[item]`` is read for every
            ``item`` in ``range(x.size(0))``.
        segment_size: Width of each window.

    Returns:
        A tensor shaped like ``x[:, :, :segment_size]`` (same dtype and device
        as ``x``) whose entry ``item`` holds
        ``x[item, :, ids_str[item]:ids_str[item] + segment_size]``.

    Note:
        No bounds checking happens here; callers are expected to choose starts
        so that every window fits inside dim 2 of ``x``.
    """
    windows = torch.zeros_like(x[:, :, :segment_size])
    batch_size = x.size(0)
    for item in range(batch_size):
        start = ids_str[item]
        stop = start + segment_size
        windows[item] = x[item, :, start:stop]
    return windows
| |
|
| |
|
def rand_ids_segments(lengths, segment_size=200):
    """Draw a random window start index for every sequence in a batch.

    Each start is sampled from ``0 .. lengths[i] - segment_size`` (both ends
    inclusive), so a window of ``segment_size`` positions beginning there lies
    inside the first ``lengths[i]`` positions of item ``i``.

    Args:
        lengths: Tensor of per-item valid lengths; its first dimension is the
            batch size.
        segment_size: Width of the window that will be cut at each start.

    Returns:
        ``torch.long`` tensor of shape ``(lengths.shape[0],)`` on
        ``lengths.device``. Items shorter than ``segment_size`` get start 0;
        a window cut there extends past the valid length (presumably into
        padding -- the caller must make sure the padded tensor is wide enough).
    """
    b = lengths.shape[0]
    # Largest start that still keeps the window inside the valid length.
    # Clamped at 0 so short items never get a negative start, which slicing
    # would interpret as an offset from the end of the tensor.
    ids_str_max = (lengths - segment_size).clamp(min=0).to(dtype=torch.long)
    # "+ 1" makes the upper bound inclusive: rand() lies in [0, 1), so scaling
    # by ids_str_max alone could never produce ids_str_max itself.
    ids_str = (torch.rand([b], device=lengths.device) * (ids_str_max + 1)).to(
        dtype=torch.long
    )
    # Guard against float32 rounding pushing a sample up to ids_str_max + 1.
    return torch.minimum(ids_str, ids_str_max)
| |
|
| |
|
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
    """Round ``length`` up to the nearest multiple of ``2 ** num_downsamplings_in_unet``.

    Presumably used so that a sequence can be halved
    ``num_downsamplings_in_unet`` times by a UNet without remainder -- confirm
    with callers.

    Args:
        length: Length to round up; a length that is already a multiple is
            returned unchanged.
        num_downsamplings_in_unet: Exponent of the power of two to align to.

    Returns:
        The smallest multiple of ``2 ** num_downsamplings_in_unet`` that is
        ``>= length`` (an ``int`` when ``length`` is an ``int``).

    >>> fix_len_compatibility(5)
    8
    >>> fix_len_compatibility(8)
    8
    >>> fix_len_compatibility(9, num_downsamplings_in_unet=3)
    16
    """
    factor = 2**num_downsamplings_in_unet
    # Ceiling division via negated floor division: exact for ints (no float
    # round-trip) and, unlike incrementing by 1 until divisible, it also
    # terminates for non-integral lengths.
    return -(-length // factor) * factor
| |
|