---
license: apache-2.0
tags:
- robotics
- grasping
- graph-neural-network
- point-cloud
- contact-prediction
- morphology
- cross-attention
---

# Graspmax: GeoMatch & GeoMatch++

**Graspmax** contains two geometry-aware contact prediction models for dexterous robotic grasping, trained on the CMapDataset across 5 robot end-effectors.

---

## Models

### GeoMatch

Dual GCN encoder (object + robot surface) → projection layers → 5 autoregressive (AR) MLP modules → contact map prediction, trained with binary cross-entropy (BCE).

> Based on: [*Geometry Matching for Multi-Embodiment Grasping*](https://openreview.net/forum?id=oyWkrG-LD5) (NeurIPS 2024)

### GeoMatch++

Extends GeoMatch with a **morphology encoder** (a third GCN over the robot's kinematic-tree graph) and a **DCP-style (Deep Closest Point) cross-attention transformer** that fuses object geometry with robot morphology before contact prediction. The pretrained GeoMatch encoders are frozen. Only the morphology encoder, transformer, projection heads, and AR modules are trained.

> Based on: [*GeoMatch++: Morphology-Aware Grasping via Correspondence Learning*](https://arxiv.org/abs/2412.18998)

---

## Architecture Comparison

| Component | GeoMatch | GeoMatch++ |
|---|---|---|
| Object GCN encoder | 3 layers × 256 → 512 | Same, **frozen** |
| Robot surface GCN | 3 layers × 256 → 512 | Same, **frozen** |
| Morphology encoder | — | **NEW** GCN(9 → 256×3 → 512) |
| Cross-attention | — | **NEW** DCP transformer (512-dim, 4 heads, 1 layer) |
| Projection heads | Linear(512→64) × 2 | Same, re-initialised |
| AR keypoint modules | 5× MLP | Same, re-initialised |
| **Total params** | **~1.9M** | **~6.4M (5.8M trainable)** |

---

## Training Details

### GeoMatch

| Setting | Value |
|---|---|
| Dataset | CMapDataset (ContactDB + YCB) |
| Training samples | 41,871 |
| End-effectors | EZGripper, Barrett, Robotiq 3-Finger, Allegro, ShadowHand |
| Batch size | 256 |
| Optimizer | Adam (β₁=0.9, β₂=0.99) |
| Learning rate | 1e-4 |
| Epochs | 200 |
| Hardware | AMD Instinct MI300X (192 GB HBM3), ROCm 6.2.4 |
| Training time | 22.18 hours |
| Precision | FP32 |
| **Final val loss** | **0.435** |

### GeoMatch++

| Setting | Value |
|---|---|
| Initialisation | Pretrained GeoMatch encoders (frozen) |
| Trainable params | ~5.8M |
| Batch size | 32 per GPU × 8 GPUs = 256 effective |
| Optimizer | Adam (β₁=0.9, β₂=0.99) |
| Learning rate | 5e-5 |
| Epochs | 150 |
| Hardware | 8× AMD Instinct MI300X, ROCm 6.2.4 (DDP) |
| Training time | ~2.8 hours |
| Precision | FP32 |
| **Final val loss** | **0.350** (↓19.5% vs GeoMatch) |
| **Final val accuracy** | **0.940** |

### GeoMatch++ Training Curves

| Epoch | Val Loss | Val Accuracy |
|---|---|---|
| 0 | 0.465 | 0.999 |
| 25 | 0.370 | 0.880 |
| 50 | — | — |
| 89 | 0.362 | 0.902 |
| 100 | — | — |
| 140 | — | — |
| 149 | 0.350 | 0.940 |

---

## Checkpoints

### GeoMatch

| File | Epoch | Notes |
|---|---|---|
| `checkpoint_epoch50.pth` | 50 | Early training |
| `checkpoint_epoch100.pth` | 100 | Mid-training |
| `checkpoint_epoch150.pth` | 150 | Near-converged |
| `geomatch_final.pth` | 200 | **Final model (recommended)** |

### GeoMatch++

| File | Epoch | Notes |
|---|---|---|
| `geomatch_pp_checkpoint_epoch50.pth` | 50 | Early training |
| `geomatch_pp_checkpoint_epoch100.pth` | 100 | Mid-training |
| `geomatch_pp_checkpoint_epoch140.pth` | 140 | Near-converged |
| `geomatch_pp_final.pth` | 149 | **Final model (recommended)** |

---

## Usage

### GeoMatch

```python
import sys

import torch

sys.path.append(".")  # run from the repository root so `config` and `models` are importable

import config
from models.geomatch import GeoMatch

# ROCm builds of PyTorch also expose AMD GPUs through the "cuda" device name.
model = GeoMatch(config).cuda()
model.load_state_dict(torch.load("geomatch_final.pth", map_location="cuda"))
model.eval()

# The inputs below must already be prepared as batched tensors with these shapes.
with torch.no_grad():
    contact_map, keypoint_probs = model(
        obj_pc,               # [B, 2048, 3]       object point cloud
        robot_pc,             # [B, 6, 3]          robot surface points
        robot_key_point_idx,  # [B, 6]             keypoint indices
        obj_adj,              # [B, 2048, 2048]    object adjacency
        robot_adj,            # [B, 6, 6]          robot adjacency
        xyz_prev,             # [B, 6, 3]          previous keypoint positions
    )
# contact_map:    [B, 2048, 6, 1]
# keypoint_probs: [B, 2048, 5, 1]
```

### GeoMatch++

```python
import sys

import torch

sys.path.append(".")  # run from the repository root so `config` and `models` are importable

import config
from models.geomatch_pp import GeoMatchPP

# ROCm builds of PyTorch also expose AMD GPUs through the "cuda" device name.
model = GeoMatchPP(config).cuda()
model.load_state_dict(torch.load("geomatch_pp_final.pth", map_location="cuda"))
model.eval()

# The inputs below must already be prepared as batched tensors with these shapes.
with torch.no_grad():
    contact_map, keypoint_probs = model(
        obj_pc,               # [B, 2048, 3]       object point cloud
        robot_pc,             # [B, 6, 3]          robot surface points
        robot_key_point_idx,  # [B, 6]             keypoint indices
        obj_adj,              # [B, 2048, 2048]    object adjacency
        robot_adj,            # [B, 6, 6]          robot adjacency
        xyz_prev,             # [B, 6, 3]          previous keypoint positions
        morph_features,       # [B, 32, 9]         morphology node features
        morph_adj,            # [B, 32, 32]        morphology adjacency
    )
# contact_map:    [B, 2048, 6, 1]
# keypoint_probs: [B, 2048, 5, 1]
```

Morphology graphs (`morph_features`, `morph_adj`) are pre-built per robot with `preprocess_morphology.py` and stored in `gnn_morphology_new.pt`.

---

## Repository Structure

```
models/
  geomatch.py      # GeoMatch model (dual GCN + AR modules)
  geomatch_pp.py   # GeoMatch++ model (+ morphology encoder + DCP transformer)
  gnn.py           # Graph Convolutional Network
  mlp.py           # MLP building block
config.py          # Hyperparameters for both models
```

---

## Citation

```bibtex
@inproceedings{geomatch2024,
  title     = {Geometry Matching for Multi-Embodiment Grasping},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year      = {2024},
}

@article{geomatch_pp2024,
  title   = {GeoMatch++: Morphology-Aware Grasping via Correspondence Learning},
  journal = {arXiv preprint arXiv:2412.18998},
  year    = {2024},
}
```

---

## License

Original GeoMatch code © 2023 DeepMind Technologies Limited, licensed under the Apache License 2.0. The GeoMatch++ extension and checkpoints were produced by [Dimios45](https://huggingface.co/Dimios45) as part of the Graspmax project.