File size: 7,538 Bytes
4724018
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import numpy as np
import h5py
import torch
import torch.nn.functional as F
import random
import json
from tqdm import tqdm

def chamfer_distance(points_pred, points_gt):
    """Compute a batched Chamfer distance using squared Euclidean distances.

    Args:
        points_pred (torch.Tensor): (B, N, D) batch of predicted points.
        points_gt (torch.Tensor): (B, M, D) batch of reference points. M does
            not have to equal N.

    Returns:
        torch.Tensor: (B,) tensor. For each batch element: the mean, over
        reference points, of the squared distance to the closest predicted
        point, plus the mean, over predicted points, of the squared distance
        to the closest reference point.
    """
    x, y = points_pred, points_gt
    # Per-point squared norms, (B, N) and (B, M). Computing them directly
    # (instead of reading the diagonal of x @ x^T / y @ y^T) avoids two
    # quadratic-size products and removes the requirement that N == M.
    x_sq = (x * x).sum(dim=-1)
    y_sq = (y * y).sum(dim=-1)
    cross = torch.bmm(x, y.transpose(2, 1))  # (B, N, M) inner products
    # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>
    P = x_sq.unsqueeze(2) + y_sq.unsqueeze(1) - 2 * cross
    return P.min(1)[0].mean(dim=1) + P.min(2)[0].mean(dim=1)

def chamfer_distance1(pc1, pc2):
    """Chamfer distance between two point clouds using squared L2 distances.

    Args:
        pc1 (torch.Tensor): first point cloud; the pairwise matrix is reduced
            as (N1, N2), i.e. this is expected to be (N1, D).
        pc2 (torch.Tensor): second point cloud, expected to be (N2, D).

    Returns:
        torch.Tensor: scalar tensor, the mean nearest-neighbour squared
        distance from pc1 to pc2 plus the same quantity from pc2 to pc1.
    """
    # (N1, N2) matrix of squared Euclidean distances.
    sq_dists = torch.cdist(pc1, pc2, p=2).pow(2)

    # Closest pc2 point for every pc1 point (reduce over columns) ...
    forward = sq_dists.min(dim=1).values
    # ... and closest pc1 point for every pc2 point (reduce over rows).
    backward = sq_dists.min(dim=0).values

    return forward.mean() + backward.mean()

def points_to_voxel_grid(points, voxel_size, grid_min, grid_max):
    """Converts a point cloud to a voxel grid representation.

    Points outside the half-open box [grid_min, grid_max) are ignored.

    Args:
        points (torch.Tensor): (N, 3) tensor of point coordinates.
        voxel_size (float): Size of each voxel.
        grid_min (torch.Tensor): (3,) tensor of minimum grid coordinates.
        grid_max (torch.Tensor): (3,) tensor of maximum grid coordinates.

    Returns:
        torch.Tensor: (Dx, Dy, Dz) boolean tensor representing the voxel grid,
        allocated on the same device as ``points``.
    """
    device = points.device
    # The bounds may have been created on another device than the points.
    grid_min = grid_min.to(device)
    grid_max = grid_max.to(device)

    grid_size = ((grid_max - grid_min) / voxel_size).round().long()
    grid = torch.zeros(grid_size.tolist(), dtype=torch.bool, device=device)

    inside = ((points >= grid_min) & (points < grid_max)).all(dim=1)
    valid_points = points[inside]
    voxel_indices = ((valid_points - grid_min) / voxel_size).floor().long()
    # A coordinate just below grid_max can still round up to grid_size after
    # the floating-point subtraction/division; keep it in the last voxel
    # instead of indexing out of bounds.
    voxel_indices = torch.minimum(voxel_indices, grid_size - 1)

    grid[voxel_indices[:, 0], voxel_indices[:, 1], voxel_indices[:, 2]] = True
    return grid

def volume_iou(grid1, grid2):
    """Calculates the Volume IoU between two voxel grids.

    Args:
        grid1 (torch.Tensor): (Dx, Dy, Dz) boolean tensor representing the first voxel grid.
        grid2 (torch.Tensor): (Dx, Dy, Dz) boolean tensor representing the second voxel grid.

    Returns:
        torch.Tensor: 0-dim float tensor holding the Volume IoU score. The
        score is 0 when both grids are empty.
    """
    intersection = (grid1 & grid2).sum().float()
    union = (grid1 | grid2).sum().float()
    # An empty union implies an empty intersection, so clamping the
    # denominator to 1 yields 0 in that case without a special-case branch
    # and keeps the return type a tensor on every path.
    return intersection / union.clamp(min=1.0)

def evaluate_4d(pts_pred, pts_gt, voxel_size=0.1, grid_bound=1.5):
    """Evaluates one predicted point-cloud sequence against its ground truth.

    Args:
        pts_pred (torch.Tensor): predicted sequence; dim 0 indexes frames and
            each frame is flattened to (-1, 3) before voxelization.
        pts_gt (torch.Tensor): ground-truth sequence with the same layout.
        voxel_size (float): edge length of the voxels used for the IoU.
        grid_bound (float): the voxel grid spans [-grid_bound, grid_bound)
            along every axis.

    Returns:
        tuple: (IoU averaged over frames, Chamfer distance averaged over
        frames, MSE over the whole sequence).
    """
    N = pts_pred.shape[0]
    total_iou, total_cd = 0.0, 0.0

    # Create the grid bounds on the device of the inputs so the comparisons
    # inside points_to_voxel_grid do not mix devices.
    device = pts_pred.device
    grid_min = torch.full((3,), -float(grid_bound), device=device)
    grid_max = torch.full((3,), float(grid_bound), device=device)

    for frame_pred, frame_gt in zip(pts_pred, pts_gt):
        total_cd += chamfer_distance1(frame_pred, frame_gt)

        grid1 = points_to_voxel_grid(frame_gt.reshape(-1, 3), voxel_size, grid_min, grid_max)
        grid2 = points_to_voxel_grid(frame_pred.reshape(-1, 3), voxel_size, grid_min, grid_max)
        total_iou += volume_iou(grid1, grid2)

    total_iou /= N
    total_cd /= N
    total_mse = F.mse_loss(pts_pred, pts_gt)
    return total_iou, total_cd, total_mse

def evaluate_test_4d(pts_pred_all, pts_gt_all):
    """Runs evaluate_4d on every test sequence and aggregates the metrics.

    Args:
        pts_pred_all: predicted sequences, indexed by sample along dim 0.
        pts_gt_all: ground-truth sequences, indexed the same way.

    Returns:
        tuple: (mean IoU, mean Chamfer distance, mean MSE, per-sample IoU
        list, per-sample Chamfer list, per-sample MSE list).
    """
    num_sequences = pts_pred_all.shape[0]
    print(f"Num of testing samples: {num_sequences}")

    single_iou, single_cd, single_mse = [], [], []
    per_metric = (single_iou, single_cd, single_mse)
    for idx in tqdm(range(num_sequences)):
        scores = evaluate_4d(pts_pred_all[idx], pts_gt_all[idx])
        # scores is (iou, cd, mse); route each value to its per-sample list.
        for bucket, score in zip(per_metric, scores):
            bucket.append(score)

    # Accumulate from a float zero, in sample order, to form the averages.
    mean_iou = sum(single_iou, 0.) / num_sequences
    mean_cd = sum(single_cd, 0.) / num_sequences
    mean_mse = sum(single_mse, 0.) / num_sequences
    return mean_iou, mean_cd, mean_mse, single_iou, single_cd, single_mse

if __name__ == "__main__":
    # Script entry point: load the test split, pair every ground-truth
    # sequence with its prediction file, report IoU / Chamfer / MSE, and
    # optionally write an HTML gallery sorted by IoU.
    import argparse
    import glob, os
    import html

    parser = argparse.ArgumentParser()
    parser.add_argument('--split_lst', type=str, default='./outputs_v3_test_list.json')
    parser.add_argument('--pred_path', type=str, default='./outputs/dit_8layers_2048p_hf-objaverse-v1_24frames_pointembed_latent256_deform0.001_8gpus_v5/test_100_25steps')
    parser.add_argument('--save_html', action='store_true')
    parser.add_argument('--num_samples', type=int, default=100)
    # Directory holding the ground-truth .h5 files named in the split list.
    parser.add_argument('--data_root', type=str, default='/mnt/kostas-graid/datasets/chenwang/traj/ObjaverseXL_sketchfab/raw/hf-objaverse-v1/outputs_v3')
    # Where to write the gallery; defaults to <pred_path>/visualize.html so
    # the relative <img src> entries resolve next to the predictions.
    parser.add_argument('--html_path', type=str, default=None)
    args = parser.parse_args()

    with open(args.split_lst) as f:
        split_lst = json.load(f)
    # Fixed seed so the evaluated subset is reproducible between runs.
    random.seed(0)
    random.shuffle(split_lst)
    split_lst = split_lst[:args.num_samples]
    print(split_lst)

    pts_gt_all = []
    pts_pred_all = []
    gif_paths = []

    for i, sample_name in enumerate(split_lst):
        # Read everything needed while the file is open, then close it.
        with h5py.File(f'{args.data_root}/{sample_name}', 'r') as model_metas:
            # Every second frame from index 1 up to 47 (24 frames).
            pts_gt = np.array(model_metas['x'])[1:48:2]
            log_e = torch.log10(torch.from_numpy(np.array(model_metas['E'])).float())
            nu = np.array(model_metas['nu'])
        # NOTE(review): presumably maps raw coordinates into the frame the
        # predictions use — confirm against the data pipeline.
        pts_gt = (pts_gt - 5) / 2

        # Prediction files are keyed by sample index, log10(E) and nu.
        pred_stem = f"{args.pred_path}/{i}_{log_e:03f}_{nu:03f}"
        pts_pred = np.load(f"{pred_stem}.npy")

        pts_gt_all.append(pts_gt)
        pts_pred_all.append(pts_pred)
        # NOTE(review): assumes each prediction has a rendered GIF with the
        # same stem next to the .npy file — confirm with the renderer.
        gif_paths.append(f"{pred_stem}.gif")
    pts_gt_all = np.stack(pts_gt_all, axis=0)
    pts_pred_all = np.stack(pts_pred_all, axis=0)
    print("Loaded all test samples")
    # evaluate_test_4d expects (predictions, ground truth), in that order.
    iou, cd, mse, single_iou, single_cd, single_mse = evaluate_test_4d(
        torch.from_numpy(pts_pred_all), torch.from_numpy(pts_gt_all)
    )
    print("IOU, CD, MSE:", iou, cd, mse)

    if args.save_html:
        # Convert the per-sample metrics to plain floats before sorting so
        # this works whether they are Python numbers or 0-dim tensors.
        single_iou = np.array([float(v) for v in single_iou])
        single_cd = np.array([float(v) for v in single_cd])
        single_mse = np.array([float(v) for v in single_mse])

        # Worst IoU first.
        sorted_indices = np.argsort(single_iou)
        pts_gt_all = pts_gt_all[sorted_indices]
        pts_pred_all = pts_pred_all[sorted_indices]
        single_iou = single_iou[sorted_indices]
        single_cd = single_cd[sorted_indices]
        single_mse = single_mse[sorted_indices]
        gif_paths = [gif_paths[i] for i in sorted_indices]

        rows = [
            "<!DOCTYPE html>",
            "<html lang='en'>",
            "<head>",
            "  <meta charset='utf-8'>",
            "  <title>GIF gallery</title>",
            "  <style>",
            "     body{margin:0;font-family:sans-serif;background:#fafafa;color:#333}",
            "     .row{padding:16px;text-align:center;border-bottom:1px solid #eee;}",
            "     img{max-width:50%;height:auto;display:block;margin:0 auto;}",
            "     .caption{margin-top:8px;font-size:0.9rem;word-break:break-all;}",
            "  </style>",
            "</head>",
            "<body>",
        ]

        # One <div> per gif with its metrics as caption.
        for i, gif in enumerate(gif_paths):
            # html.escape also escapes quotes, so the values are safe inside
            # the single-quoted attributes.
            src = html.escape(gif.split('/')[-1])
            alt = html.escape(gif)
            rows.append(
                f"  <div class='row'>"
                f"<img src='{src}' alt='{alt}'>"
                f"<p class='caption'>"
                f"Name: {html.escape(gif)}<br>"
                f"IoU: {single_iou[i]:.4f}, Chamfer Distance: {single_cd[i]:.4f}, MSE: {single_mse[i]:.4f}"
                f"</p>"
                f"</div>"
            )

        rows += ["</body>", "</html>"]
        html_path = args.html_path or os.path.join(args.pred_path, 'visualize.html')
        os.makedirs(os.path.dirname(html_path) or '.', exist_ok=True)
        with open(html_path, 'w') as f:
            f.write('\n'.join(rows))