File size: 3,387 Bytes
789eef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
from fused_ssim import fused_ssim
from pytorch_msssim import SSIM
import matplotlib.pyplot as plt
import numpy as np
import time
import os

# Human-readable name of the current CUDA device; it is interpolated into the
# titles of both benchmark plots.
gpu = torch.cuda.get_device_name()

# Render every figure produced by this script with matplotlib's "ggplot" style.
plt.style.use("ggplot")

# Untimed calls issued before every measurement so that one-time start-up
# costs (e.g. lazy CUDA initialisation on the very first call) are not charged
# to the first timed iteration.
WARMUP_ITERATIONS = 3

# Directory the benchmark figures are written to (relative to the CWD).
OUTPUT_DIR = os.path.join("..", "images")


def _make_inputs(batch, channels, size):
    """Create one random image pair per SSIM implementation.

    Both libraries receive independent clones of the same two random
    ``[batch, channels, size, size]`` CUDA tensors. The first image of each
    pair is wrapped in ``torch.nn.Parameter`` so that ``backward()`` has a
    leaf to write gradients to.

    Returns:
        ``((img1_pm, img2_pm), (img1_fused, img2_fused))``
    """
    with torch.no_grad():
        img1_og = torch.rand([batch, channels, size, size], device="cuda")
        img2_og = torch.rand([batch, channels, size, size], device="cuda")

        img1_fused = torch.nn.Parameter(img1_og.clone())
        img2_fused = img2_og.clone()

        img1_pm = torch.nn.Parameter(img1_og.clone())
        img2_pm = img2_og.clone()
    return (img1_pm, img2_pm), (img1_fused, img2_fused)


def _mean_ms(fn, img1, img2, *, iterations, backward,
             warmup=WARMUP_ITERATIONS, **fn_kwargs):
    """Return the mean wall-clock time of one ``fn`` call in milliseconds.

    Args:
        fn: Callable invoked as ``fn(img1, img2, **fn_kwargs)``.
        img1, img2: Inputs forwarded to ``fn``.
        iterations: Number of timed calls to average over.
        backward: If true, ``.backward()`` is called on every result so the
            measurement covers forward + backward.
        warmup: Number of untimed calls made before the timer starts.
        **fn_kwargs: Extra keyword arguments forwarded to ``fn``.
    """
    def run(count):
        for _ in range(count):
            value = fn(img1, img2, **fn_kwargs)
            if backward:
                value.backward()

    run(warmup)
    # CUDA kernels are launched asynchronously: drain everything that is
    # still queued (input creation, warm-up, the previous measurement) so the
    # timed interval contains only the calls issued below.
    torch.cuda.synchronize()
    begin = time.perf_counter()  # monotonic, unlike time.time()
    run(iterations)
    # Wait for the timed kernels to actually finish before reading the clock.
    torch.cuda.synchronize()
    end = time.perf_counter()
    return (end - begin) / iterations * 1000


def _benchmark(pm_ssim, dimensions, batch, channels, iterations, training):
    """Time both SSIM implementations for every image size in ``dimensions``.

    With ``training=True`` each iteration is forward + backward with autograd
    enabled. With ``training=False`` each iteration is forward only, autograd
    is disabled, and ``fused_ssim`` is called with ``train=False``.

    Returns:
        Dict mapping the plot label of each library to its list of mean
        per-iteration times in milliseconds, one entry per dimension.
    """
    data = {
        "pytorch_mssim": [],
        "fused-ssim": []
    }
    fused_kwargs = {} if training else {"train": False}

    for d in dimensions:
        pm_inputs, fused_inputs = _make_inputs(batch, channels, d)
        with torch.set_grad_enabled(training):
            data["pytorch_mssim"].append(
                _mean_ms(pm_ssim, *pm_inputs,
                         iterations=iterations, backward=training)
            )
            data["fused-ssim"].append(
                _mean_ms(fused_ssim, *fused_inputs,
                         iterations=iterations, backward=training,
                         **fused_kwargs)
            )
    return data


def _save_plot(num_pixels, data, ylabel, title, filename):
    """Plot one timing curve per library and save the figure in OUTPUT_DIR."""
    plt.clf()  # start from an empty figure so earlier curves do not leak in
    for label, times in data.items():
        plt.plot(num_pixels, times, label=label)
    plt.legend()
    plt.xlabel("Number of pixels (in millions).")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.savefig(os.path.join(OUTPUT_DIR, filename), dpi=300)


def main():
    """Run the training and inference benchmarks and write both figures."""
    torch.manual_seed(0)

    B, CH = 5, 1
    dimensions = list(range(50, 1550, 50))
    iterations = 50

    # Create the output directory up front so a missing directory cannot
    # discard the results after the whole benchmark has already run.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    pm_ssim = SSIM(data_range=1.0, channel=CH)
    # Total pixels per batch, in millions, used as the x axis of both plots.
    num_pixels = (B * np.array(dimensions) ** 2) / 1e6

    data = _benchmark(pm_ssim, dimensions, B, CH, iterations, training=True)
    _save_plot(
        num_pixels,
        data,
        "Time for one training iteration (ms).",
        f"Training Benchmark on {gpu}.",
        "training_time.png",
    )

    data = _benchmark(pm_ssim, dimensions, B, CH, iterations, training=False)
    _save_plot(
        num_pixels,
        data,
        "Time for one inference iteration (ms).",
        f"Inference Benchmark on {gpu}.",
        "inference_time.png",
    )


if __name__ == "__main__":
    main()