rayli commited on
Commit
73fb7a2
·
verified ·
1 Parent(s): cec3eae

Align exported part colors with kinematic node palette

Browse files
app.py CHANGED
@@ -86,6 +86,7 @@ from instruct_particulate.utils.inference_visualization_utils import (
86
  from instruct_particulate.utils.partfield_feature_utils import (
87
  ensure_partfield_assets_downloaded,
88
  )
 
89
 
90
 
91
  REPO_ROOT = Path(__file__).resolve().parent
@@ -1554,14 +1555,12 @@ UPRIGHT_PICKER_JS = r"""
1554
  }
1555
  """
1556
 
1557
- KINEMATIC_TREE_EDITOR_JS = r"""
 
1558
  () => {
1559
- const palette = [
1560
- "#fca5a5", "#fdba74", "#fde047", "#86efac",
1561
- "#67e8f9", "#93c5fd", "#c4b5fd", "#f0abfc",
1562
- "#f9a8d4", "#a7f3d0", "#fcd34d", "#bfdbfe",
1563
- "#ddd6fe", "#fecaca", "#bbf7d0", "#bae6fd"
1564
- ];
1565
  const NODE_WIDTH = 128;
1566
  const NODE_HEIGHT = 44;
1567
  const PROMPT_POINT_PRECISION = 6;
@@ -2772,6 +2771,7 @@ KINEMATIC_TREE_EDITOR_JS = r"""
2772
  waitForEditor();
2773
  }
2774
  """
 
2775
 
2776
 
2777
  def _extract_gradio_path(value: Any) -> Path | None:
@@ -4870,13 +4870,13 @@ def _kinematic_tree_editor_html() -> str:
4870
  <svg id="kin-tree-edges" class="kin-tree-edge-layer"></svg>
4871
  <div id="kin-tree-joints" class="kin-tree-joint-layer"></div>
4872
  <div id="kin-tree-nodes" class="kin-tree-node-layer">
4873
- <div class="kin-node" style="left: 28px; top: 42px; background: #fca5a5;">
4874
  <div class="kin-node-row">
4875
  <input value="base" aria-label="Link name" />
4876
  <button class="kin-node-delete" title="Delete link" type="button">×</button>
4877
  </div>
4878
  </div>
4879
- <div class="kin-node" style="left: 220px; top: 180px; background: #fdba74;">
4880
  <div class="kin-node-row">
4881
  <input value="moving_part" aria-label="Link name" />
4882
  <button class="kin-node-delete" title="Delete link" type="button">×</button>
 
86
  from instruct_particulate.utils.partfield_feature_utils import (
87
  ensure_partfield_assets_downloaded,
88
  )
89
+ from instruct_particulate.utils.visualization_utils import LINK_COLOR_HEX
90
 
91
 
92
  REPO_ROOT = Path(__file__).resolve().parent
 
1555
  }
1556
  """
1557
 
1558
+ KINEMATIC_TREE_EDITOR_JS = (
1559
+ """
1560
  () => {
1561
+ const palette = """
1562
+ + json.dumps(list(LINK_COLOR_HEX))
1563
+ + r""";
 
 
 
1564
  const NODE_WIDTH = 128;
1565
  const NODE_HEIGHT = 44;
1566
  const PROMPT_POINT_PRECISION = 6;
 
2771
  waitForEditor();
2772
  }
2773
  """
2774
+ )
2775
 
2776
 
2777
  def _extract_gradio_path(value: Any) -> Path | None:
 
4870
  <svg id="kin-tree-edges" class="kin-tree-edge-layer"></svg>
4871
  <div id="kin-tree-joints" class="kin-tree-joint-layer"></div>
4872
  <div id="kin-tree-nodes" class="kin-tree-node-layer">
4873
+ <div class="kin-node" style="left: 28px; top: 42px; background: {LINK_COLOR_HEX[0]};">
4874
  <div class="kin-node-row">
4875
  <input value="base" aria-label="Link name" />
4876
  <button class="kin-node-delete" title="Delete link" type="button">×</button>
4877
  </div>
4878
  </div>
4879
+ <div class="kin-node" style="left: 220px; top: 180px; background: {LINK_COLOR_HEX[1]};">
4880
  <div class="kin-node-row">
4881
  <input value="moving_part" aria-label="Link name" />
4882
  <button class="kin-node-delete" title="Delete link" type="button">×</button>
instruct_particulate/utils/inference_visualization_utils.py CHANGED
@@ -165,7 +165,8 @@ def save_segmented_visualizations(
165
  for part_id in unique_part_ids
166
  ]
167
  mesh_parts_segmented = create_textured_mesh_parts(
168
- [mesh_part.copy() for mesh_part in mesh_parts_original]
 
169
  )
170
  axes = create_motion_axis_meshes(
171
  mesh_parts_original,
 
165
  for part_id in unique_part_ids
166
  ]
167
  mesh_parts_segmented = create_textured_mesh_parts(
168
+ [mesh_part.copy() for mesh_part in mesh_parts_original],
169
+ part_ids=unique_part_ids,
170
  )
171
  axes = create_motion_axis_meshes(
172
  mesh_parts_original,
instruct_particulate/utils/visualization_utils.py CHANGED
@@ -6,29 +6,48 @@ from PIL import Image
6
 
7
  from instruct_particulate.utils.articulation_utils import plucker_to_axis_point
8
 
9
- COLORS = [
10
- (72, 36, 117),
11
- (33, 145, 140),
12
- (189, 223, 38),
13
- (153, 80, 8),
14
- (12, 12, 242),
15
- (242, 12, 150),
16
- (12, 242, 150),
17
- (12, 150, 242)
18
- ]
19
- ARROW_COLOR_REVOLUTE = (255, 0, 0)
20
- ARROW_COLOR_PRISMATIC = (255, 255, 0)
21
-
22
-
23
- def create_textured_mesh_parts(mesh_parts, colors=COLORS, tex_res=256):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # Create a texture map with evenly distributed color blocks
25
  # Use a horizontal strip layout: texture height = tex_res, width = num_parts * tex_res
 
 
 
 
 
26
  texture_height = block_width = tex_res
27
  texture_width = len(mesh_parts) * block_width
28
  texture_array = np.zeros((texture_height, texture_width, 3), dtype=np.uint8)
29
 
30
- for i in range(len(mesh_parts)):
31
- color_rgb = colors[i % len(colors)][:3]
32
  x_start = i * block_width
33
  x_end = (i + 1) * block_width
34
  texture_array[:, x_start:x_end] = color_rgb
@@ -255,10 +274,12 @@ def create_motion_axis_meshes(
255
  is_part_prismatic: np.ndarray,
256
  revolute_plucker: np.ndarray,
257
  prismatic_axis: np.ndarray,
 
258
  ):
259
  """Create arrow/ring meshes visualizing predicted joint motion."""
260
  axes = []
261
  for mesh_part, part_id in zip(mesh_parts, unique_part_ids, strict=True):
 
262
  if is_part_revolute[part_id]:
263
  axis_direction, axis_point = plucker_to_axis_point(revolute_plucker[part_id])
264
  arrow_start, arrow_end = get_3D_arrow_on_points(
@@ -271,7 +292,7 @@ def create_motion_axis_meshes(
271
  create_arrow(
272
  arrow_start,
273
  arrow_end,
274
- color=ARROW_COLOR_REVOLUTE,
275
  radius=0.01,
276
  radius_tip=0.018,
277
  )
@@ -283,7 +304,7 @@ def create_motion_axis_meshes(
283
  arrow_direction,
284
  major_radius=0.03,
285
  minor_radius=0.006,
286
- color=ARROW_COLOR_REVOLUTE,
287
  )
288
  )
289
  axes.append(
@@ -292,7 +313,7 @@ def create_motion_axis_meshes(
292
  arrow_direction,
293
  major_radius=0.03,
294
  minor_radius=0.006,
295
- color=ARROW_COLOR_REVOLUTE,
296
  )
297
  )
298
  elif is_part_prismatic[part_id]:
@@ -305,7 +326,7 @@ def create_motion_axis_meshes(
305
  create_arrow(
306
  arrow_start,
307
  arrow_end,
308
- color=ARROW_COLOR_PRISMATIC,
309
  radius=0.01,
310
  radius_tip=0.018,
311
  )
 
6
 
7
  from instruct_particulate.utils.articulation_utils import plucker_to_axis_point
8
 
9
+ LINK_COLOR_HEX = (
10
+ "#fca5a5",
11
+ "#fdba74",
12
+ "#fde047",
13
+ "#86efac",
14
+ "#67e8f9",
15
+ "#93c5fd",
16
+ "#c4b5fd",
17
+ "#f0abfc",
18
+ "#f9a8d4",
19
+ "#a7f3d0",
20
+ "#fcd34d",
21
+ "#bfdbfe",
22
+ "#ddd6fe",
23
+ "#fecaca",
24
+ "#bbf7d0",
25
+ "#bae6fd",
26
+ )
27
+
28
+
29
+ def _hex_to_rgb(color: str) -> tuple[int, int, int]:
30
+ color = color.removeprefix("#")
31
+ return tuple(int(color[index : index + 2], 16) for index in range(0, 6, 2))
32
+
33
+
34
+ COLORS = tuple(_hex_to_rgb(color) for color in LINK_COLOR_HEX)
35
+
36
+
37
+ def create_textured_mesh_parts(mesh_parts, part_ids=None, colors=COLORS, tex_res=256):
38
  # Create a texture map with evenly distributed color blocks
39
  # Use a horizontal strip layout: texture height = tex_res, width = num_parts * tex_res
40
+ part_ids = list(range(len(mesh_parts))) if part_ids is None else list(part_ids)
41
+ if len(part_ids) != len(mesh_parts):
42
+ raise ValueError(
43
+ f"part_ids must align with mesh_parts, got {len(part_ids)} ids for {len(mesh_parts)} meshes"
44
+ )
45
  texture_height = block_width = tex_res
46
  texture_width = len(mesh_parts) * block_width
47
  texture_array = np.zeros((texture_height, texture_width, 3), dtype=np.uint8)
48
 
49
+ for i, part_id in enumerate(part_ids):
50
+ color_rgb = colors[int(part_id) % len(colors)][:3]
51
  x_start = i * block_width
52
  x_end = (i + 1) * block_width
53
  texture_array[:, x_start:x_end] = color_rgb
 
274
  is_part_prismatic: np.ndarray,
275
  revolute_plucker: np.ndarray,
276
  prismatic_axis: np.ndarray,
277
+ colors=COLORS,
278
  ):
279
  """Create arrow/ring meshes visualizing predicted joint motion."""
280
  axes = []
281
  for mesh_part, part_id in zip(mesh_parts, unique_part_ids, strict=True):
282
+ axis_color = colors[int(part_id) % len(colors)][:3]
283
  if is_part_revolute[part_id]:
284
  axis_direction, axis_point = plucker_to_axis_point(revolute_plucker[part_id])
285
  arrow_start, arrow_end = get_3D_arrow_on_points(
 
292
  create_arrow(
293
  arrow_start,
294
  arrow_end,
295
+ color=axis_color,
296
  radius=0.01,
297
  radius_tip=0.018,
298
  )
 
304
  arrow_direction,
305
  major_radius=0.03,
306
  minor_radius=0.006,
307
+ color=axis_color,
308
  )
309
  )
310
  axes.append(
 
313
  arrow_direction,
314
  major_radius=0.03,
315
  minor_radius=0.006,
316
+ color=axis_color,
317
  )
318
  )
319
  elif is_part_prismatic[part_id]:
 
326
  create_arrow(
327
  arrow_start,
328
  arrow_end,
329
+ color=axis_color,
330
  radius=0.01,
331
  radius_tip=0.018,
332
  )