| import dataclasses |
| import os |
|
|
| import hydra |
| import numpy as np |
| import torch |
| from flask import Flask, jsonify, request, render_template |
| from flask_cors import CORS |
| from omegaconf import OmegaConf |
| from safetensors.torch import load_model |
| from scipy.spatial.transform import Rotation |
|
|
| from point_sam import build_point_sam |
| import argparse |
|
|
# Flask application; static files are served from the local "static" folder.
app = Flask(__name__, static_folder="static")
# Enable cross-origin requests with flask_cors' default settings.
CORS(app)


# Fixed-size store of per-upload session data, filled by the /set_pointcloud
# handler and read back by the /set_prompts handler.
MAX_POINT_ID = 100
# Index of the next slot to fill; also handed to the client as its "user_id".
point_info_id = 0
point_info_list = [None] * MAX_POINT_ID
|
|
@dataclasses.dataclass
class AuxInputs:
    """Bundle of tensors handed to the mask decoder as ``aux_inputs``.

    The three required fields are filled from the encoder's patch dictionary
    when masks are decoded; the interpolation fields are optional and default
    to ``None``.
    """

    # Required inputs.
    coords: torch.Tensor
    features: torch.Tensor
    centers: torch.Tensor
    # Optional inputs; nothing in this file sets them.
    interp_index: torch.Tensor = None
    interp_weight: torch.Tensor = None
|
|
def repeat_interleave(x: torch.Tensor, repeats: int, dim: int):
    """Repeat every slice of ``x`` along ``dim`` ``repeats`` times, in place order.

    Each slice is repeated consecutively (``a, a, b, b`` rather than
    ``a, b, a, b``), matching ``torch.repeat_interleave`` with an integer
    repeat count.

    Args:
        x: Input tensor.
        repeats: Number of consecutive copies of each slice.
        dim: Axis to repeat along; negative values count from the last axis.

    Returns:
        ``x`` itself when ``repeats == 1``, otherwise a tensor whose size along
        ``dim`` is ``repeats`` times larger.

    Raises:
        IndexError: If ``dim`` is outside ``[-x.dim(), x.dim())``.
    """
    if repeats == 1:
        return x

    ndim = x.dim()
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for a {ndim}-D tensor")
    # Resolve negative axes up front: the arithmetic below uses ``dim + 1``,
    # which would silently point at the wrong axis for a negative ``dim``.
    if dim < 0:
        dim += ndim

    # Insert a singleton axis right after ``dim``, broadcast it to ``repeats``
    # (-1 keeps every other axis unchanged) and merge the two axes.
    target_shape = [-1] * (ndim + 1)
    target_shape[dim + 1] = repeats
    return x.unsqueeze(dim + 1).expand(target_shape).flatten(dim, dim + 1)
|
|
|
|
class PointCloudProcessor:
    """Normalize a point cloud and package it as arrays or tensors.

    The first cloud passed to ``__call__`` fixes ``center`` (its mean point)
    and ``scale`` (the largest distance of any point from that mean). Those
    values are cached on the instance and reused by later calls and by
    ``normalize``.

    Args:
        device: Device on which tensors are created when
            ``return_tensors == "pt"``.
        batch: When true, a leading axis of size 1 is added to both outputs.
        return_tensors: ``"pt"`` for torch tensors or ``"np"`` for NumPy
            arrays.
    """

    def __init__(self, device="cuda", batch=True, return_tensors="pt"):
        self.device = device
        self.batch = batch
        self.return_tensors = return_tensors

        # Normalization statistics; filled in by the first ``__call__``.
        self.center = None
        self.scale = None

    def __call__(self, xyz: np.ndarray, rgb: np.ndarray):
        """Return normalized ``(coords, feats)`` for one point cloud.

        Coordinates are shifted by ``center`` and divided by ``scale``; colors
        are mapped from ``[0, 255]`` to ``[-1, 1]``. Both outputs are float32.

        Raises:
            ValueError: If ``return_tensors`` is neither ``"np"`` nor ``"pt"``.
        """
        if self.center is None or self.scale is None:
            center = xyz.mean(0)
            scale = np.max(np.linalg.norm(xyz - center, axis=-1))
            # A single point (or all points identical) has zero extent;
            # dividing by it would turn every coordinate into NaN, so fall
            # back to a unit scale.
            if scale == 0:
                scale = 1.0
            self.center = center
            self.scale = scale

        xyz = (xyz - self.center) / self.scale
        rgb = ((rgb / 255.0) - 0.5) * 2

        if self.return_tensors == "np":
            coords = np.float32(xyz)
            feats = np.float32(rgb)
            if self.batch:
                coords = np.expand_dims(coords, 0)
                feats = np.expand_dims(feats, 0)
        elif self.return_tensors == "pt":
            coords = torch.tensor(xyz, dtype=torch.float32, device=self.device)
            feats = torch.tensor(rgb, dtype=torch.float32, device=self.device)
            if self.batch:
                coords = coords.unsqueeze(0)
                feats = feats.unsqueeze(0)
        else:
            raise ValueError(self.return_tensors)

        return coords, feats

    def normalize(self, xyz):
        """Apply the cached center/scale to ``xyz``.

        ``center`` and ``scale`` must already be set, either by a previous
        ``__call__`` or by direct assignment.
        """
        return (xyz - self.center) / self.scale
|
|
|
|
class PointCloudSAMPredictor:
    """Stateful wrapper that runs prompt-based segmentation on one point cloud.

    Typical flow: ``set_pointcloud`` encodes a cloud once, ``set_prompts``
    stores the clicked points and their labels, and ``predict_mask`` decodes
    masks from the cached encoding. The logits of the last prediction are kept
    in ``prompt_mask`` and the entry selected by ``candidate_index`` is fed
    back as the mask prompt of the next prediction.
    """

    input_xyz: np.ndarray
    input_rgb: np.ndarray
    prompt_coords: list[tuple[float, float, float]]
    prompt_labels: list[int]

    coords: torch.Tensor
    feats: torch.Tensor

    pc_embedding: torch.Tensor
    patches: dict[str, torch.Tensor]
    prompt_mask: torch.Tensor

    def __init__(self):
        print("Created model")
        model = build_point_sam("./model-2.safetensors")
        model.pc_encoder.patch_embed.grouper.num_groups = 1024
        model.pc_encoder.patch_embed.grouper.group_size = 128

        # One device decision for the whole predictor, so the model, the input
        # processor and the prompt tensors always agree (previously the model
        # fell back to CPU while prompts were still created on "cuda").
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cuda":
            model = model.cuda()
        model.eval()
        self.model = model

        # Raw input cloud.
        self.input_rgb = None
        self.input_xyz = None

        # Normalized inputs and the processor that holds their center/scale.
        self.input_processor = None
        self.coords = None
        self.feats = None

        # Encoder outputs cached by ``set_pointcloud``.
        self.pc_embedding = None
        self.patches = None

        # Prompt state.
        self.prompt_coords = None
        self.prompt_labels = None
        self.prompt_mask = None
        self.candidate_index = 0

    @torch.no_grad()
    def set_pointcloud(self, xyz, rgb):
        """Normalize and encode a new cloud, discarding any previous mask prompt.

        Args:
            xyz: Point coordinates as an array the processor can normalize.
            rgb: Per-point colors in the 0-255 range.
        """
        self.input_xyz = xyz
        self.input_rgb = rgb

        # A fresh processor per cloud so center/scale are recomputed.
        self.input_processor = PointCloudProcessor(device=self.device)
        self.coords, self.feats = self.input_processor(xyz, rgb)

        self.pc_embedding, self.patches = self.model.pc_encoder(self.coords, self.feats)
        self.prompt_mask = None

    def set_prompts(self, prompt_coords, prompt_labels):
        """Store the prompt points and labels used by the next ``predict_mask``."""
        self.prompt_coords = prompt_coords
        self.prompt_labels = prompt_labels

    @torch.no_grad()
    def _decode_masks(self, prompt_coords, prompt_labels, prompt_masks, multimask_output):
        """Run the prompt encoders and the mask decoder on the cached encoding.

        Reads ``self.pc_embedding`` and the ``centers``, ``knn_idx``,
        ``coords`` and ``feats`` entries of ``self.patches``.

        Returns:
            The ``(logits, iou_preds)`` pair produced by the mask decoder.
        """
        patches = self.patches
        centers = patches["centers"]
        knn_idx = patches["knn_idx"]
        coords = patches["coords"]
        feats = patches["feats"]
        aux_inputs = AuxInputs(coords=coords, features=feats, centers=centers)

        pc_pe = self.model.point_encoder.pe_layer(centers)
        sparse_embeddings = self.model.point_encoder(prompt_coords, prompt_labels)
        dense_embeddings = self.model.mask_encoder(prompt_masks, coords, centers, knn_idx)
        # Tile the dense embedding so its leading size matches the sparse one.
        dense_embeddings = repeat_interleave(
            dense_embeddings,
            sparse_embeddings.shape[0] // dense_embeddings.shape[0],
            0,
        )

        logits, iou_preds = self.model.mask_decoder(
            self.pc_embedding,
            pc_pe,
            sparse_embeddings,
            dense_embeddings,
            aux_inputs=aux_inputs,
            multimask_output=multimask_output,
        )
        return logits, iou_preds

    @torch.no_grad()
    def predict_mask(self):
        """Predict masks for the stored prompts.

        Prompt coordinates are normalized with the same center/scale as the
        cloud. Multiple candidate masks are requested only when there is
        exactly one prompt point. The resulting logits are ordered by
        descending predicted score, stored in ``prompt_mask``, and
        ``candidate_index`` is reset to 0.

        Returns:
            A boolean NumPy array: the score-sorted logits thresholded at 0.
        """
        normalized_prompt_coords = self.input_processor.normalize(
            np.array(self.prompt_coords)
        )
        prompt_coords = torch.tensor(
            normalized_prompt_coords, dtype=torch.float32, device=self.device
        ).reshape(1, -1, 3)
        prompt_labels = torch.tensor(
            self.prompt_labels, dtype=torch.bool, device=self.device
        ).reshape(1, -1)

        multimask_output = prompt_coords.shape[1] == 1

        # Feed the previously selected candidate back in as the mask prompt.
        previous_mask = None
        if self.prompt_mask is not None:
            previous_mask = self.prompt_mask[self.candidate_index].unsqueeze(0)

        logits, scores = self._decode_masks(
            prompt_coords, prompt_labels, previous_mask, multimask_output
        )
        logits = logits.squeeze(0)
        scores = scores.squeeze(0)

        # Best-scoring mask first.
        _, indices = scores.sort(descending=True)
        logits = logits[indices]

        self.prompt_mask = logits
        self.candidate_index = 0

        return (logits > 0).cpu().numpy()

    def set_candidate(self, index):
        """Select which stored mask is fed back on the next ``predict_mask``."""
        self.candidate_index = index
|
|
|
|
# Single predictor instance shared by every request handler below. It is
# constructed at import time, which is when the model is built.
predictor = PointCloudSAMPredictor()
|
|
|
|
@app.route("/")
def index():
    """Serve ``index.html`` from the application's static folder."""
    entry_page = "index.html"
    return app.send_static_file(entry_page)
|
|
@app.route("/assets/<path:path>")
def assets_route(path):
    """Serve a file from the ``assets`` directory of the static folder.

    The requested sub-path is printed before the file is sent.
    """
    print(path)
    asset_name = "assets/" + path
    return app.send_static_file(asset_name)
|
|
|
|
@app.route("/hello_world", methods=["GET"])
def hello_world():
    """Return a fixed greeting string."""
    greeting = "Hello, World!"
    return greeting
|
|
|
|
@app.route("/set_pointcloud", methods=["POST"])
def set_pointcloud():
    """Encode an uploaded point cloud and cache the result in a session slot.

    Expects a JSON object with ``points`` and ``colors``; each is reshaped to
    rows of three values. The cloud is encoded with the shared predictor and
    CPU copies of the encoding, the patch tensors and the normalization
    center/scale are stored in ``point_info_list``.

    Returns:
        JSON ``{"user_id": <slot index>}`` on success, or a JSON error with
        status 400 when the payload is malformed or empty.
    """
    global point_info_id

    request_data = request.get_json()
    if (
        not isinstance(request_data, dict)
        or "points" not in request_data
        or "colors" not in request_data
    ):
        return jsonify({"error": "expected a JSON object with 'points' and 'colors'"}), 400

    try:
        xyz = np.array(request_data["points"]).reshape(-1, 3)
        rgb = np.array(list(request_data["colors"])).reshape(-1, 3)
    except (TypeError, ValueError):
        return jsonify({"error": "'points' and 'colors' must be arrays of 3-value rows"}), 400
    if len(xyz) == 0:
        return jsonify({"error": "point cloud is empty"}), 400

    predictor.set_pointcloud(xyz, rgb)

    # Store CPU copies of everything a later /set_prompts call restores.
    session = {
        "pc_embedding": predictor.pc_embedding.cpu(),
        "patches": {
            "centers": predictor.patches["centers"].cpu(),
            "knn_idx": predictor.patches["knn_idx"].cpu(),
            "coords": predictor.coords.cpu(),
            "feats": predictor.feats.cpu(),
        },
        "center": predictor.input_processor.center,
        "scale": predictor.input_processor.scale,
        "prompt_mask": None,
    }

    # The store has a fixed number of slots, so reuse them cyclically instead
    # of indexing past the end: once every slot has been used, a new upload
    # replaces the oldest session.
    capacity = len(point_info_list)
    slot = point_info_id % capacity
    point_info_list[slot] = session
    point_info_id = (slot + 1) % capacity

    return jsonify({"user_id": slot})
| |
|
|
@app.route("/set_candidate", methods=["POST"])
def set_candidate():
    """Select which candidate mask the shared predictor feeds back next.

    Expects a JSON object with an integer ``index``.

    Returns:
        The string ``"success"``, or a JSON error with status 400 when
        ``index`` is missing or not an integer.
    """
    request_data = request.get_json()
    candidate_index = request_data.get("index") if isinstance(request_data, dict) else None
    # Reject bad values here: the index is stored on the shared predictor and
    # only used on the next prediction, where a non-integer would fail.
    # ``bool`` is excluded explicitly because it is a subclass of ``int``.
    if isinstance(candidate_index, bool) or not isinstance(candidate_index, int):
        return jsonify({"error": "'index' must be an integer"}), 400

    predictor.set_candidate(candidate_index)
    return "success"
|
|
|
|
def visualize_pcd_with_prompts(xyz, rgb, prompt_coords, prompt_labels):
    """Build a trimesh scene showing the cloud and one sphere per prompt.

    Each sphere is an icosphere scaled by 0.02 and moved to the prompt
    position. Its vertex color is ``[255, 0, 0]`` when the matching label is
    truthy and ``[0, 255, 0]`` otherwise.

    Returns:
        A ``trimesh.Scene`` holding the point cloud followed by the spheres.
    """
    import trimesh

    geometries = [trimesh.PointCloud(xyz, rgb)]
    for idx, position in enumerate(prompt_coords):
        marker = trimesh.creation.icosphere()
        marker.apply_scale(0.02)
        marker.apply_translation(position)
        if prompt_labels[idx]:
            marker.visual.vertex_colors = [255, 0, 0]
        else:
            marker.visual.vertex_colors = [0, 255, 0]
        geometries.append(marker)

    return trimesh.Scene(geometries)
|
|
|
|
@app.route("/set_prompts", methods=["POST"])
def set_prompts():
    """Predict masks for a session's prompts using its cached encoding.

    Expects a JSON object with ``prompt_coords``, ``prompt_labels`` and the
    ``user_id`` returned by ``/set_pointcloud``. The session's cached tensors
    and normalization values are loaded into the shared predictor, masks are
    predicted, and the resulting logits are stored back on the session.

    Returns:
        JSON ``{"mask": ...}`` with the predicted boolean masks as nested
        lists (an empty list when no prompts are given); a JSON error with
        status 400 for a malformed payload or 404 for an unknown session.
    """
    request_data = request.get_json()
    if not isinstance(request_data, dict):
        return jsonify({"error": "expected a JSON object"}), 400
    print(request_data.keys())

    missing = [
        key
        for key in ("prompt_coords", "prompt_labels", "user_id")
        if key not in request_data
    ]
    if missing:
        return jsonify({"error": f"missing fields: {missing}"}), 400

    prompt_coords = request_data["prompt_coords"]
    prompt_labels = request_data["prompt_labels"]
    user_id = request_data["user_id"]
    print(user_id)

    if not isinstance(prompt_coords, list) or not isinstance(prompt_labels, list):
        return jsonify({"error": "'prompt_coords' and 'prompt_labels' must be lists"}), 400

    # ``bool`` is a subclass of ``int``, so exclude it explicitly.
    valid_id = (
        isinstance(user_id, int)
        and not isinstance(user_id, bool)
        and 0 <= user_id < len(point_info_list)
    )
    point_info = point_info_list[user_id] if valid_id else None
    if point_info is None:
        return jsonify({"error": "unknown user_id"}), 404

    # Same availability check the predictor uses when placing the model, so
    # the restored tensors end up where the model is.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cached_patches = point_info["patches"]
    predictor.pc_embedding = point_info["pc_embedding"].to(device)
    predictor.patches = {
        key: cached_patches[key].to(device)
        for key in ("centers", "knn_idx", "coords", "feats")
    }
    predictor.input_processor.center = point_info["center"]
    predictor.input_processor.scale = point_info["scale"]

    stored_mask = point_info["prompt_mask"]
    predictor.prompt_mask = None if stored_mask is None else stored_mask.to(device)

    if len(prompt_coords) == 0:
        # Clearing the prompts must also clear the session's stored mask;
        # otherwise the stale mask is loaded again on the next request.
        predictor.prompt_mask = None
        point_info["prompt_mask"] = None
        return jsonify({"mask": []})

    predictor.set_prompts(prompt_coords, prompt_labels)
    pred_mask = predictor.predict_mask()
    point_info["prompt_mask"] = predictor.prompt_mask.cpu()

    return jsonify({"mask": pred_mask.tolist()})
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the bind address and start the server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    # Debug mode enables the interactive debugger, which lets a client execute
    # arbitrary code. With the default host the server listens on all
    # interfaces, so debug mode is opt-in rather than always on.
    parser.add_argument(
        "--debug",
        action="store_true",
        help="run Flask in debug mode (do not use on a reachable network)",
    )
    args = parser.parse_args()
    app.run(host=args.host, port=args.port, debug=args.debug)
|
|