| import unittest | |
| import torch | |
| import numpy as np | |
| from tensorboardX import SummaryWriter | |
| from tensorboardX import summary | |
| from .expect_reader import compare_proto | |
# Seed NumPy's global RNG at import time; test_smoke draws its sample data
# from the global generator (np.random.randint / np.random.rand).
np.random.seed(0)

# Fixed confusion-matrix counts at five thresholds, shared by the raw
# PR-curve tests.  The fixture describes 75 positives and 150 negatives:
# at every index TP + FN == 75 and FP + TN == 150.
true_positive_counts = [75, 64, 21, 5, 0]
false_negative_counts = [0, 11, 54, 70, 75]
false_positive_counts = [150, 105, 18, 0, 0]
true_negative_counts = [0, 45, 132, 150, 150]

# Derived metrics for the same thresholds:
#   precision = TP / (TP + FP)   (listed as 0.0 where TP + FP == 0)
#   recall    = TP / (TP + FN)
precision = [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0]
recall = [1.0, 0.8533334, 0.28, 0.0666667, 0.0]
class PRCurveTest(unittest.TestCase):
    """Tests for tensorboardX precision-recall (PR) curve summaries."""

    def test_smoke(self):
        """Write a sampled and a raw PR curve through SummaryWriter without error.

        The writer logs into a throwaway temporary directory, so running the
        test does not leave event files behind in the working directory.
        """
        import tempfile  # local import: only this test needs it

        # Contexts exit in reverse order: the writer is closed (flushing its
        # files) before the temporary directory is deleted.
        with tempfile.TemporaryDirectory() as logdir, SummaryWriter(logdir) as writer:
            # 100 random binary labels with 100 random scores, global step 1.
            writer.add_pr_curve('xoxo', np.random.randint(2, size=100),
                                np.random.rand(100), 1)
            writer.add_pr_curve_raw('prcurve with raw data',
                                    true_positive_counts,
                                    false_positive_counts,
                                    true_negative_counts,
                                    false_negative_counts,
                                    precision,
                                    recall,
                                    1)

    # NOTE(review): "purve" is a typo for "curve", but do not rename casually:
    # compare_proto receives the test case (`self`) and may locate its stored
    # expectation by test-method name -- confirm and rename any fixture together.
    def test_pr_purve(self):
        """Compare summary.pr_curve on fixed labels/scores with the stored expectation."""
        # 100 binary labels, hard-coded (not drawn from the RNG) so the
        # serialized proto is identical on every run.
        random_labels = np.array([0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1,
                                  1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0,
                                  0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1,
                                  1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0,
                                  1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0])
        # 100 matching scores, stored as float16.
        # NOTE(review): the stored expectation presumably depends on this
        # dtype -- do not change it without regenerating the fixture.
        random_probs = np.array([0.33327776, 0.30032885, 0.79012837, 0.04306813, 0.65221544,
                                 0.58481968, 0.28305522, 0.53795795, 0.00729739, 0.52266951,
                                 0.22464247, 0.11262435, 0.41573075, 0.92493992, 0.73066758,
                                 0.43867735, 0.27955449, 0.56975382, 0.53933028, 0.34392824,
                                 0.30312509, 0.81732807, 0.55408544, 0.3969487, 0.31768033,
                                 0.24353266, 0.47198005, 0.19999122, 0.05788022, 0.24046305,
                                 0.04651082, 0.30061738, 0.78321545, 0.82670207, 0.49200517,
                                 0.80904619, 0.96711993, 0.3160946, 0.01049424, 0.60108337,
                                 0.56508792, 0.83729429, 0.9717386, 0.46306053, 0.80232138,
                                 0.24166823, 0.7393237, 0.50820418, 0.04944932, 0.53854157,
                                 0.10765172, 0.84723855, 0.20518299, 0.3143431, 0.51299074,
                                 0.47065695, 0.54267833, 0.1812676, 0.06265177, 0.34110327,
                                 0.30915171, 0.91870169, 0.91309447, 0.31395817, 0.36780571,
                                 0.98297986, 0.00594547, 0.52839042, 0.70229202, 0.37779588,
                                 0.15207045, 0.59759632, 0.72397032, 0.71502195, 0.90135725,
                                 0.43970107, 0.17123532, 0.08785938, 0.04986818, 0.62702444,
                                 0.69171023, 0.30537792, 0.30285433, 0.27124347, 0.27693729,
                                 0.7136039, 0.48022489, 0.20916285, 0.2018599, 0.92401008,
                                 0.30189681, 0.46862626, 0.96353024, 0.30468533, 0.68281294,
                                 0.30623562, 0.40795975, 0.76824531, 0.89824215, 0.69845035],
                                dtype=np.float16)
        # Final argument 1 is passed through as the summary's step value.
        compare_proto(summary.pr_curve('tag', random_labels, random_probs, 1), self)

    # NOTE(review): same "purve" typo as above -- keep the name in sync with
    # any expectation fixture compare_proto may derive from it.
    def test_pr_purve_raw(self):
        """Compare summary.pr_curve_raw on the module-level counts with the stored expectation."""
        compare_proto(summary.pr_curve_raw('prcurve with raw data',
                                           true_positive_counts,
                                           false_positive_counts,
                                           true_negative_counts,
                                           false_negative_counts,
                                           precision,
                                           recall,
                                           1),
                      self)