File size: 4,421 Bytes
ccef021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import List

import torch

def check_is_bitwise_equal_comparator(ans: torch.Tensor, ref: torch.Tensor, result: torch.Tensor):
    """
    Return if two tensors are bitwise equal
    Return a bool if avoid_sync is False, else return a tensor
    """
    assert ans.shape == ref.shape, "Shape mismatch"
    torch.all(torch.eq(ans, ref), out=result)

def check_is_bitwise_equal(name: str, ans: torch.Tensor, ref: torch.Tensor, quiet: bool = False) -> bool:
    is_bitwise_equal = torch.equal(ans, ref)
    if not quiet and not is_bitwise_equal:
        print(f"`{name}` mismatch: not bitwise equal. Mismatch count: {(ans != ref).sum().item()} out of {ans.numel()}")
    return is_bitwise_equal

def get_cos_diff(ans: torch.Tensor, ref: torch.Tensor) -> float:
    """
    Calculate the cosine diff between two tensors
    Return a float if avoid_sync is False, else return a tensor
    """
    ans, ref = ans.double(), ref.double()
    if (ref*ref).sum().item() < 1e-12:
        return 0
    denominator = (ans*ans + ref*ref).sum().item()
    sim = 2 * (ans*ref).sum().item() / denominator
    return 1 - sim

def check_is_allclose(name: str, ans: torch.Tensor, ref: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7, quiet: bool = False) -> bool:
    """
    Check if two tensors are close enough
    Return a bool if avoid_sync is False, else return a tensor
    """
    assert ans.shape == ref.shape, f"`{name}` Shape mismatch: {ans.shape} vs {ref.shape}"
    assert ans.dtype == ref.dtype, f"`{name}` Dtype mismatch: {ans.dtype} vs {ref.dtype}"
    
    ans = ans.clone().to(torch.float)
    ref = ref.clone().to(torch.float)

    def report_err(*args, **kwargs):
        if not quiet:
            print(*args, **kwargs)

    # Deal with anomalies
    def deal_with_anomalies(val: float):
        ref_mask = (ref == val) if (val == val) else (ref != ref)
        ans_mask = (ans == val) if (val == val) else (ans != ans)
        ref[ref_mask] = 0.0
        ans[ans_mask] = 0.0
        if not torch.equal(ref_mask, ans_mask):
            report_err(f"`{name}` Anomaly number `{val}` mismatch: {ans_mask.sum().item()} in ans but {ref_mask.sum().item()} in ref")
            return False
        return True
    
    anomalies_check_passed = True
    anomalies_check_passed &= deal_with_anomalies(float("inf"))
    anomalies_check_passed &= deal_with_anomalies(float("-inf"))
    anomalies_check_passed &= deal_with_anomalies(float("nan"))

    cos_diff = get_cos_diff(ans, ref)
    raw_abs_err = torch.abs(ans-ref)
    raw_rel_err = raw_abs_err / (torch.abs(ref)+(1e-6))
    rel_err = raw_rel_err.masked_fill(raw_abs_err<abs_tol, 0)
    abs_err = raw_abs_err.masked_fill(raw_rel_err<rel_tol, 0)
    pass_mask = (abs_err < abs_tol) | (rel_err < rel_tol)

    if not anomalies_check_passed:
        return False

    if not pass_mask.all():
        report_err(f"`{name}` mismatch")
        max_abs_err_pos: int = torch.argmax(abs_err, keepdim=True).item()
        max_rel_err_pos: int = torch.argmax(rel_err, keepdim=True).item()
        def get_pos_in_tensor(t: torch.Tensor, pos: int) -> List[int]:
            result = []
            for size in t.shape[::-1]:
                result.append(pos % size)
                pos = pos // size
            assert pos == 0
            return result[::-1]
        report_err(f"max abs err: {torch.max(abs_err).item()}: pos {get_pos_in_tensor(ans, max_abs_err_pos)}, {ans.reshape(-1)[max_abs_err_pos].item()} vs {ref.reshape(-1)[max_abs_err_pos].item()}")
        report_err(f"max rel err: {torch.max(rel_err).item()}: pos {get_pos_in_tensor(ans, max_rel_err_pos)}, {ans.reshape(-1)[max_rel_err_pos].item()} vs {ref.reshape(-1)[max_rel_err_pos].item()}")
        report_err(f"{pass_mask.sum()} out of {pass_mask.numel()} passed ({pass_mask.sum()/pass_mask.numel()*100.0:.2f}%)")
        report_err(f"Cosine diff: {cos_diff} (threshold: {cos_diff_tol})")
        return False
    else:
        if abs(cos_diff) > cos_diff_tol:
            report_err(f"`{name}` mismatch: Cosine diff too large: {cos_diff} vs {cos_diff_tol})")
            return False
        return True
    
def check_is_allclose_comparator(name: str, ans: torch.Tensor, ref: torch.Tensor, out: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7):
    out.fill_(check_is_allclose(name, ans, ref, abs_tol, rel_tol, cos_diff_tol))