Spaces:
Runtime error
Runtime error
JeffLiang
committed on
Commit
·
fcdbf88
1
Parent(s):
d66c7c7
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +88 -0
- requirements.txt +9 -0
- sam_3d.py +266 -0
- sam_vit_b_01ec64.pth +3 -0
- scannet_data/scene0000_00/color/0.jpg +0 -0
- scannet_data/scene0000_00/color/100.jpg +0 -0
- scannet_data/scene0000_00/color/1000.jpg +0 -0
- scannet_data/scene0000_00/color/1020.jpg +0 -0
- scannet_data/scene0000_00/color/1040.jpg +0 -0
- scannet_data/scene0000_00/color/1060.jpg +0 -0
- scannet_data/scene0000_00/color/1080.jpg +0 -0
- scannet_data/scene0000_00/color/1100.jpg +0 -0
- scannet_data/scene0000_00/color/1120.jpg +0 -0
- scannet_data/scene0000_00/color/1140.jpg +0 -0
- scannet_data/scene0000_00/color/1160.jpg +0 -0
- scannet_data/scene0000_00/color/1180.jpg +0 -0
- scannet_data/scene0000_00/color/120.jpg +0 -0
- scannet_data/scene0000_00/color/1200.jpg +0 -0
- scannet_data/scene0000_00/color/1220.jpg +0 -0
- scannet_data/scene0000_00/color/1240.jpg +0 -0
- scannet_data/scene0000_00/color/1260.jpg +0 -0
- scannet_data/scene0000_00/color/1280.jpg +0 -0
- scannet_data/scene0000_00/color/1300.jpg +0 -0
- scannet_data/scene0000_00/color/1320.jpg +0 -0
- scannet_data/scene0000_00/color/1340.jpg +0 -0
- scannet_data/scene0000_00/color/1360.jpg +0 -0
- scannet_data/scene0000_00/color/1380.jpg +0 -0
- scannet_data/scene0000_00/color/140.jpg +0 -0
- scannet_data/scene0000_00/color/1400.jpg +0 -0
- scannet_data/scene0000_00/color/1420.jpg +0 -0
- scannet_data/scene0000_00/color/1440.jpg +0 -0
- scannet_data/scene0000_00/color/1460.jpg +0 -0
- scannet_data/scene0000_00/color/1480.jpg +0 -0
- scannet_data/scene0000_00/color/1500.jpg +0 -0
- scannet_data/scene0000_00/color/1520.jpg +0 -0
- scannet_data/scene0000_00/color/1540.jpg +0 -0
- scannet_data/scene0000_00/color/1560.jpg +0 -0
- scannet_data/scene0000_00/color/1580.jpg +0 -0
- scannet_data/scene0000_00/color/160.jpg +0 -0
- scannet_data/scene0000_00/color/1600.jpg +0 -0
- scannet_data/scene0000_00/color/1620.jpg +0 -0
- scannet_data/scene0000_00/color/1640.jpg +0 -0
- scannet_data/scene0000_00/color/1660.jpg +0 -0
- scannet_data/scene0000_00/color/1680.jpg +0 -0
- scannet_data/scene0000_00/color/1700.jpg +0 -0
- scannet_data/scene0000_00/color/1720.jpg +0 -0
- scannet_data/scene0000_00/color/1740.jpg +0 -0
- scannet_data/scene0000_00/color/1760.jpg +0 -0
- scannet_data/scene0000_00/color/1780.jpg +0 -0
- scannet_data/scene0000_00/color/180.jpg +0 -0
app.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import ast
|
| 6 |
+
import time
|
| 7 |
+
import random
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
from plyfile import PlyData
|
| 15 |
+
import gradio as gr
|
| 16 |
+
import plotly.graph_objs as go
|
| 17 |
+
|
| 18 |
+
from sam_3d import SAM3DDemo
|
| 19 |
+
|
| 20 |
+
def pc_to_plot(pc):
    """Build an interactive Plotly 3D scatter figure from a ply point-cloud record.

    `pc` must expose 'x'/'y'/'z' coordinate arrays and 'red'/'green'/'blue'
    per-point color channels (0-255 integers).
    """
    # Per-point CSS color strings for the scatter markers.
    point_colors = [
        'rgb({},{},{})'.format(r, g, b)
        for r, g, b in zip(pc['red'], pc['green'], pc['blue'])
    ]
    scatter = go.Scatter3d(
        x=pc['x'], y=pc['y'], z=pc['z'],
        mode='markers',
        marker=dict(
            size=2,
            color=point_colors,
        ),
    )
    # Hide all three axes so only the cloud itself is rendered.
    layout = dict(
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
        ),
    )
    return go.Figure(data=[scatter], layout=layout)
|
| 36 |
+
|
| 37 |
+
def inference(scene_name, granularity, coords, plot):
    """Gradio callback: run SAM-3D segmentation around a user-supplied 3D point.

    Returns four outputs matching the Interface wiring: the point-selection
    preview plot, the projected-points image, the SAM-mask image, and the
    final segmented point-cloud plot.
    """
    print(scene_name, coords)
    demo = SAM3DDemo('vit_b', 'sam_vit_b_01ec64.pth', scene_name)
    # The textbox delivers the coordinate as a string like "[0, -2.5, 0.7]".
    anchor = ast.literal_eval(coords)
    pc_selected, img_points, img_masks, pc_final = demo.run_with_coord(anchor, int(granularity))
    return (
        pc_to_plot(pc_selected),
        Image.fromarray(img_points),
        Image.fromarray(img_masks),
        pc_to_plot(pc_final),
    )
|
| 43 |
+
|
| 44 |
+
# Pre-load the point clouds of the three supported scenes once at startup so
# the gradio example gallery can show them without re-reading the .ply files.
plydatas = []
for scene_name in ['scene0000_00', 'scene0001_00', 'scene0002_00']:
    plydata = PlyData.read(f"./scannet_data/{scene_name}/{scene_name}.ply")
    # elements[0] is the vertex element: x/y/z plus red/green/blue fields.
    data = plydata.elements[0].data
    plydatas.append(data)

# Each example row mirrors the Interface inputs:
# [scene name, mask granularity, anchor coordinate, preview plot].
examples = [['scene0000_00', 0, [0, -2.5, 0.7], pc_to_plot(plydatas[0])],
            ['scene0001_00', 0, [0, -2.5, 1], pc_to_plot(plydatas[1])],
            ['scene0002_00', 0, [0, -2.5, 1], pc_to_plot(plydatas[2])],]

title = 'Segment_Anything on 3D in-door point clouds'

description = """
Gradio Demo for Segment Anything on 3D indoor scenes (ScanNet supported). \n
The logic is straighforward: 1) Find a point in 3D; 2) project the 3D point to valid images; 3) perform 2D SAM on valid images; 4) reproject 2D results back to 3D; 5) Visualization.
Unfortunatly, it does not support click the point cloud to generate coordinates automatically. You may want to write down the coordinates and put it manually. \n
"""

article = """
<p style='text-align: center'>
<a href='https://arxiv.org/abs/2210.04150' target='_blank'>
Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP
</a>

<a href='https://github.com/facebookresearch/ov-seg' target='_blank'>Github Repo</a></p>
"""


# Wire the inference callback to the UI and launch the app (blocking call).
gr.Interface(
    inference,
    inputs=[
        gr.Dropdown(choices=['scene0000_00', 'scene0001_00', 'scene0002_00'], label="Scannet scene name (limited scenes supported)"),
        gr.Dropdown(choices=[0, 1, 2], label="Mask granularity from 0 (most coarse) to 2 (most precise)"),
        gr.Textbox(lines=1, label='Coordinates'),
        gr.Plot(label="Input Point cloud (For visualization and point finding only, click responce not supported yet.)"),
    ],
    outputs=[gr.Plot(label='Selected point(s): red points show the top 10 cloest points for your input anchor point'),
             gr.Image(label='Selected image with projected points'),
             gr.Image(label='Selected image processed after SAM'),
             gr.Plot(label='Output Point cloud: blue points represent the mask')],
    title=title,
    description=description,
    article=article,
    examples=examples).launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
numpy
|
| 3 |
+
plyfile
|
| 4 |
+
plotly
|
| 5 |
+
matplotlib
|
| 6 |
+
opencv-python
|
| 7 |
+
torch==1.10.1+cu113
|
| 8 |
+
torchvision==0.11.2+cu113
|
| 9 |
+
git+https://github.com/facebookresearch/segment-anything.git
|
sam_3d.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import copy
|
| 3 |
+
import random
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import numpy as np
|
| 9 |
+
from plyfile import PlyData
|
| 10 |
+
|
| 11 |
+
from segment_anything import SamPredictor, sam_model_registry
|
| 12 |
+
|
| 13 |
+
def get_image_ids(path):
    """Return the sorted stem names (e.g. '100' from '100.jpg') of files in *path*.

    Directories are skipped. NOTE(review): ids are strings, so the sort is
    lexicographic ('1000' sorts before '120'), not numeric — confirm this
    ordering is intended by downstream consumers.
    """
    stems = [
        entry.split('.')[0]
        for entry in os.listdir(path)
        if os.path.isfile(path + '/' + entry)  # keep only regular files
    ]
    return sorted(stems)
| 17 |
+
|
| 18 |
+
def load_align_matrix_from_txt(path):
    """Parse the 4x4 axis-alignment matrix from a ScanNet scene-info .txt file.

    Falls back to the identity matrix when no 'axisAlignment' entry exists
    (test-set scenes do not ship one).

    Args:
        path: path to the scene's `<scene>.txt` metadata file.

    Returns:
        (4, 4) float numpy array.
    """
    axis_align_matrix = np.eye(4)
    # `with` closes the file handle; the original left it open until GC.
    with open(path) as f:
        for line in f:
            if 'axisAlignment' in line:
                # Take the right-hand side of "axisAlignment = v0 v1 ... v15".
                # split() is whitespace-robust, unlike the original
                # strip('axisAlignment = ') char-set strip + split(' ').
                values = line.split('=', 1)[1].split()
                axis_align_matrix = np.array(
                    [float(v) for v in values]).reshape((4, 4))
                break
    return axis_align_matrix
|
| 31 |
+
|
| 32 |
+
def load_matrix_from_txt(path, shape=(4, 4)):
    """Read whitespace-separated floats from *path* into an array of *shape*."""
    with open(path) as f:
        # read().split() treats newlines and spaces uniformly, matching the
        # original's replace('\n', ' ') + split().
        tokens = f.read().split()
    values = [float(t) for t in tokens]
    return np.array(values).reshape(shape)
|
| 38 |
+
|
| 39 |
+
def load_image(path):
    """Load an image file from disk as a numpy array (H, W[, C])."""
    with Image.open(path) as img:
        return np.array(img)
|
| 42 |
+
|
| 43 |
+
def convert_from_uvd(u, v, d, intr, pose, align):
    """Back-project depth pixel (u, v) with raw depth d into aligned world coords.

    Args:
        u, v: pixel column / row in the depth image.
        d: raw depth value; divided by depth_scale=1000, so presumably
           millimeters (ScanNet convention) — TODO confirm.
        intr: 4x4 depth-camera intrinsic matrix (fx, fy, cx, cy are read
              from the top-left 3x3 part).
        pose: 4x4 camera-to-world pose matrix.
        align: 4x4 axis-alignment matrix for the scene.

    Returns:
        Length-3 array of world coordinates, or (None, None, None) when the
        depth reading is zero (no measurement at this pixel).
    """
    if d == 0:
        return None, None, None

    # Removed the original's unused `extr = np.linalg.inv(pose)` — a wasted
    # 4x4 inverse per call in a per-pixel hot path.
    fx = intr[0, 0]
    fy = intr[1, 1]
    cx = intr[0, 2]
    cy = intr[1, 2]
    depth_scale = 1000

    # Pinhole back-projection into the camera frame.
    z = d / depth_scale
    x = (u - cx) * z / fx
    y = (v - cy) * z / fy

    # Camera -> world -> axis-aligned world, then dehomogenize.
    world = (align @ pose @ np.array([x, y, z, 1]))
    return world[:3] / world[3]
|
| 60 |
+
|
| 61 |
+
# Find the cloest point in the cloud with select
|
| 62 |
+
def find_closest_point(point, point_cloud, num=1):
    """Return indices and coordinates of the *num* points nearest to *point*.

    Distances are Euclidean, computed row-wise against *point_cloud* (N, 3);
    results are ordered nearest-first.
    """
    dists = np.linalg.norm(point_cloud - point, axis=1)
    nearest_idx = np.argsort(dists)[:num]
    return nearest_idx, point_cloud[nearest_idx]
|
| 73 |
+
|
| 74 |
+
# Quick-look matplotlib 3D scatter of a point cloud (debug helper; the demo
# UI itself renders point clouds with plotly, not this).
def plot_3d(xdata, ydata, zdata, color=None, b_min=2, b_max=8, view=(45, 45)):
    # NOTE(review): 'rgb' is not a built-in matplotlib colormap; this likely
    # only works when `color` is an explicit per-point RGB array (cmap is then
    # ignored) — confirm callers never pass scalar color values.
    fig, ax = plt.subplots(subplot_kw={"projection": "3d"}, dpi=200)
    ax.view_init(view[0], view[1])  # (elevation, azimuth) in degrees
    # Clamp all three axes to the same range so the cloud keeps its aspect.
    ax.set_xlim(b_min, b_max)
    ax.set_ylim(b_min, b_max)
    ax.set_zlim(b_min, b_max)
    ax.scatter3D(xdata, ydata, zdata, c=color, cmap='rgb', s=0.1)
|
| 81 |
+
|
| 82 |
+
class SAM3DDemo(object):
    """End-to-end SAM-on-3D pipeline for a single ScanNet scene.

    Workflow (see run_with_coord): pick the points nearest a user anchor,
    project them into the scene's RGB-D frames, run 2D SAM on frames where
    the projection is valid, then lift the 2D masks back onto the
    ground-truth point cloud via voxel matching.
    """

    def __init__(self, sam_model, sam_ckpt, scene_name):
        """Load the SAM checkpoint and all per-frame data for *scene_name*.

        NOTE(review): .cuda() requires a GPU; loads every color/depth/pose
        frame of the scene into memory up front.
        """
        sam = sam_model_registry[sam_model](checkpoint=sam_ckpt).cuda()
        self.predictor = SamPredictor(sam)
        self.scene_name = scene_name
        scene_path = os.path.join('./scannet_data', scene_name)
        self.color_path = os.path.join(scene_path, 'color')
        self.depth_path = os.path.join(scene_path, 'depth')
        self.pose_path = os.path.join(scene_path, 'pose')
        self.intrinsic_path = os.path.join(scene_path, 'intrinsic')
        # (sic) attribute name keeps the original's 'matirx' spelling.
        self.align_matirx_path = f'{scene_path}/{scene_name}.txt'
        # Frame ids are string stems sorted lexicographically (see
        # get_image_ids); poses/images below are index-aligned with img_ids.
        self.img_ids = get_image_ids(self.color_path)
        self.align_matrix = load_align_matrix_from_txt(self.align_matirx_path)
        self.intrinsic_depth = load_matrix_from_txt(os.path.join(self.intrinsic_path, 'intrinsic_depth.txt'))
        self.poses = [load_matrix_from_txt(os.path.join(self.pose_path, f'{i}.txt')) for i in self.img_ids]
        self.rgb_images = [load_image(os.path.join(self.color_path, f'{i}.jpg')) for i in self.img_ids]
        self.depth_images = [load_image(os.path.join(self.depth_path, f'{i}.png'))for i in self.img_ids]

    def project_3D_to_images(self, select_points, valid_margin=20):
        """Project the selected 3D points into every frame; keep valid frames.

        A frame is valid when ALL projected points land at least
        *valid_margin* pixels inside the image, have positive projected
        depth, and the first point's projected depth agrees with the sensor
        depth within +/-20%.

        Returns:
            (valid frame indices, {frame index -> 2xN pixel coords},
             a rendered preview image of the last valid frame with the
             points drawn in red).

        NOTE(review): `valid_img_ids[-1]` raises IndexError when no frame
        passes the checks — there is no fallback for that case.
        """
        valid_img_ids = []
        valid_points = {}
        for img_i in range(len(self.img_ids)):
            rgb_img = self.rgb_images[img_i]
            depth_img = self.depth_images[img_i]
            extrinsics = self.poses[img_i]
            # World (axis-aligned) -> camera -> depth-image pixel space.
            projection_matrix = self.intrinsic_depth @ np.linalg.inv(self.align_matrix @ extrinsics)
            raw_points = np.vstack((select_points.T, np.ones((1, select_points.T.shape[1]))))
            raw_points = np.dot(projection_matrix, raw_points)
            # bounding simplest
            points = raw_points[:2, :] / raw_points[2, :]
            points = np.round(points).astype(np.int32)
            valid = (points[0] >= valid_margin).all() & (points[1] >= valid_margin).all() \
                & (points[0] < (rgb_img.shape[1] - valid_margin)).all() & (points[1] < (rgb_img.shape[0] - valid_margin)).all() \
                & (raw_points[2, :] > 0).all()
            if valid:
                # Occlusion check: projected depth of the first point must be
                # within +/- depth_margin/2 (relative) of the sensor depth.
                depth_margin = 0.4
                gt_depths = depth_img[points[1], points[0]] / 1000
                proj_depths = raw_points[2, :]
                if (proj_depths[0] > (1 - depth_margin / 2.0) * gt_depths[0]) & (proj_depths[0] < (1 + depth_margin / 2.0) * gt_depths[0]):
                    valid_img_ids.append(img_i)
                    valid_points[img_i] = points

        # Render a preview of the last valid frame with the projected points.
        show_id = valid_img_ids[-1]
        show_points = valid_points[show_id]
        rgb_img = self.rgb_images[show_id]
        fig, ax = plt.subplots()
        ax.imshow(rgb_img)
        for x, y in zip(show_points[0], show_points[1]):
            ax.plot(x, y, 'ro')
        canvas = fig.canvas
        canvas.draw()
        w, h = canvas.get_width_height()
        # Rasterize the matplotlib canvas into an RGB uint8 array.
        rgb_img_w_points = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(h, w, 3)
        print("projecting 3D point to images successfully...")
        return valid_img_ids, valid_points, rgb_img_w_points

    def process_img_w_sam(self, valid_img_ids, valid_points, granularity):
        """Run SAM on every valid frame, prompted by its first projected point.

        Returns per-frame color overlays (gray 0.5 background, blue where the
        chosen-granularity mask is on) and a 3-panel preview image of the
        masks at all granularities.

        NOTE(review): the preview below the loop reuses `masks` and `rgb_img`
        from the LAST loop iteration — it only shows the last valid frame if
        that frame happens to be processed last, and raises NameError when
        valid_img_ids is empty. Confirm this is acceptable for the demo.
        """
        mask_colors = []
        for img_i in range(len(self.img_ids)):
            rgb_img = self.rgb_images[img_i]
            # Default overlay: uniform 0.5 gray (meaning "not in mask").
            msk_color = np.full(rgb_img.shape, 0.5)
            if img_i in valid_img_ids:
                self.predictor.set_image(rgb_img)
                # Prompt SAM with the first projected point, shape (1, 2).
                point_coor = valid_points[img_i].T[0][None]
                masks, _, _ = self.predictor.predict(point_coords=point_coor, point_labels=np.array([1]))
                m = masks[granularity]
                msk_color[m] = [0, 0, 1.0]  # blue marks masked pixels
            mask_colors.append(msk_color)
        # Preview: the three SAM granularities side by side (see NOTE above).
        fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(24, 8))
        for i in range(3):
            mask_img = masks[i][:,:,None] * rgb_img
            axs[i].set_title(f'granularity {i}')
            axs[i].imshow(mask_img)
        canvas = fig.canvas
        canvas.draw()
        w, h = canvas.get_width_height()
        rgb_img_w_masks = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(h, w, 3)
        print("processing images with SAM successfully...")
        return mask_colors, rgb_img_w_masks

    def project_mask_to_3d(self, mask_colors, sample_ratio=0.002):
        """Back-project a random subset of depth pixels to 3D with mask colors.

        Each depth pixel is kept with probability *sample_ratio*; kept pixels
        are unprojected via convert_from_uvd and tagged with the overlay
        color sampled from the (rescaled) RGB-resolution mask overlay.

        Returns parallel lists: x, y, z coordinates and [color] entries.
        """
        x_data, y_data, z_data, c_data = [], [], [], []
        for img_i in range(len(self.img_ids)):
            id = self.img_ids[img_i]
            # RGBD
            d = self.depth_images[img_i]
            c = self.rgb_images[img_i]
            p = self.poses[img_i]
            msk_color = mask_colors[img_i]
            # Projecting RGB features into the point space
            for i in range(d.shape[0]):
                for j in range(d.shape[1]):
                    # Random subsampling keeps the reprojection tractable.
                    if random.random() < sample_ratio:
                        x, y, z = convert_from_uvd(j, i, d[i, j], self.intrinsic_depth, p, self.align_matrix)
                        if x is None:
                            continue  # zero-depth pixel: no measurement
                        x_data.append(x)
                        y_data.append(y)
                        z_data.append(z)
                        # Map depth-image coords to RGB-image coords (the two
                        # streams have different resolutions).
                        ci = int(i * c.shape[0] / d.shape[0])
                        cj = int(j * c.shape[1] / d.shape[1])
                        c_data.append([msk_color[ci, cj]])
        print("reprojecting images to 3D points successfully...")
        return x_data, y_data, z_data, c_data

    def match_projected_point_to_gt_point(self, x_data, y_data, z_data, c_data, gt_coords):
        """Transfer colors from reprojected points onto the GT cloud via voxels.

        Both clouds are quantized into 0.2-unit voxels; each GT voxel is
        matched to the image-side voxel with identical quantized coords, and
        receives the mean color of that voxel's samples. Unmatched GT voxels
        reuse the most recent matched voxel (simple carry-forward fill).

        Returns:
            torch tensor of per-GT-point colors, aligned with gt_coords rows.

        NOTE(review): `int(loc)` below assumes each matched location array
        has exactly one element; multiple image voxels colliding on the same
        quantized coords would raise TypeError — confirm this cannot happen.
        """

        c_data = torch.tensor(np.concatenate(c_data, axis=0))
        img_coords = np.array([x_data, y_data, z_data], dtype=np.float32).T
        # Quantize both clouds into 0.2-unit voxel grid coordinates.
        gt_quant_coords = np.floor_divide(gt_coords, 0.2)
        img_quant_coords = np.floor_divide(img_coords, 0.2)

        # Remove the redundant coords
        unique_gt_coords, gt_inverse_indices = np.unique(gt_quant_coords, axis=0, return_inverse=True)
        unique_img_coords, img_inverse_indices = np.unique(img_quant_coords, axis=0, return_inverse=True)

        # Match the coords in gt_coords to img_coords
        def find_loc(vec):
            # Wrap the (possibly empty) match array in a 0-d object array so
            # apply_along_axis does not try to broadcast ragged results.
            obj = np.empty((), dtype=object)
            out = np.where((unique_img_coords == vec).all(1))[0]
            obj[()] = out
            return obj

        gt_2_img_map = np.apply_along_axis(find_loc, 1, unique_gt_coords)
        # Since some places are empty, use a simple carry-forward interpolation
        gt_2_img_map_filled = []
        start_id = np.array([0])
        for loc in gt_2_img_map:
            if not np.any(loc):
                loc = start_id  # no match: reuse the last matched voxel
            else:
                start_id = loc
            gt_2_img_map_filled.append(int(loc))

        # Mean color of the samples that fell into each image-side voxel.
        mean_colors = []
        for i in range(unique_img_coords.shape[0]):
            valid_locs = np.where(img_inverse_indices == i)
            mean_f = torch.mean(c_data[valid_locs], axis=0)
            mean_colors.append(mean_f.unsqueeze(0))
        mean_colors = torch.cat(mean_colors)
        # Project the averaged features back to groundtruth point clouds
        img_2_gt_colors = mean_colors[gt_2_img_map_filled]
        projected_gt_colors = img_2_gt_colors[gt_inverse_indices]
        print("convert projected points to GT points successfully...")

        return projected_gt_colors

    def render_point_cloud(self, data, color):
        """Return a shallow copy of ply record *data* recolored by *color*.

        *color* holds floats in [0, 1]; channels are scaled to uint8 0-255.
        NOTE(review): copy.copy is shallow — writes below mutate the shared
        underlying numpy record array; confirm callers do not rely on the
        original colors afterwards.
        """
        data_copy = copy.copy(data)
        uint_color = torch.round(torch.tensor(color) * 255).to(torch.uint8)
        data_copy['red'] = uint_color[:, 0]
        data_copy['green'] = uint_color[:, 1]
        data_copy['blue'] = uint_color[:, 2]
        return data_copy

    def run_with_coord(self, point, granularity):
        """Full pipeline for one user click: select, project, SAM, lift to 3D.

        Args:
            point: [x, y, z] anchor coordinate in the aligned GT cloud frame.
            granularity: SAM mask index 0 (coarsest) to 2 (finest).

        Returns:
            (GT cloud with the 10 selected points in red,
             preview image of projected points,
             preview image of SAM masks,
             GT cloud with the segmented region in blue).
        """
        x_data, y_data, z_data, c_data = [], [], [], []

        plydata = PlyData.read(f"./scannet_data/{self.scene_name}/{self.scene_name}.ply")
        data = plydata.elements[0].data

        # gt_coords stand for the groundtruth point cloud coordinates
        gt_coords = np.array([data['x'], data['y'], data['z']], dtype=np.float32).T
        gt_color = np.array([data['red'], data['green'], data['blue']], dtype=np.float32).T
        blank_color = np.full(gt_color.shape, 0.5)

        # Highlight the 10 GT points nearest the anchor in red.
        select_index, select_points = find_closest_point(point, gt_coords, num=10)
        point_select_color = blank_color.copy()
        point_select_color[select_index] = [1.0, 0, 0]
        data_point_select = self.render_point_cloud(data, point_select_color)

        valid_img_ids, valid_points, rgb_img_w_points = self.project_3D_to_images(select_points)
        mask_colors, rgb_img_w_masks = self.process_img_w_sam(valid_img_ids, valid_points, granularity)
        x_data, y_data, z_data, c_data = self.project_mask_to_3d(mask_colors)
        projected_gt_colors = self.match_projected_point_to_gt_point(x_data, y_data, z_data, c_data, gt_coords)

        data_final = self.render_point_cloud(data, projected_gt_colors)

        return data_point_select, rgb_img_w_points, rgb_img_w_masks, data_final
|
sam_vit_b_01ec64.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
|
| 3 |
+
size 375042383
|
scannet_data/scene0000_00/color/0.jpg
ADDED
|
scannet_data/scene0000_00/color/100.jpg
ADDED
|
scannet_data/scene0000_00/color/1000.jpg
ADDED
|
scannet_data/scene0000_00/color/1020.jpg
ADDED
|
scannet_data/scene0000_00/color/1040.jpg
ADDED
|
scannet_data/scene0000_00/color/1060.jpg
ADDED
|
scannet_data/scene0000_00/color/1080.jpg
ADDED
|
scannet_data/scene0000_00/color/1100.jpg
ADDED
|
scannet_data/scene0000_00/color/1120.jpg
ADDED
|
scannet_data/scene0000_00/color/1140.jpg
ADDED
|
scannet_data/scene0000_00/color/1160.jpg
ADDED
|
scannet_data/scene0000_00/color/1180.jpg
ADDED
|
scannet_data/scene0000_00/color/120.jpg
ADDED
|
scannet_data/scene0000_00/color/1200.jpg
ADDED
|
scannet_data/scene0000_00/color/1220.jpg
ADDED
|
scannet_data/scene0000_00/color/1240.jpg
ADDED
|
scannet_data/scene0000_00/color/1260.jpg
ADDED
|
scannet_data/scene0000_00/color/1280.jpg
ADDED
|
scannet_data/scene0000_00/color/1300.jpg
ADDED
|
scannet_data/scene0000_00/color/1320.jpg
ADDED
|
scannet_data/scene0000_00/color/1340.jpg
ADDED
|
scannet_data/scene0000_00/color/1360.jpg
ADDED
|
scannet_data/scene0000_00/color/1380.jpg
ADDED
|
scannet_data/scene0000_00/color/140.jpg
ADDED
|
scannet_data/scene0000_00/color/1400.jpg
ADDED
|
scannet_data/scene0000_00/color/1420.jpg
ADDED
|
scannet_data/scene0000_00/color/1440.jpg
ADDED
|
scannet_data/scene0000_00/color/1460.jpg
ADDED
|
scannet_data/scene0000_00/color/1480.jpg
ADDED
|
scannet_data/scene0000_00/color/1500.jpg
ADDED
|
scannet_data/scene0000_00/color/1520.jpg
ADDED
|
scannet_data/scene0000_00/color/1540.jpg
ADDED
|
scannet_data/scene0000_00/color/1560.jpg
ADDED
|
scannet_data/scene0000_00/color/1580.jpg
ADDED
|
scannet_data/scene0000_00/color/160.jpg
ADDED
|
scannet_data/scene0000_00/color/1600.jpg
ADDED
|
scannet_data/scene0000_00/color/1620.jpg
ADDED
|
scannet_data/scene0000_00/color/1640.jpg
ADDED
|
scannet_data/scene0000_00/color/1660.jpg
ADDED
|
scannet_data/scene0000_00/color/1680.jpg
ADDED
|
scannet_data/scene0000_00/color/1700.jpg
ADDED
|
scannet_data/scene0000_00/color/1720.jpg
ADDED
|
scannet_data/scene0000_00/color/1740.jpg
ADDED
|
scannet_data/scene0000_00/color/1760.jpg
ADDED
|
scannet_data/scene0000_00/color/1780.jpg
ADDED
|
scannet_data/scene0000_00/color/180.jpg
ADDED
|