Raildart committed on
Commit
3a595fe
·
verified ·
1 Parent(s): 26e38da

Update README.md

Files changed (1)
  1. README.md +143 -1
README.md CHANGED
@@ -7,4 +7,146 @@ tags:
7
  - graph-matching
8
  ---
9
 
10
- # GMT(Graph-Matching-Transformer)
+ # GMT: Graph Matching Transformer
11
+
12
+ **GMT** (Graph Matching Transformer) is a PyTorch-based framework for matching and aligning 2D curves (graphs) using geometric embeddings and a cross-attention Transformer architecture. It ships four model variants (`tiny`, `small`, `medium`, and `large`) that trade computational cost for capacity.
13
+
14
+ ---
15
+
16
+ ## Key Features
17
+
18
+ - **Multi-Geometry Support**: Generates and processes sinusoids, circles, ellipses, and random polylines.
19
+ - **Curvature & Ray Embeddings**: Computes curvature, ray distances, incidence angles, and hit flags for each point.
20
+ - **Index & Initial Shift Embedding**: Includes normalized index, curvature, and initial displacement as features.
21
+ - **Cross-Attention Transformer**: Two-stream self-attention on target & baseline, followed by cross-attention for fine-grained alignment.
22
+ - **Variants**: Four predefined configurations (`tiny`, `small`, `medium`, `large`) with adjustable `d_model`, depth, and feed-forward dimensions.
23
+ - **Metal/CUDA/CPU**: Auto-selects MPS (Apple Silicon), CUDA, or CPU device.
24
+ - **Visualizations**: Built-in training loss curves, inference progression plots, and error distribution histograms.
25
+
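The device auto-selection mentioned in the feature list could look roughly like the sketch below. The helper name `pick_device` and the preference order (MPS, then CUDA, then CPU, following the order in the bullet) are assumptions for illustration, not confirmed details of the GMT code.

```python
def pick_device(mps_available: bool, cuda_available: bool) -> str:
    """Return a device string, preferring MPS, then CUDA, then CPU.

    Hypothetical helper: the preference order is an assumption.
    """
    if mps_available:
        return "mps"
    if cuda_available:
        return "cuda"
    return "cpu"


# With PyTorch installed, the two flags would come from
#   torch.backends.mps.is_available() and torch.cuda.is_available(),
# and the result would be wrapped as torch.device(pick_device(...)).
print(pick_device(mps_available=False, cuda_available=False))  # cpu
```

Keeping the decision in a pure function makes the fallback order easy to test on a machine without a GPU.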
26
+ ---
27
+
28
+ ## Repository Structure
29
+
30
+ ```text
31
+ weights/ # Weights folder
32
+ README.md
33
+ train.py # Entry-point for training all variants
34
+ infer.py # CLI for inference and mapping extraction
35
+ gmt/ # Core package
36
+     __init__.py
37
+     variants.py # Model configurations
38
+     utils.py # Geometry & resampling utilities
39
+     embeddings.py # Ray-segment embedding functions
40
+     dataset.py # ThreadedRayDataset & helpers
41
+     model.py # Transformer definitions
42
+     trainer.py # Training loop and checkpointing
43
+ experiment.ipynb # Jupyter notebook demo
44
+ LICENSE
45
+ requirements.txt # Python dependencies
46
+ ```
47
+
48
+ ---
49
+
50
+ ## Installation
51
+
52
+ ```bash
53
+ # Clone repository
54
+ git clone https://github.com/raildart/gmt.git
55
+ cd gmt
56
+
57
+ # (Optional) Create virtual environment
58
+ python -m venv .venv
59
+ source .venv/bin/activate # or .venv\Scripts\activate on Windows
60
+
61
+ # Install dependencies
62
+ pip install -r requirements.txt
63
+ ```
64
+
65
+ ---
66
+
67
+ ## Quick Start
68
+
69
+ ### Training All Variants
70
+
71
+ ```bash
72
+ python train.py \
73
+ --epochs 30 \
74
+ --batch_size 64 \
75
+ --lr 5e-5
76
+ ```
77
+
78
+ This will train `tiny`, `small`, `medium`, and `large` sequentially and save checkpoints as `GMT_<variant>.pth`.
79
+
80
+ ### Running Inference with External Geometries
81
+
82
+ ```bash
83
+ python infer.py \
84
+ --variant medium \
85
+ --external path/to/geoms.npz \
86
+ --samples 5 \
87
+ --batch_size 16 \
88
+ --save
89
+ ```
90
+
91
+ This loads your own `.npz` with `baseline` and `target` arrays, runs the model, plots 5 sample alignments, and saves `mappings_medium.npz`.
92
+
93
+ ---
94
+
95
+ ## Model Variants & Performance
96
+
97
+ The table below summarizes each variant’s architecture along with its final test MSE (mean squared error). Entries marked `X` are placeholders awaiting measured results.
98
+
99
+ | Variant | d_model | Layers | FF Dim | Dropout | Test MSE |
100
+ | ------- | ------: | -----: | -----: | ------: | -------: |
101
+ | tiny | 128 | 2 | 256 | 0.10 | 0.0034 |
102
+ | small | 256 | 3 | 512 | 0.15 | 0.0028 |
103
+ | medium | 512 | 4 | 1024 | 0.20 | 0.0026 |
104
+ | large | 768 | 5 | 1536 | 0.20 | X |
105
+
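As an illustration, the architecture columns of the table can be written as a small configuration registry. This is only a sketch: `VariantConfig`, `VARIANTS`, and `get_variant` are hypothetical names, and the actual structure returned by `gmt.variants.define_variants` may differ.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class VariantConfig:
    """Hyperparameters for one model size (field names are assumptions)."""
    d_model: int
    num_layers: int
    ff_dim: int
    dropout: float


# Values copied from the table above.
VARIANTS = {
    "tiny": VariantConfig(d_model=128, num_layers=2, ff_dim=256, dropout=0.10),
    "small": VariantConfig(d_model=256, num_layers=3, ff_dim=512, dropout=0.15),
    "medium": VariantConfig(d_model=512, num_layers=4, ff_dim=1024, dropout=0.20),
    "large": VariantConfig(d_model=768, num_layers=5, ff_dim=1536, dropout=0.20),
}


def get_variant(name: str) -> VariantConfig:
    """Look up a variant by name, failing loudly on unknown names."""
    try:
        return VARIANTS[name]
    except KeyError:
        raise ValueError(
            f"unknown variant {name!r}; expected one of {sorted(VARIANTS)}"
        ) from None


print(get_variant("medium"))
```

In every row of the table the feed-forward dimension is twice `d_model`, so a registry like this could also derive `ff_dim` instead of storing it.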
106
+ ### Mean Squared Error (MSE)
107
+
108
+ The **Mean Squared Error (MSE)** is our primary training and evaluation metric. For a single predicted sequence $\hat{\mathbf{y}} = [\hat{y}_1, \hat{y}_2, \dots, \hat{y}_N]$ and its ground-truth sequence $\mathbf{y} = [y_1, y_2, \dots, y_N]$, the MSE is computed as:
109
+
110
+ $$
111
+ \mathrm{MSE}(\mathbf{y}, \hat{\mathbf{y}}) \;=\; \frac{1}{N} \sum_{i=1}^{N} \bigl(y_i - \hat{y}_i\bigr)^{2}.
112
+ $$
113
+
114
+ In our setting, each sequence consists of 2-D displacements for $N$ resampled points, so the squared error at each point sums over both coordinates before averaging over the $N$ points:
115
+
116
+ $$
117
+ \mathrm{MSE} = \frac{1}{N}\sum_{i=1}^{N}\Bigl[(\Delta x_i - \widehat{\Delta x}_i)^2 + (\Delta y_i - \widehat{\Delta y}_i)^2\Bigr].
118
+ $$
119
+
120
+ During training, we report the **batch-averaged** MSE each epoch, and at the end we compute the **dataset-wide** MSE by averaging over all samples. Lower MSE indicates that the model’s predicted alignment shifts more closely match the true geometric offsets.
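For concreteness, the 2-D displacement MSE above can be sketched in plain Python. The function name `displacement_mse` is hypothetical, and the normalization follows the formula as written (divide by $N$). Whether the training code uses this or a per-element mean, which is what `torch.nn.MSELoss` computes by default and would divide by $2N$, is not confirmed here.

```python
from typing import Sequence, Tuple

Point = Tuple[float, float]


def displacement_mse(true: Sequence[Point], pred: Sequence[Point]) -> float:
    """Sum squared x/y displacement errors per point, then average over N points."""
    if len(true) != len(pred):
        raise ValueError("sequences must have the same length")
    if not true:
        raise ValueError("sequences must be non-empty")
    total = 0.0
    for (dx, dy), (dx_hat, dy_hat) in zip(true, pred):
        total += (dx - dx_hat) ** 2 + (dy - dy_hat) ** 2
    return total / len(true)


true_shifts = [(1.0, 0.0), (0.0, 2.0)]
pred_shifts = [(0.5, 0.0), (0.0, 1.0)]
print(displacement_mse(true_shifts, pred_shifts))  # 0.625
```

Here the per-point errors are 0.25 and 1.0, giving (0.25 + 1.0) / 2 = 0.625.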
121
+
122
+ ---
123
+
124
+ ## API Usage
125
+
126
+ ```python
127
+ from gmt.dataset import ThreadedRayDataset
128
+ from gmt.model import ComplexCrossTransformer
129
+ from gmt.trainer import train
130
+ from gmt.variants import define_variants
131
+
132
+ # Create dataset
133
+ ds = ThreadedRayDataset(num_samples=5000, max_workers=8)
134
+ feat_dim = ds.tgt_feats.shape[-1]
135
+
136
+ # Choose a variant
137
+ variant = 'medium'
138
+ model = ComplexCrossTransformer(tgt_dim=feat_dim, base_dim=3, variant=variant)
139
+
140
+ # Train
141
+ trained_model = train(ds, model, variant=variant, epochs=20, batch_size=64, lr=5e-5)
142
+ ```
143
+
144
+ ## GitHub
145
+
146
+ https://github.com/raildart/GMT
147
+
148
+ ---
149
+
150
+ ## License
151
+
152
+ This project is licensed under the [MIT License](LICENSE).