File size: 1,659 Bytes
5db43ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np

from util.json_writer import json_wrtiter


class ErrorMeasure:
    """Accumulate mean-squared errors between predictions and targets.

    Each call to :meth:`append` records one batch: the MSE over all values,
    the MSE of each of the first six columns, and the batch's row count.
    :meth:`to_dict` / :meth:`save_to_json` report averages over every recorded
    batch, weighted by row count so that a short final batch does not skew
    the result.
    """

    # Output key for each tracked column, in column order: rotational angle,
    # left arm roll, right arm roll, left arm pitch, right arm pitch, body size.
    COMPONENT_KEYS = (
        "rotational_angle_error",
        "left_arm_roll_error",
        "right_arm_roll_error",
        "left_arm_pitch_error",
        "right_arm_pitch_error",
        "body_size_error",
    )

    def __init__(self):
        # Per-batch MSE over every value in the batch.
        self.err_list = []
        # Per-batch MSE of columns 0..5 respectively (see COMPONENT_KEYS).
        self.err0_list = []
        self.err1_list = []
        self.err2_list = []
        self.err3_list = []
        self.err4_list = []
        self.err5_list = []
        # Rows contributed by each batch; used as weights when averaging.
        self.count_list = []

    @property
    def _component_lists(self):
        """Per-column error lists, ordered to match ``COMPONENT_KEYS``."""
        # Read the attributes on each access so reassigning e.g. err0_list
        # from outside is still honoured.
        return (
            self.err0_list,
            self.err1_list,
            self.err2_list,
            self.err3_list,
            self.err4_list,
            self.err5_list,
        )

    def append(self, pred, truth):
        """Record the errors for one batch of predictions.

        Args:
            pred: predictions; a 1-D tensor is treated as a single row. Must
                support ``dim``/``unsqueeze``/``detach``/``cpu`` (e.g. a
                ``torch.Tensor``) — assumes at least 6 columns.
            truth: targets, same conventions as ``pred``.
        """
        if pred.dim() == 1:
            pred = pred.unsqueeze(0)
        if truth.dim() == 1:
            truth = truth.unsqueeze(0)

        # Compute everything before mutating state so a failure (e.g. too few
        # columns) cannot leave the lists with mismatched lengths.
        total_err = self.get_err(pred, truth)
        column_errs = [
            self.get_err(pred[:, col], truth[:, col])
            for col in range(len(self.COMPONENT_KEYS))
        ]
        # Rows after broadcasting: one side may be a single unsqueezed row.
        n_rows = max(pred.shape[0], truth.shape[0])

        self.err_list.append(total_err)
        for errs, value in zip(self._component_lists, column_errs):
            errs.append(value)
        self.count_list.append(n_rows)

    def get_err(self, a, b):
        """Return the mean squared difference of ``a`` and ``b`` as a Python float."""
        return (a - b).detach().cpu().square().mean().item()

    def _weighted_mean(self, values):
        """Average per-batch ``values`` weighted by batch row count.

        Returns ``None`` when nothing has been recorded (serialised as JSON
        ``null`` rather than the non-standard ``NaN``).
        """
        if not values:
            return None
        if len(values) != len(self.count_list) or sum(self.count_list) == 0:
            # Lists populated without append() carry no row counts.
            return float(np.mean(values))
        return float(np.dot(values, self.count_list) / sum(self.count_list))

    def to_dict(self):
        """Return the aggregated errors keyed as written by :meth:`save_to_json`."""
        data_dict = {"total_error": self._weighted_mean(self.err_list)}
        for key, errs in zip(self.COMPONENT_KEYS, self._component_lists):
            data_dict[key] = self._weighted_mean(errs)
        return data_dict

    def save_to_json(self, path):
        """Write the aggregated errors (see :meth:`to_dict`) to ``path``."""
        json_wrtiter(self.to_dict(), path)