xiaohang07 commited on
Commit
4475574
·
verified ·
1 Parent(s): 4a6c57a

Upload 7 files

Browse files
Files changed (7) hide show
  1. Voc_prior +130 -0
  2. config.json +29 -0
  3. mattergpt_pipeline.py +40 -0
  4. mattergpt_wrapper.py +68 -0
  5. model.py +312 -0
  6. pytorch_model.pt +3 -0
  7. usage_example.py +70 -0
Voc_prior ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ S
2
+ o-o
3
+ He
4
+ Dy
5
+ -o-
6
+ Ne
7
+ +-o
8
+ Re
9
+ Bi
10
+ Cu
11
+ oo+
12
+ 16
13
+ Sc
14
+ --o
15
+ Nd
16
+ Lu
17
+ -+o
18
+ Te
19
+ Si
20
+ o+o
21
+ Er
22
+ 1
23
+ Sr
24
+ Hg
25
+ 3
26
+ oo-
27
+ 8
28
+ Ru
29
+ H
30
+ Mo
31
+ Tc
32
+ 12
33
+ 11
34
+ +oo
35
+ Pb
36
+ 6
37
+ In
38
+ La
39
+ --+
40
+ C
41
+ Sn
42
+ Se
43
+ B
44
+ Ar
45
+ o--
46
+ -o+
47
+ Ga
48
+ ++o
49
+ Rh
50
+ Sm
51
+ Ir
52
+ Li
53
+ Tl
54
+ 18
55
+ I
56
+ Cl
57
+ Ag
58
+ Ba
59
+ Ta
60
+ Ho
61
+ Tb
62
+ As
63
+ -+-
64
+ Gd
65
+ Os
66
+ O
67
+ 15
68
+ ---
69
+ W
70
+ F
71
+ 13
72
+ Pm
73
+ K
74
+ Na
75
+ 9
76
+ Eu
77
+ Ce
78
+ 14
79
+ -++
80
+ 5
81
+ Ge
82
+ Yb
83
+ Al
84
+ Rb
85
+ Pd
86
+ Ni
87
+ Cd
88
+ Hf
89
+ P
90
+ Zn
91
+ Ti
92
+ Nb
93
+ 0
94
+ Pr
95
+ 7
96
+ Mg
97
+ Y
98
+ +-+
99
+ ooo
100
+ Pt
101
+ +--
102
+ 19
103
+ Cs
104
+ N
105
+ -oo
106
+ +o-
107
+ o-+
108
+ Xe
109
+ 4
110
+ o+-
111
+ Tm
112
+ 2
113
+ Cr
114
+ Fe
115
+ +o+
116
+ Zr
117
+ ++-
118
+ Kr
119
+ 10
120
+ +++
121
+ Co
122
+ o++
123
+ Be
124
+ Br
125
+ Mn
126
+ Ca
127
+ Au
128
+ V
129
+ Sb
130
+ 17
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "gpt",
3
+ "architectures": [
4
+ "GPT"
5
+ ],
6
+ "vocab_size": 132,
7
+ "block_size": 397,
8
+ "n_layer": 12,
9
+ "n_head": 12,
10
+ "n_embd": 768,
11
+ "num_props": 2,
12
+ "activation_function": "gelu_new",
13
+ "resid_pdrop": 0.1,
14
+ "embd_pdrop": 0.1,
15
+ "attn_pdrop": 0.1,
16
+ "layer_norm_epsilon": 1e-5,
17
+ "initializer_range": 0.02,
18
+ "summary_type": "cls_index",
19
+ "summary_use_proj": true,
20
+ "summary_activation": null,
21
+ "summary_proj_to_labels": true,
22
+ "summary_first_dropout": 0.1,
23
+ "scale_attn_weights": true,
24
+ "use_cache": true,
25
+ "bos_token_id": 130,
26
+ "eos_token_id": 131,
27
+ "lstm": false,
28
+ "lstm_layers": 0
29
+ }
mattergpt_pipeline.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Pipeline
2
+ import torch
3
+ from typing import Dict, List, Union
4
+
5
class MatterGPTPipeline(Pipeline):
    """Generation pipeline that conditions MatterGPT on target material properties.

    Accepts a single dict (or a list of dicts) with keys ``formation_energy``
    and ``band_gap`` and returns the decoded generated sequence(s).
    """

    def __init__(self, model, tokenizer, device=-1):
        super().__init__(model=model, tokenizer=tokenizer, device=device)

    def _sanitize_parameters(self, **kwargs):
        # No runtime-tunable parameters: all sampling settings are fixed below.
        return {}, {}, {}

    def preprocess(self, inputs: Union[Dict[str, float], List[Dict[str, float]]]) -> Dict[str, torch.Tensor]:
        """Build the start-token batch and the property-conditioning tensor."""
        batch = [inputs] if isinstance(inputs, dict) else inputs
        conditions = [[cond['formation_energy'], cond['band_gap']] for cond in batch]
        start = '>'  # generation start marker in the vocabulary
        start_id = torch.tensor([self.tokenizer.stoi[start]], dtype=torch.long)
        input_ids = start_id[None, ...].repeat(len(conditions), 1).to(self.device)
        # Shape (batch, 1, num_props): one conditioning vector per sequence.
        prop = torch.tensor(conditions, dtype=torch.float).unsqueeze(1).to(self.device)
        return {"input_ids": input_ids, "prop": prop}

    def _forward(self, model_inputs):
        # Sampling hyper-parameters mirror the reference usage example.
        return self.model.generate(
            model_inputs["input_ids"],
            prop=model_inputs["prop"],
            max_length=self.model.config.block_size,
            temperature=1.2,
            do_sample=True,
            top_k=0,
            top_p=0.9,
        )

    def postprocess(self, model_outputs):
        return [self.tokenizer.decode(seq.tolist()) for seq in model_outputs]

    def __call__(self, inputs: Union[Dict[str, float], List[Dict[str, float]]]):
        # Run the three stages directly, bypassing the parent Pipeline machinery.
        return self.postprocess(self._forward(self.preprocess(inputs)))
mattergpt_wrapper.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import PreTrainedModel, PretrainedConfig
4
+ from model import GPT, GPTConfig # Import your original model and config classes
5
+ import json
6
+
7
class CustomGPTConfig(PretrainedConfig):
    """HF-compatible config that mirrors every keyword argument as an attribute.

    ``PretrainedConfig`` only keeps kwargs it recognizes, so model-specific
    fields (block_size, num_props, lstm, ...) are copied explicitly.
    """

    model_type = "gpt"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for name in kwargs:
            setattr(self, name, kwargs[name])
15
class MatterGPTWrapper(PreTrainedModel):
    """Thin Hugging Face ``PreTrainedModel`` wrapper around the custom GPT in model.py.

    Delegates forward/generation to the wrapped ``GPT`` instance and handles
    loading/saving of the raw ``pytorch_model.pt`` state dict.
    """

    config_class = CustomGPTConfig
    base_model_prefix = "gpt"

    def __init__(self, config):
        super().__init__(config)
        # Rebuild the original GPTConfig from the HF config's attribute dict;
        # GPTConfig takes vocab_size/block_size plus arbitrary extra kwargs.
        self.model = GPT(GPTConfig(**config.__dict__))

    def forward(self, input_ids, attention_mask=None, labels=None, prop=None):
        # attention_mask is accepted for HF-API compatibility but ignored by GPT.
        # Returns GPT's (logits, loss, attn_maps, embed) tuple unchanged.
        return self.model(input_ids, targets=labels, prop=prop)

    def generate(self, input_ids, prop, max_length, num_return_sequences=1, **kwargs):
        # NOTE(review): num_return_sequences is accepted but unused — callers
        # wanting multiple samples must batch input_ids themselves.
        steps = max_length - input_ids.shape[1]
        return self.model.sample(input_ids, steps, prop=prop, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_path, *model_args, **kwargs):
        """Load config.json + pytorch_model.pt from a local directory path."""
        config_file = f"{pretrained_model_path}/config.json"
        with open(config_file, 'r') as f:
            config_dict = json.load(f)

        config = CustomGPTConfig(**config_dict)

        model = cls(config)

        # NOTE(review): torch.load on an untrusted checkpoint can execute
        # arbitrary code via pickle — only load weights from trusted sources.
        state_dict = torch.load(f"{pretrained_model_path}/pytorch_model.pt", map_location="cpu")
        model.model.load_state_dict(state_dict)

        return model

    def save_pretrained(self, save_directory):
        """Write config.json and the inner GPT's state dict to *save_directory*."""
        self.config.save_pretrained(save_directory)
        torch.save(self.model.state_dict(), f"{save_directory}/pytorch_model.pt")
48
+
49
+ class SimpleTokenizer:
50
+ def __init__(self, vocab_file):
51
+ with open(vocab_file, 'r') as f:
52
+ self.vocab = f.read().splitlines()
53
+ self.vocab = sorted(set(self.vocab + ['<', '>']))
54
+ self.stoi = {ch: i for i, ch in enumerate(self.vocab)}
55
+ self.itos = {i: ch for i, ch in enumerate(self.vocab)}
56
+
57
+ def encode(self, text):
58
+ return [self.stoi[token] for token in text.split()]
59
+
60
+ def decode(self, ids):
61
+ return " ".join([self.itos[int(i)] for i in ids if i in self.itos]).replace("<", "").strip()
62
+
63
+ def __call__(self, text, return_tensors=None):
64
+ encoded = self.encode(text)
65
+ if return_tensors == 'pt':
66
+ import torch
67
+ return {'input_ids': torch.tensor([encoded])}
68
+ return {'input_ids': [encoded]}
model.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Yan Chen 2023.10
3
+ # yanchen@xjtu.edu.com
4
+ """
5
+ GPT model:
6
+ - the initial stem consists of a combination of token encoding and a positional encoding
7
+ - the meat of it is a uniform sequence of Transformer blocks
8
+ - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
9
+ - all blocks feed into a central residual pathway similar to resnets
10
+ - the final decoder is a linear projection into a vanilla Softmax classifier
11
+ """
12
+
13
+ import math,json
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+
18
class GPTConfig:
    """Hyper-parameter container shared by all GPT variants.

    Dropout rates are class-level defaults; every extra keyword argument is
    attached verbatim as an attribute (n_layer, num_props, lstm, ...).
    """

    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1

    def __init__(self, vocab_size, block_size, **kwargs):
        self.vocab_size = vocab_size
        self.block_size = block_size
        for name in kwargs:
            setattr(self, name, kwargs[name])
+
30
class GPT1Config(GPTConfig):
    """ GPT-1 like network roughly 125M params """
    # Preset architecture sizes; everything else comes from GPTConfig defaults/kwargs.
    n_layer = 12
    n_head = 12
    n_embd = 768
35
+
36
class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.

    forward() returns both the projected output and the pre-dropout attention
    weights so callers can collect attention maps.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the
        # input sequence. The mask is one position longer than block_size when
        # property conditioning is on, because GPT.forward prepends one property token.
        num = int(bool(config.num_props))
        # num = 1
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size + num, config.block_size + num))
                                     .view(1, 1, config.block_size + num, config.block_size + num))

        self.n_head = config.n_head

    def forward(self, x, layer_past=None):
        # layer_past is accepted for API compatibility but unused (no KV caching).
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        # keep the normalized weights from before dropout for visualization
        attn_save = att
        att = self.attn_drop(att)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y, attn_save
83
+
84
class Block(nn.Module):
    """One pre-LayerNorm Transformer block: masked self-attention followed by a
    GELU MLP, each wrapped in a residual connection. Passes through the
    attention map produced by the attention layer."""

    def __init__(self, config):
        super().__init__()
        # Attribute names (ln1/ln2/attn/mlp) are part of the checkpoint layout.
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        attn_out, attn_map = self.attn(self.ln1(x))
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x, attn_map
104
+
105
class GPT(nn.Module):
    """ the full GPT language model, with a context size of block_size

    Optionally conditions on a material-property vector (config.num_props):
    the vector is projected and prepended to the token sequence as one extra
    "type 0" token, whose logit position is stripped from the output.
    """

    def __init__(self, config):
        super().__init__()
        #print(json.dumps(config.__dict__, indent=2))
        # input embedding stem
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        # Token "types": 0 = property-conditioning token, 1 = regular sequence token.
        self.type_emb = nn.Embedding(2, config.n_embd)
        if config.num_props:
            # Projects the num_props-dimensional property vector into embedding space.
            self.prop_nn = nn.Linear(config.num_props, config.n_embd)

        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.block_size = config.block_size

        if config.lstm:
            self.lstm = nn.LSTM(input_size = config.n_embd, hidden_size = config.n_embd, num_layers = config.lstm_layers, dropout = 0.3, bidirectional = False)
        self.apply(self._init_weights)

        #logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        """Maximum sequence length supported by the positional embeddings."""
        return self.block_size

    def _init_weights(self, module):
        # GPT-style init: N(0, 0.02) for linear/embedding weights, zero biases,
        # identity-like init for LayerNorm.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.LSTM)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name

                if 'bias' in pn:
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif 'weight' in pn and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                           % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, idx, targets=None, prop = None):
        """Run the model.

        Args:
            idx: (b, t) long tensor of token ids, t <= block_size.
            targets: optional (b, t) long tensor used for the cross-entropy loss.
            prop: property tensor whose last dim equals config.num_props,
                shaped (b, num_props) or (b, 1, num_props). Required when
                config.num_props is truthy.

        Returns:
            (logits, loss, attn_maps, embed): logits is (b, t, vocab_size) with
            the prepended property position stripped; loss is None when targets
            is None; attn_maps lists one attention map per layer; embed is the
            embedded input before the transformer blocks.
        """
        b, t = idx.size()
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."

        if self.config.num_props:
            assert prop.size(-1) == self.config.num_props, "Num_props should be equal to last dim of property vector"

        # forward the GPT model
        token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
        type_embeddings = self.type_emb(torch.ones((b,t), dtype = torch.long, device = idx.device))
        x = self.drop(token_embeddings + position_embeddings + type_embeddings)

        embed = x

        if self.config.num_props:
            # the property token gets type id 0 to distinguish it from real tokens
            type_embd = self.type_emb(torch.zeros((b, 1), dtype = torch.long, device = idx.device))
            if prop.ndim == 2:
                p = self.prop_nn(prop.unsqueeze(1)) # for single property
            else:
                p = self.prop_nn(prop) # for multiproperty
            p += type_embd
            x = torch.cat([p, x], 1)

        # x = self.blocks(x)
        attn_maps = []

        for layer in self.blocks:
            x, attn = layer(x)
            attn_maps.append(attn)

        x = self.ln_f(x)
        logits = self.head(x)

        # drop the logit at the prepended property position, if any
        num = int(bool(self.config.num_props))

        logits = logits[:, num:, :]

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.view(-1))

        return logits, loss, attn_maps, embed # (num_layers, batch_size, num_heads, max_seq_len, max_seq_len)


    @torch.no_grad()
    def sample(self, x, steps, temperature=1.0, do_sample=False, top_k=None, top_p=None, prop=None):
        """
        Take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
        the sequence, feeding the predictions back into the model each time. Clearly the sampling
        has quadratic complexity unlike an RNN that is only linear, and has a finite context window
        of block_size, unlike an RNN that has an infinite context window.

        Most likely you'll want to make sure to be in model.eval() mode of operation for this.

        Returns the generated sequence with the initial conditioning token stripped.
        """
        # BUGFIX: the previous defaults (top_k=None, top_p=None) crashed inside
        # top_k_top_p_filtering (min(None, int) / None > 0.0 raise TypeError).
        # None now means "filter disabled", matching the filter's own defaults.
        top_k = 0 if top_k is None else top_k
        top_p = 0.0 if top_p is None else top_p

        def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
            """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
            Args:
                logits: logits distribution shape (batch size x vocabulary size)
                top_k > 0: keep only top k tokens with highest probability (top-k filtering).
                top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                    Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
            """
            top_k = min(top_k, logits.size(-1)) # Safety check
            if top_k > 0:
                # Remove all tokens with a probability less than the last token of the top-k
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = filter_value

            if top_p > 0.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift the indices to the right to keep also the first token above the threshold
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                # scatter sorted tensors to original indexing
                indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
                logits[indices_to_remove] = filter_value
            return logits


        for k in range(steps):
            x_cond = x if x.size(1) <= self.block_size else x[:, -self.block_size:] # crop context if needed

            # forward the model to get the logits for the index in the sequence
            logits, _, _, _ = self(x_cond, prop = prop) # for sampling, no target

            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature

            # optionally crop the logits to only the top k options OR using nucleus (top-p) filtering
            logits = top_k_top_p_filtering(logits, top_p=top_p, top_k=top_k)

            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)

            # sample from the distribution or take the most likely
            if do_sample:
                x_next = torch.multinomial(probs, num_samples=1)
            else:
                _, x_next = torch.topk(probs, k=1, dim=-1)

            # append sampled index to the running sequence and continue
            x = torch.cat((x, x_next), dim=1)

        return x[:, 1:]
pytorch_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d0071fa7c05449273bfaf60357e2c4f9525c8b8f47e6d313856160312a72b21
3
+ size 349946009
usage_example.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Example script: load MatterGPT from the current directory and generate
crystal sequences conditioned on formation energy and band gap."""
from mattergpt_wrapper import MatterGPTWrapper, SimpleTokenizer
import torch
from tqdm import tqdm
import os
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the model
model_path = "./"  # Directory containing config.json and pytorch_model.pt
if not os.path.exists(os.path.join(model_path, "config.json")):
    raise FileNotFoundError(f"Config file not found in {model_path}")
if not os.path.exists(os.path.join(model_path, "pytorch_model.pt")):
    raise FileNotFoundError(f"Model weights not found in {model_path}")

model = MatterGPTWrapper.from_pretrained(model_path)
model.to('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Model loaded from {model_path}")

# Load the tokenizer
tokenizer_path = "Voc_prior"
if not os.path.exists(tokenizer_path):
    raise FileNotFoundError(f"Tokenizer vocabulary file not found at {tokenizer_path}")
tokenizer = SimpleTokenizer(tokenizer_path)
logger.info(f"Tokenizer loaded from {tokenizer_path}")

# Function to generate a single sequence
def generate_single(condition):
    """Generate one decoded sequence for a [formation_energy, band_gap] pair."""
    # '>' is the generation start token in the vocabulary.
    context = '>'
    x = torch.tensor([tokenizer.stoi[context]], dtype=torch.long)[None,...].to(model.device)
    # Shape (1, 1, num_props): one conditioning vector for one sequence.
    p = torch.tensor([condition]).unsqueeze(1).to(model.device)

    generated = model.generate(x, prop=p, max_length=model.config.block_size, temperature=1.2, do_sample=True, top_k=0, top_p=0.9)
    return tokenizer.decode(generated[0].tolist())

# Function to generate multiple sequences
def generate_multiple(condition, num_sequences, batch_size=32):
    """Generate *num_sequences* decoded sequences in batches of *batch_size*,
    all conditioned on the same property pair."""
    all_sequences = []
    for _ in tqdm(range(0, num_sequences, batch_size)):
        # Last batch may be smaller than batch_size.
        current_batch_size = min(batch_size, num_sequences - len(all_sequences))
        context = '>'
        x = torch.tensor([tokenizer.stoi[context]], dtype=torch.long)[None,...].repeat(current_batch_size, 1).to(model.device)
        # (current_batch_size, 1, num_props): the same condition repeated per row.
        p = torch.tensor([condition]).repeat(current_batch_size, 1).unsqueeze(1).to(model.device)

        generated = model.generate(x, prop=p, max_length=model.config.block_size, temperature=1.2, do_sample=True, top_k=0, top_p=0.9)
        all_sequences.extend([tokenizer.decode(seq.tolist()) for seq in generated])

        if len(all_sequences) >= num_sequences:
            break

    return all_sequences[:num_sequences]

# Example usage
condition = [-1.0, 2.0] # eform and bandgap

# Generate a single sequence
logger.info("Generating a single sequence:")
single_sequence = generate_single(condition)
print(single_sequence)
print()

# Generate multiple sequences
num_sequences = 10
logger.info(f"Generating {num_sequences} sequences:")
multiple_sequences = generate_multiple(condition, num_sequences)
for i, seq in enumerate(multiple_sequences, 1):
    print(seq)

logger.info("Generation complete")