| import numpy as np |
| import json |
| import os |
| import itertools |
| import trimesh |
| from matplotlib.path import Path |
| from collections import Counter |
| from sklearn.neighbors import KNeighborsClassifier |
|
|
|
|
| def load_segmentation(path, shape): |
| """ |
| Get a segmentation mask for a given image |
| Arguments: |
| path: path to the segmentation json file |
| shape: shape of the output mask |
| Returns: |
| Returns a segmentation mask |
| """ |
| with open(path) as json_file: |
| dict = json.load(json_file) |
| segmentations = [] |
| for key, val in dict.items(): |
| if not key.startswith('item'): |
| continue |
|
|
| |
| |
| |
|
|
| coordinates = [] |
| for segmentation_coord in val['segmentation']: |
| |
| x = segmentation_coord[::2] |
| y = segmentation_coord[1::2] |
| xy = np.vstack((x, y)).T |
| coordinates.append(xy) |
|
|
| segmentations.append({ |
| 'type': val['category_name'], |
| 'type_id': val['category_id'], |
| 'coordinates': coordinates |
| }) |
|
|
| return segmentations |
|
|
|
|
| def smpl_to_recon_labels(recon, smpl, k=1): |
| """ |
| Get the bodypart labels for the recon object by using the labels from the corresponding smpl object |
| Arguments: |
| recon: trimesh object (fully clothed model) |
| shape: trimesh object (smpl model) |
| k: number of nearest neighbours to use |
| Returns: |
| Returns a dictionary containing the bodypart and the corresponding indices |
| """ |
| smpl_vert_segmentation = json.load( |
| open( |
| os.path.join(os.path.dirname(__file__), |
| 'smpl_vert_segmentation.json'))) |
| n = smpl.vertices.shape[0] |
| y = np.array([None] * n) |
| for key, val in smpl_vert_segmentation.items(): |
| y[val] = key |
|
|
| classifier = KNeighborsClassifier(n_neighbors=1) |
| classifier.fit(smpl.vertices, y) |
|
|
| y_pred = classifier.predict(recon.vertices) |
|
|
| recon_labels = {} |
| for key in smpl_vert_segmentation.keys(): |
| recon_labels[key] = list( |
| np.argwhere(y_pred == key).flatten().astype(int)) |
|
|
| return recon_labels |
|
|
|
|
| def extract_cloth(recon, segmentation, K, R, t, smpl=None): |
| """ |
| Extract a portion of a mesh using 2d segmentation coordinates |
| Arguments: |
| recon: fully clothed mesh |
| seg_coord: segmentation coordinates in 2D (NDC) |
| K: intrinsic matrix of the projection |
| R: rotation matrix of the projection |
| t: translation vector of the projection |
| Returns: |
| Returns a submesh using the segmentation coordinates |
| """ |
| seg_coord = segmentation['coord_normalized'] |
| mesh = trimesh.Trimesh(recon.vertices, recon.faces) |
| extrinsic = np.zeros((3, 4)) |
| extrinsic[:3, :3] = R |
| extrinsic[:, 3] = t |
| P = K[:3, :3] @ extrinsic |
|
|
| P_inv = np.linalg.pinv(P) |
|
|
| |
| |
| points_so_far = [] |
| faces = recon.faces |
| for polygon in seg_coord: |
| n = len(polygon) |
| coords_h = np.hstack((polygon, np.ones((n, 1)))) |
| |
| XYZ = P_inv @ coords_h[:, :, None] |
| XYZ = XYZ.reshape((XYZ.shape[0], XYZ.shape[1])) |
| XYZ = XYZ[:, :3] / XYZ[:, 3, None] |
|
|
| p = Path(XYZ[:, :2]) |
|
|
| grid = p.contains_points(recon.vertices[:, :2]) |
| indeces = np.argwhere(grid == True) |
| points_so_far += list(indeces.flatten()) |
|
|
| if smpl is not None: |
| num_verts = recon.vertices.shape[0] |
| recon_labels = smpl_to_recon_labels(recon, smpl) |
| body_parts_to_remove = [ |
| 'rightHand', 'leftToeBase', 'leftFoot', 'rightFoot', 'head', |
| 'leftHandIndex1', 'rightHandIndex1', 'rightToeBase', 'leftHand', |
| 'rightHand' |
| ] |
| type = segmentation['type_id'] |
|
|
| |
| |
| |
| if type == 1 or type == 3 or type == 10: |
| body_parts_to_remove += ['leftForeArm', 'rightForeArm'] |
| |
| elif type == 5 or type == 6 or type == 12 or type == 13 or type == 8 or type == 9: |
| body_parts_to_remove += [ |
| 'leftForeArm', 'rightForeArm', 'leftArm', 'rightArm' |
| ] |
| |
| elif type == 7: |
| body_parts_to_remove += [ |
| 'leftLeg', 'rightLeg', 'leftForeArm', 'rightForeArm', |
| 'leftArm', 'rightArm' |
| ] |
|
|
| verts_to_remove = list( |
| itertools.chain.from_iterable( |
| [recon_labels[part] for part in body_parts_to_remove])) |
|
|
| label_mask = np.zeros(num_verts, dtype=bool) |
| label_mask[verts_to_remove] = True |
|
|
| seg_mask = np.zeros(num_verts, dtype=bool) |
| seg_mask[points_so_far] = True |
|
|
| |
| |
| extra_verts_to_remove = np.array(list(seg_mask) and list(label_mask)) |
|
|
| combine_mask = np.zeros(num_verts, dtype=bool) |
| combine_mask[points_so_far] = True |
| combine_mask[extra_verts_to_remove] = False |
|
|
| all_indices = np.argwhere(combine_mask == True).flatten() |
|
|
| i_x = np.where(np.in1d(faces[:, 0], all_indices))[0] |
| i_y = np.where(np.in1d(faces[:, 1], all_indices))[0] |
| i_z = np.where(np.in1d(faces[:, 2], all_indices))[0] |
|
|
| faces_to_keep = np.array(list(set(i_x).union(i_y).union(i_z))) |
| mask = np.zeros(len(recon.faces), dtype=bool) |
| if len(faces_to_keep) > 0: |
| mask[faces_to_keep] = True |
|
|
| mesh.update_faces(mask) |
| mesh.remove_unreferenced_vertices() |
|
|
| |
|
|
| return mesh |
|
|
| return None |
|
|