| from typing import List, Union | |
| import numpy as np | |
| import torch | |
| from transformer_deploy.benchmarks.utils import compare_outputs, to_numpy | |
def check_accuracy(
    engine_name: str,
    pytorch_output: List[torch.Tensor],
    engine_output: List[Union[np.ndarray, torch.Tensor]],
    tolerance: float,
) -> None:
    """
    Compare engine predictions with a reference.
    Raise an AssertionError if the difference is above a threshold.

    :param engine_name: string used in error message, if any
    :param pytorch_output: reference output used for the comparison
    :param engine_output: output from the engine
    :param tolerance: if the discrepancy reported by ``compare_outputs`` is above this threshold
        (or is NaN), an error will be raised
    :raises AssertionError: if the discrepancy is not ``<= tolerance``
    """
    reference = to_numpy(pytorch_output)
    candidate = to_numpy(engine_output)
    discrepancy = compare_outputs(pytorch_output=reference, engine_output=candidate)
    # Explicit raise instead of a bare ``assert``: asserts are stripped under ``python -O``,
    # which would silently disable this accuracy check. The exception type is unchanged.
    # ``not (x <= tol)`` rather than ``x > tol`` so that a NaN discrepancy still fails.
    if not discrepancy <= tolerance:
        # The element-wise diff is only built on failure, for the error report.
        diff = np.asarray(reference) - np.asarray(candidate)
        raise AssertionError(
            f"{engine_name} discrepancy is too high ({discrepancy:.2f} > {tolerance}):\n"
            f"PyTorch:\n{reference}\n"
            f"VS\n"
            f"Engine:\n{candidate}\n"
            f"Diff:\n"
            f"{diff}\n"
            "Tolerance can be increased with --atol parameter."
        )