| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ Benchmarking the library on inference and training """ |
| |
|
| | from transformers import HfArgumentParser, PyTorchBenchmark, PyTorchBenchmarkArguments |
| |
|
| |
|
def _unused_args_message(unused_args, deprecated_args):
    """Build the error text for command-line arguments the parser did not consume.

    Args:
        unused_args: Raw command-line tokens left over after parsing.
        deprecated_args: Names (without the leading ``--``) of flags that are
            no longer accepted in their underscore form.

    Returns:
        A single human-readable message. It contains one migration hint per
        deprecated flag, followed by a listing of any other unrecognised
        arguments.
    """
    arg_error_msg = "Arg --no_{0} is no longer used, please use --no-{0} instead."
    messages = []
    wrong_args = []
    for arg in unused_args:
        # arg[2:] strips the leading "--" so the token can be compared with
        # the bare names stored in deprecated_args.
        if arg[2:] in deprecated_args:
            # arg[5:] strips the leading "--no_"; this assumes every deprecated
            # entry starts with "no_" (matching the message template above).
            messages.append(arg_error_msg.format(arg[5:]))
        else:
            wrong_args.append(arg)
    if wrong_args:
        messages.append(
            f"Some specified arguments are not used by the HfArgumentParser: {wrong_args}"
        )
    return " ".join(messages)


def main():
    """Parse PyTorch benchmark arguments from the command line and run the benchmark.

    Arguments that ``HfArgumentParser`` does not consume are reported in a
    single ``ValueError``. Flags listed in
    ``PyTorchBenchmarkArguments.deprecated_args`` get a hint pointing at the
    hyphenated ``--no-<name>`` replacement. Any other leftover arguments are
    listed verbatim.

    Raises:
        ValueError: if any command-line argument was not consumed by the parser.
    """
    parser = HfArgumentParser(PyTorchBenchmarkArguments)
    # Have the parser return unconsumed arguments instead of raising. They can
    # then be inspected directly, rather than recovered by splitting and
    # eval-ing the text of the parser's own ValueError. That approach was
    # unsafe, broke as soon as more than one argument was left over, and
    # masked unrelated ValueErrors raised during parsing.
    parsed = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    benchmark_args, remaining_args = parsed[0], parsed[-1]
    if remaining_args:
        raise ValueError(
            _unused_args_message(remaining_args, PyTorchBenchmarkArguments.deprecated_args)
        )

    benchmark = PyTorchBenchmark(args=benchmark_args)
    benchmark.run()
| |
|
| |
|
# Run the benchmark only when this file is executed as a script; importing
# the module has no side effects beyond defining main().
if __name__ == "__main__":
    main()
| |
|