"""Check that two causal attention-mask builders agree, then time them."""

import timeit

import torch


def attnmask_new(sz):
    """Build a square causal mask by taking the upper triangle of a -inf matrix.

    Args:
        sz: Side length of the square mask.

    Returns:
        A ``(sz, sz)`` tensor with ``-inf`` strictly above the main diagonal
        and ``0`` on and below it.
    """
    # diagonal=1 keeps only entries strictly above the main diagonal;
    # everything else is zeroed by triu.
    return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)


def attnmask_old(sz):
    """Build a square causal mask by taking the log of a lower-triangular ones matrix.

    Args:
        sz: Side length of the square mask.

    Returns:
        A ``(sz, sz)`` tensor with ``-inf`` strictly above the main diagonal
        and ``0`` on and below it.
    """
    # log(1) == 0 on/below the diagonal, log(0) == -inf above it.
    return torch.log(torch.tril(torch.ones(sz, sz)))


def compare_masks(sz):
    """Check that both mask builders produce identical tensors for one size.

    Args:
        sz: Side length of the masks to build and compare.

    Returns:
        ``True`` when the two masks are equal.

    Raises:
        ValueError: If the two masks differ (a notice is printed first).
    """
    if torch.equal(attnmask_new(sz), attnmask_old(sz)):
        return True

    print("\nThe masks are NOT equal.")
    raise ValueError("Masks differ, check implementation.")


def main():
    """Verify the builders agree for sizes 1 to 99, then benchmark both at size 256."""
    for sz in range(1, 100):
        compare_masks(sz)
    print("All masks are equal for sizes 1 to 99.")

    # Time 1000 constructions of a size-256 mask with each implementation
    # using timeit.
    size = 256
    new_time = timeit.timeit(lambda: attnmask_new(size), number=1000)
    old_time = timeit.timeit(lambda: attnmask_old(size), number=1000)
    print(f"New mask time for size {size}: {new_time:.6f} seconds")
    print(f"Old mask time for size {size}: {old_time:.6f} seconds")

    if new_time < old_time:
        print("New mask implementation is faster.")
    else:
        print("Old mask implementation is faster.")


if __name__ == "__main__":
    main()