| import torch | |
def attnmask_new(sz):
    """Build a square additive causal attention mask.

    Args:
        sz: Side length of the square mask.

    Returns:
        An ``(sz, sz)`` tensor holding ``-inf`` strictly above the main
        diagonal and ``0`` on and below it.
    """
    # Start from a matrix that is -inf everywhere ...
    blocked = torch.full((sz, sz), float("-inf"))
    # ... then keep only the entries above the diagonal; the rest become 0.
    return blocked.triu(diagonal=1)
def attnmask_old(sz):
    """Build a square additive causal attention mask via ``log``.

    Args:
        sz: Side length of the square mask.

    Returns:
        An ``(sz, sz)`` tensor holding ``-inf`` strictly above the main
        diagonal and ``0`` on and below it.
    """
    # Ones on and below the diagonal, zeros above it.
    visible = torch.ones(sz, sz).tril()
    # log(1) == 0 keeps visible positions; log(0) == -inf blocks the rest.
    return visible.log()
def compare_masks(sz):
    """Check that both mask implementations agree for a given size.

    Args:
        sz: Side length of the masks to build and compare.

    Returns:
        ``True`` when ``attnmask_new(sz)`` and ``attnmask_old(sz)`` are
        equal according to ``torch.equal``.

    Raises:
        ValueError: If the two masks differ; a notice is printed first.
    """
    masks_match = torch.equal(attnmask_new(sz), attnmask_old(sz))
    if not masks_match:
        print("\nThe masks are NOT equal.")
        raise ValueError("Masks differ, check implementation.")
    return True
if __name__ == "__main__":
    # Equivalence sweep: compare_masks raises on the first size that differs,
    # so reaching the print means every size in the range matched.
    for n in range(1, 100):
        compare_masks(n)
    print("All masks are equal for sizes 1 to 99.")

    # Time repeated construction of a fixed-size mask with each
    # implementation using timeit.
    import timeit

    size = 256
    runs = 1000
    new_time = timeit.timeit(lambda: attnmask_new(size), number=runs)
    old_time = timeit.timeit(lambda: attnmask_old(size), number=runs)
    print(f"New mask time for size {size}: {new_time:.6f} seconds")
    print(f"Old mask time for size {size}: {old_time:.6f} seconds")

    # A tie is reported in favour of the old implementation.
    verdict = (
        "New mask implementation is faster."
        if new_time < old_time
        else "Old mask implementation is faster."
    )
    print(verdict)