import multiprocessing

import networkx as nx
import numpy as np
import argparse
import os

import trimesh
from tqdm import tqdm
import ray

from check_valid import check_step_valid_soild, load_data_with_prefix
from eval_brepgen import normalize_pc
from eval_unique_novel import *
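
# Uniqueness / novelty evaluation for generated B-rep samples: load each sample's
# `data.npz`, build a face-adjacency graph for it (helpers come from
# `eval_unique_novel`), group the graphs by face count, and compare them pairwise,
# optionally distributing the comparisons with Ray. Unique sample prefixes are
# written to `<train_root>_novel_<n_bit>bit.txt` and identical pairs to
# `<train_root>_identical_pairs_<n_bit>bit.txt`.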


def find_connected_components(matrix):
    """Return the connected components (lists of node indices) of an adjacency matrix."""
    N = len(matrix)
    visited = [False] * N
    components = []

    def dfs(idx, component):
        # Iterative DFS to avoid recursion-depth limits on large graphs.
        stack = [idx]
        while stack:
            node = stack.pop()
            if not visited[node]:
                visited[node] = True
                component.append(node)
                for neighbor in range(N):
                    if matrix[node][neighbor] and not visited[neighbor]:
                        stack.append(neighbor)

    for i in range(N):
        if not visited[i]:
            component = []
            dfs(i, component)
            components.append(component)

    return components
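
# Illustrative example (hypothetical input): a 4x4 adjacency matrix in which
# nodes 0-1 and 2-3 are connected yields two components:
#   find_connected_components([[0, 1, 0, 0],
#                              [1, 0, 0, 0],
#                              [0, 0, 0, 1],
#                              [0, 0, 1, 0]])
#   -> [[0, 1], [2, 3]]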


def compute_unique(graph_list, atol=None, is_use_ray=False, batch_size=100000, num_max_split_batch=128):
    """Compare all graph pairs and return (indices of unique graphs, list of identical pairs)."""
    N = len(graph_list)
    identical_pairs = []
    unique_graph_idx = list(range(N))
    # Upper-triangular indices: every unordered pair (i, j) with i < j.
    pair_0, pair_1 = np.triu_indices(N, k=1)
    check_pairs = np.column_stack((pair_0, pair_1))

    # If there would be more than 64 batches, re-split into num_max_split_batch larger batches.
    num_split_batch = len(check_pairs) // batch_size
    if num_split_batch > 64:
        num_split_batch = num_max_split_batch
        batch_size = len(check_pairs) // num_split_batch

    if not is_use_ray:
        for idx1, idx2 in tqdm(check_pairs):
            if is_graph_identical(graph_list[idx1], graph_list[idx2], atol=atol):
                if idx2 in unique_graph_idx:
                    unique_graph_idx.remove(idx2)
                identical_pairs.append((idx1, idx2))
    else:
        # Submit one Ray task per batch; stepping by batch_size keeps the final
        # partial batch instead of dropping it.
        futures = []
        batch_starts = list(range(0, len(check_pairs), batch_size))
        for start in tqdm(batch_starts):
            batch_pairs = check_pairs[start: start + batch_size]
            batch_graph_pair = [(graph_list[idx1], graph_list[idx2]) for idx1, idx2 in batch_pairs]
            futures.append(is_graph_identical_remote.remote(batch_graph_pair, atol))
        results = ray.get(futures)

        for batch_idx, start in enumerate(tqdm(batch_starts)):
            for idx, is_identical in enumerate(results[batch_idx]):
                if not is_identical:
                    continue
                idx1, idx2 = check_pairs[start + idx]
                if idx2 in unique_graph_idx:
                    unique_graph_idx.remove(idx2)
                identical_pairs.append((idx1, idx2))

    return unique_graph_idx, identical_pairs
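
# Illustrative call (graphs are assumed to come from build_graph in eval_unique_novel;
# the tolerance value is a placeholder):
#   unique_idx, dup_pairs = compute_unique(graphs, atol=1e-4, is_use_ray=False)
# "unique_idx" holds the indices of graphs not identical to any earlier graph,
# and "dup_pairs" every (i, j) pair that compared as identical.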


def test_check():
    # Self-check on is_graph_identical: the same three faces in permuted order,
    # with the adjacency pair remapped to match the permutation.
    sample = np.random.rand(3, 32, 32, 3)
    face1 = sample[[0, 1, 2]]
    face2 = sample[[0, 2, 1]]
    faces_adj1 = [[0, 1]]
    faces_adj2 = [[0, 2]]

    graph1 = build_graph(face1, faces_adj1)
    graph2 = build_graph(face2, faces_adj2)

    is_identical = is_graph_identical(graph1, graph2)

    print("Graphs are equal" if is_identical else "Graphs are not equal")


def load_data_from_npz(data_npz_file):
    """Load face sample points and face-adjacency pairs from a data.npz file."""
    data_npz = np.load(data_npz_file, allow_pickle=True)
    # NOTE: data_npz1 is loaded but not used below.
    data_npz1 = np.load(data_npz_file.replace("deepcad_32", "deepcad_train_v6"), allow_pickle=True)

    if 'face_edge_adj' in data_npz:
        # Format 1: predicted faces plus a per-face list of edge ids; two faces are
        # adjacent if they share at least one edge id.
        faces = data_npz['pred_face']
        face_edge_adj = data_npz['face_edge_adj']
        faces_adj_pair = []
        N = face_edge_adj.shape[0]
        for face_idx1 in range(N):
            for face_idx2 in range(face_idx1 + 1, N):
                face_edges1 = face_edge_adj[face_idx1]
                face_edges2 = face_edge_adj[face_idx2]
                if sorted((face_idx1, face_idx2)) in faces_adj_pair:
                    continue
                if len(set(face_edges1).intersection(set(face_edges2))) > 0:
                    faces_adj_pair.append(sorted((face_idx1, face_idx2)))
        return faces, faces_adj_pair

    # Format 2: sampled face points plus explicit (edge, face, face) connectivity.
    if 'sample_points_faces' in data_npz:
        face_points = data_npz['sample_points_faces']
        edge_face_connectivity = data_npz['edge_face_connectivity']
    elif 'pred_face' in data_npz and 'pred_edge_face_connectivity' in data_npz:
        face_points = data_npz['pred_face']
        edge_face_connectivity = data_npz['pred_edge_face_connectivity']
    else:
        raise ValueError("Invalid data format")
    faces_adj_pair = []
    for edge_idx, face_idx1, face_idx2 in edge_face_connectivity:
        faces_adj_pair.append([face_idx1, face_idx2])
    if face_points.shape[-1] != 3:
        face_points = face_points[..., :3]

    src_shape = face_points.shape
    face_points = normalize_pc(face_points.reshape(-1, 3)).reshape(src_shape)
    return face_points, faces_adj_pair
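
# Illustrative usage (the path and array shape are placeholders):
#   faces, adj_pairs = load_data_from_npz("samples/00000000/data.npz")
# "faces" holds per-face sampled points (e.g. an (N, res, res, 3) array) and
# "adj_pairs" lists index pairs of faces that share an edge.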


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_root", type=str, required=True)
    parser.add_argument("--n_bit", type=int)
    parser.add_argument("--atol", type=float)
    parser.add_argument("--use_ray", action='store_true')
    parser.add_argument("--load_batch_size", type=int, default=100)
    parser.add_argument("--compute_batch_size", type=int, default=10000)
    parser.add_argument("--txt", type=str, default=None)
    parser.add_argument("--num_cpus", type=int, default=32)
    args = parser.parse_args()
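
    # Example invocation (script name, paths, and values are placeholders):
    #   python eval_unique.py --train_root ./samples --n_bit 6 --use_ray --num_cpus 32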
    train_data_root = args.train_root
    is_use_ray = args.use_ray
    n_bit = args.n_bit
    atol = args.atol
    load_batch_size = args.load_batch_size
    compute_batch_size = args.compute_batch_size
    folder_list_txt = args.txt
    num_cpus = args.num_cpus

    # Exactly one of --n_bit or --atol must be provided.
    if not n_bit and not atol:
        raise ValueError("Must set either n_bit or atol")
    if n_bit and atol:
        raise ValueError("Cannot set both n_bit and atol")

    if n_bit:
        atol = None
    if atol:
        # n_bit = -1 marks the tolerance-based mode (it only appears in output file names).
        n_bit = -1

    if folder_list_txt:
        with open(folder_list_txt, "r") as f:
            check_folders = [line.strip() for line in f.readlines()]
    else:
        check_folders = None

    print("Loading data files...")
    data_npz_file_list = load_data_with_prefix(train_data_root, 'data.npz')
    data_npz_file_list.sort()
    if is_use_ray:
        # Load files and build graphs in parallel batches of `load_batch_size`.
        ray.init(num_cpus=num_cpus)  # limit Ray workers to --num_cpus
        futures = []
        graph_list = []
        prefix_list = []
        for i in tqdm(range(0, len(data_npz_file_list), load_batch_size)):
            batch_data_npz_file_list = data_npz_file_list[i: i + load_batch_size]
            futures.append(load_and_build_graph_remote.remote(batch_data_npz_file_list, check_folders, n_bit))
        for future in tqdm(futures):
            result = ray.get(future)
            graph_list_batch, prefix_list_batch = result
            graph_list.extend(graph_list_batch)
            prefix_list.extend(prefix_list_batch)
        ray.shutdown()
    else:
        graph_list, prefix_list = load_and_build_graph(data_npz_file_list, n_bit)
    print(f"Loaded {len(graph_list)} data files")

    graph_node_num = np.array([graph.number_of_nodes() for graph in graph_list])

    # Create (or truncate) the two report files up front; they are appended to per face count.
    identical_pairs_txt = train_data_root + f"_identical_pairs_{n_bit}bit.txt"
    fp_identical_pairs = open(identical_pairs_txt, "w")
    fp_identical_pairs.close()
    novel_txt = train_data_root + f"_novel_{n_bit}bit.txt"
    fp_novel = open(novel_txt, "w")
    fp_novel.close()

    if is_use_ray:
        ray.init(_temp_dir=r"/mnt/d/img2brep/ray_temp", num_cpus=num_cpus)
    unique_graph_idx_list = []
    # Graphs can only be identical if they have the same number of faces, so
    # compare within each face-count bucket (3 to 30 faces).
    pbar = tqdm(range(3, 31))
    for num_face in pbar:
        print(f"Processing {num_face}")
        pbar.set_description(f"Processing {num_face}")
        fp_identical_pairs = open(identical_pairs_txt, "a")
        fp_novel = open(novel_txt, "a")
        print(f"face_num = {num_face}", file=fp_identical_pairs)

        hits_graph_idx = np.where(graph_node_num == num_face)[0]
        hits_graph = [graph_list[idx] for idx in tqdm(hits_graph_idx)]
        hits_graph_prefix = [prefix_list[idx] for idx in hits_graph_idx]

        if len(hits_graph) != 0:
            local_unique_graph_idx_list, identical_pairs = compute_unique(hits_graph, atol, is_use_ray, compute_batch_size)
            for unique_graph_idx in local_unique_graph_idx_list:
                print(f"{hits_graph_prefix[unique_graph_idx]}", file=fp_novel)

            # Map bucket-local indices back to global graph indices.
            local_unique_graph_idx_list = [hits_graph_idx[idx] for idx in local_unique_graph_idx_list]
            unique_graph_idx_list.extend(local_unique_graph_idx_list)

            if len(identical_pairs) > 0:
                for idx1, idx2 in identical_pairs:
                    print(f"{hits_graph_prefix[idx1]} {hits_graph_prefix[idx2]}", file=fp_identical_pairs)
            pbar.set_postfix({"Local Unique": len(local_unique_graph_idx_list) / len(hits_graph),
                              "Total Unique": len(unique_graph_idx_list) / len(graph_list)})
            print(f"Unique: {len(local_unique_graph_idx_list)}/{len(hits_graph_idx)}"
                  f"={len(local_unique_graph_idx_list) / len(hits_graph_idx)}", file=fp_identical_pairs)
        else:
            print(f"face_num = {num_face} has no data", file=fp_identical_pairs)
        fp_identical_pairs.close()
        fp_novel.close()

    if is_use_ray:
        ray.shutdown()

    print(f"Unique num: {len(unique_graph_idx_list)}/{len(graph_list)}={len(unique_graph_idx_list) / len(graph_list)}")
    print(f"Identical pairs are saved to {identical_pairs_txt}")
    print(f"Novel samples are saved to {novel_txt}")
    print("Done")


if __name__ == "__main__":
    main()