n0w0f commited on
Commit
283f249
ยท
verified ยท
1 Parent(s): 7949a14

Update README for v2: NL queries, 1024 ctx, LaCLIP architecture

Browse files
Files changed (1) hide show
  1. README.md +158 -138
README.md CHANGED
@@ -1,111 +1,139 @@
1
- # MatText Aligned Embeddings: Multi-Modal Material Retrieval
2
 
3
- **A CLIP-style multi-modal embedding model that aligns 10 different material text representations into a shared 128-d vector space for cross-modal retrieval.**
4
 
5
- Query with *any* modality (composition, CIF, SLICES, natural language, z-matrix...) โ†’ retrieve materials with similar properties across *all* modalities.
 
 
 
 
 
 
 
 
 
 
6
 
7
  ## ๐Ÿ—๏ธ Architecture
8
 
9
  ```
10
- โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
11
- โ”‚ MatTextEncoder โ”‚
12
- โ”‚ โ”‚
13
- โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚
14
- โ”‚ โ”‚ Shared Backbone: ModernBERT-base (150M params) โ”‚ โ”‚
15
- โ”‚ โ”‚ - 8192 token context window (handles long CIFs) โ”‚ โ”‚
16
- โ”‚ โ”‚ - Mean pooling โ†’ 768-d representation โ”‚ โ”‚
17
- โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚
18
- โ”‚ โ”‚ โ”‚
19
- โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚
20
- โ”‚ โ–ผ โ–ผ โ–ผ โ”‚
21
- โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚
22
- โ”‚ โ”‚ Projection โ”‚ โ”‚ Projection โ”‚ โ”‚ Projection โ”‚ ... โ”‚
23
- โ”‚ โ”‚ composition โ”‚ โ”‚ cif_sym โ”‚ โ”‚ slices โ”‚ โ”‚
24
- โ”‚ โ”‚ 768โ†’768โ†’128 โ”‚ โ”‚ 768โ†’768โ†’128 โ”‚ โ”‚ 768โ†’768โ†’128 โ”‚ โ”‚
25
- โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚
26
- โ”‚ โ–ผ โ–ผ โ–ผ โ”‚
27
- โ”‚ 128-d L2-norm 128-d L2-norm 128-d L2-norm โ”‚
28
- โ”‚ โ”‚
29
- โ”‚ โ”€โ”€โ”€โ”€ Shared Embedding Space โ”€โ”€โ”€โ”€ โ”‚
30
- โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
 
 
31
  ```
32
 
33
- ### Key Design Decisions
34
-
35
- | Decision | Choice | Rationale |
36
- |----------|--------|-----------|
37
- | Backbone | ModernBERT-base | 8192 ctx handles long CIFs; fast RoPE attention |
38
- | Projection | 2-layer MLP per modality | MultiMat recipe: modality-specific heads preserve specialization |
39
- | Embedding dim | 128 | Standard for contrastive learning; compact for FAISS |
40
- | Loss | AllPairsCLIP + Property-MSE | Aligns all N(N-1)/2 modality pairs; property regularization |
41
- | Temperature | Learnable (init 0.07) | CLIP standard; learned ฯ„ improves convergence |
42
-
43
- ## ๐Ÿ“Š Modalities Supported
44
-
45
- | Modality | Column | Example | Query Type |
46
- |----------|--------|---------|------------|
47
- | Composition | `composition` | `Fe2O3` | "Find iron oxides" |
48
- | Atom Sequence | `atom_sequences` | `Fe Fe Fe O O O` | Element lists |
49
- | CIF (symmetrized) | `cif_symmetrized` | Full CIF text | Paste CIF data |
50
- | CIF (P1) | `cif_p1` | Full CIF in P1 | Paste CIF data |
51
- | Z-matrix | `zmatrix` | `Fe\nO 1 2.0\nO 1 2.0 2 90` | Internal coords |
52
- | Atom Seq++ | `atom_sequences_plusplus` | `Fe O 3.57 3.57 90 90` | Elements + lattice |
53
- | SLICES | `slices` | `Fe O 0 1 o o o` | SLICES encoding |
54
- | Crystal Text (LLM) | `crystal_text_llm` | `3.6 3.6 3.6\n90 90 90\nFe...` | Gruver format |
55
- | Local Environment | `local_env` | SMILES-like env | Local bonding |
56
- | Natural Language | `robocrys_rep` | "FeO crystallizes in..." | Plain English |
57
- | **Property Query** | property text | "bandgap: 1.5 eV" | Property search |
58
 
59
- ## ๐Ÿงช Training Recipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- Based on three key papers:
 
 
 
 
 
62
 
63
- 1. **MultiMat** (AllPairsCLIP, [arxiv:2312.00111](https://arxiv.org/abs/2312.00111)): Sum of symmetric InfoNCE over all modality pairs
64
- 2. **MatExpert** ([arxiv:2410.21317](https://arxiv.org/abs/2410.21317)): Propertyโ†”structure contrastive alignment
65
- 3. **CrystalCLR** ([arxiv:2211.13408](https://arxiv.org/abs/2211.13408)): Composition similarity loss
66
- 4. **SupReMix** ([arxiv:2309.16633](https://arxiv.org/abs/2309.16633)): Property-label-aware soft contrastive
 
67
 
68
  ### Two-Phase Training
69
 
70
- **Phase 1 โ€” Multi-modal alignment** (pretrain100k_v2, 50k samples):
71
- - AllPairsCLIP loss across all 10 modalities
72
- - Random modality sampling (4/10 per step) for VRAM efficiency
73
- - Each step aligns C(4,2)=6 modality pairs
74
 
75
- **Phase 2 โ€” Property-conditioned alignment** (bandgap + form_energy, 50k samples):
76
- - Same CLIP loss + property similarity MSE loss
77
- - Property text "composition: Fe2O3 | bandgap: 2.1000" aligned with structure representations
78
- - Materials with similar property values cluster in embedding space
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  ### Hyperparameters
81
 
82
- ```
83
  encoder: answerdotai/ModernBERT-base
84
  embed_dim: 128
85
- max_length: 512 tokens
86
- batch_size: 32 ร— 8 grad_accum = 256 effective
87
- learning_rate: 2e-5 (cosine decay, 10% warmup)
88
  temperature: learnable (init 0.07)
89
  epochs: 3 per phase
90
  optimizer: AdamW (weight_decay=0.01)
91
- fp16: True
92
  gradient_checkpointing: True
 
93
  ```
94
 
95
  ## ๐Ÿš€ Quick Start
96
 
97
- ### Training
98
 
99
  ```bash
100
- pip install torch transformers datasets faiss-cpu huggingface_hub trackio
101
 
102
- # Local GPU
103
- python train_mattext_embeddings.py
104
 
105
- # HF Jobs (recommended: a10g-large, 24GB VRAM)
106
- # Set timeout to 6h
107
  ```
108
 
 
 
 
 
 
109
  ### Inference & Search
110
 
111
  ```python
@@ -113,119 +141,111 @@ import torch
113
  import faiss
114
  import json
115
  import numpy as np
116
- from transformers import AutoModel, AutoTokenizer
117
-
118
- # Load model
119
  from train_mattext_embeddings import MatTextEncoder, Config, search_vector_db
120
 
 
121
  config = Config()
122
  config.device = "cuda" if torch.cuda.is_available() else "cpu"
123
-
124
  model = MatTextEncoder(config)
125
  model.load_state_dict(torch.load("mattext-embeddings/model.pt", map_location=config.device))
126
- model = model.to(config.device)
127
- model.eval()
128
-
129
  tokenizer = AutoTokenizer.from_pretrained(config.encoder_name)
130
 
131
  # Load FAISS indices
132
  indices = {}
133
- for mod in ["composition", "crystal_text_llm", "slices", "cif_symmetrized"]:
134
  index = faiss.read_index(f"mattext-embeddings/faiss/{mod}.index")
135
  with open(f"mattext-embeddings/faiss/{mod}_metadata.json") as f:
136
  metadata = json.load(f)
137
  indices[mod] = {"index": index, "metadata": metadata}
138
-
139
- # Search!
140
- results = search_vector_db("Fe2O3", "composition", model, tokenizer, indices, config, k=5)
141
- for score, meta in results:
142
- print(f"Score: {score:.4f} | {meta['composition']}")
143
  ```
144
 
145
- ### Cross-Modal Query Examples
146
 
147
  ```python
148
- # Query by composition โ†’ find across all modalities
149
- search_vector_db("SiO2", "composition", model, tokenizer, indices, config)
 
 
 
150
 
151
- # Query by natural language โ†’ find materials
152
- search_vector_db("perovskite with high bandgap", "robocrys_rep", model, tokenizer, indices, config)
 
153
 
154
- # Query by SLICES representation
155
- search_vector_db("Si O 0 1 o o o", "slices", model, tokenizer, indices, config)
156
 
157
- # Query by CIF data
158
- search_vector_db("data_SiO2\n_symmetry P1\n...", "cif_symmetrized", model, tokenizer, indices, config)
159
 
160
- # Property-conditioned query
161
- search_vector_db("composition: Si | bandgap: 1.1200", "property", model, tokenizer, indices, config)
 
 
 
162
  ```
163
 
164
- ## ๐Ÿ”ฌ Evaluation Metrics
 
 
 
 
 
 
 
 
 
 
165
 
166
- Cross-modal Recall@k: for each material, embed in modality A, retrieve in modality B, check if correct match is in top-k.
167
 
168
- | Pair | R@1 | R@5 | R@10 |
169
- |------|-----|-----|------|
170
- | composition โ†’ crystal_text_llm | TBD | TBD | TBD |
171
- | composition โ†’ cif_symmetrized | TBD | TBD | TBD |
172
- | slices โ†’ crystal_text_llm | TBD | TBD | TBD |
173
- | robocrys_rep โ†’ composition | TBD | TBD | TBD |
174
 
175
  *Results populated after training.*
176
 
177
  ## ๐Ÿงฉ Extending: Graph Embeddings
178
 
179
- The architecture supports adding graph neural network (GNN) embeddings:
180
 
181
  ```python
182
- # Add a GNN projection head
183
- from torch_geometric.nn import SchNet, DimeNet # or CGCNN
184
 
185
  class GraphEncoder(nn.Module):
186
  def __init__(self, embed_dim=128):
187
  super().__init__()
188
- self.gnn = SchNet(hidden_channels=256, num_filters=128, num_interactions=6)
189
  self.proj = ModalityProjection(256, embed_dim)
190
 
191
  def forward(self, data):
192
- # data: PyG Data with pos, z (atomic numbers), batch
193
  h = self.gnn(data.z, data.pos, data.batch)
194
  return self.proj(h)
195
 
196
- # Add to MatTextEncoder:
197
- model.graph_encoder = GraphEncoder(config.embed_dim)
198
- model.projections["graph"] = model.graph_encoder.proj
199
-
200
- # Training: treat graph embeddings as another modality in AllPairsCLIP
201
  ```
202
 
203
- For graph embeddings, convert CIF โ†’ PyG Data (using `pymatgen` + `torch_geometric`):
204
- ```python
205
- from pymatgen.core import Structure
206
- from torch_geometric.data import Data
207
-
208
- def cif_to_graph(cif_string, cutoff=5.0):
209
- struct = Structure.from_str(cif_string, fmt="cif")
210
- # Get neighbors within cutoff
211
- neighbors = struct.get_all_neighbors(cutoff)
212
- # Build edge_index, pos, z ...
213
- return Data(z=atomic_numbers, pos=positions, edge_index=edge_index)
214
- ```
215
 
216
  ## ๐Ÿ“š References
217
 
218
- - **MatText**: [arxiv:2406.17295](https://arxiv.org/abs/2406.17295) โ€” Dataset and text representations
219
- - **MultiMat**: [arxiv:2312.00111](https://arxiv.org/abs/2312.00111) โ€” AllPairsCLIP for materials
220
- - **MatExpert**: [arxiv:2410.21317](https://arxiv.org/abs/2410.21317) โ€” Propertyโ†”structure alignment
221
- - **CrystalCLR**: [arxiv:2211.13408](https://arxiv.org/abs/2211.13408) โ€” Contrastive learning for crystals
222
- - **SupReMix**: [arxiv:2309.16633](https://arxiv.org/abs/2309.16633) โ€” Property-aware hard negatives
223
- - **Symile**: [arxiv:2411.01053](https://arxiv.org/abs/2411.01053) โ€” Total-correlation loss for M modalities
 
224
 
225
  ## ๐Ÿ“„ License
226
 
227
  MIT
228
-
229
- ## ๐Ÿ”— Dataset
230
-
231
- [n0w0f/MatText](https://huggingface.co/datasets/n0w0f/MatText) โ€” 100k+ crystal structures in 10 text representations
 
1
+ # MatText Aligned Embeddings v2: Multi-Modal Material Retrieval with Natural Language Queries
2
 
3
+ **A CLIP-style multi-modal embedding model that aligns 10+ material text representations into a shared 128-d vector space. Query with natural language ("oxide with high bandgap"), composition, CIF, SLICES, or any modality โ†’ retrieve matching materials.**
4
 
5
+ ## ๐Ÿ†• v2 Key Features
6
+
7
+ | Feature | v1 | v2 |
8
+ |---------|----|----|
9
+ | Context length | 512 tokens | **1024 tokens** (captures long CIFs) |
10
+ | Natural language queries | โŒ | **โœ… "oxide with high bandgap"** |
11
+ | Property-aware retrieval | Basic | **LaCLIP-style diverse NL descriptions** |
12
+ | GPU optimization | fp16 / 24GB | **bf16 / 80GB A100 optimized** |
13
+ | Effective batch size | 256 | **288** |
14
+ | Modalities per step | 4 | **5** |
15
+ | Flash Attention 2 | โŒ | **โœ… (auto-detect)** |
16
 
17
  ## ๐Ÿ—๏ธ Architecture
18
 
19
  ```
20
+ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
21
+ โ”‚ MatTextEncoder (157M params) โ”‚
22
+ โ”‚ โ”‚
23
+ โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚
24
+ โ”‚ โ”‚ Shared Backbone: ModernBERT-base (150M params, 8192 ctx) โ”‚ โ”‚
25
+ โ”‚ โ”‚ Mean pooling โ†’ 768-d representation โ”‚ โ”‚
26
+ โ”‚ โ”‚ Gradient checkpointing + bf16 โ”‚ โ”‚
27
+ โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚
28
+ โ”‚ โ”‚ โ”‚
29
+ โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚
30
+ โ”‚ โ–ผ โ–ผ โ–ผ โ–ผ โ”‚
31
+ โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€๏ฟฝ๏ฟฝโ”€โ”€โ”€โ”€โ” โ”‚
32
+ โ”‚ โ”‚comp โ”‚ โ”‚cif_sym โ”‚ โ”‚nl_property_desc โ”‚ โ”‚property โ”‚ ...ร—12 โ”‚
33
+ โ”‚ โ”‚768โ†’768 โ”‚ โ”‚768โ†’768 โ”‚ โ”‚768โ†’768โ†’128 โ”‚ โ”‚768โ†’768 โ”‚ โ”‚
34
+ โ”‚ โ”‚โ†’128 โ”‚ โ”‚โ†’128 โ”‚ โ”‚"oxide with high โ”‚ โ”‚โ†’128 โ”‚ โ”‚
35
+ โ”‚ โ”‚ โ”‚ โ”‚ โ”‚ โ”‚ bandgap" queries โ”‚ โ”‚ โ”‚ โ”‚
36
+ โ”‚ โ””โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚
37
+ โ”‚ โ–ผ โ–ผ โ–ผ โ–ผ โ”‚
38
+ โ”‚ 128-d L2 128-d L2 128-d L2 128-d L2 โ”‚
39
+ โ”‚ โ”‚
40
+ โ”‚ โ”€โ”€โ”€โ”€ Shared 128-d Embedding Space โ”€โ”€โ”€โ”€ โ”‚
41
+ โ”‚ (FAISS IndexFlatIP for cosine similarity search) โ”‚
42
+ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
43
  ```
44
 
45
+ ### 12 Projection Heads
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ | # | Head | Input | Purpose |
48
+ |---|------|-------|---------|
49
+ | 1 | `composition` | "Fe2O3" | Formula queries |
50
+ | 2 | `atom_sequences` | "Fe Fe O O O" | Element list queries |
51
+ | 3 | `cif_symmetrized` | Full CIF | Paste CIF data |
52
+ | 4 | `cif_p1` | CIF in P1 | P1 space group CIF |
53
+ | 5 | `zmatrix` | Z-matrix coords | Internal coordinates |
54
+ | 6 | `atom_sequences_plusplus` | Elements + lattice | Atom sequence + cell |
55
+ | 7 | `slices` | SLICES encoding | Compact structure encoding |
56
+ | 8 | `crystal_text_llm` | Gruver format | Lattice + coords text |
57
+ | 9 | `local_env` | SMILES-like env | Local bonding environment |
58
+ | 10 | `robocrys_rep` | NL description | "FeO crystallizes in..." |
59
+ | 11 | **`nl_property_description`** | **Free-form NL** | **"oxide with high bandgap"** |
60
+ | 12 | `property` | Structured props | "bandgap: 2.1 eV" |
61
+
62
+ ## ๐Ÿ” How NL Queries Work
63
+
64
+ The key innovation is a **LaCLIP-style** training approach ([arxiv:2305.20088](https://arxiv.org/abs/2305.20088)):
65
 
66
+ 1. **During Phase 2 training**, for each material with known properties (bandgap, formation energy), we generate **diverse natural language descriptions** from templates:
67
+ - `"A wide bandgap oxide suitable for UV applications, bandgap 3.20 eV"`
68
+ - `"TiO2: oxide semiconductor with wide band gap of 3.20 electron volts"`
69
+ - `"This binary oxide (TiO2) exhibits a wide bandgap of approximately 3.20 eV"`
70
+
71
+ 2. These NL descriptions are passed through a **dedicated `nl_property_description` projection head** and aligned with ALL structure modalities via InfoNCE.
72
 
73
+ 3. **At inference**, when you query `"oxide with high bandgap"`, the model maps it through the same NL head into the shared embedding space, and FAISS finds the nearest materials โ€” those that were trained to be close to similar descriptions.
74
+
75
+ This is distinct from `robocrys_rep` (which describes crystal *structure*: "FeO crystallizes in the rock salt structure..."). The NL query head describes *properties* ("wide bandgap oxide").
76
+
77
+ ## ๐Ÿงช Training Recipe
78
 
79
  ### Two-Phase Training
80
 
81
+ **Phase 1 โ€” Multi-modal alignment** (pretrain100k_v2, 60k samples, 3 epochs):
82
+ - AllPairsCLIP loss across 10 modalities
83
+ - Random modality sampling (5/10 per step) โ€” always includes composition + crystal_text_llm
84
+ - Effective batch 288
85
 
86
+ **Phase 2 โ€” Property-conditioned + NL query alignment** (bandgap + formation_energy, 60k samples, 3 epochs):
87
+ - AllPairsCLIP loss (structure modalities)
88
+ - **NL description โ†” structure InfoNCE** (the key NL query loss)
89
+ - Property โ†” composition/crystal_text_llm InfoNCE ([MatExpert](https://arxiv.org/abs/2410.21317))
90
+ - SupReMix-style property similarity MSE ([arxiv:2309.16633](https://arxiv.org/abs/2309.16633))
91
+ - Loss weights: `L = L_clip + 0.3 * L_property + 0.5 * L_nl`
92
+
93
+ ### Based On
94
+
95
+ | Paper | Contribution | ArXiv |
96
+ |-------|-------------|-------|
97
+ | **MultiMat** | AllPairsCLIP loss | [2312.00111](https://arxiv.org/abs/2312.00111) |
98
+ | **MatExpert** | Propertyโ†”structure InfoNCE | [2410.21317](https://arxiv.org/abs/2410.21317) |
99
+ | **LaCLIP** | LLM text augmentation for CLIP | [2305.20088](https://arxiv.org/abs/2305.20088) |
100
+ | **SupReMix** | Property-label-aware soft contrastive | [2309.16633](https://arxiv.org/abs/2309.16633) |
101
+ | **CrystalCLR** | Composition similarity | [2211.13408](https://arxiv.org/abs/2211.13408) |
102
 
103
  ### Hyperparameters
104
 
105
+ ```yaml
106
  encoder: answerdotai/ModernBERT-base
107
  embed_dim: 128
108
+ max_length: 1024 tokens
109
+ batch_size: 48 ร— 6 grad_accum = 288 effective
110
+ learning_rate: 2e-5 (phase 1), 1e-5 (phase 2)
111
  temperature: learnable (init 0.07)
112
  epochs: 3 per phase
113
  optimizer: AdamW (weight_decay=0.01)
114
+ precision: bf16 (A100) / fp16 (T4/V100)
115
  gradient_checkpointing: True
116
+ max_modalities_per_step: 5
117
  ```
118
 
119
  ## ๐Ÿš€ Quick Start
120
 
121
+ ### Training (your GPU)
122
 
123
  ```bash
124
+ pip install torch transformers datasets faiss-cpu huggingface_hub trackio accelerate
125
 
126
+ # Optional but recommended for A100/H100:
127
+ pip install flash-attn --no-build-isolation
128
 
129
+ python train_mattext_embeddings.py
 
130
  ```
131
 
132
+ The script auto-detects:
133
+ - GPU capability (bf16 for Ampere+, fp16 otherwise)
134
+ - Flash Attention 2 availability
135
+ - CUDA vs CPU
136
+
137
  ### Inference & Search
138
 
139
  ```python
 
141
  import faiss
142
  import json
143
  import numpy as np
144
+ from transformers import AutoTokenizer
 
 
145
  from train_mattext_embeddings import MatTextEncoder, Config, search_vector_db
146
 
147
+ # Load
148
  config = Config()
149
  config.device = "cuda" if torch.cuda.is_available() else "cpu"
 
150
  model = MatTextEncoder(config)
151
  model.load_state_dict(torch.load("mattext-embeddings/model.pt", map_location=config.device))
152
+ model = model.to(config.device).eval()
 
 
153
  tokenizer = AutoTokenizer.from_pretrained(config.encoder_name)
154
 
155
  # Load FAISS indices
156
  indices = {}
157
+ for mod in ["composition", "crystal_text_llm", "slices", "cif_symmetrized", "robocrys_rep"]:
158
  index = faiss.read_index(f"mattext-embeddings/faiss/{mod}.index")
159
  with open(f"mattext-embeddings/faiss/{mod}_metadata.json") as f:
160
  metadata = json.load(f)
161
  indices[mod] = {"index": index, "metadata": metadata}
 
 
 
 
 
162
  ```
163
 
164
+ ### Query Examples
165
 
166
  ```python
167
+ # ๐Ÿ” Natural language property queries (THE KEY FEATURE)
168
+ search_vector_db("oxide with high bandgap", "nl_property_description", model, tokenizer, indices, config)
169
+ search_vector_db("stable ternary nitride", "nl_property_description", model, tokenizer, indices, config)
170
+ search_vector_db("narrow bandgap semiconductor for IR", "nl_property_description", model, tokenizer, indices, config)
171
+ search_vector_db("metallic binary compound", "nl_property_description", model, tokenizer, indices, config)
172
 
173
+ # ๐Ÿงช Composition queries
174
+ search_vector_db("Fe2O3", "composition", model, tokenizer, indices, config)
175
+ search_vector_db("BaTiO3", "composition", model, tokenizer, indices, config)
176
 
177
+ # ๐Ÿ“– Structure description queries
178
+ search_vector_db("perovskite with octahedral coordination", "robocrys_rep", model, tokenizer, indices, config)
179
 
180
+ # ๐Ÿ“Š Structured property queries
181
+ search_vector_db("composition: TiO2 | bandgap: 3.2000", "property", model, tokenizer, indices, config)
182
 
183
+ # ๐Ÿ”ฌ CIF queries (paste your CIF)
184
+ search_vector_db("data_TiO2\n_symmetry P1\n_cell 4.59 4.59 2.96 90 90 90", "cif_symmetrized", ...)
185
+
186
+ # ๐Ÿงฌ SLICES queries
187
+ search_vector_db("Ti O 0 1 o o o", "slices", model, tokenizer, indices, config)
188
  ```
189
 
190
+ ## ๐Ÿ“Š Evaluation Metrics
191
+
192
+ Cross-modal Recall@k on test set:
193
+
194
+ | Pair | R@1 | R@5 | R@10 | R@20 |
195
+ |------|-----|-----|------|------|
196
+ | composition โ†’ crystal_text_llm | TBD | TBD | TBD | TBD |
197
+ | composition โ†’ cif_symmetrized | TBD | TBD | TBD | TBD |
198
+ | composition โ†’ slices | TBD | TBD | TBD | TBD |
199
+ | slices โ†’ crystal_text_llm | TBD | TBD | TBD | TBD |
200
+ | robocrys_rep โ†’ composition | TBD | TBD | TBD | TBD |
201
 
202
+ NL Query Results:
203
 
204
+ | Query | Top-1 Match | Score |
205
+ |-------|------------|-------|
206
+ | "oxide with high bandgap" | TBD | TBD |
207
+ | "narrow bandgap semiconductor" | TBD | TBD |
208
+ | "stable binary oxide" | TBD | TBD |
 
209
 
210
  *Results populated after training.*
211
 
212
  ## ๐Ÿงฉ Extending: Graph Embeddings
213
 
214
+ The architecture is plug-and-play for new modalities:
215
 
216
  ```python
217
+ # Add a GNN modality
218
+ from torch_geometric.nn import SchNet
219
 
220
  class GraphEncoder(nn.Module):
221
  def __init__(self, embed_dim=128):
222
  super().__init__()
223
+ self.gnn = SchNet(hidden_channels=256)
224
  self.proj = ModalityProjection(256, embed_dim)
225
 
226
  def forward(self, data):
 
227
  h = self.gnn(data.z, data.pos, data.batch)
228
  return self.proj(h)
229
 
230
+ # Register as new modality
231
+ model.projections["graph"] = graph_encoder.proj
232
+ # It gets aligned automatically through AllPairsCLIP
 
 
233
  ```
234
 
235
+ ## ๐Ÿ“ฆ Dataset
236
+
237
+ [n0w0f/MatText](https://huggingface.co/datasets/n0w0f/MatText) โ€” 100k+ crystal structures in 10+ text representations
 
 
 
 
 
 
 
 
 
238
 
239
  ## ๐Ÿ“š References
240
 
241
+ - **MatText**: [arxiv:2406.17295](https://arxiv.org/abs/2406.17295)
242
+ - **MultiMat**: [arxiv:2312.00111](https://arxiv.org/abs/2312.00111)
243
+ - **MatExpert**: [arxiv:2410.21317](https://arxiv.org/abs/2410.21317)
244
+ - **LaCLIP**: [arxiv:2305.20088](https://arxiv.org/abs/2305.20088)
245
+ - **SupReMix**: [arxiv:2309.16633](https://arxiv.org/abs/2309.16633)
246
+ - **CrystalCLR**: [arxiv:2211.13408](https://arxiv.org/abs/2211.13408)
247
+ - **Symile**: [arxiv:2411.01053](https://arxiv.org/abs/2411.01053)
248
 
249
  ## ๐Ÿ“„ License
250
 
251
  MIT