justuswill commited on
Commit
e7a2b01
·
verified ·
1 Parent(s): 86ddd89

Upload 3 files

Browse files
Files changed (3) hide show
  1. eval/cifar.yml +49 -0
  2. eval/imagenet.yml +61 -0
  3. eval/paper_plots.py +100 -0
eval/cifar.yml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # rounded to 4 digits
2
+ jpeg:
3
+ bpp: [ 2.4004, 2.8381, 3.2233, 3.5125, 3.743, 3.9558, 4.1803, 4.5109, 5.0316, 6.1434, 7.5109, 10.7999 ]
4
+ psnr: [ 18.6467, 23.3137, 25.517, 26.7397, 27.572, 28.255, 28.9256, 29.8221, 31.1432, 33.6497, 36.4185, 41.9347 ]
5
+ fid: [ 224.9505, 112.9873, 75.5656, 57.9521, 47.1583, 39.2226, 32.5657, 24.8751, 16.5128, 7.61, 3.635, 2.4892 ]
6
+ jpeg2000:
7
+ bpp: [ 1.9891, 2.5018, 4.884, 8.0687, 12.0546, 21.8545 ]
8
+ psnr: [ 14.0417, 17.0844, 25.7253, 32.4832, 39.2544, .inf ]
9
+ fid: [ 334.3462, 222.3907, 70.642, 10.7621, 2.9819, 0.0 ]
10
+ bpg:
11
+ bpp: [ 0.3304, 0.7377, 1.737, 3.5282, 6.03, 8.9413, 14.2088 ]
12
+ psnr: [ 19.2975, 25.0311, 30.9917, 36.815, 40.5962, 41.9634, .inf ]
13
+ fid: [ 176.8806, 67.5432, 19.8662, 4.1032, 1.6116, 1.481, 0.0 ]
14
+ vdm-d:
15
+ bpp: [ 0.0001, 0.0004, 0.001, 0.0024, 0.0062, 0.0162, 0.0431, 0.129, 0.4065, 1.0298, 1.9598, 3.0714, 4.3022, 5.644, 7.1119, 8.7178, 10.4483, 12.3011, 14.3031, 16.3632, 17.5816 ]
16
+ psnr: [ 8.3897, 10.7161, 11.9383, 12.8162, 13.8472, 15.2244, 16.9644, 19.0609, 21.5213, 24.3224, 27.3847, 30.6517, 34.1026, 37.72, 41.5499, 45.6346, 49.6708, 53.4106, 58.1225, 70.9762, .inf ]
17
+ fid: [ 337.6811, 306.4971, 319.1342, 318.8984, 327.4887, 285.7153, 154.9047, 82.0968, 51.0361, 33.0263, 21.2701, 12.7124, 7.0161, 3.5962, 1.5872, 0.574, 0.2017, 0.072, 0.0213, 0.0011, 0.0 ]
18
+ vdm-a:
19
+ bpp: [ 0.0001, 0.0004, 0.001, 0.0024, 0.0062, 0.0161, 0.0431, 0.1289, 0.4066, 1.0299, 1.9599, 3.0717, 4.3024, 5.6444, 7.1126, 8.7188, 10.4493, 12.302, 14.3048, 16.3645, 17.5823 ]
20
+ psnr: [ 10.5164, 10.6789, 10.9824, 11.606, 12.7016, 14.1219, 15.9098, 17.9932, 20.4458, 23.2081, 26.2264, 29.4655, 32.8867, 36.4988, 40.3707, 44.5329, 48.6745, 52.587, 57.8859, 70.9822, .inf ]
21
+ fid: [ 82.7664, 83.0947, 83.1105, 82.5617, 82.336, 78.6079, 73.4927, 64.3986, 50.9205, 35.9904, 22.6822, 12.5784, 6.5951, 3.2898, 1.4879, 0.5702, 0.2174, 0.084, 0.0226, 0.0011, 0.0 ]
22
+ vdm-f:
23
+ bpp: [ 0.0001, 0.0004, 0.001, 0.0024, 0.0062, 0.0162, 0.0431, 0.1289, 0.4059, 1.0284, 1.9573, 3.0677, 4.2964, 5.6371, 7.1034, 8.7071, 10.4352, 12.285, 14.2854, 16.3431, 17.5594 ]
24
+ psnr: [ 10.3074, 10.706, 11.1598, 11.9164, 13.0177, 14.5288, 16.4034, 18.6046, 21.1607, 24.0178, 27.1459, 30.4554, 33.9089, 37.5705, 41.469, 45.5837, 49.5898, 53.2924, 58.484, 71.3887, .inf ]
25
+ fid: [ 70.5077, 66.8462, 61.5901, 57.3662, 54.1494, 50.5307, 46.4654, 39.5295, 30.6986, 20.7308, 12.469, 6.9301, 3.7285, 1.9744, 0.9323, 0.3961, 0.1631, 0.0693, 0.02, 0.001, 0.0 ]
26
+ vdm-1000d:
27
+ bpp: [ 0.0001, 0.0012, 0.0057, 0.0315, 0.1422, 0.489, 1.2785, 2.7981, 5.181, 8.4341, 9.7324 ]
28
+ psnr: [ 8.426, 11.9452, 13.8502, 16.9632, 21.5297, 27.3811, 34.0989, 41.5468, 49.6702, 58.1181, .inf ]
29
+ fid: [ 344.1216, 313.8488, 273.115, 167.7786, 102.622, 56.8647, 27.4678, 9.96, 2.6366, 0.5639, 0.1203 ]
30
+ uqdm-d:
31
+ bpp: [ 0.0, 0.8459, 2.7828, 7.6704, 15.3017 ]
32
+ psnr: [ 11.5512, 21.8842, 31.175, 43.7569, 59.9976 ]
33
+ fid: [ 348.7675, 47.7964, 11.9617, 0.8906, 0.0234 ]
34
+ uqdm-a:
35
+ bpp: [ 0.0, 0.8448, 2.7813, 7.6689, 15.3002 ]
36
+ psnr: [ 10.043, 19.0172, 28.2593, 40.5961, 56.1671 ]
37
+ fid: [ 315.5473, 126.366, 36.134, 2.3176, 0.0698 ]
38
+ uqdm-f:
39
+ bpp: [0.0, 0.8347, 2.7833, 7.6738, 15.3061]
40
+ psnr: [9.2672, 21.792, 31.1732, 43.7335, 60.0059]
41
+ fid: [324.4288, 36.2628, 9.317, 0.6881, 0.0235]
42
+ vae-b:
43
+ bpp: [ 0.0269, 0.2894, 1.0495, 2.7783, 4.053, 5.5443, 9.5346 ]
44
+ psnr: [ 16.1589, 22.9358, 29.6548, 37.4456, 41.5039, 45.5439, 53.3573 ]
45
+ fid: [ 322.1337, 105.1356, 32.7888, 5.0927, 1.8276, 0.5906, 0.0638 ]
46
+ vae-m:
47
+ bpp: [ 0.0568, 0.3355, 1.1218, 2.8024, 4.0976, 5.5526, 9.3886 ]
48
+ psnr: [ 17.9086, 24.2149, 30.7686, 38.124, 41.9559, 45.8761, 54.0515 ]
49
+ fid: [ 140.5861, 70.7924, 22.6757, 4.3984, 1.6444, 0.532, 0.0544 ]
eval/imagenet.yml ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # rounded to 4 digits
2
+ jpeg:
3
+ bpp: [ 0.7018, 1.0361, 1.3602, 1.6185, 1.8293, 2.0268, 2.2384, 2.5554, 3.0622, 4.1426, 5.4319, 8.6382 ]
4
+ psnr: [ 19.1856, 23.4067, 25.3096, 26.3807, 27.1229, 27.744, 28.3645, 29.2213, 30.5114, 32.9266, 35.2304, 38.5133 ]
5
+ fid: [ 188.6707, 58.9173, 33.6194, 24.6107, 19.6276, 16.4469, 13.8417, 11.0175, 7.9301, 4.4655, 3.0617, 2.0701 ]
6
+ jpeg2000:
7
+ bpp: [ 1.2229, 2.4075, 4.7929, 7.9738, 11.9161, 17.7443 ]
8
+ psnr: [ 22.6253, 26.8887, 32.7778, 38.807, 44.8636, .inf ]
9
+ fid: [ 97.6303, 30.7883, 7.1988, 1.4258, 0.3257, 0.0 ]
10
+ bpg:
11
+ bpp: [ 0.1342, 0.4744, 1.4261, 3.088, 5.3774, 8.2017, 13.2612 ]
12
+ psnr: [ 20.1191, 25.3813, 30.9028, 35.5346, 37.8201, 38.7962, .inf ]
13
+ fid: [ 119.4003, 47.5895, 9.8001, 2.9095, 1.9414, 1.657, 0.0 ]
14
+ vdm-d:
15
+ bpp: [ 0.0003, 0.0008, 0.002, 0.0046, 0.0101, 0.0224, 0.0529, 0.1411, 0.3803, 0.8563, 1.5483, 2.3845, 3.3193, 4.3402, 5.4513, 6.6648, 7.997, 9.4661, 11.0878, 12.8641, 14.7471 ]
16
+ psnr: [ 11.5434, 12.3961, 13.3009, 14.3916, 15.6468, 17.1123, 18.7804, 20.6163, 22.657, 24.9221, 27.3765, 30.0165, 32.8282, 35.7928, 38.8757, 42.0351, 45.2283, 48.4134, 51.5481, 54.9842, 60.0558 ]
17
+ fid: [ 321.9626, 319.7044, 289.2591, 262.9117, 239.1254, 183.5961, 120.7653, 82.5857, 54.025, 32.5723, 18.1992, 10.0922, 5.5587, 3.0785, 1.6479, 0.8395, 0.4086, 0.1921, 0.0831, 0.036, 0.0 ]
18
+ vdm-a:
19
+ bpp: [ 0.0003, 0.0008, 0.002, 0.0046, 0.0101, 0.0224, 0.0529, 0.1412, 0.3804, 0.8563, 1.5483, 2.3845, 3.3194, 4.3402, 5.4515, 6.6652, 7.9973, 9.4666, 11.0887, 12.8654, 14.7486 ]
20
+ psnr: [ 10.7752, 11.2343, 12.0064, 13.0485, 14.3389, 15.8163, 17.4941, 19.3574, 21.4164, 23.6685, 26.0841, 28.678, 31.4593, 34.4045, 37.4753, 40.6318, 43.8405, 47.0932, 50.3936, 53.9429, 58.2943 ]
21
+ fid: [ 132.1083, 128.8545, 125.7113, 119.9762, 112.6877, 102.679, 88.1364, 72.923, 54.3236, 35.9094, 20.2573, 10.1911, 5.1121, 2.6013, 1.3265, 0.6607, 0.3216, 0.1509, 0.0696, 0.0345, 0.0194 ]
22
+ vdm-f:
23
+ bpp: [ 0.0003, 0.0046, 0.0529, 0.3803, 1.5483, 2.3845, 3.3194, 4.3402, 5.4514, 6.665, 7.9971, 9.4664, 11.0883, 12.8647, 14.7479 ]
24
+ psnr: [ 10.4359, 13.0568, 17.5926, 21.7079, 26.6407, 29.3521, 32.228, 35.2574, 38.3862, 41.593, 44.8423, 48.0954, 51.2717, 54.5982, 60.0537 ]
25
+ fid: [ 114.5655, 98.7355, 73.8959, 42.8638, 11.3897, 4.3585, 1.6068, 0.6184, 0.2578, 0.1178, 0.0572, 0.0296, 0.018, 0.0119, 0.0059 ]
26
+ vdm-1000d:
27
+ bpp: [ 0.0003, 0.0018, 0.0086, 0.035, 0.1264, 0.3872, 0.9594, 2.0214, 3.7732, 6.421, 9.6478 ]
28
+ psnr: [ 11.56, 13.2967, 15.6984, 18.7889, 22.6692, 27.3418, 32.7914, 38.8265, 45.1906, 51.5336, 60.0486 ]
29
+ fid: [ 344.1216, 313.8488, 273.1115, 167.7786, 102.6219, 56.8647, 27.4678, 9.96, 2.2637, 0.5639, 0.1203 ]
30
+ uqdm-d:
31
+ bpp: [ 0.0, 0.5803, 2.1678, 7.1596, 15.1466 ]
32
+ psnr: [ 11.6448, 21.566, 30.0196, 42.7952, 58.9717 ]
33
+ fid: [ 315.7637, 70.5995, 9.5944, 0.4343, 0.0079 ]
34
+ uqdm-a:
35
+ bpp: [ 0.0, 0.5922, 2.2064, 7.1931, 15.18 ]
36
+ psnr: [ 9.812, 18.2937, 27.0955, 39.4883, 55.7446 ]
37
+ fid: [ 348.9983, 115.0491, 20.403, 0.9739, 0.0235 ]
38
+ uqdm-f:
39
+ bpp: [ 0.0, 0.5857, 2.1904, 7.1823, 15.1691 ]
40
+ psnr: [ 8.6462, 21.3866, 30.0129, 42.7483, 58.9717 ]
41
+ fid: [ 341.5454, 63.9474, 7.8778, 0.2446, 0.0079 ]
42
+ ctc:
43
+ bpp: [ 0.083, 0.3398, 0.6538, 1.398, 3.3254, 4.288 ]
44
+ psnr: [ 14.6201, 17.6903, 21.5264, 26.441, 32.2562, 33.6447 ]
45
+ fid: [ 233.5478, 151.6671, 86.6849, 29.4421, 4.912, 3.3053 ]
46
+ cdc-0:
47
+ bpp: [ 0.4123, 0.6286, 0.8253, 1.2664, 3.7378, 5.1776, 5.2558 ]
48
+ psnr: [ 24.5071, 26.1107, 26.9455, 27.7432, 28.583, 29.3795, 29.5278 ]
49
+ fid: [ 37.1236, 24.621, 19.252, 16.7105, 12.7257, 10.1391, 9.6735 ]
50
+ cdc-p:
51
+ bpp: [ 0.3722, 0.5957, 0.8319, 3.9075, 16.9651, 17.645 ]
52
+ psnr: [ 21.2959, 23.5755, 24.9322, 26.1104, 26.1137, 26.1169 ]
53
+ fid: [ 19.2056, 12.5911, 9.608, 5.631, 5.6307, 5.5274 ]
54
+ vae-b:
55
+ bpp: [ 0.0271, 0.2135, 0.9105, 2.6478, 3.97, 5.7155, 10.4487 ]
56
+ psnr: [ 17.8252, 23.4422, 29.528, 37.0424, 40.9555, 44.9634, 53.1489 ]
57
+ fid: [ 197.7646, 79.2116, 22.8542, 3.3213, 1.0492, 0.2569, 0.011 ]
58
+ vae-m:
59
+ bpp: [ 0.0463, 0.2319, 0.9435, 2.6222, 3.8701, 5.53, 10.1296 ]
60
+ psnr: [ 19.6651, 24.3748, 30.5472, 37.8518, 41.5246, 45.3234, 53.8203 ]
61
+ fid: [ 163.125, 67.5424, 15.861, 1.8724, 0.565, 0.139, 0.0109 ]
eval/paper_plots.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import seaborn as sns
3
+ import yaml
4
+
5
+ """
6
+ Script to create the plots in 'Progressive Compression with Universally Quantized Diffusion Models', Yang et al., 2025.
7
+ """
8
+
9
+
10
+ def rd_fid(dataset='imagenet', baselines=None):
11
+ """
12
+ Create R-D and R-FID curves from precomputed models nad baseline,
13
+ evaluated on 10,000 images from the de-duplicated evaluation set.
14
+
15
+ Inputs:
16
+ -------
17
+ dataset: 'cifar' or 'imagenet'
18
+ baselines: (optional) list of baselines from
19
+ 'uqdm' or 'uqdm-d', 'uqdm-a', 'uqdm-f' - our model via (d)enoising, (a)ncestral, or (f)low-based sampling
20
+ 'vdm' or 'vdm-d', 'vdm-a', 'vdm-f', 'vdm-1000d' - theoretical results of Gaussian diffusion
21
+ 'jpeg', 'jpeg2000', 'bpg' - wavelet-based traditional codecs
22
+ 'ctc', - progressive neural codec via hierarchically quantized latent space (Jeon et al., 2023)
23
+ 'cdc' or 'cdc-0', 'cdc-p' - non-progressive neural codec with conditional diffusion model (Yang et al., 2023)
24
+ 'vae' or 'vae-b', 'vae-m' - non-progressive neural codec with VAE (Ballé et al., 2018) or (Minnen et al., 2020)
25
+ save: (optional) filename to save plot to
26
+ """
27
+
28
+ # Load results and select baselines
29
+ with open('%s.yml' % dataset, 'r') as f:
30
+ results = yaml.safe_load(f)
31
+ if baselines is None:
32
+ baselines = ['jpeg', 'jpeg2000', 'bpg', 'ctc', 'cdc', 'vdm', 'uqdm']
33
+ if 'cdc' in baselines:
34
+ baselines += ['cdc-0', 'cdc-p']
35
+ if 'vae' in baselines:
36
+ baselines += ['vae-b', 'vae-m']
37
+ if 'vdm' in baselines:
38
+ baselines += ['vdm-1000d', 'vdm-d', 'vdm-a', 'vdm-f']
39
+ if 'uqdm' in baselines:
40
+ baselines += ['uqdm-d', 'uqdm-a', 'uqdm-f']
41
+ baselines = [b for b in baselines if b not in ['uqdm', 'vdm', 'cdc', 'vae'] and b in results.keys()]
42
+
43
+ # Style setting
44
+ pl_kwargs = {'alpha': 0.8, 'lw': 2}
45
+ pl_styles = {
46
+ 'uqdm-d': dict(ls='-+', color='darkorange', label='UQDM T=4, denoise'),
47
+ 'uqdm-a': dict(ls='--x', color='darkorange', label='UQDM T=4, ancestral'),
48
+ 'uqdm-f': dict(ls=':x', color='darkorange', label='UQDM T=4, flow-based'),
49
+ 'vdm-d': dict(ls='-+', color='blue', label='VDM T=20, denoise', alpha=0.6, lw=1.5),
50
+ 'vdm-a': dict(ls='--x', color='blue', label='VDM T=20, ancestral', alpha=0.6, lw=1.5),
51
+ 'vdm-f': dict(ls=':x', color='blue', label='VDM T=20, flow-based', alpha=0.6, lw=1.5),
52
+ 'vdm-1000d': dict(ls=':+', color='darkturquoise', label='VDM T=1000, denoise', alpha=0.6, lw=1.5),
53
+ 'jpeg': dict(ls='-.+', color='red', label='JPEG'),
54
+ 'jpeg2000': dict(ls='-x', color='red', label='JPEG2000'),
55
+ 'bpg': dict(ls='-x', color='sienna', label='BPG'),
56
+ 'ctc': dict(ls='-x', color='fuchsia', label='CTC'),
57
+ 'cdc-0': dict(ls='-x', color='green', label='CDC (p=0)'),
58
+ 'cdc-p': dict(ls='-.x', color='green', label='CDC (p=0.9)'),
59
+ 'vae-b': dict(ls='--x', color='limegreen', label='VAE (Ballé 2018)'),
60
+ 'vae-m': dict(ls='-+', color='limegreen', label='VAE (Minnen 2020)'),
61
+ }
62
+ sns.set_style('whitegrid')
63
+
64
+ # Plots
65
+ textwidth = 5.5206 * 2.5
66
+ fig_rd, ax_rd = plt.subplots(figsize=(0.45 * textwidth, 0.36 * textwidth))
67
+ fig_fid, ax_fid = plt.subplots(figsize=(0.45 * textwidth, 0.36 * textwidth))
68
+ for b in baselines:
69
+ bpp, psnr, fid = results[b]['bpp'], results[b]['psnr'], results[b]['fid']
70
+ kwargs = pl_kwargs | pl_styles[b]
71
+ ls = kwargs.pop('ls', None)
72
+ ax_rd.plot(bpp, psnr, ls, **kwargs)
73
+ ax_fid.plot(bpp, fid, ls, **kwargs)
74
+ ax_rd.legend(loc='lower right')
75
+ ax_fid.legend(loc='upper right')
76
+ ax_rd.set(xlabel='Rate (bpp)', ylabel='PSNR (dB)')
77
+ ax_fid.set(xlabel='Rate (bpp)', ylabel='FID')
78
+ ax_rd.grid(visible=True)
79
+ ax_fid.grid(visible=True)
80
+ ax_fid.set_yscale('symlog')
81
+ fig_rd.tight_layout()
82
+ fig_fid.tight_layout()
83
+ fig_rd.savefig('tmp_rd.png', bbox_inches='tight', pad_inches=0, dpi=600)
84
+ fig_fid.savefig('tmp_fid.png', bbox_inches='tight', pad_inches=0, dpi=600)
85
+ plt.show()
86
+
87
+
88
+ if __name__ == '__main__':
89
+ # Rate-distortion, Rate-Realism
90
+ rd_fid(dataset='cifar')
91
+ rd_fid(dataset='imagenet')
92
+ # Plots from the slides
93
+ # Gaussian vs Uniform
94
+ # rd_fid(dataset='imagenet', baselines=['vdm', 'uqdm'])
95
+ # Traditional Baselines
96
+ # rd_fid(dataset='imagenet', baselines=['jpeg', 'jpeg2000', 'bpg', 'uqdm'])
97
+ # Neural Baselines
98
+ # rd_fid(dataset='imagenet', baselines=['ctc', 'cdc', 'vae', 'uqdm'])
99
+ # Progressive Baselines
100
+ # rd_fid(dataset='imagenet', baselines=['jpeg2000', 'ctc', 'uqdm'])