from typing import List

import torch


def check_is_bitwise_equal_comparator(ans: torch.Tensor, ref: torch.Tensor, result: torch.Tensor) -> None:
    """
    Write whether `ans` and `ref` are element-wise equal into `result`.

    The comparison is `torch.all(torch.eq(ans, ref))`, stored in-place in
    `result` through the `out=` argument; nothing is returned and no value is
    pulled back to Python.

    Raises:
        AssertionError: if the two tensors have different shapes.
    """
    assert ans.shape == ref.shape, "Shape mismatch"
    torch.all(torch.eq(ans, ref), out=result)


def check_is_bitwise_equal(name: str, ans: torch.Tensor, ref: torch.Tensor, quiet: bool = False) -> bool:
    """
    Return True if `torch.equal(ans, ref)` holds, otherwise False.

    Unless `quiet` is set, a mismatch is reported on stdout together with the
    number of differing elements (or the two shapes, when they differ).

    Args:
        name: label used in the mismatch message.
        ans: tensor under test.
        ref: reference tensor.
        quiet: suppress the mismatch message when True.
    """
    is_bitwise_equal = torch.equal(ans, ref)
    if not quiet and not is_bitwise_equal:
        if ans.shape != ref.shape:
            # `ans != ref` would raise (or silently broadcast) on differing
            # shapes, so report the shapes instead of a mismatch count.
            print(f"`{name}` mismatch: not bitwise equal. Shape mismatch: {ans.shape} vs {ref.shape}")
        else:
            print(f"`{name}` mismatch: not bitwise equal. Mismatch count: {(ans != ref).sum().item()} out of {ans.numel()}")
    return is_bitwise_equal


def get_cos_diff(ans: torch.Tensor, ref: torch.Tensor) -> float:
    """
    Calculate the cosine diff between two tensors.

    The value is `1 - 2 * sum(ans * ref) / sum(ans^2 + ref^2)`, computed in
    float64. When the squared norm of `ref` is below 1e-12 the function
    returns 0.0 without looking at `ans`, which also avoids a zero division.
    """
    ans, ref = ans.double(), ref.double()
    if (ref * ref).sum().item() < 1e-12:
        return 0.0
    denominator = (ans * ans + ref * ref).sum().item()
    sim = 2 * (ans * ref).sum().item() / denominator
    return 1 - sim


def check_is_allclose(name: str, ans: torch.Tensor, ref: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7, quiet: bool = False) -> bool:
    """
    Check if two tensors are close enough.

    The check passes when all of the following hold:
      * +inf, -inf and NaN appear at exactly the same positions in both tensors;
      * every element satisfies `abs_err < abs_tol` or `rel_err < rel_tol`,
        where `rel_err = abs_err / (|ref| + 1e-6)`;
      * `abs(get_cos_diff(ans, ref)) <= cos_diff_tol`.

    Args:
        name: label used in diagnostic messages.
        ans: tensor under test.
        ref: reference tensor; must match `ans` in shape and dtype.
        abs_tol: absolute error tolerance.
        rel_tol: relative error tolerance.
        cos_diff_tol: cosine diff tolerance.
        quiet: suppress diagnostic messages when True.

    Returns:
        True if the tensors are considered close, False otherwise.

    Raises:
        AssertionError: on shape or dtype mismatch.
    """
    assert ans.shape == ref.shape, f"`{name}` Shape mismatch: {ans.shape} vs {ref.shape}"
    assert ans.dtype == ref.dtype, f"`{name}` Dtype mismatch: {ans.dtype} vs {ref.dtype}"

    # Work on float32 copies: the anomaly handling below overwrites elements
    # in place and must not touch the caller's tensors.
    ans = ans.clone().to(torch.float)
    ref = ref.clone().to(torch.float)

    def report_err(*args, **kwargs):
        if not quiet:
            print(*args, **kwargs)

    # Deal with anomalies
    def deal_with_anomalies(val: float) -> bool:
        # NaN never compares equal to itself, so it is located via `x != x`.
        is_nan = val != val
        ref_mask = (ref != ref) if is_nan else (ref == val)
        ans_mask = (ans != ans) if is_nan else (ans == val)
        # Zero the anomalies so they do not poison the error metrics below.
        ref[ref_mask] = 0.0
        ans[ans_mask] = 0.0
        if not torch.equal(ref_mask, ans_mask):
            report_err(f"`{name}` Anomaly number `{val}` mismatch: {ans_mask.sum().item()} in ans but {ref_mask.sum().item()} in ref")
            return False
        return True

    anomalies_check_passed = True
    for anomaly in (float("inf"), float("-inf"), float("nan")):
        # Deliberately no short-circuit: every anomaly kind is checked and reported.
        anomalies_check_passed &= deal_with_anomalies(anomaly)

    # NOTE(review): from here down to the `get_pos_in_tensor` header the
    # source text was corrupted (a `<...>` span was dropped). The statements
    # are reconstructed from the names used afterwards; confirm against the
    # upstream version of this file, in particular where the anomaly result
    # is turned into a failure.
    if not anomalies_check_passed:
        return False

    cos_diff = get_cos_diff(ans, ref)
    raw_abs_err = torch.abs(ans - ref)
    raw_rel_err = raw_abs_err / (torch.abs(ref) + (1e-6))
    # An element passing one criterion has its error under the other
    # criterion zeroed, so the reported maxima only reflect real failures.
    rel_err = raw_rel_err.masked_fill(raw_abs_err < abs_tol, 0)
    abs_err = raw_abs_err.masked_fill(raw_rel_err < rel_tol, 0)
    pass_mask = (abs_err < abs_tol) | (rel_err < rel_tol)

    if not pass_mask.all():
        report_err(f"`{name}` mismatch")
        # Flat (row-major) indices of the worst offenders.
        max_abs_err_pos = int(torch.argmax(abs_err).item())
        max_rel_err_pos = int(torch.argmax(rel_err).item())

        def get_pos_in_tensor(t: torch.Tensor, pos: int) -> List[int]:
            # Convert a flat row-major index into per-dimension coordinates.
            result = []
            for size in t.shape[::-1]:
                pos, coord = divmod(pos, size)
                result.append(coord)
            assert pos == 0
            return result[::-1]

        report_err(f"max abs err: {torch.max(abs_err).item()}: pos {get_pos_in_tensor(ans, max_abs_err_pos)}, {ans.reshape(-1)[max_abs_err_pos].item()} vs {ref.reshape(-1)[max_abs_err_pos].item()}")
        report_err(f"max rel err: {torch.max(rel_err).item()}: pos {get_pos_in_tensor(ans, max_rel_err_pos)}, {ans.reshape(-1)[max_rel_err_pos].item()} vs {ref.reshape(-1)[max_rel_err_pos].item()}")
        report_err(f"{pass_mask.sum()} out of {pass_mask.numel()} passed ({pass_mask.sum()/pass_mask.numel()*100.0:.2f}%)")
        report_err(f"Cosine diff: {cos_diff} (threshold: {cos_diff_tol})")
        return False

    if abs(cos_diff) > cos_diff_tol:
        report_err(f"`{name}` mismatch: Cosine diff too large: {cos_diff} vs {cos_diff_tol})")
        return False
    return True


def check_is_allclose_comparator(name: str, ans: torch.Tensor, ref: torch.Tensor, out: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7) -> None:
    """
    Run `check_is_allclose` and store its boolean result in `out` via `fill_`.

    Diagnostics are not suppressed (`quiet` keeps its default); nothing is
    returned.
    """
    out.fill_(check_is_allclose(name, ans, ref, abs_tol, rel_tol, cos_diff_tol))