| import torch | |
def _chunk_indices(order, num_tokens, seq_len):
    """Split the 1-D index tensor ``order`` into consecutive groups of ``num_tokens``.

    The last group holds the remainder and may be shorter. Returns ``[]`` when
    ``seq_len`` is 0.

    Raises:
        ValueError: if ``num_tokens`` is not a positive integer.
    """
    if num_tokens <= 0:
        # range() would raise a cryptic error for 0 and silently yield nothing
        # for negatives, which would schedule no positions at all.
        raise ValueError(f"num_tokens must be a positive integer, got {num_tokens!r}")
    return [order[start:start + num_tokens] for start in range(0, seq_len, num_tokens)]


def create_mask_sequence(mask_cfg, seq_len, generator=None):
    """Build an ordered list of index groups covering positions ``0 .. seq_len - 1``.

    Together, the returned groups contain every position exactly once.
    Presumably the caller processes (e.g. masks/predicts) one group per step;
    confirm against the caller.

    Args:
        mask_cfg: Mapping with key ``'type'``. Supported values:
            - ``'raster order'``: positions in ascending order, chunked into groups
              of ``mask_cfg['num_tokens']``.
            - ``'random order'``: a random permutation of positions, chunked into
              groups of ``mask_cfg['num_tokens']``.
            - ``'single'``: one group containing all positions in ascending order.
        seq_len: Number of positions.
        generator: Optional ``torch.Generator`` passed to ``torch.randperm`` for
            reproducible ``'random order'`` sequences. It is ignored by other types.

    Returns:
        A list of 1-D index tensors (int64, the default dtype of ``torch.arange`` and
        ``torch.randperm``). The final chunk may be shorter than ``num_tokens``.

    Raises:
        ValueError: if ``num_tokens`` is not positive for a chunked type.
        NotImplementedError: if ``mask_cfg['type']`` is not supported.
    """
    type_name = mask_cfg['type']
    if type_name == 'raster order':
        return _chunk_indices(torch.arange(seq_len), mask_cfg['num_tokens'], seq_len)
    if type_name == 'random order':
        order = torch.randperm(seq_len, generator=generator)
        return _chunk_indices(order, mask_cfg['num_tokens'], seq_len)
    if type_name == 'single':
        return [torch.arange(seq_len)]
    raise NotImplementedError(
        f"Unsupported mask type {type_name!r}; "
        "expected 'raster order', 'random order', or 'single'"
    )