| import argparse | |
| import glob | |
| import importlib | |
| import itertools | |
| import os | |
| import torch | |
| from common.bench_framework import (make_bwd_benchmark_for_case, | |
| make_bwd_benchmark_plot_for_case, | |
| make_fwd_benchmark_for_case, | |
| make_fwd_benchmark_plot_for_case) | |
| from common.diff_engine import DiffCase, calculate_diff | |
def make_title_tag():
    """Return a bracketed tag naming the compute device and torch version.

    The device part is the name torch reports for CUDA device 0 when CUDA
    is available, and the literal ``"CPU"`` otherwise.

    Returns:
        A string of the form ``"[<device> | torch <version>]"``.
    """
    has_cuda = torch.cuda.is_available()
    device_label = torch.cuda.get_device_name(0) if has_cuda else "CPU"
    return f"[{device_label} | torch {torch.__version__}]"
def plot_result(r_path):
    """Render the benchmark CSV at ``r_path + ".csv"`` as a bar-chart PNG.

    The CSV is read with pandas; its ``"Naive"`` and ``"Cuda"`` columns are
    drawn as bars keyed by the ``"config"`` column, every bar is annotated
    with its height formatted as ``x<value>``, and the chart is written to
    ``r_path + ".png"``.

    Args:
        r_path: Path of the result files without an extension.
    """
    # Imported locally: these packages are only needed when plotting.
    import matplotlib.pyplot as plt
    import pandas as pd

    df = pd.read_csv(r_path + ".csv")

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        df.plot(x="config", y=["Naive", "Cuda"], kind="bar", ax=ax)
        ax.set_title(
            "Speedup over torch (higher is better)\n" + make_title_tag(),
            fontsize=14,
            fontweight="bold",
        )
        ax.set_ylabel("Relative Speedup", fontsize=14)
        ax.set_xlabel("")
        # Acts on the current axes, which is `ax` (just created above).
        plt.xticks(rotation=45,
                   fontsize=12,
                   ha="right",
                   rotation_mode="anchor")

        # One container per plotted column; label each bar with its value.
        for container in ax.containers:
            labels = [f"x{bar.get_height():.2f}" for bar in container]
            ax.bar_label(container,
                         labels=labels,
                         label_type="edge",
                         fontsize=10)

        fig.tight_layout()
        fig.savefig(r_path + ".png", bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--case",
                    choices=["rms", "add_rms", "poly", "mul_poly"],
                    required=True)
    ap.add_argument("--plot", action="store_true")
    ap.add_argument(
        "--save-path",
        type=str,
        default="./configs/",
        help="Path to save benchmark results",
    )
    return ap.parse_args()


def _remove_matching(directory, patterns):
    """Delete every file in *directory* matching any of the glob *patterns*."""
    # Collect all matches first, then delete (same order as globbing).
    doomed = [
        path for pattern in patterns
        for path in glob.glob(os.path.join(directory, pattern))
    ]
    for path in doomed:
        os.remove(path)


def main():
    """Run the diff check and the fwd/bwd benchmarks for one case.

    Command-line options:
        --case       One of ``rms``, ``add_rms``, ``poly``, ``mul_poly``
                     (required); the module ``cases.<case>`` is imported and
                     its ``CASE`` attribute is used.
        --plot       Run the plot-oriented benchmarks and render a PNG per
                     direction via ``plot_result``.
        --save-path  Base output directory (default ``./configs/``); results
                     go to ``<save-path>/<case>``.

    Before benchmarking, ``calculate_diff`` is called on the case with
    ``batch_size=2, seq_len=128, hidden_size=4096``. After the benchmarks,
    files matching ``*.html`` plus ``*.csv`` (plot mode) or ``*.png``
    (non-plot mode) are removed from the output directory.
    """
    args = _parse_args()

    torch.set_default_device("cuda")
    case: DiffCase = importlib.import_module(f"cases.{args.case}").CASE

    calculate_diff(
        case,
        batch_size=2,
        seq_len=128,
        hidden_size=4096,
    )

    save_dir = os.path.join(args.save_path, args.case)
    # The "poly" cases (poly, mul_poly) use larger hidden dimensions.
    dim = [8192, 16384] if "poly" in args.case else [2048, 4096]

    if args.plot:
        batch_size_range = [1]
        seq_length_range = [4096, 8192, 16384]
        configs = list(
            itertools.product(batch_size_range, seq_length_range, dim))
        line_names = {"naive": "Naive", "cuda": "Cuda"}
        factories = (
            ("fwd", make_fwd_benchmark_plot_for_case),
            ("bwd", make_bwd_benchmark_plot_for_case),
        )
        for direction, make_bench in factories:
            plot_name = f"plot_{args.case}-{direction}-perf"
            bench = make_bench(
                case=case,
                configs=configs,
                plot_name=plot_name,
                line_names=dict(line_names),  # fresh dict per benchmark
            )
            bench.run(print_data=True, save_path=save_dir)
            # plot_result reads <plot_name>.csv and writes <plot_name>.png.
            plot_result(os.path.join(save_dir, plot_name))
        # NOTE(review): these patterns are not limited to this run's plot
        # names, so any other *.csv already in save_dir (e.g. from a non-plot
        # run of the same case) is deleted too -- confirm this is intended.
        _remove_matching(save_dir, ("*.html", "*.csv"))
    else:
        batch_size_range = [2**i for i in range(4)]  # 1, 2, 4, 8
        seq_length_range = [2**i for i in range(10, 14)]  # 1024 .. 8192
        # Note the ordering differs from plot mode: dim comes first here.
        configs = list(
            itertools.product(dim, batch_size_range, seq_length_range))
        line_names = {"naive": "Naive", "cuda": "Cuda", "speedup": "SpeedUp"}
        factories = (
            ("fwd", make_fwd_benchmark_for_case),
            ("bwd", make_bwd_benchmark_for_case),
        )
        for direction, make_bench in factories:
            bench = make_bench(
                case=case,
                configs=configs,
                plot_name=f"{args.case}-{direction}-perf",
                line_names=dict(line_names),  # fresh dict per benchmark
            )
            bench.run(print_data=True, save_path=save_dir)
        # NOTE(review): likewise, every *.png in save_dir is removed here,
        # including plots left by an earlier --plot run -- confirm intended.
        _remove_matching(save_dir, ("*.html", "*.png"))
# Run the benchmark driver only when this file is executed as a script.
if __name__ == "__main__":
    main()