File size: 4,095 Bytes
382733a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import numpy as np
from sklearn.cluster import KMeans
from sklearn.neighbors import KDTree

import torch
import torch.nn.functional as F

def get_voxel_partfeats(voxel_coords, part_planes, resolution=64):
    """Sample one tri-plane part feature per voxel.

    Voxel indices are mapped to cell centres in [-0.5, 0.5]. They are then
    re-centred on their bounding box and scaled so the longest bbox side spans
    [-0.9, 0.9]. The result is used to query ``part_planes`` through
    ``sample_triplane_feat``.

    Args:
        voxel_coords: torch tensor whose column 0 is dropped and whose
            remaining columns are voxel indices in ``[0, resolution)``.
            It is expected to be (N, 4), with column 0 presumably a batch
            index (TODO confirm against callers).
        part_planes: tri-plane feature tensor accepted by
            ``sample_triplane_feat`` (presumably (1, 3, C, H, W)). Query
            points are created on its device and in its dtype.
        resolution: voxel grid resolution per axis (default 64).

    Returns:
        numpy array of shape (N, C), one feature row per voxel.

    Raises:
        ValueError: if ``voxel_coords`` has no rows.
    """
    coords = ((voxel_coords[:, 1:] + 0.5) / resolution - 0.5).cpu().numpy()
    if coords.shape[0] == 0:
        raise ValueError("get_voxel_partfeats: voxel_coords is empty")

    bbmin = coords.min(0)
    bbmax = coords.max(0)
    center = (bbmin + bbmax) * 0.5
    extent = (bbmax - bbmin).max()
    # A zero extent (e.g. a single voxel) would give scale=inf and NaN coords.
    # Any finite scale maps such points to the origin.
    scale = 2.0 * 0.9 / extent if extent > 0 else 1.0
    coords = (coords - center) * scale

    # Match the planes' device/dtype; grid_sample requires the grid dtype to
    # equal the input dtype.
    query = torch.from_numpy(coords).reshape(1, -1, 3).to(
        device=part_planes.device, dtype=part_planes.dtype
    )
    # Features are converted to numpy below, so no autograd graph is needed.
    with torch.no_grad():
        part_feats = sample_triplane_feat(part_planes, query)  # (1, N, C)
    return part_feats.cpu().numpy().reshape(-1, part_feats.shape[-1])

def sample_triplane_feat(feature_triplane, normalized_pos):
    """Bilinearly sample a tri-plane feature volume and sum the three planes.

    Args:
        feature_triplane: 5-D tensor (B, P, C, H, W). Planes are taken along
            dim 1, and planes 0, 1 and 2 are sampled.
        normalized_pos: (B, N, 3) query points, expected in [-1, 1].
            Out-of-range points are clamped by ``padding_mode='border'``.

    Plane 0 is indexed by coordinates (0, 1), plane 1 by (1, 2) and plane 2
    by (0, 2). The first coordinate of each pair is the grid x (width) axis.

    Returns:
        (B, N, C) tensor: the sum of the three bilinear samples per point.
    """
    # (plane index, component used as grid x, component used as grid y)
    plane_axes = ((0, 0, 1), (1, 1, 2), (2, 0, 2))
    planes = torch.unbind(feature_triplane, dim=1)

    summed = None
    for plane_idx, u_axis, v_axis in plane_axes:
        grid = torch.cat(
            [
                normalized_pos[:, :, u_axis:u_axis + 1],
                normalized_pos[:, :, v_axis:v_axis + 1],
            ],
            dim=-1,
        ).unsqueeze(dim=1)  # (B, 1, N, 2)
        sampled = F.grid_sample(
            planes[plane_idx],
            grid,
            padding_mode='border',
            align_corners=True,
        )  # (B, C, 1, N)
        # Accumulate in plane order so the result equals (p0 + p1) + p2.
        summed = sampled if summed is None else summed + sampled

    return summed.squeeze(dim=2).permute(0, 2, 1)

def _l2_normalize_rows(feats):
    """Return ``feats`` as float32 with each row scaled to unit L2 norm.

    The norm is computed in float32 because squaring and summing many float16
    channels can exceed float16's range (max 65504) and yield inf. All-zero
    rows stay zero instead of becoming NaN.
    """
    feats = np.asarray(feats, dtype=np.float32)
    norms = np.linalg.norm(feats, axis=-1, keepdims=True)
    return feats / np.maximum(norms, 1e-12)


def _cluster_means(feats, labels, fallback_centers):
    """Return the (K, C) array of per-cluster mean features.

    ``fallback_centers[j]`` is used for cluster ``j`` when it has no members,
    whose mean would otherwise be NaN.
    """
    means = []
    for j, fallback in enumerate(fallback_centers):
        members = feats[labels == j]
        means.append(members.mean(axis=0) if len(members) else fallback)
    return np.array(means)


def cosegment_part(app_coords, app_part_planes, struct_coords, struct_part_planes, num_clusters=30):
    """Co-segment two voxel shapes into parts that share label ids.

    1. The appearance shape is clustered with K-means on its L2-normalised
       part features.
    2. The structure shape is clustered with K-means seeded from the
       appearance cluster means.
    3. Each structure cluster is relabelled to its nearest appearance
       cluster mean.

    Args:
        app_coords, app_part_planes: voxel coords and tri-plane features of
            the appearance shape (see ``get_voxel_partfeats``).
        struct_coords, struct_part_planes: same for the structure shape.
        num_clusters: number of K-means clusters for both shapes.

    Returns:
        labels1: (N1,) cluster labels of the appearance voxels.
        relabelled_2: (N2, 1) structure labels mapped into the appearance
            label space. The trailing axis comes from ``KDTree.query(k=1)``.
        point_feat1: (N1, C) float32 unit-norm appearance features.
        point_feat2: (N2, C) float32 unit-norm structure features.
    """
    struct_partfield_feats = get_voxel_partfeats(struct_coords, struct_part_planes)
    app_partfield_feats = get_voxel_partfeats(app_coords, app_part_planes)

    point_feat1 = _l2_normalize_rows(app_partfield_feats)
    point_feat2 = _l2_normalize_rows(struct_partfield_feats)

    # 1) Unconstrained K-means on the appearance shape.
    clustering1 = KMeans(n_clusters=num_clusters, random_state=0, n_init="auto").fit(point_feat1)
    labels1 = clustering1.labels_
    feature_means1 = _cluster_means(point_feat1, labels1, clustering1.cluster_centers_)

    # 2) K-means on the structure shape, seeded with the appearance means.
    #    An explicit init array implies a single run; n_init=1 says so and
    #    avoids sklearn's "explicit initial center" warning.
    clustering2 = KMeans(
        n_clusters=num_clusters, random_state=0, init=feature_means1, n_init=1
    ).fit(point_feat2)
    feature_means2 = _cluster_means(point_feat2, clustering2.labels_, clustering2.cluster_centers_)

    # 3) Nearest appearance mean for each structure mean. Several structure
    #    clusters may map onto the same appearance id.
    tree = KDTree(feature_means1)
    _, nn_idx = tree.query(feature_means2, k=1)  # (K, 1)
    relabelled_2 = nn_idx[clustering2.labels_]

    return labels1, relabelled_2, point_feat1, point_feat2

def cluster_geoms(struct_coords, struct_part_planes, num_clusters=10):
    """Cluster the voxels of one shape into parts with K-means.

    Args:
        struct_coords, struct_part_planes: voxel coords and tri-plane features
            of the shape (see ``get_voxel_partfeats``).
        num_clusters: number of K-means clusters.

    Returns:
        (N,) array of K-means cluster labels, one per voxel.
    """
    # Normalise in float32: squaring/summing float16 channels can overflow.
    point_feat = np.asarray(
        get_voxel_partfeats(struct_coords, struct_part_planes), dtype=np.float32
    )
    norms = np.linalg.norm(point_feat, axis=-1, keepdims=True)
    # Guard zero-norm rows, which would otherwise feed NaN into KMeans.
    point_feat = point_feat / np.maximum(norms, 1e-12)

    clustering = KMeans(n_clusters=num_clusters, random_state=0, n_init="auto").fit(point_feat)
    return clustering.labels_