# File size: 3,637 Bytes
# aae3ba1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch

def move_masked_to_left_brute_force(tensor, mask):
    """Row-by-row reference implementation of left-compaction by mask.

    For every row, the entries of ``tensor`` selected by ``mask`` are placed
    first (keeping their original order) and the remaining positions are
    filled with zeros.  The output is then truncated along dimension 1 to
    the largest number of selected entries found in any row.

    Args:
        tensor: Tensor whose first two dimensions line up with ``mask``;
            any trailing dimensions are carried along unchanged.
        mask: Boolean tensor of shape ``(batch, seq)``.  The batch must be
            non-empty because the per-row results are stacked.

    Returns:
        A ``(results, new_mask)`` tuple.  ``results`` holds the compacted
        rows; ``new_mask`` is a boolean tensor, on the same device as
        ``mask``, that is ``True`` exactly where ``results`` holds a
        selected entry.
    """
    # Column positions, created on the mask's device so the returned mask
    # can be compared directly with device-resident tensors.
    positions = torch.arange(mask.shape[1], device=mask.device)

    rows = []
    row_masks = []
    for row, row_mask in zip(tensor, mask):
        kept = row[row_mask]
        padding = torch.zeros_like(row[~row_mask])
        rows.append(torch.cat([kept, padding]))
        # The first `row_mask.sum()` columns are the valid ones.
        row_masks.append(positions < row_mask.sum())

    results = torch.stack(rows)
    new_mask = torch.stack(row_masks)

    max_length = int(mask.sum(dim=1).max())
    return results[:, :max_length], new_mask[:, :max_length]


def move_masked_to_left_ids(tensor, mask, pad_zero=True):
    """Compact the masked entries of a 2-D tensor to the left of each row.

    Args:
        tensor: Tensor of shape ``(batch, seq)``.
        mask: Boolean tensor of shape ``(batch, seq)`` marking the entries
            to keep.
        pad_zero: When true, positions past each row's kept entries are
            overwritten with 0; otherwise they hold the row's unselected
            entries.

    Returns:
        ``(result, new_mask)`` where ``result`` has as many columns as the
        largest per-row count of kept entries and ``new_mask`` is ``True``
        on the leading kept positions of every row.
    """
    kept_per_row = mask.sum(dim=1)

    # Destination column for every source column: kept entries are numbered
    # 0, 1, ... in order of appearance; dropped entries continue the
    # numbering after all kept entries of the same row.
    kept_rank = mask.cumsum(dim=1) - 1
    dropped_rank = (~mask).cumsum(dim=1) - 1
    dropped_rank += kept_per_row.unsqueeze(1)
    destination = torch.where(mask, kept_rank, dropped_rank)

    # Each row of `destination` is a permutation, so sorting it yields the
    # inverse mapping (which source column feeds each destination column).
    source = torch.argsort(destination, dim=1)
    compacted = torch.gather(tensor, 1, source)[:, :kept_per_row.max()]

    columns = torch.arange(compacted.shape[1], device=kept_per_row.device)
    new_mask = columns.unsqueeze(0) < kept_per_row.unsqueeze(1)
    if pad_zero:
        compacted[~new_mask] = 0
    return compacted, new_mask


def move_masked_to_left(tensor, mask, pad_zero=True):
    """Compact the masked entries of each row to the left (vectorised).

    Works for ``tensor`` of shape ``(batch, seq, *trailing)`` with any
    number of trailing dimensions (including none); the trailing dimensions
    are moved together with their ``seq`` position.

    Args:
        tensor: Tensor of shape ``(batch, seq, *trailing)``.
        mask: Boolean tensor of shape ``(batch, seq)`` marking the entries
            to keep.
        pad_zero: When true, positions past each row's kept entries are
            overwritten with 0; otherwise they hold the row's unselected
            entries.

    Returns:
        ``(result, new_mask)``.  ``result`` has shape
        ``(batch, max_kept, *trailing)`` where ``max_kept`` is the largest
        per-row count of kept entries; ``new_mask`` has shape
        ``(batch, max_kept)`` and is ``True`` on the leading kept positions
        of every row.  ``tensor`` itself is not modified.
    """
    length = mask.sum(dim=1)
    max_length = int(length.max())

    # Destination column of every source column: kept entries are numbered
    # first (in order), dropped entries continue after them.
    masked_index = mask.cumsum(dim=1) - 1
    unmasked_index = (~mask).cumsum(dim=1) - 1 + length.unsqueeze(1)
    s2t_index = torch.where(mask, masked_index, unmasked_index)

    # Each row of s2t_index is a permutation, so argsort inverts it.  Only
    # the first `max_length` destination columns survive the truncation, so
    # only those are gathered.
    t2s_index = torch.argsort(s2t_index, dim=1)[:, :max_length]

    # Broadcast the (batch, max_length) index over the trailing dimensions,
    # because gather needs an index with the same rank as its input.
    trailing = tuple(tensor.shape[2:])
    gather_index = t2s_index.reshape(
        tuple(t2s_index.shape) + (1,) * len(trailing)
    ).expand(-1, -1, *trailing)
    result = torch.gather(tensor, 1, gather_index)

    new_mask = (
        torch.arange(max_length, device=length.device).unsqueeze(0)
        < length.unsqueeze(1)
    )
    if pad_zero:
        result[~new_mask] = 0
    return result, new_mask

def get_mask_of_last_masked_index_brute_force(mask, length):
    """Loop-based reference: select the last ``length`` masked positions.

    Each row of ``mask`` is scanned from right to left and masked positions
    are marked until the row's budget is used up.  Rows with fewer masked
    positions than their budget get all of them marked.

    Args:
        mask: Boolean tensor of shape ``(batch, seq)``.
        length: Either a Python ``int`` applied to every row, or an
            indexable tensor holding one budget per row.

    Returns:
        Boolean tensor of shape ``(batch, seq)``, on the same device as
        ``mask``, that is ``True`` at the selected positions.
    """
    width = mask.shape[1]
    rows = []
    for i in range(mask.shape[0]):
        row = mask[i]
        remaining = length if isinstance(length, int) else length[i].item()
        selected = [False] * width
        for j in range(width - 1, -1, -1):
            if row[j] and remaining > 0:
                selected[j] = True
                remaining -= 1
        rows.append(selected)
    return torch.tensor(rows, dtype=torch.bool, device=mask.device)

def get_mask_of_last_masked_index(mask, length):
    """Select the last ``length`` masked positions of every row (vectorised).

    Args:
        mask: Boolean tensor of shape ``(batch, seq)``.
        length: Number of trailing masked positions to keep; anything that
            can be subtracted from the per-row masked count (e.g. an int or
            a tensor with one value per row).

    Returns:
        Tensor shaped like ``mask`` that is ``True`` only at masked
        positions belonging to the last ``length`` masked entries of a row.
    """
    running_count = mask.cumsum(dim=1)
    # Number of leading masked entries to skip in each row; a masked entry
    # is kept once the running count exceeds this threshold.
    to_skip = (mask.sum(dim=1) - length).unsqueeze(1)
    return (running_count > to_skip) & mask


def test_move_masked_to_left():
    """Compare ``move_masked_to_left`` with the brute-force reference.

    Uses a random ``(10, 20, 5)`` tensor and a random boolean mask, and
    additionally checks that the returned mask is a left-aligned run of
    ``True`` whose length equals each row's masked count.
    """
    batch, seq_len, width = 10, 20, 5
    tensor = torch.randn(batch, seq_len, width)
    mask = torch.randint(0, 2, (batch, seq_len)).bool()

    fast_result, fast_mask = move_masked_to_left(tensor, mask)
    slow_result, slow_mask = move_masked_to_left_brute_force(tensor, mask)

    assert torch.all(fast_result == slow_result)
    assert torch.all(fast_mask == slow_mask)

    kept_per_row = mask.sum(dim=1)
    assert torch.all(kept_per_row == fast_mask.sum(dim=1))

    for row_mask, kept in zip(fast_mask, kept_per_row):
        # Valid positions come first, padding positions after them.
        assert row_mask[:kept].all()
        assert not row_mask[kept:].any()

def test_get_mask_of_last_masked_index():
    """Compare ``get_mask_of_last_masked_index`` with its brute-force twin.

    Uses a random ``(10, 20)`` boolean mask and random per-row lengths
    drawn from ``[0, 10)``.
    """
    batch, seq_len = 10, 20
    mask = torch.randint(0, 2, (batch, seq_len)).bool()
    length = torch.randint(0, seq_len // 2, (batch,))

    expected = get_mask_of_last_masked_index_brute_force(mask, length)
    actual = get_mask_of_last_masked_index(mask, length)
    assert torch.all(actual == expected)

if __name__ == "__main__":
    # Run both self-checks, in order, when executed as a script.
    for check in (test_move_masked_to_left, test_get_mask_of_last_masked_index):
        check()