Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| import torch | |
| import argparse | |
| from dataclasses import dataclass | |
| from arrgh import arrgh | |
| import polyscope as ps | |
| import polyscope.imgui as psim | |
| import potpourri3d as pp3d | |
| import trimesh | |
| import cuml | |
| import xgboost as xgb | |
| import os, random | |
| import sys | |
| sys.path.append("..") | |
| from partfield.utils import * | |
class State:
    """Mutable session state shared by the loaders, the fitting code and the UI callback.

    Defaults are declared on the class, so ``State.<attr>`` still reads the
    default value. ``__init__`` gives every instance private copies of the
    mutable defaults (the numpy arrays and the name->index dict). Separate
    ``State`` objects therefore never alias each other's annotations or
    structure map.
    """

    objects = None  # list of per-mesh dicts (from load_mesh_and_features) that are displayed / predicted on
    train_objects = None  # list of per-mesh dicts fitted on when fit_to == "TrainingSet"

    # Input options
    subsample_inputs: int = -1  # -1: use every input; otherwise keep this many shuffled inputs
    n_train_subset: int = 0  # how many shuffled inputs are split off as the training set

    # Label
    N_class: int = 2  # number of classes offered in the UI

    # Annotations
    # A annotations (initially A = 0). The feature width 448 must match the
    # per-face feature vectors that get appended to anno_feat.
    # NOTE(review): hard-coded; confirm it matches the feature files in use.
    anno_feat: np.ndarray = np.zeros((0, 448), dtype=np.float32)  # [A,F]
    anno_label: np.ndarray = np.zeros((0,), dtype=np.int32)  # [A]
    anno_pos: np.ndarray = np.zeros((0, 3), dtype=np.float32)  # [A,3]

    # Intermediate selection data
    is_selecting: bool = False  # True while the UI waits for a click on a mesh face
    selection_class: int = 0  # class assigned to the next clicked face

    # Fitting algorithm
    fit_to: str = "Annotations"  # one of fit_to_list
    fit_method: str = "LogisticRegression"  # one of fit_methods_list
    auto_update_fit: bool = True  # refit automatically after annotation / setting changes

    # Training data
    # T training datapoints
    train_feat: np.ndarray = np.zeros((0, 448), dtype=np.float32)  # [T,F]
    train_label: np.ndarray = np.zeros((0,), dtype=np.int32)  # [T]

    # Viz
    grid_w: int = 8  # objects per row of the display grid
    per_obj_shift: float = 2.  # spacing between neighbouring grid slots
    anno_radius: float = 0.01  # radius of the annotation points
    ps_cloud_annotation = None  # polyscope point-cloud handle for the annotations
    ps_structure_name_to_index_map = {}  # polyscope mesh name -> index into `objects`

    def __init__(self):
        # Copy the mutable class-level defaults per instance. Without this, every
        # State shares one dict, which initialize_object_viz mutates in place.
        self.anno_feat = self.anno_feat.copy()
        self.anno_label = self.anno_label.copy()
        self.anno_pos = self.anno_pos.copy()
        self.train_feat = self.train_feat.copy()
        self.train_label = self.train_label.copy()
        self.ps_structure_name_to_index_map = dict(self.ps_structure_name_to_index_map)


# Names accepted by update_prediction for state.fit_method / state.fit_to (also the UI combo entries)
fit_methods_list = ["LinearRegression", "LogisticRegression", "LinearSVC", "RandomForest", "NearestNeighbors", "XGBoost"]
fit_to_list = ["Annotations", "TrainingSet"]
def load_mesh_and_features(mesh_filepath, ind, require_gt=False, gt_label_fol=""):
    """Load one mesh, its features and (if present) its ground-truth labels.

    The mesh is expected to be named ``feat_pca_<core>_0.ply``. Features are read
    from ``part_feat_<core>_0_batch.npy`` in the same directory, and labels from
    ``<gt_label_fol>/<core>.seg``.

    Args:
        mesh_filepath: path to the ``feat_pca_*_0.ply`` mesh.
        ind: integer index, used as a zero-padded prefix of the display name.
        require_gt: if True, a missing label file is an error.
        gt_label_fol: directory containing the ``.seg`` label files.

    Returns:
        dict with keys 'nicename', 'mesh_filepath', 'feature_filepath',
        'V' (float32 vertex array), 'F' (face array), 'feat_np' (float32
        features) and 'gt_labels' (int32 labels minus one, or None when absent).

    Raises:
        ValueError: if ``require_gt`` is set and the label file is missing.
    """
    dirpath, filename = os.path.split(mesh_filepath)
    prefix, suffix = "feat_pca_", "_0.ply"
    if not (filename.startswith(prefix) and filename.endswith(suffix)):
        # The core name is cut out by position, so any other naming produces
        # feature / label paths that will not match.
        print(f"  WARNING: '{filename}' does not match '{prefix}<name>{suffix}'")
    filename_core = filename[len(prefix):-len(suffix)]  # splits off "feat_pca_" ... "_0.ply"
    feature_filename = "part_feat_" + filename_core + "_0_batch.npy"
    feature_filepath = os.path.join(dirpath, feature_filename)
    gt_filename = filename_core + ".seg"
    gt_filepath = os.path.join(gt_label_fol, gt_filename)
    have_gt = os.path.isfile(gt_filepath)

    print(" Reading file:")
    print(f" Mesh filename: {mesh_filepath}")
    print(f" Feature filename: {feature_filepath}")
    print(f" Ground Truth Label filename: {gt_filepath} -- present = {have_gt}")

    # Fail before doing the (potentially slow) feature / mesh loading.
    if require_gt and not have_gt:
        raise ValueError("could not find ground-truth file, but it is required")

    # load features
    feat = np.load(feature_filepath, allow_pickle=False).astype(np.float32)

    # load mesh things
    # TODO replace this with just loading V/F from numpy archive
    tm = load_mesh_util(mesh_filepath)
    V = np.array(tm.vertices, dtype=np.float32)
    F = np.array(tm.faces)

    # load ground truth, if available
    if have_gt:
        # ndmin=1 keeps a single-entry file 1-D, like the per-face predictions it
        # is compared against. The -1 presumably maps 1-based .seg labels to
        # 0-based class ids (TODO confirm against the dataset format).
        gt_labels = np.loadtxt(gt_filepath, ndmin=1).astype(np.int32) - 1
    else:
        gt_labels = None

    return {
        'nicename': f"{ind:02d}_{filename_core}",
        'mesh_filepath': mesh_filepath,
        'feature_filepath': feature_filepath,
        'V': V,
        'F': F,
        'feat_np': feat,
        'gt_labels': gt_labels,
    }
def shift_for_ind(state : State, ind):
    """Return the xyz translation that places object number ``ind`` in the display grid.

    Objects fill rows of ``state.grid_w`` columns. Columns advance along +x and
    rows along -z, each step being ``state.per_obj_shift`` units; y is always 0.
    """
    row, col = divmod(ind, state.grid_w)
    step = state.per_obj_shift
    return np.array([step * col, 0, -step * row])
def viz_upper_limit(state : State, ind_count):
    """Return the far grid corner used to size the viewer for ``ind_count`` objects.

    x covers at most ``state.grid_w`` columns. z is ``-per_obj_shift`` times
    the number of completed rows (``ind_count // grid_w``). y is 0.
    """
    step = state.per_obj_shift
    cols_used = min(ind_count, state.grid_w)
    full_rows = ind_count // state.grid_w
    return np.array([step * cols_used, 0, -step * full_rows])
def initialize_object_viz(state : State, obj, index=0):
    """Register ``obj`` as a grey polyscope surface mesh placed at grid slot ``index``.

    Stores the polyscope handle in ``obj['ps_mesh']``, restricts picking to
    faces, and records ``nicename -> index`` so that picks on this mesh can be
    mapped back to ``state.objects[index]``.
    """
    name = obj['nicename']
    mesh = ps.register_surface_mesh(name, obj['V'], obj['F'], color=(.8, .8, .8))
    obj['ps_mesh'] = mesh
    mesh.translate(shift_for_ind(state, index))
    mesh.set_selection_mode('faces_only')
    state.ps_structure_name_to_index_map[name] = index
def _make_classifier(fit_method):
    """Return an unfitted classifier for one of the names in fit_methods_list.

    Raises:
        ValueError: for a name that is not recognized.
    """
    if fit_method == "LinearRegression":
        return cuml.multiclass.MulticlassClassifier(cuml.linear_model.LinearRegression(), strategy='ovr')
    if fit_method == "LogisticRegression":
        return cuml.multiclass.MulticlassClassifier(cuml.linear_model.LogisticRegression(), strategy='ovr')
    if fit_method == "LinearSVC":
        return cuml.multiclass.MulticlassClassifier(cuml.svm.LinearSVC(), strategy='ovr')
    if fit_method == "RandomForest":
        return cuml.ensemble.RandomForestClassifier()
    if fit_method == "NearestNeighbors":
        return cuml.multiclass.MulticlassClassifier(cuml.neighbors.KNeighborsRegressor(n_neighbors=1), strategy='ovr')
    if fit_method == "XGBoost":
        return xgb.XGBClassifier(max_depth=7, n_estimators=1000)
    raise ValueError("unrecognized fit method")


def _training_data(state: State):
    """Return the (features, labels) pair to fit on, as selected by ``state.fit_to``.

    Raises:
        ValueError: for an unrecognized ``state.fit_to``.
    """
    if state.fit_to == "TrainingSet":
        feats = np.concatenate([obj['feat_np'] for obj in state.train_objects], axis=0)
        labels = np.concatenate([obj['gt_labels'] for obj in state.train_objects], axis=0)
        return feats, labels
    if state.fit_to == "Annotations":
        return state.anno_feat, state.anno_label
    raise ValueError("unrecognized fit to")


def update_prediction(state: State):
    """Fit the selected classifier and store predicted labels on every displayed object.

    Fits ``state.fit_method`` either to the interactive annotations or to the
    ground-truth labels of ``state.train_objects`` (per ``state.fit_to``). It then
    writes ``obj['pred_label']`` for each object in ``state.objects``.

    When fitting to the training set, ``state.N_class`` is reset from the
    training labels, and the accuracy against any available ground truth is
    printed.

    Returns ``state`` without fitting when there is nothing to fit yet: fewer
    than two distinct annotated classes, or no training objects.

    Raises:
        ValueError: for an unrecognized fit method or fit target.
    """
    print("Updating predictions..")

    # Quick out if we don't have at least two distinct class labels present
    if state.fit_to == "Annotations" and len(np.unique(state.anno_label)) <= 1:
        return state
    # Quick out if we don't have any training objects (None, or an empty split)
    if state.fit_to == "TrainingSet" and not state.train_objects:
        return state

    classifier = _make_classifier(state.fit_method)
    feats, labels = _training_data(state)
    if state.fit_to == "TrainingSet":
        # Plain int, so the value round-trips cleanly through the imgui InputInt widget.
        state.N_class = int(np.max(labels)) + 1

    if state.fit_method == "XGBoost":
        # XGBClassifier only accepts labels 0..K-1, but annotations may use a
        # sparse subset of classes (e.g. only 0 and 2). Fit on dense indices
        # and map the predictions back to the original class ids.
        classes, dense_labels = np.unique(labels, return_inverse=True)
        classifier.fit(feats, dense_labels)

        def predict(x):
            return classes[np.asarray(classifier.predict(x), dtype=np.int64)]
    else:
        classifier.fit(feats, labels)
        predict = classifier.predict

    n_total = 0
    n_correct = 0
    for obj in state.objects:
        obj['pred_label'] = predict(obj['feat_np'])
        if obj['gt_labels'] is not None:
            n_total += obj['gt_labels'].shape[0]
            n_correct += int(np.sum(obj['pred_label'] == obj['gt_labels']))

    if state.fit_to == "TrainingSet" and n_total > 0:
        frac = n_correct / n_total
        print(f"Test accuracy: {n_correct:d} / {n_total:d} {100*frac:.02f}%")

    print("Done updating predictions.")
    return state
def update_prediction_viz(state: State):
    """Push every object's predicted labels to its polyscope mesh as a face scalar.

    Objects without a 'pred_label' entry (no fit has run yet) are skipped. The
    colormap range is pinned to [0, N_class-1] so that class colors agree across
    meshes. Returns ``state``.
    """
    predicted = [entry for entry in state.objects if 'pred_label' in entry]
    for entry in predicted:
        mesh = entry['ps_mesh']
        mesh.add_scalar_quantity(
            "pred labels",
            entry['pred_label'],
            defined_on='faces',
            vminmax=(0, state.N_class - 1),
            cmap='turbo',
            enabled=True,
        )
    return state
def update_annotation_viz(state: State):
    """Draw the current annotations as a point cloud colored by class.

    Registers a point cloud named "annotations" at ``state.anno_pos`` with
    radius ``state.anno_radius``. It is colored by ``state.anno_label`` over
    [0, N_class-1], and the handle is kept in ``state.ps_cloud_annotation``.
    Returns ``state``.
    """
    cloud = ps.register_point_cloud(
        "annotations",
        state.anno_pos,
        radius=state.anno_radius,
        material='candy',
    )
    cloud.add_scalar_quantity(
        "labels",
        state.anno_label,
        vminmax=(0, state.N_class - 1),
        cmap='turbo',
        enabled=True,
    )
    state.ps_cloud_annotation = cloud
    return state
def filter_old_labels(state: State):
    """Drop annotations whose class id is no longer valid (``label >= state.N_class``).

    Used when N_class changes. The feature, label and position arrays are
    filtered with the same row mask, so they stay aligned. Returns ``state``.
    """
    still_valid = state.anno_label < state.N_class
    state.anno_feat = state.anno_feat[still_valid]
    state.anno_label = state.anno_label[still_valid]
    state.anno_pos = state.anno_pos[still_valid]
    return state
def undo_last_annotation(state: State):
    """Remove the most recently added annotation from all three annotation arrays.

    On empty arrays this leaves them empty (slicing cannot fail). Returns ``state``.
    """
    all_but_last = slice(None, -1)
    state.anno_feat = state.anno_feat[all_but_last, :]
    state.anno_label = state.anno_label[all_but_last]
    state.anno_pos = state.anno_pos[all_but_last, :]
    return state
def _handle_selection_click(state: State):
    """Selection mode: turn a left click on a mesh face into a new annotation.

    The annotation stores the face's feature row, ``state.selection_class`` and
    the face centroid shifted to the object's grid slot. Selection mode ends
    only once one of the registered meshes is hit; other clicks are ignored.
    Returns ``state``.
    """
    io = psim.GetIO()
    if not io.MouseClicked[0]:
        return state
    pick_result = ps.pick(screen_coords=io.MousePos)
    # Check if we hit one of the meshes (not e.g. the annotation point cloud)
    if not (pick_result.is_hit and pick_result.structure_name in state.ps_structure_name_to_index_map):
        return state
    if pick_result.structure_data['element_type'] != "face":
        # shouldn't be possible: meshes are registered with selection mode 'faces_only'
        raise ValueError("pick returned non-face")

    i_obj = state.ps_structure_name_to_index_map[pick_result.structure_name]
    f_hit = pick_result.structure_data['index']
    obj = state.objects[i_obj]
    face_corners = obj['V'][obj['F'][f_hit, :], :]
    new_anno_feat = obj['feat_np'][f_hit, :]
    new_anno_pos = np.mean(face_corners, axis=0) + shift_for_ind(state, i_obj)

    state.anno_feat = np.concatenate((state.anno_feat, new_anno_feat[None, :]))
    # Keep the label array's dtype instead of promoting it to the default int type
    new_label = np.array((state.selection_class,), dtype=state.anno_label.dtype)
    state.anno_label = np.concatenate((state.anno_label, new_label))
    state.anno_pos = np.concatenate((state.anno_pos, new_anno_pos[None, :]))
    state = update_annotation_viz(state)
    state.is_selecting = False

    if state.auto_update_fit:
        state = update_prediction(state)
        state = update_prediction_viz(state)
    return state


def ps_callback(state_list):
    """Per-frame polyscope UI callback.

    ``state_list[0]`` is the State being edited; the one-element list is a
    hacky pass-by-reference. While ``state.is_selecting`` is set, only the pick
    prompt is shown. Otherwise the Annotate / Fit / Visualization panels are
    drawn, and with auto-update enabled the fit is refreshed after any change
    that affects it.
    """
    state : State = state_list[0]

    # If we're in selection mode, that's the only thing we can do
    if state.is_selecting:
        psim.TextUnformatted(f"Annotating class {state.selection_class:02d}. Click on any mesh face.")
        _handle_selection_click(state)
        return

    # If not selecting, build the main UI
    needs_pred_update = False

    psim.PushItemWidth(150)
    changed, n_class = psim.InputInt("N_class", state.N_class, step=1)
    psim.PopItemWidth()
    if changed:
        # Keep at least one class. N_class <= 0 would leave no class buttons
        # and an inverted (0, N_class-1) color range.
        state.N_class = max(1, n_class)
        n_before = state.anno_label.shape[0]
        state = filter_old_labels(state)
        state = update_annotation_viz(state)
        # Refit only if annotations were actually dropped. Always recolor the
        # predictions, since their color range depends on N_class.
        if state.anno_label.shape[0] != n_before:
            needs_pred_update = True
        state = update_prediction_viz(state)

    # Check for keypress annotation: 'w' selects class 0, digit keys select classes 1-9
    class_keys = {'w': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9}
    for key, i_class in class_keys.items():
        if i_class >= state.N_class:
            continue
        if psim.IsKeyPressed(ps.get_key_code(key)):
            state.is_selecting = True
            state.selection_class = i_class

    psim.SetNextItemOpen(True, psim.ImGuiCond_FirstUseEver)
    if psim.TreeNode("Annotate"):
        psim.TextUnformatted("New class annotation. Select class to add annotation for:")
        psim.TextUnformatted("(alternately, press key {w,1,2,3,4...})")
        for i_class in range(state.N_class):
            if i_class > 0:
                psim.SameLine()
            if psim.Button(f"{i_class:02d}"):
                state.is_selecting = True
                state.selection_class = i_class
        if psim.Button("Undo Last Annotation"):
            state = undo_last_annotation(state)
            state = update_annotation_viz(state)
            needs_pred_update = True
        psim.TreePop()

    psim.SetNextItemOpen(True, psim.ImGuiCond_FirstUseEver)
    if psim.TreeNode("Fit"):
        psim.PushItemWidth(150)
        changed, ind = psim.Combo("Fit To", fit_to_list.index(state.fit_to), fit_to_list)
        if changed:
            # Index into the same list the combo was built from
            state.fit_to = fit_to_list[ind]
            needs_pred_update = True
        changed, ind = psim.Combo("Fit Method", fit_methods_list.index(state.fit_method), fit_methods_list)
        if changed:
            state.fit_method = fit_methods_list[ind]
            needs_pred_update = True
        if psim.Button("Update fit"):
            state = update_prediction(state)
            state = update_prediction_viz(state)
        psim.SameLine()
        changed, state.auto_update_fit = psim.Checkbox("Auto-update fit", state.auto_update_fit)
        if changed:
            needs_pred_update = True
        psim.PopItemWidth()
        psim.TreePop()

    psim.SetNextItemOpen(True, psim.ImGuiCond_FirstUseEver)
    if psim.TreeNode("Visualization"):
        psim.PushItemWidth(150)
        changed, state.anno_radius = psim.SliderFloat("Annotation Point Radius", state.anno_radius, 0.00001, 0.02)
        if changed:
            state = update_annotation_viz(state)
        psim.PopItemWidth()
        psim.TreePop()

    if needs_pred_update and state.auto_update_fit:
        state = update_prediction(state)
        state = update_prediction_viz(state)
def main():
    """Command-line entry point: load meshes and features, then run the interactive UI.

    Each ``--meshes`` entry may be a file or a directory. A directory
    contributes every ``*.ply`` file whose path contains ``feat_pca``.

    The file list is shuffled and optionally subsampled. It is optionally split
    into a training set, which switches fitting to "TrainingSet". The remaining
    meshes are shown in a grid for interactive annotation and prediction.
    """
    state = State()

    ## Parse args
    parser = argparse.ArgumentParser()
    parser.add_argument('--meshes', nargs='+', help='List of meshes to process.', required=True)
    parser.add_argument('--n_train_subset', default=0, type=int, help='How many meshes to train on.')
    parser.add_argument('--gt_label_fol', default="../data/coseg_guitar/gt", help='Path where labels are stored.')
    parser.add_argument('--subsample_inputs', default=state.subsample_inputs, type=int, help='Only use this many randomly chosen inputs (-1 = all)')
    parser.add_argument('--per_obj_shift', default=state.per_obj_shift, type=float, help='How to space out objects in UI grid')
    parser.add_argument('--grid_w', default=state.grid_w, type=int, help='Grid width')
    parser.add_argument('--seed', default=None, type=int, help='Seed for shuffling the inputs (default: unseeded)')
    args = parser.parse_args()

    state.n_train_subset = args.n_train_subset
    state.subsample_inputs = args.subsample_inputs
    state.per_obj_shift = args.per_obj_shift
    state.grid_w = args.grid_w
    if state.grid_w < 1:
        # shift_for_ind divides by grid_w
        parser.error("--grid_w must be at least 1")

    ## Load data
    # First, resolve directories to load all files in directory
    all_filepaths = []
    print("Resolving passed directories")
    for entry in args.meshes:
        if not os.path.isdir(entry):
            all_filepaths.append(entry)
            continue
        print(f" processing directory {entry}")
        # sorted() makes the listing order (and therefore a seeded shuffle) reproducible
        for filename in sorted(os.listdir(entry)):
            file_path = os.path.join(entry, filename)
            if os.path.isfile(file_path) and file_path.endswith(".ply") and "feat_pca" in file_path:
                print(f" adding file {file_path}")
                all_filepaths.append(file_path)

    random.Random(args.seed).shuffle(all_filepaths)
    if state.subsample_inputs >= 0:
        all_filepaths = all_filepaths[:state.subsample_inputs]

    if state.n_train_subset != 0:
        train_filepaths = all_filepaths[:state.n_train_subset]
        all_filepaths = all_filepaths[state.n_train_subset:]
        print(f"Loading {len(train_filepaths)} training files")
        state.train_objects = [
            load_mesh_and_features(file_path, i, require_gt=True, gt_label_fol=args.gt_label_fol)
            for i, file_path in enumerate(train_filepaths)
        ]
        if state.train_objects:
            state.fit_to = "TrainingSet"

    if not all_filepaths:
        parser.error("no meshes left to display (check --meshes, --subsample_inputs and --n_train_subset)")

    # Load files. Ground truth is looked up in the same folder as for training,
    # so that update_prediction can report test accuracy on these objects.
    print(f"Loading {len(all_filepaths)} files")
    state.objects = [
        load_mesh_and_features(file_path, i, gt_label_fol=args.gt_label_fol)
        for i, file_path in enumerate(all_filepaths)
    ]

    ## Set up visualization
    ps.init()
    ps.set_automatically_compute_scene_extents(False)
    lim = viz_upper_limit(state, len(state.objects))
    ps.set_length_scale(np.linalg.norm(lim) / 4.)
    # Bound the scene by the actual translated geometry. Pairing the grid corner
    # from viz_upper_limit (whose z is <= 0) with a fixed low corner of
    # (0, -1, -1) gave low.z > high.z as soon as there was a second grid row.
    shifted_mins = [o['V'].min(axis=0) + shift_for_ind(state, i) for i, o in enumerate(state.objects)]
    shifted_maxs = [o['V'].max(axis=0) + shift_for_ind(state, i) for i, o in enumerate(state.objects)]
    ps.set_bounding_box(np.min(shifted_mins, axis=0), np.max(shifted_maxs, axis=0))
    for ind, o in enumerate(state.objects):
        initialize_object_viz(state, o, ind)

    print(f"Loaded {len(state.objects)} objects")
    if state.n_train_subset != 0:
        print(f"Loaded {len(state.train_objects)} training objects")

    # One first prediction
    # (does nothing if there are no annotations / training data)
    state = update_prediction(state)
    state = update_prediction_viz(state)

    # Start the interactive UI
    ps.set_user_callback(lambda: ps_callback([state]))
    ps.show()


if __name__ == "__main__":
    main()