| import torch | |
| from fused_ssim import fused_ssim | |
| from pytorch_msssim import SSIM | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import time | |
| import os | |
plt.style.use('ggplot')
# Name of the current CUDA device; it is embedded in both plot titles.
gpu = torch.cuda.get_device_name()


def _make_inputs(batch, channels, dim):
    """Create one random image pair and hand each library its own copy.

    Args:
        batch: Number of images in the batch.
        channels: Number of channels per image.
        dim: Height and width of the (square) images.

    Returns:
        ``((img1_pm, img2_pm), (img1_fused, img2_fused))``. The first image
        of every pair is wrapped in ``torch.nn.Parameter``; the second is a
        plain tensor clone.
    """
    with torch.no_grad():
        img1_og = torch.rand([batch, channels, dim, dim], device="cuda")
        img2_og = torch.rand([batch, channels, dim, dim], device="cuda")

        img1_fused = torch.nn.Parameter(img1_og.clone())
        img2_fused = img2_og.clone()

        img1_pm = torch.nn.Parameter(img1_og.clone())
        img2_pm = img2_og.clone()
    return (img1_pm, img2_pm), (img1_fused, img2_fused)


def _training_step(ssim_fn, img1, img2):
    """Run one forward + backward pass of ``ssim_fn`` on the image pair."""
    # Gradients accumulate in img1.grad across calls; this is the same for
    # both libraries, so the comparison is unaffected.
    ssim_fn(img1, img2).backward()


def _inference_step(ssim_fn, img1, img2):
    """Run one forward pass of ``ssim_fn`` with gradient tracking disabled."""
    # img1 is a Parameter, so without no_grad() the forward pass would be
    # recorded for autograd and the "inference" timing would include that.
    with torch.no_grad():
        ssim_fn(img1, img2)


def _fused_ssim_inference(img1, img2):
    """Call ``fused_ssim`` with ``train=False``, as the inference run does."""
    return fused_ssim(img1, img2, train=False)


def _average_ms(step, ssim_fn, img1, img2, iterations):
    """Return the mean time of ``step(ssim_fn, img1, img2)`` in milliseconds.

    Args:
        step: ``_training_step`` or ``_inference_step``.
        ssim_fn: Callable taking ``(img1, img2)`` and returning the SSIM value.
        img1: First image tensor passed to ``ssim_fn``.
        img2: Second image tensor passed to ``ssim_fn``.
        iterations: Number of timed calls to average over.

    Returns:
        Average milliseconds per call.
    """
    # Untimed warm-up call so one-off initialisation is not billed to the
    # measurement.
    step(ssim_fn, img1, img2)
    # CUDA work is queued asynchronously: drain the queue (input creation,
    # warm-up) before starting the clock, and again before stopping it.
    torch.cuda.synchronize()
    begin = time.perf_counter()
    for _ in range(iterations):
        step(ssim_fn, img1, img2)
    torch.cuda.synchronize()
    end = time.perf_counter()
    return (end - begin) / iterations * 1000


def _run_sweep(step, pm_fn, fused_fn, batch, channels, dimensions, iterations):
    """Time both SSIM implementations for every image size in ``dimensions``.

    Args:
        step: ``_training_step`` or ``_inference_step``.
        pm_fn: The pytorch_msssim callable.
        fused_fn: The fused-ssim callable.
        batch: Number of images in the batch.
        channels: Number of channels per image.
        dimensions: Iterable of square image side lengths.
        iterations: Number of timed calls per measurement.

    Returns:
        Dict with keys ``"pytorch_mssim"`` and ``"fused-ssim"``; each value
        is a list of average milliseconds, one entry per dimension.
    """
    data = {
        "pytorch_mssim": [],
        "fused-ssim": [],
    }
    for d in dimensions:
        pm_inputs, fused_inputs = _make_inputs(batch, channels, d)
        data["pytorch_mssim"].append(
            _average_ms(step, pm_fn, *pm_inputs, iterations)
        )
        data["fused-ssim"].append(
            _average_ms(step, fused_fn, *fused_inputs, iterations)
        )
    return data


def _save_plot(num_pixels, data, ylabel, title, filename):
    """Plot one timing curve per library and save it under ``../images``.

    Args:
        num_pixels: X-axis values (millions of pixels per batch).
        data: Mapping from series label to its list of timings.
        ylabel: Y-axis label.
        title: Figure title.
        filename: File name of the PNG inside ``../images``.
    """
    for label, timings in data.items():
        plt.plot(num_pixels, timings, label=label)
    plt.legend()
    plt.xlabel("Number of pixels (in millions).")
    plt.ylabel(ylabel)
    plt.title(title)

    out_dir = os.path.join("..", "images")
    # Create the folder if needed so a finished sweep is not lost to a
    # missing output directory.
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, filename), dpi=300)
    # Reset the figure so the next plot starts from a clean canvas.
    plt.clf()


if __name__ == "__main__":
    torch.manual_seed(0)

    B, CH = 5, 1
    dimensions = list(range(50, 1550, 50))
    iterations = 50
    # Pixels per batch, in millions; shared x-axis of both plots.
    num_pixels = (B * np.array(dimensions) ** 2) / 1e6

    pm_ssim = SSIM(data_range=1.0, channel=CH)

    # Training benchmark: forward + backward.
    training = _run_sweep(
        _training_step, pm_ssim, fused_ssim, B, CH, dimensions, iterations
    )
    _save_plot(
        num_pixels,
        training,
        "Time for one training iteration (ms).",
        f"Training Benchmark on {gpu}.",
        "training_time.png",
    )

    # Inference benchmark: forward only, no autograd for either library.
    inference = _run_sweep(
        _inference_step, pm_ssim, _fused_ssim_inference, B, CH, dimensions,
        iterations,
    )
    _save_plot(
        num_pixels,
        inference,
        "Time for one inference iteration (ms).",
        f"Inference Benchmark on {gpu}.",
        "inference_time.png",
    )