Upload 3 files
Browse files- eval/cifar.yml +49 -0
- eval/imagenet.yml +61 -0
- eval/paper_plots.py +100 -0
eval/cifar.yml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# rounded to 4 digits
|
| 2 |
+
jpeg:
|
| 3 |
+
bpp: [ 2.4004, 2.8381, 3.2233, 3.5125, 3.743, 3.9558, 4.1803, 4.5109, 5.0316, 6.1434, 7.5109, 10.7999 ]
|
| 4 |
+
psnr: [ 18.6467, 23.3137, 25.517, 26.7397, 27.572, 28.255, 28.9256, 29.8221, 31.1432, 33.6497, 36.4185, 41.9347 ]
|
| 5 |
+
fid: [ 224.9505, 112.9873, 75.5656, 57.9521, 47.1583, 39.2226, 32.5657, 24.8751, 16.5128, 7.61, 3.635, 2.4892 ]
|
| 6 |
+
jpeg2000:
|
| 7 |
+
bpp: [ 1.9891, 2.5018, 4.884, 8.0687, 12.0546, 21.8545 ]
|
| 8 |
+
psnr: [ 14.0417, 17.0844, 25.7253, 32.4832, 39.2544, .inf ]
|
| 9 |
+
fid: [ 334.3462, 222.3907, 70.642, 10.7621, 2.9819, 0.0 ]
|
| 10 |
+
bpg:
|
| 11 |
+
bpp: [ 0.3304, 0.7377, 1.737, 3.5282, 6.03, 8.9413, 14.2088 ]
|
| 12 |
+
psnr: [ 19.2975, 25.0311, 30.9917, 36.815, 40.5962, 41.9634, .inf ]
|
| 13 |
+
fid: [ 176.8806, 67.5432, 19.8662, 4.1032, 1.6116, 1.481, 0.0 ]
|
| 14 |
+
vdm-d:
|
| 15 |
+
bpp: [ 0.0001, 0.0004, 0.001, 0.0024, 0.0062, 0.0162, 0.0431, 0.129, 0.4065, 1.0298, 1.9598, 3.0714, 4.3022, 5.644, 7.1119, 8.7178, 10.4483, 12.3011, 14.3031, 16.3632, 17.5816 ]
|
| 16 |
+
psnr: [ 8.3897, 10.7161, 11.9383, 12.8162, 13.8472, 15.2244, 16.9644, 19.0609, 21.5213, 24.3224, 27.3847, 30.6517, 34.1026, 37.72, 41.5499, 45.6346, 49.6708, 53.4106, 58.1225, 70.9762, .inf ]
|
| 17 |
+
fid: [ 337.6811, 306.4971, 319.1342, 318.8984, 327.4887, 285.7153, 154.9047, 82.0968, 51.0361, 33.0263, 21.2701, 12.7124, 7.0161, 3.5962, 1.5872, 0.574, 0.2017, 0.072, 0.0213, 0.0011, 0.0 ]
|
| 18 |
+
vdm-a:
|
| 19 |
+
bpp: [ 0.0001, 0.0004, 0.001, 0.0024, 0.0062, 0.0161, 0.0431, 0.1289, 0.4066, 1.0299, 1.9599, 3.0717, 4.3024, 5.6444, 7.1126, 8.7188, 10.4493, 12.302, 14.3048, 16.3645, 17.5823 ]
|
| 20 |
+
psnr: [ 10.5164, 10.6789, 10.9824, 11.606, 12.7016, 14.1219, 15.9098, 17.9932, 20.4458, 23.2081, 26.2264, 29.4655, 32.8867, 36.4988, 40.3707, 44.5329, 48.6745, 52.587, 57.8859, 70.9822, .inf ]
|
| 21 |
+
fid: [ 82.7664, 83.0947, 83.1105, 82.5617, 82.336, 78.6079, 73.4927, 64.3986, 50.9205, 35.9904, 22.6822, 12.5784, 6.5951, 3.2898, 1.4879, 0.5702, 0.2174, 0.084, 0.0226, 0.0011, 0.0 ]
|
| 22 |
+
vdm-f:
|
| 23 |
+
bpp: [ 0.0001, 0.0004, 0.001, 0.0024, 0.0062, 0.0162, 0.0431, 0.1289, 0.4059, 1.0284, 1.9573, 3.0677, 4.2964, 5.6371, 7.1034, 8.7071, 10.4352, 12.285, 14.2854, 16.3431, 17.5594 ]
|
| 24 |
+
psnr: [ 10.3074, 10.706, 11.1598, 11.9164, 13.0177, 14.5288, 16.4034, 18.6046, 21.1607, 24.0178, 27.1459, 30.4554, 33.9089, 37.5705, 41.469, 45.5837, 49.5898, 53.2924, 58.484, 71.3887, .inf ]
|
| 25 |
+
fid: [ 70.5077, 66.8462, 61.5901, 57.3662, 54.1494, 50.5307, 46.4654, 39.5295, 30.6986, 20.7308, 12.469, 6.9301, 3.7285, 1.9744, 0.9323, 0.3961, 0.1631, 0.0693, 0.02, 0.001, 0.0 ]
|
| 26 |
+
vdm-1000d:
|
| 27 |
+
bpp: [ 0.0001, 0.0012, 0.0057, 0.0315, 0.1422, 0.489, 1.2785, 2.7981, 5.181, 8.4341, 9.7324 ]
|
| 28 |
+
psnr: [ 8.426, 11.9452, 13.8502, 16.9632, 21.5297, 27.3811, 34.0989, 41.5468, 49.6702, 58.1181, .inf ]
|
| 29 |
+
fid: [ 344.1216, 313.8488, 273.115, 167.7786, 102.622, 56.8647, 27.4678, 9.96, 2.6366, 0.5639, 0.1203 ]
|
| 30 |
+
uqdm-d:
|
| 31 |
+
bpp: [ 0.0, 0.8459, 2.7828, 7.6704, 15.3017 ]
|
| 32 |
+
psnr: [ 11.5512, 21.8842, 31.175, 43.7569, 59.9976 ]
|
| 33 |
+
fid: [ 348.7675, 47.7964, 11.9617, 0.8906, 0.0234 ]
|
| 34 |
+
uqdm-a:
|
| 35 |
+
bpp: [ 0.0, 0.8448, 2.7813, 7.6689, 15.3002 ]
|
| 36 |
+
psnr: [ 10.043, 19.0172, 28.2593, 40.5961, 56.1671 ]
|
| 37 |
+
fid: [ 315.5473, 126.366, 36.134, 2.3176, 0.0698 ]
|
| 38 |
+
uqdm-f:
|
| 39 |
+
bpp: [0.0, 0.8347, 2.7833, 7.6738, 15.3061]
|
| 40 |
+
psnr: [9.2672, 21.792, 31.1732, 43.7335, 60.0059]
|
| 41 |
+
fid: [324.4288, 36.2628, 9.317, 0.6881, 0.0235]
|
| 42 |
+
vae-b:
|
| 43 |
+
bpp: [ 0.0269, 0.2894, 1.0495, 2.7783, 4.053, 5.5443, 9.5346 ]
|
| 44 |
+
psnr: [ 16.1589, 22.9358, 29.6548, 37.4456, 41.5039, 45.5439, 53.3573 ]
|
| 45 |
+
fid: [ 322.1337, 105.1356, 32.7888, 5.0927, 1.8276, 0.5906, 0.0638 ]
|
| 46 |
+
vae-m:
|
| 47 |
+
bpp: [ 0.0568, 0.3355, 1.1218, 2.8024, 4.0976, 5.5526, 9.3886 ]
|
| 48 |
+
psnr: [ 17.9086, 24.2149, 30.7686, 38.124, 41.9559, 45.8761, 54.0515 ]
|
| 49 |
+
fid: [ 140.5861, 70.7924, 22.6757, 4.3984, 1.6444, 0.532, 0.0544 ]
|
eval/imagenet.yml
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# rounded to 4 digits
|
| 2 |
+
jpeg:
|
| 3 |
+
bpp: [ 0.7018, 1.0361, 1.3602, 1.6185, 1.8293, 2.0268, 2.2384, 2.5554, 3.0622, 4.1426, 5.4319, 8.6382 ]
|
| 4 |
+
psnr: [ 19.1856, 23.4067, 25.3096, 26.3807, 27.1229, 27.744, 28.3645, 29.2213, 30.5114, 32.9266, 35.2304, 38.5133 ]
|
| 5 |
+
fid: [ 188.6707, 58.9173, 33.6194, 24.6107, 19.6276, 16.4469, 13.8417, 11.0175, 7.9301, 4.4655, 3.0617, 2.0701 ]
|
| 6 |
+
jpeg2000:
|
| 7 |
+
bpp: [ 1.2229, 2.4075, 4.7929, 7.9738, 11.9161, 17.7443 ]
|
| 8 |
+
psnr: [ 22.6253, 26.8887, 32.7778, 38.807, 44.8636, .inf ]
|
| 9 |
+
fid: [ 97.6303, 30.7883, 7.1988, 1.4258, 0.3257, 0.0 ]
|
| 10 |
+
bpg:
|
| 11 |
+
bpp: [ 0.1342, 0.4744, 1.4261, 3.088, 5.3774, 8.2017, 13.2612 ]
|
| 12 |
+
psnr: [ 20.1191, 25.3813, 30.9028, 35.5346, 37.8201, 38.7962, .inf ]
|
| 13 |
+
fid: [ 119.4003, 47.5895, 9.8001, 2.9095, 1.9414, 1.657, 0.0 ]
|
| 14 |
+
vdm-d:
|
| 15 |
+
bpp: [ 0.0003, 0.0008, 0.002, 0.0046, 0.0101, 0.0224, 0.0529, 0.1411, 0.3803, 0.8563, 1.5483, 2.3845, 3.3193, 4.3402, 5.4513, 6.6648, 7.997, 9.4661, 11.0878, 12.8641, 14.7471 ]
|
| 16 |
+
psnr: [ 11.5434, 12.3961, 13.3009, 14.3916, 15.6468, 17.1123, 18.7804, 20.6163, 22.657, 24.9221, 27.3765, 30.0165, 32.8282, 35.7928, 38.8757, 42.0351, 45.2283, 48.4134, 51.5481, 54.9842, 60.0558 ]
|
| 17 |
+
fid: [ 321.9626, 319.7044, 289.2591, 262.9117, 239.1254, 183.5961, 120.7653, 82.5857, 54.025, 32.5723, 18.1992, 10.0922, 5.5587, 3.0785, 1.6479, 0.8395, 0.4086, 0.1921, 0.0831, 0.036, 0.0 ]
|
| 18 |
+
vdm-a:
|
| 19 |
+
bpp: [ 0.0003, 0.0008, 0.002, 0.0046, 0.0101, 0.0224, 0.0529, 0.1412, 0.3804, 0.8563, 1.5483, 2.3845, 3.3194, 4.3402, 5.4515, 6.6652, 7.9973, 9.4666, 11.0887, 12.8654, 14.7486 ]
|
| 20 |
+
psnr: [ 10.7752, 11.2343, 12.0064, 13.0485, 14.3389, 15.8163, 17.4941, 19.3574, 21.4164, 23.6685, 26.0841, 28.678, 31.4593, 34.4045, 37.4753, 40.6318, 43.8405, 47.0932, 50.3936, 53.9429, 58.2943 ]
|
| 21 |
+
fid: [ 132.1083, 128.8545, 125.7113, 119.9762, 112.6877, 102.679, 88.1364, 72.923, 54.3236, 35.9094, 20.2573, 10.1911, 5.1121, 2.6013, 1.3265, 0.6607, 0.3216, 0.1509, 0.0696, 0.0345, 0.0194 ]
|
| 22 |
+
vdm-f:
|
| 23 |
+
bpp: [ 0.0003, 0.0046, 0.0529, 0.3803, 1.5483, 2.3845, 3.3194, 4.3402, 5.4514, 6.665, 7.9971, 9.4664, 11.0883, 12.8647, 14.7479 ]
|
| 24 |
+
psnr: [ 10.4359, 13.0568, 17.5926, 21.7079, 26.6407, 29.3521, 32.228, 35.2574, 38.3862, 41.593, 44.8423, 48.0954, 51.2717, 54.5982, 60.0537 ]
|
| 25 |
+
fid: [ 114.5655, 98.7355, 73.8959, 42.8638, 11.3897, 4.3585, 1.6068, 0.6184, 0.2578, 0.1178, 0.0572, 0.0296, 0.018, 0.0119, 0.0059 ]
|
| 26 |
+
vdm-1000d:
|
| 27 |
+
bpp: [ 0.0003, 0.0018, 0.0086, 0.035, 0.1264, 0.3872, 0.9594, 2.0214, 3.7732, 6.421, 9.6478 ]
|
| 28 |
+
psnr: [ 11.56, 13.2967, 15.6984, 18.7889, 22.6692, 27.3418, 32.7914, 38.8265, 45.1906, 51.5336, 60.0486 ]
|
| 29 |
+
fid: [ 344.1216, 313.8488, 273.1115, 167.7786, 102.6219, 56.8647, 27.4678, 9.96, 2.2637, 0.5639, 0.1203 ]
|
| 30 |
+
uqdm-d:
|
| 31 |
+
bpp: [ 0.0, 0.5803, 2.1678, 7.1596, 15.1466 ]
|
| 32 |
+
psnr: [ 11.6448, 21.566, 30.0196, 42.7952, 58.9717 ]
|
| 33 |
+
fid: [ 315.7637, 70.5995, 9.5944, 0.4343, 0.0079 ]
|
| 34 |
+
uqdm-a:
|
| 35 |
+
bpp: [ 0.0, 0.5922, 2.2064, 7.1931, 15.18 ]
|
| 36 |
+
psnr: [ 9.812, 18.2937, 27.0955, 39.4883, 55.7446 ]
|
| 37 |
+
fid: [ 348.9983, 115.0491, 20.403, 0.9739, 0.0235 ]
|
| 38 |
+
uqdm-f:
|
| 39 |
+
bpp: [ 0.0, 0.5857, 2.1904, 7.1823, 15.1691 ]
|
| 40 |
+
psnr: [ 8.6462, 21.3866, 30.0129, 42.7483, 58.9717 ]
|
| 41 |
+
fid: [ 341.5454, 63.9474, 7.8778, 0.2446, 0.0079 ]
|
| 42 |
+
ctc:
|
| 43 |
+
bpp: [ 0.083, 0.3398, 0.6538, 1.398, 3.3254, 4.288 ]
|
| 44 |
+
psnr: [ 14.6201, 17.6903, 21.5264, 26.441, 32.2562, 33.6447 ]
|
| 45 |
+
fid: [ 233.5478, 151.6671, 86.6849, 29.4421, 4.912, 3.3053 ]
|
| 46 |
+
cdc-0:
|
| 47 |
+
bpp: [ 0.4123, 0.6286, 0.8253, 1.2664, 3.7378, 5.1776, 5.2558 ]
|
| 48 |
+
psnr: [ 24.5071, 26.1107, 26.9455, 27.7432, 28.583, 29.3795, 29.5278 ]
|
| 49 |
+
fid: [ 37.1236, 24.621, 19.252, 16.7105, 12.7257, 10.1391, 9.6735 ]
|
| 50 |
+
cdc-p:
|
| 51 |
+
bpp: [ 0.3722, 0.5957, 0.8319, 3.9075, 16.9651, 17.645 ]
|
| 52 |
+
psnr: [ 21.2959, 23.5755, 24.9322, 26.1104, 26.1137, 26.1169 ]
|
| 53 |
+
fid: [ 19.2056, 12.5911, 9.608, 5.631, 5.6307, 5.5274 ]
|
| 54 |
+
vae-b:
|
| 55 |
+
bpp: [ 0.0271, 0.2135, 0.9105, 2.6478, 3.97, 5.7155, 10.4487 ]
|
| 56 |
+
psnr: [ 17.8252, 23.4422, 29.528, 37.0424, 40.9555, 44.9634, 53.1489 ]
|
| 57 |
+
fid: [ 197.7646, 79.2116, 22.8542, 3.3213, 1.0492, 0.2569, 0.011 ]
|
| 58 |
+
vae-m:
|
| 59 |
+
bpp: [ 0.0463, 0.2319, 0.9435, 2.6222, 3.8701, 5.53, 10.1296 ]
|
| 60 |
+
psnr: [ 19.6651, 24.3748, 30.5472, 37.8518, 41.5246, 45.3234, 53.8203 ]
|
| 61 |
+
fid: [ 163.125, 67.5424, 15.861, 1.8724, 0.565, 0.139, 0.0109 ]
|
eval/paper_plots.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import seaborn as sns
|
| 3 |
+
import yaml
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
Script to create the plots in 'Progressive Compression with Universally Quantized Diffusion Models', Yang et al., 2025.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def rd_fid(dataset='imagenet', baselines=None):
|
| 11 |
+
"""
|
| 12 |
+
Create R-D and R-FID curves from precomputed models nad baseline,
|
| 13 |
+
evaluated on 10,000 images from the de-duplicated evaluation set.
|
| 14 |
+
|
| 15 |
+
Inputs:
|
| 16 |
+
-------
|
| 17 |
+
dataset: 'cifar' or 'imagenet'
|
| 18 |
+
baselines: (optional) list of baselines from
|
| 19 |
+
'uqdm' or 'uqdm-d', 'uqdm-a', 'uqdm-f' - our model via (d)enoising, (a)ncestral, or (f)low-based sampling
|
| 20 |
+
'vdm' or 'vdm-d', 'vdm-a', 'vdm-f', 'vdm-1000d' - theoretical results of Gaussian diffusion
|
| 21 |
+
'jpeg', 'jpeg2000', 'bpg' - wavelet-based traditional codecs
|
| 22 |
+
'ctc', - progressive neural codec via hierarchically quantized latent space (Jeon et al., 2023)
|
| 23 |
+
'cdc' or 'cdc-0', 'cdc-p' - non-progressive neural codec with conditional diffusion model (Yang et al., 2023)
|
| 24 |
+
'vae' or 'vae-b', 'vae-m' - non-progressive neural codec with VAE (Ballé et al., 2018) or (Minnen et al., 2020)
|
| 25 |
+
save: (optional) filename to save plot to
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
# Load results and select baselines
|
| 29 |
+
with open('%s.yml' % dataset, 'r') as f:
|
| 30 |
+
results = yaml.safe_load(f)
|
| 31 |
+
if baselines is None:
|
| 32 |
+
baselines = ['jpeg', 'jpeg2000', 'bpg', 'ctc', 'cdc', 'vdm', 'uqdm']
|
| 33 |
+
if 'cdc' in baselines:
|
| 34 |
+
baselines += ['cdc-0', 'cdc-p']
|
| 35 |
+
if 'vae' in baselines:
|
| 36 |
+
baselines += ['vae-b', 'vae-m']
|
| 37 |
+
if 'vdm' in baselines:
|
| 38 |
+
baselines += ['vdm-1000d', 'vdm-d', 'vdm-a', 'vdm-f']
|
| 39 |
+
if 'uqdm' in baselines:
|
| 40 |
+
baselines += ['uqdm-d', 'uqdm-a', 'uqdm-f']
|
| 41 |
+
baselines = [b for b in baselines if b not in ['uqdm', 'vdm', 'cdc', 'vae'] and b in results.keys()]
|
| 42 |
+
|
| 43 |
+
# Style setting
|
| 44 |
+
pl_kwargs = {'alpha': 0.8, 'lw': 2}
|
| 45 |
+
pl_styles = {
|
| 46 |
+
'uqdm-d': dict(ls='-+', color='darkorange', label='UQDM T=4, denoise'),
|
| 47 |
+
'uqdm-a': dict(ls='--x', color='darkorange', label='UQDM T=4, ancestral'),
|
| 48 |
+
'uqdm-f': dict(ls=':x', color='darkorange', label='UQDM T=4, flow-based'),
|
| 49 |
+
'vdm-d': dict(ls='-+', color='blue', label='VDM T=20, denoise', alpha=0.6, lw=1.5),
|
| 50 |
+
'vdm-a': dict(ls='--x', color='blue', label='VDM T=20, ancestral', alpha=0.6, lw=1.5),
|
| 51 |
+
'vdm-f': dict(ls=':x', color='blue', label='VDM T=20, flow-based', alpha=0.6, lw=1.5),
|
| 52 |
+
'vdm-1000d': dict(ls=':+', color='darkturquoise', label='VDM T=1000, denoise', alpha=0.6, lw=1.5),
|
| 53 |
+
'jpeg': dict(ls='-.+', color='red', label='JPEG'),
|
| 54 |
+
'jpeg2000': dict(ls='-x', color='red', label='JPEG2000'),
|
| 55 |
+
'bpg': dict(ls='-x', color='sienna', label='BPG'),
|
| 56 |
+
'ctc': dict(ls='-x', color='fuchsia', label='CTC'),
|
| 57 |
+
'cdc-0': dict(ls='-x', color='green', label='CDC (p=0)'),
|
| 58 |
+
'cdc-p': dict(ls='-.x', color='green', label='CDC (p=0.9)'),
|
| 59 |
+
'vae-b': dict(ls='--x', color='limegreen', label='VAE (Ballé 2018)'),
|
| 60 |
+
'vae-m': dict(ls='-+', color='limegreen', label='VAE (Minnen 2020)'),
|
| 61 |
+
}
|
| 62 |
+
sns.set_style('whitegrid')
|
| 63 |
+
|
| 64 |
+
# Plots
|
| 65 |
+
textwidth = 5.5206 * 2.5
|
| 66 |
+
fig_rd, ax_rd = plt.subplots(figsize=(0.45 * textwidth, 0.36 * textwidth))
|
| 67 |
+
fig_fid, ax_fid = plt.subplots(figsize=(0.45 * textwidth, 0.36 * textwidth))
|
| 68 |
+
for b in baselines:
|
| 69 |
+
bpp, psnr, fid = results[b]['bpp'], results[b]['psnr'], results[b]['fid']
|
| 70 |
+
kwargs = pl_kwargs | pl_styles[b]
|
| 71 |
+
ls = kwargs.pop('ls', None)
|
| 72 |
+
ax_rd.plot(bpp, psnr, ls, **kwargs)
|
| 73 |
+
ax_fid.plot(bpp, fid, ls, **kwargs)
|
| 74 |
+
ax_rd.legend(loc='lower right')
|
| 75 |
+
ax_fid.legend(loc='upper right')
|
| 76 |
+
ax_rd.set(xlabel='Rate (bpp)', ylabel='PSNR (dB)')
|
| 77 |
+
ax_fid.set(xlabel='Rate (bpp)', ylabel='FID')
|
| 78 |
+
ax_rd.grid(visible=True)
|
| 79 |
+
ax_fid.grid(visible=True)
|
| 80 |
+
ax_fid.set_yscale('symlog')
|
| 81 |
+
fig_rd.tight_layout()
|
| 82 |
+
fig_fid.tight_layout()
|
| 83 |
+
fig_rd.savefig('tmp_rd.png', bbox_inches='tight', pad_inches=0, dpi=600)
|
| 84 |
+
fig_fid.savefig('tmp_fid.png', bbox_inches='tight', pad_inches=0, dpi=600)
|
| 85 |
+
plt.show()
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
if __name__ == '__main__':
|
| 89 |
+
# Rate-distortion, Rate-Realism
|
| 90 |
+
rd_fid(dataset='cifar')
|
| 91 |
+
rd_fid(dataset='imagenet')
|
| 92 |
+
# Plots from the slides
|
| 93 |
+
# Gaussian vs Uniform
|
| 94 |
+
# rd_fid(dataset='imagenet', baselines=['vdm', 'uqdm'])
|
| 95 |
+
# Traditional Baselines
|
| 96 |
+
# rd_fid(dataset='imagenet', baselines=['jpeg', 'jpeg2000', 'bpg', 'uqdm'])
|
| 97 |
+
# Neural Baselines
|
| 98 |
+
# rd_fid(dataset='imagenet', baselines=['ctc', 'cdc', 'vae', 'uqdm'])
|
| 99 |
+
# Progressive Baselines
|
| 100 |
+
# rd_fid(dataset='imagenet', baselines=['jpeg2000', 'ctc', 'uqdm'])
|