File size: 31,608 Bytes
2dd4628
 
 
2f3ab6d
 
2dd4628
 
 
 
2f3ab6d
2dd4628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f2214c
 
 
 
 
 
 
 
 
 
2dd4628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f3ab6d
2dd4628
 
 
 
 
 
 
4f2214c
 
 
2dd4628
 
 
4f2214c
 
 
 
 
 
 
2dd4628
 
 
 
2f3ab6d
 
 
4f2214c
2f3ab6d
4f2214c
 
 
 
2f3ab6d
 
 
 
4f2214c
2f3ab6d
 
 
 
 
 
 
2dd4628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f3ab6d
2dd4628
2f3ab6d
2dd4628
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
import json
import os
import re

import numpy as np
import trimesh

from instruct_particulate.utils.articulation_utils import (
    compute_part_transforms,
    plucker_to_axis_point,
)


def _sanitize_link_name(name: str) -> str:
    """Turn an arbitrary label into a safe link-name identifier.

    Every maximal run of ASCII letters/digits in ``str(name)`` is kept;
    everything else (whitespace, punctuation, underscores, non-ASCII
    characters) acts as a separator. The kept runs are joined with single
    underscores, so the result never starts, ends, or repeats ``_``.

    Returns:
        The sanitized name, or ``"link"`` when nothing usable remains.
    """
    tokens = [token for token in re.split(r"[^A-Za-z0-9]+", str(name)) if token]
    return "_".join(tokens) if tokens else "link"


def _deduplicate_link_names(link_name_by_part_id):
    """Sanitize link names and make every one of them distinct.

    Args:
        link_name_by_part_id: Mapping ``{part_id: raw_name}``. Its iteration
            order decides which part keeps the bare (unsuffixed) name.

    Returns:
        ``{part_id: unique_name}``. The first occurrence of a sanitized name is
        kept as-is; later occurrences get ``_1``, ``_2``, ... appended. If a
        generated suffix would collide with a name that was already assigned
        (e.g. a part whose own name is literally ``"door_1"``), that index is
        skipped in favour of the next free one, so all returned names differ.
    """
    next_index_by_base = {}
    taken = set()
    deduplicated = {}
    for part_id, link_name in link_name_by_part_id.items():
        base_name = _sanitize_link_name(link_name)
        index = next_index_by_base.get(base_name, 0)
        candidate = base_name if index == 0 else f"{base_name}_{index}"
        # A plain counter per base name is not enough: "a", "a", "a_1" would
        # otherwise yield "a_1" twice. Advance until the name is free.
        while candidate in taken:
            index += 1
            candidate = f"{base_name}_{index}"
        next_index_by_base[base_name] = index + 1
        taken.add(candidate)
        deduplicated[part_id] = candidate
    return deduplicated


def _resolve_link_names(unique_part_ids, link_names=None):
    """Map each part ID to a link name.

    Args:
        unique_part_ids: Iterable of integer-like part IDs.
        link_names: Optional naming source.

            * ``None``: every part gets ``"link_<pid>"``.
            * ``dict``: looked up by ``int(pid)``. It falls back to the string
              key ``str(pid)`` (as produced by JSON-loaded mappings) and then
              to ``"link_<pid>"``.
            * any other sequence: indexed by the part ID itself.

    Returns:
        ``{int(pid): name}``. Names taken from ``link_names`` are sanitized
        and de-duplicated. The default ``"link_<pid>"`` names are returned
        as-is. An empty ``unique_part_ids`` yields ``{}``.

    Raises:
        ValueError: If ``link_names`` is a sequence without an entry for
            every requested part ID. Negative IDs are rejected too, because
            they would otherwise silently index from the end of the sequence.
    """
    part_ids = [int(pid) for pid in unique_part_ids]
    if not part_ids:
        # Avoids max() on an empty sequence in the sequence branch below.
        return {}
    if link_names is None:
        return {pid: f"link_{pid}" for pid in part_ids}
    if isinstance(link_names, dict):
        resolved = {}
        for pid in part_ids:
            if pid in link_names:
                resolved[pid] = link_names[pid]
            else:
                resolved[pid] = link_names.get(str(pid), f"link_{pid}")
        return _deduplicate_link_names(resolved)
    if min(part_ids) < 0 or len(link_names) <= max(part_ids):
        raise ValueError("link_names does not cover all requested part IDs")
    return _deduplicate_link_names({pid: link_names[pid] for pid in part_ids})


def _mesh_has_embedded_visual_material(mesh) -> bool:
    """Report whether *mesh* carries its own appearance information.

    A mesh qualifies when its ``visual`` attribute exists and either reports a
    ``kind`` of ``"texture"``, ``"vertex"`` or ``"face"``, or exposes a
    non-``None`` ``material`` attribute. Missing attributes are treated as
    absent rather than raising.
    """
    visual = getattr(mesh, "visual", None)
    if visual is None:
        return False
    has_colored_kind = getattr(visual, "kind", None) in {"texture", "vertex", "face"}
    # `material` is only consulted when the kind check did not already succeed.
    return has_colored_kind or getattr(visual, "material", None) is not None


def _matrix_to_gltf_trs(matrix):
    """Decompose a 4x4 homogeneous transform into glTF TRS components.

    Args:
        matrix: 4x4 array-like transform.

    Returns:
        ``(translation, rotation, scale)`` as float64 arrays of shape (3,),
        (4,) and (3,). ``rotation`` is a unit quaternion in glTF's
        ``[x, y, z, w]`` order. ``scale`` holds the norms of the columns of
        the upper-left 3x3 block.
    """
    matrix = np.asarray(matrix, dtype=np.float64)
    translation = matrix[:3, 3].copy()
    linear = matrix[:3, :3]
    scale = np.linalg.norm(linear, axis=0)
    # Divide each column by its norm to leave (ideally) a pure rotation.
    rot = linear / scale

    # Trace-based matrix -> quaternion conversion. When the trace is not
    # positive, branch on the largest diagonal entry to avoid dividing by a
    # tiny number.
    trace = np.trace(rot)
    if trace > 0:
        s = 0.5 / np.sqrt(trace + 1.0)
        w = 0.25 / s
        x = (rot[2, 1] - rot[1, 2]) * s
        y = (rot[0, 2] - rot[2, 0]) * s
        z = (rot[1, 0] - rot[0, 1]) * s
    elif rot[0, 0] > rot[1, 1] and rot[0, 0] > rot[2, 2]:
        s = 2.0 * np.sqrt(1.0 + rot[0, 0] - rot[1, 1] - rot[2, 2])
        w = (rot[2, 1] - rot[1, 2]) / s
        x = 0.25 * s
        y = (rot[0, 1] + rot[1, 0]) / s
        z = (rot[0, 2] + rot[2, 0]) / s
    elif rot[1, 1] > rot[2, 2]:
        s = 2.0 * np.sqrt(1.0 + rot[1, 1] - rot[0, 0] - rot[2, 2])
        w = (rot[0, 2] - rot[2, 0]) / s
        x = (rot[0, 1] + rot[1, 0]) / s
        y = 0.25 * s
        z = (rot[1, 2] + rot[2, 1]) / s
    else:
        s = 2.0 * np.sqrt(1.0 + rot[2, 2] - rot[0, 0] - rot[1, 1])
        w = (rot[1, 0] - rot[0, 1]) / s
        x = (rot[0, 2] + rot[2, 0]) / s
        y = (rot[1, 2] + rot[2, 1]) / s
        z = 0.25 * s

    rotation = np.array([x, y, z, w], dtype=np.float64)
    # Remove accumulated round-off so the stored quaternion has unit length.
    rotation /= np.linalg.norm(rotation)
    return translation, rotation, scale


def _read_glb_binary_chunk(path):
    """Return the BIN chunk payload of the GLB file at *path* as a bytearray.

    An empty bytearray is returned when the file has no BIN chunk.
    """
    with open(path, "rb") as f:
        f.read(12)  # GLB header: magic, version, total length
        json_chunk_length = int.from_bytes(f.read(4), byteorder="little")
        f.read(4)  # chunk type ("JSON")
        f.seek(json_chunk_length, os.SEEK_CUR)
        bin_header = f.read(8)
        if len(bin_header) < 8 or bin_header[4:8] != b"BIN\x00":
            return bytearray()
        bin_chunk_length = int.from_bytes(bin_header[:4], byteorder="little")
        return bytearray(f.read(bin_chunk_length))


def _gltf_to_clean_dict(gltf):
    """Convert a pygltflib ``GLTF2`` object into a JSON-ready, validated dict.

    The conversion runs in three stages:

    1. numpy scalars/arrays and attribute objects become plain Python values.
    2. References to accessors, materials, meshes, cameras, skins, nodes,
       samplers and images that point outside their arrays are removed.
    3. ``None`` values and empty lists are dropped, because glTF expects
       optional properties to be omitted rather than null/empty.
    """

    def is_valid_index(value, count):
        # True if value is an int index into a list of length count.
        return isinstance(value, int) and 0 <= value < count

    def to_serializable(obj):
        # Recursively convert numpy values and attribute objects to JSON types.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, (np.integer, np.floating)):
            return obj.item()
        if isinstance(obj, np.bool_):
            return bool(obj)
        if hasattr(obj, "__dict__") and not isinstance(obj, (str, bytes, type)):
            return {
                key: to_serializable(value)
                for key, value in obj.__dict__.items()
                if not key.startswith("_")  # skip private attributes
            }
        if isinstance(obj, dict):
            return {k: to_serializable(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [to_serializable(item) for item in obj]
        if hasattr(obj, "to_dict") and callable(getattr(obj, "to_dict")):
            return to_serializable(obj.to_dict())
        return obj  # str, int, float, bool, None

    def drop_empty(obj):
        # Remove None values and empty lists, recursively.
        if isinstance(obj, dict):
            result = {}
            for key, value in obj.items():
                cleaned = drop_empty(value)
                if cleaned is None or (isinstance(cleaned, list) and not cleaned):
                    continue
                result[key] = cleaned
            return result
        if isinstance(obj, list):
            return [item for item in map(drop_empty, obj) if item is not None]
        return obj

    def drop_invalid_ref(container, key, count):
        # Delete container[key] when it is not a valid index into a list of size count.
        if key in container and not is_valid_index(container[key], count):
            del container[key]

    gltf_dict = to_serializable(gltf.to_dict())

    def count_of(key):
        # `or []` tolerates keys that are present but null.
        return len(gltf_dict.get(key) or [])

    # Mesh primitives: attribute/index accessors and materials.
    num_accessors = count_of("accessors")
    num_materials = count_of("materials")
    for mesh in gltf_dict.get("meshes") or []:
        for primitive in mesh.get("primitives") or []:
            attributes = primitive.get("attributes")
            if isinstance(attributes, dict):
                primitive["attributes"] = {
                    attr_name: accessor_idx
                    for attr_name, accessor_idx in attributes.items()
                    if is_valid_index(accessor_idx, num_accessors)
                }
            drop_invalid_ref(primitive, "indices", num_accessors)
            drop_invalid_ref(primitive, "material", num_materials)

    # Nodes: mesh/camera/skin references and child node indices.
    nodes = gltf_dict.get("nodes") or []
    num_meshes = count_of("meshes")
    num_cameras = count_of("cameras")
    num_skins = count_of("skins")
    for node in nodes:
        drop_invalid_ref(node, "mesh", num_meshes)
        drop_invalid_ref(node, "camera", num_cameras)
        drop_invalid_ref(node, "skin", num_skins)
        if "children" in node:
            children = [
                child_idx
                for child_idx in (node["children"] or [])
                if is_valid_index(child_idx, len(nodes))
            ]
            if children:
                node["children"] = children
            else:
                del node["children"]

    # Textures: sampler and image (source) references.
    num_images = count_of("images")
    num_samplers = count_of("samplers")
    for texture in gltf_dict.get("textures") or []:
        drop_invalid_ref(texture, "sampler", num_samplers)
        drop_invalid_ref(texture, "source", num_images)

    # Must run last so the fields removed above do not leave null/empty residue.
    return drop_empty(gltf_dict)


def _write_glb_file(output_path, gltf_dict, binary_data):
    """Write *gltf_dict* (JSON chunk) and *binary_data* (BIN chunk) as a GLB file."""
    json_bytes = json.dumps(gltf_dict, separators=(",", ":")).encode("utf-8")
    # GLB chunks must be 4-byte aligned: JSON is padded with spaces, BIN with zeros.
    json_bytes += b" " * (-len(json_bytes) % 4)
    bin_bytes = bytes(binary_data) + b"\x00" * (-len(binary_data) % 4)
    total_length = 12 + 8 + len(json_bytes) + 8 + len(bin_bytes)

    with open(output_path, "wb") as f:
        f.write(b"glTF")
        f.write((2).to_bytes(4, byteorder="little"))  # container version
        f.write(total_length.to_bytes(4, byteorder="little"))
        f.write(len(json_bytes).to_bytes(4, byteorder="little"))
        f.write(b"JSON")
        f.write(json_bytes)
        f.write(len(bin_bytes).to_bytes(4, byteorder="little"))
        f.write(b"BIN\x00")
        f.write(bin_bytes)


def export_animated_glb_file(
    mesh_parts,
    unique_part_ids,
    motion_hierarchy,
    is_part_revolute,
    is_part_prismatic,
    revolute_plucker,
    revolute_range,
    prismatic_axis,
    prismatic_range,
    animation_frames,
    output_path,
    include_axes=False,
    axes_meshes=None,
    fps=30.0,
):
    """
    Export an animated GLB file with proper node transformations.

    Each mesh part becomes a separate node. Each node gets baked
    translation/rotation/scale animation channels that represent the
    articulation motion over time.

    Args:
        mesh_parts: List of trimesh objects, one per part
        unique_part_ids: Array of unique part IDs (parallel to ``mesh_parts``)
        motion_hierarchy: List of (parent_id, child_id) tuples defining the kinematic tree
        is_part_revolute: Boolean array indicating if each part has a revolute joint
        is_part_prismatic: Boolean array indicating if each part has a prismatic joint
        revolute_plucker: Plucker coordinates for revolute joint axes
        revolute_range: [low, high] angle limits for revolute joints
        prismatic_axis: Direction vectors for prismatic joints
        prismatic_range: [low, high] displacement limits for prismatic joints
        animation_frames: Number of keyframes in the animation (>= 1)
        output_path: Path to the output animated GLB file
        include_axes: Whether to include axis visualization meshes
        axes_meshes: List of trimesh objects representing axis visualizations
            (arrows/rings). They are added as static, non-animated nodes.
        fps: Keyframe rate. Keyframe ``k`` is placed at ``k / fps`` seconds.

    Raises:
        ValueError: If ``animation_frames < 1`` or ``fps <= 0``.

    The animation interpolates linearly from the low limit (state=0) to the
    high limit (state=1). States are sampled uniformly over
    ``animation_frames`` keyframes, so the clip lasts
    ``(animation_frames - 1) / fps`` seconds.
    """
    import tempfile
    from pygltflib import (
        GLTF2,
        Accessor,
        Animation,
        AnimationChannel,
        AnimationSampler,
        Buffer,
        BufferView,
    )

    if animation_frames < 1:
        raise ValueError(f"animation_frames must be >= 1, got {animation_frames}")
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    # Step 1: build the scene with trimesh, which handles textures/UVs correctly.
    # Part nodes get predictable names so their glTF node indices can be found later.
    scene = trimesh.Scene()
    part_node_names = []
    for i, mesh_part in enumerate(mesh_parts):
        node_name = f"part_node_{i}"
        part_node_names.append(node_name)
        scene.add_geometry(mesh_part, node_name=node_name)
    if include_axes and axes_meshes:
        for i, axis_mesh in enumerate(axes_meshes):
            scene.add_geometry(axis_mesh, node_name=f"axis_node_{i}")

    # Step 2: round-trip through a temporary GLB to get the glTF JSON structure
    # (pygltflib) and the raw BIN payload that the JSON's buffer views point into.
    with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        scene.export(tmp_path)
        gltf = GLTF2().load(tmp_path)
        binary_data = _read_glb_binary_chunk(tmp_path)
    finally:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)

    if not gltf.buffers:
        # No geometry buffer was written; create the GLB-embedded buffer 0.
        gltf.buffers = [Buffer(byteLength=len(binary_data))]

    node_name_to_idx = {
        node.name: idx for idx, node in enumerate(gltf.nodes or []) if node.name
    }

    if not gltf.animations:
        gltf.animations = []
    animation = Animation(channels=[], samplers=[])
    gltf.animations.append(animation)

    def add_float_accessor(values, accessor_type):
        """Append float32 *values* to buffer 0 and return the new accessor index."""
        values = np.ascontiguousarray(values, dtype=np.float32)
        binary_data.extend(b"\x00" * (-len(binary_data) % 4))  # 4-byte alignment
        byte_offset = len(binary_data)
        binary_data.extend(values.tobytes())
        gltf.buffers[0].byteLength = len(binary_data)
        gltf.bufferViews.append(
            BufferView(buffer=0, byteOffset=byte_offset, byteLength=values.nbytes)
        )
        if values.ndim == 1:
            v_min, v_max = [float(values.min())], [float(values.max())]
        else:
            v_min, v_max = values.min(axis=0).tolist(), values.max(axis=0).tolist()
        gltf.accessors.append(
            Accessor(
                bufferView=len(gltf.bufferViews) - 1,
                componentType=5126,  # FLOAT
                count=len(values),
                type=accessor_type,
                max=v_max,
                min=v_min,
            )
        )
        return len(gltf.accessors) - 1

    # Step 3: keyframe times. Keyframe k sits at k / fps, so consecutive keys
    # are exactly one frame apart.
    states = np.linspace(0.0, 1.0, animation_frames)
    times = np.arange(animation_frames, dtype=np.float64) / float(fps)
    time_acc_idx = add_float_accessor(times, "SCALAR")

    # compute_part_transforms yields transforms for all parts (indexed by part
    # id), so evaluate it once per keyframe rather than once per part.
    transforms_per_frame = [
        compute_part_transforms(
            unique_part_ids,
            motion_hierarchy,
            is_part_revolute,
            is_part_prismatic,
            revolute_plucker,
            revolute_range,
            prismatic_axis,
            prismatic_range,
            state,
        )
        for state in states
    ]

    # Step 4: one translation/rotation/scale sampler + channel per part node.
    for part_idx, part_id in enumerate(unique_part_ids):
        part_node_name = part_node_names[part_idx]
        target_node_idx = node_name_to_idx.get(part_node_name)
        if target_node_idx is None:
            print(f"Warning: Could not find node index for part {part_idx} (name: {part_node_name})")
            continue

        trs = [_matrix_to_gltf_trs(frame[part_id]) for frame in transforms_per_frame]
        translations = np.array([t for t, _, _ in trs])
        rotations = np.array([r for _, r, _ in trs])
        scales = np.array([s for _, _, s in trs])

        # q and -q encode the same rotation. Keep consecutive keys in the same
        # hemisphere so interpolation follows the short arc even in viewers
        # that do not correct the sign themselves.
        for k in range(1, len(rotations)):
            if np.dot(rotations[k], rotations[k - 1]) < 0.0:
                rotations[k] = -rotations[k]

        for path, values, accessor_type in (
            ("translation", translations, "VEC3"),
            ("rotation", rotations, "VEC4"),
            ("scale", scales, "VEC3"),
        ):
            output_acc_idx = add_float_accessor(values, accessor_type)
            animation.samplers.append(
                AnimationSampler(
                    input=time_acc_idx,
                    output=output_acc_idx,
                    interpolation="LINEAR",
                )
            )
            animation.channels.append(
                AnimationChannel(
                    sampler=len(animation.samplers) - 1,
                    target={"node": target_node_idx, "path": path},
                )
            )

    # Step 5: serialize the JSON ourselves so the extended binary_data is written.
    _write_glb_file(output_path, _gltf_to_clean_dict(gltf), binary_data)


def export_urdf(
    mesh_parts,
    unique_part_ids,
    motion_hierarchy,
    is_part_revolute,
    is_part_prismatic,
    revolute_plucker,
    revolute_range,
    prismatic_axis,
    prismatic_range,
    output_path,
    name="robot",
    link_names=None,
):
    """Write a URDF file describing the articulated parts, plus per-part meshes.

    For every part, a visual GLB and a collision OBJ are exported into a
    ``meshes`` directory next to ``output_path``. Each mesh is translated so
    that its link frame origin sits at (0, 0, 0). A joint is emitted for every
    (parent, child) edge in ``motion_hierarchy`` whose endpoints are both in
    ``unique_part_ids``.

    Args:
        mesh_parts: Meshes aligned index-for-index with ``unique_part_ids``.
            Each must provide ``.vertices``, ``.copy()`` and ``.export(path)``
            (presumably ``trimesh.Trimesh`` -- confirm against callers).
        unique_part_ids: Part ids. ``int(pid)`` indexes the resolved link names.
        motion_hierarchy: Iterable of ``(parent_id, child_id)`` pairs. Edges
            that reference ids not in ``unique_part_ids`` are ignored.
        is_part_revolute: Per-part truthy flags, indexed by part id.
        is_part_prismatic: Per-part truthy flags, indexed by part id. Only
            consulted when the part is not revolute.
        revolute_plucker: Per-part line representation, converted with
            ``plucker_to_axis_point`` into ``(axis, point)``.
        revolute_range: Per-part ``(lower, upper)`` limits for revolute joints
            (presumably radians, as URDF expects -- TODO confirm).
        prismatic_axis: Per-part translation direction for prismatic joints.
        prismatic_range: Per-part ``(lower, upper)`` limits for prismatic joints.
        output_path: Destination ``.urdf`` path. A bare filename is written
            relative to the current working directory.
        name: Value of the ``<robot name="...">`` attribute.
        link_names: Optional naming input forwarded to ``_resolve_link_names``.

    Returns:
        None. Files are written as a side effect and a message is printed.
    """
    from xml.sax.saxutils import escape

    def _attr(value):
        # Escape &, <, > and double quotes so arbitrary names remain valid
        # inside double-quoted XML attributes.
        return escape(str(value), {'"': "&quot;"})

    def _unit(vec):
        # Divide by the exact norm so unit axes stay exactly unit-length.
        # A zero vector is returned unchanged rather than dividing by zero.
        vec = np.asarray(vec, dtype=float)
        norm = np.linalg.norm(vec)
        return vec / norm if norm > 0 else vec

    # Take abspath first: dirname("robot.urdf") is "" and os.makedirs("")
    # raises FileNotFoundError. Creating mesh_dir also creates urdf_dir.
    urdf_dir = os.path.dirname(os.path.abspath(output_path))
    mesh_dir = os.path.join(urdf_dir, "meshes")
    os.makedirs(mesh_dir, exist_ok=True)

    # Keep only hierarchy edges whose endpoints both exist in the mesh.
    # If a child appears in several edges, the last edge wins.
    known_ids = set(unique_part_ids)
    parent_map = {}
    for parent_id, child_id in motion_hierarchy:
        if parent_id in known_ids and child_id in known_ids:
            parent_map[child_id] = parent_id

    # Parts without a parent are roots and keep the world origin as their frame.
    # NOTE(review): several roots yield several parentless links; most URDF
    # consumers expect exactly one root link.
    roots = {pid for pid in unique_part_ids if pid not in parent_map}
    resolved_link_names = _resolve_link_names(unique_part_ids, link_names=link_names)

    # Link frame origins, in world coordinates.
    link_origins_world = {}
    for i, pid in enumerate(unique_part_ids):
        if pid in roots:
            link_origins_world[pid] = np.zeros(3)
        elif is_part_revolute[pid]:
            # Revolute: a point on the rotation axis, so the joint rotates
            # about the child frame's origin.
            _, point = plucker_to_axis_point(revolute_plucker[pid])
            link_origins_world[pid] = point
        else:
            # Prismatic and fixed parts: the mean of the mesh vertices.
            link_origins_world[pid] = mesh_parts[i].vertices.mean(axis=0)

    urdf_lines = ['<?xml version="1.0"?>', f'<robot name="{_attr(name)}">']

    # Links
    for i, pid in enumerate(unique_part_ids):
        # Translate a copy so the link frame origin becomes the mesh origin.
        mesh_local = mesh_parts[i].copy()
        mesh_local.vertices -= link_origins_world[pid]

        # Visual geometry uses GLB so embedded textures/materials survive in
        # the URDF package. Collision geometry stays OBJ for broad simulator
        # compatibility.
        visual_mesh_filename = f"part_{pid}.glb"
        mesh_local.export(os.path.join(mesh_dir, visual_mesh_filename))
        collision_mesh_filename = f"part_{pid}_collision.obj"
        mesh_local.export(os.path.join(mesh_dir, collision_mesh_filename))

        link_name = _attr(resolved_link_names[int(pid)])

        urdf_lines.append(f'  <link name="{link_name}">')
        urdf_lines.append('    <visual>')
        urdf_lines.append('      <origin xyz="0 0 0" rpy="0 0 0"/>')
        urdf_lines.append('      <geometry>')
        urdf_lines.append(f'        <mesh filename="./meshes/{visual_mesh_filename}"/>')
        urdf_lines.append('      </geometry>')
        # Fall back to a flat grey material only when the GLB carries none.
        if not _mesh_has_embedded_visual_material(mesh_local):
            urdf_lines.append(f'      <material name="material_{pid}">')
            urdf_lines.append('        <color rgba="0.8 0.8 0.8 1.0"/>')
            urdf_lines.append('      </material>')
        urdf_lines.append('    </visual>')
        urdf_lines.append('    <collision>')
        urdf_lines.append('      <origin xyz="0 0 0" rpy="0 0 0"/>')
        urdf_lines.append('      <geometry>')
        urdf_lines.append(f'        <mesh filename="./meshes/{collision_mesh_filename}"/>')
        urdf_lines.append('      </geometry>')
        urdf_lines.append('    </collision>')
        # Placeholder inertial values; they are not derived from the mesh.
        urdf_lines.append('    <inertial>')
        urdf_lines.append('      <mass value="1.0"/>')
        urdf_lines.append('      <inertia ixx="0.1" ixy="0" ixz="0" iyy="0.1" iyz="0" izz="0.1"/>')
        urdf_lines.append('    </inertial>')
        urdf_lines.append('  </link>')

    # Joints
    for pid in unique_part_ids:
        if pid not in parent_map:
            continue
        parent_pid = parent_map[pid]

        parent_link = resolved_link_names[int(parent_pid)]
        child_link = resolved_link_names[int(pid)]
        joint_name = f"joint_{parent_link}_{child_link}"

        # Child frame origin relative to the parent frame. Both frames are only
        # translated from world coordinates, hence rpy="0 0 0".
        offset = link_origins_world[pid] - link_origins_world[parent_pid]

        if is_part_revolute[pid]:
            j_type = "revolute"
            axis, _ = plucker_to_axis_point(revolute_plucker[pid])
            axis = _unit(axis)
            lower, upper = revolute_range[pid]
        elif is_part_prismatic[pid]:
            j_type = "prismatic"
            axis = _unit(prismatic_axis[pid])
            lower, upper = prismatic_range[pid]
        else:
            # Fixed joints emit neither <axis> nor <limit>.
            j_type = "fixed"
            axis = None
            lower = upper = 0.0

        urdf_lines.append(f'  <joint name="{_attr(joint_name)}" type="{j_type}">')
        urdf_lines.append(f'    <parent link="{_attr(parent_link)}"/>')
        urdf_lines.append(f'    <child link="{_attr(child_link)}"/>')
        urdf_lines.append(
            f'    <origin xyz="{offset[0]:.6f} {offset[1]:.6f} {offset[2]:.6f}" rpy="0 0 0"/>'
        )
        if j_type != "fixed":
            urdf_lines.append(f'    <axis xyz="{axis[0]:.6f} {axis[1]:.6f} {axis[2]:.6f}"/>')
            urdf_lines.append(
                f'    <limit lower="{lower:.6f}" upper="{upper:.6f}" effort="1000" velocity="100"/>'
            )
        urdf_lines.append('  </joint>')

    urdf_lines.append('</robot>')

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(urdf_lines) + '\n')

    print(f"Exported URDF to {output_path}")