gbyuvd committed on
Commit
00c497b
·
verified ·
1 Parent(s): 6778326

Upload 7 files

Browse files
Files changed (7) hide show
  1. ChemQ3MTP.py +753 -0
  2. FastChemTokenizerHF.py +769 -0
  3. LICENSE +21 -0
  4. config.json +34 -0
  5. demo_test_mtpresult.ipynb +190 -0
  6. train-withmtp.py +365 -0
  7. train_ppokl_withsa.py +131 -0
ChemQ3MTP.py ADDED
@@ -0,0 +1,753 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========================
2
+ # ChemQ3-MTP
3
+ # MODEL COMPONENTS
4
+ # by gbyuvd
5
+ # ========================
6
+
7
+ import os
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.distributions import Categorical
12
+ from typing import List, Union, Optional, Tuple, Dict, Any
13
+ from transformers import Qwen3Config, Qwen3ForCausalLM, AutoTokenizer
14
+ from rdkit import Chem
15
+ from rdkit.Chem import Descriptors, Lipinski
16
+ import selfies as sf
17
+ from rdkit import RDLogger
18
+ RDLogger.DisableLog('rdApp.*') # suppress all SMILES parse messages
19
+ import json
20
+ from typing import List, Union, Optional, Tuple
21
+ from transformers.tokenization_utils_base import BatchEncoding
22
+ from FastChemTokenizer import FastChemTokenizerSelfies
23
+ import numpy as np
24
+ from collections import Counter
25
+ from rdkit.Chem import Descriptors, Lipinski, rdMolDescriptors
26
+
27
+ # ========================
28
+ # UTILS: SELFIES -> SMILES -> VALIDITY & LIPINSKI
29
+ # ========================
30
+
31
def selfies_to_smiles(selfies_str: str) -> str | None:
    """Decode a SELFIES string into SMILES.

    Whitespace inserted by the tokenizer is stripped before decoding.
    Any failure (malformed SELFIES, non-string input) yields ``None``
    rather than raising.
    """
    try:
        return sf.decoder(selfies_str.replace(" ", ""))
    except Exception:
        return None
38
+
39
def is_valid_smiles(smiles: str) -> bool:
    """Return True iff *smiles* is a non-empty string that RDKit can parse."""
    if not isinstance(smiles, str):
        return False
    stripped = smiles.strip()
    if not stripped:
        return False
    return Chem.MolFromSmiles(stripped) is not None
43
+
44
+ # SA Classifier
45
+ from transformers import pipeline
46
+
47
+ # Optional: lazy load so we don’t reload every time
48
_sa_classifier = None  # module-level cache; populated on first use


def get_sa_classifier():
    """Return the synthetic-accessibility text classifier, loading it lazily once."""
    global _sa_classifier
    if _sa_classifier is None:
        _sa_classifier = pipeline(
            "text-classification",
            model="gbyuvd/synthaccess-chemselfies",
        )
    return _sa_classifier
54
+
55
+
56
def compute_sa_reward(selfies_str: str) -> float:
    """Score synthetic accessibility: +score for an "easy" label, -score otherwise.

    Returns 0.0 on any classifier failure (load error, bad input, etc.).
    """
    try:
        prediction = get_sa_classifier()(selfies_str, truncation=True, max_length=128)[0]
        sign = 1.0 if prediction["label"].lower() == "easy" else -1.0
        return sign * prediction["score"]
    except Exception:
        return 0.0
67
+
68
+ # ==========================
69
+ # Reward Components
70
+ # ==========================
71
def compute_biological_diversity_score(mol) -> float:
    """Score in [0, 1] rewarding molecules built from several distinct CHONP
    elements, with a bonus for an even (high-entropy) mix among them.

    Returns 0.0 for None input, fewer than two bio elements, or RDKit errors.
    """
    if mol is None:
        return 0.0
    try:
        symbols = [a.GetSymbol() for a in mol.GetAtoms()]
        counts = Counter(symbols)
        present = set(symbols) & {"C", "H", "O", "N", "P"}
        if len(present) < 2:
            return 0.0

        # 0.3 floor for having >= 2 bio elements, up to +0.4 when all five appear.
        score = 0.3 + (len(present) - 2) / 3 * 0.4

        # Up to +0.3 for evenness: normalized Shannon entropy over bio-atom counts.
        evenness_bonus = 0.0
        n_bio = sum(counts.get(e, 0) for e in present)
        if n_bio > 0 and len(present) > 1:
            probs = [counts.get(e, 0) / n_bio for e in present]
            shannon = -sum(p * np.log2(p) for p in probs if p > 0)
            evenness_bonus = (shannon / np.log2(len(probs))) * 0.3

        return min(1.0, score + evenness_bonus)
    except Exception:
        return 0.0
102
+
103
+
104
def compute_charge_neutrality_score(mol) -> float:
    """Return 1.0 for a net-neutral molecule, else 0.0 (also for None / errors)."""
    if mol is None:
        return 0.0
    try:
        net_charge = Chem.rdmolops.GetFormalCharge(mol)
    except Exception:
        return 0.0
    return 1.0 if net_charge == 0 else 0.0
112
+
113
+
114
def compute_local_charge_penalty(mol) -> float:
    """Penalize formally charged atoms (carbocations/anions).

    Returns 1.0 when every atom is neutral, decreasing linearly with the
    fraction of charged atoms; 0.0 for None input or RDKit errors.
    """
    if mol is None:
        return 0.0
    try:
        formal_charges = [a.GetFormalCharge() for a in mol.GetAtoms()]
    except Exception:
        return 0.0
    if not formal_charges:
        # Zero-atom edge case: nothing charged, full score.
        return 1.0
    neutral_fraction = 1.0 - sum(c != 0 for c in formal_charges) / len(formal_charges)
    return max(0.0, neutral_fraction)
130
+
131
+
132
def compute_enhanced_lipinski_reward(mol) -> float:
    """Soft Lipinski rule-of-five score in [0, 1] with partial credit per rule.

    Averages four sub-scores (MW, logP, H-bond donors, H-bond acceptors);
    returns 0.0 for None input or RDKit descriptor failure.
    """
    if mol is None:
        return 0.0

    def _mw_score(mw: float) -> float:
        # Preferred drug-like window, partial credit just outside it.
        if 250 <= mw <= 500:
            return 1.0
        if 150 <= mw < 250:
            return 0.5
        if 500 < mw <= 600:
            return 0.7
        return 0.0

    def _logp_score(logp: float) -> float:
        if -1 <= logp <= 5:
            return 1.0
        if -2 <= logp < -1 or 5 < logp <= 6:
            return 0.5
        return 0.0

    try:
        parts = [
            _mw_score(Descriptors.MolWt(mol)),
            _logp_score(Descriptors.MolLogP(mol)),
        ]
        hbd = Lipinski.NumHDonors(mol)
        parts.append(1.0 if hbd <= 5 else max(0.0, 1.0 - 0.2 * (hbd - 5)))
        hba = Lipinski.NumHAcceptors(mol)
        parts.append(1.0 if hba <= 10 else max(0.0, 1.0 - 0.1 * (hba - 10)))
        return sum(parts) / len(parts)
    except Exception:
        return 0.0
162
+
163
+
164
def compute_structural_complexity_reward(mol) -> float:
    """Reward moderate structural complexity: 1-3 rings and some flexibility.

    Averages a ring-count score and a rotatable-bond flexibility score;
    returns 0.0 for None input or RDKit errors.

    Bug fix: the original checked ``rot_bonds <= 12`` before
    ``rot_bonds in (0, 1)``, so 0/1 rotatable bonds fell into the 0.7
    branch and the intended 0.5 score was unreachable dead code. The
    branches are reordered so rigid molecules (0-1 rotatable bonds) get 0.5.
    """
    if mol is None:
        return 0.0
    try:
        ring_count = rdMolDescriptors.CalcNumRings(mol)
        if 1 <= ring_count <= 3:
            ring_score = 1.0
        elif ring_count == 0:
            ring_score = 0.3
        elif ring_count <= 5:
            ring_score = 0.7
        else:
            ring_score = 0.1

        rot_bonds = Descriptors.NumRotatableBonds(mol)
        if 2 <= rot_bonds <= 8:
            flex_score = 1.0
        elif rot_bonds in (0, 1):
            # Must come before the <= 12 test, which would otherwise shadow it.
            flex_score = 0.5
        elif rot_bonds <= 12:
            flex_score = 0.7
        else:
            flex_score = 0.2

        return (ring_score + flex_score) / 2
    except Exception:
        return 0.0
184
+
185
+
186
+ # ==========================
187
+ # Unified Reward
188
+ # ==========================
189
def compute_comprehensive_reward(selfies_str: str) -> dict[str, float]:
    """Aggregate all reward components for one SELFIES string.

    Returns a dict containing each component score plus a weighted 'total'.
    Molecules that fail to decode/parse get a 'total' of 0.0.
    """
    smiles = selfies_to_smiles(selfies_str)
    mol = Chem.MolFromSmiles(smiles) if smiles else None

    rewards = {
        "validity": 0.0 if mol is None else 1.0,
        "biological_diversity": compute_biological_diversity_score(mol),
        "charge_neutrality": compute_charge_neutrality_score(mol),
        "local_charge_penalty": compute_local_charge_penalty(mol),
        "lipinski": compute_enhanced_lipinski_reward(mol),
        "structural_complexity": compute_structural_complexity_reward(mol),
    }

    if rewards["validity"] == 0:
        rewards["total"] = 0.0
    else:
        # Fixed component weights for the blended total.
        weights = {
            "validity": 1.0,
            "biological_diversity": 2.0,
            "charge_neutrality": 1.5,
            "local_charge_penalty": 1.0,
            "lipinski": 1.0,
            "structural_complexity": 0.5,
        }
        numerator = sum(rewards[k] * w for k, w in weights.items())
        rewards["total"] = numerator / sum(weights.values())

    return rewards
217
+
218
def compute_lipinski_reward(mol) -> float:
    """Hard Lipinski-style score: fraction of four rules passed.

    Rules: 250 < MW <= 500 (lower bound deliberately excludes tiny
    fragments), logP <= 5, H-bond donors <= 5, H-bond acceptors <= 10.
    Returns 0.0 for None input or on RDKit descriptor failure.

    Fix: the original used a bare ``except:``, which also swallows
    SystemExit/KeyboardInterrupt; narrowed to ``except Exception``.
    """
    if mol is None:
        return 0.0
    try:
        mw = Descriptors.MolWt(mol)
        logp = Descriptors.MolLogP(mol)
        hbd = Lipinski.NumHDonors(mol)
        hba = Lipinski.NumHAcceptors(mol)
    except Exception:
        return 0.0
    rules = [250 < mw <= 500, logp <= 5, hbd <= 5, hba <= 10]  # we dont want too small of fragments
    return sum(rules) / 4.0
230
+
231
def selfies_to_lipinski_reward(selfies_str: str) -> float:
    """Decode SELFIES to SMILES, then score the molecule with the hard Lipinski rules."""
    smiles = selfies_to_smiles(selfies_str)
    if smiles is None:
        return 0.0
    return compute_lipinski_reward(Chem.MolFromSmiles(smiles))
238
+
239
class MTPHead(nn.Module):
    """Multi-token prediction head.

    One bias-free linear projection per future offset, each conditioned on
    a learned per-offset position embedding added to the hidden states and
    layer-normalized before projection.
    """

    def __init__(self, hidden_size: int, vocab_size: int, num_future_tokens: int = 3):
        super().__init__()
        self.num_future_tokens = num_future_tokens
        self.vocab_size = vocab_size
        self.prediction_heads = nn.ModuleList(
            nn.Linear(hidden_size, vocab_size, bias=False)
            for _ in range(num_future_tokens)
        )
        self.position_embeddings = nn.Embedding(num_future_tokens, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return {'logits_t1', ..., 'logits_tK'}, each of shape [B, S, vocab]."""
        _batch, _seq, _dim = hidden_states.shape  # implicitly requires 3-D input
        logits_by_offset: Dict[str, torch.Tensor] = {}
        for offset, head in enumerate(self.prediction_heads):
            pos = self.position_embeddings(
                torch.tensor(offset, device=hidden_states.device)
            )
            conditioned = self.layer_norm(hidden_states + pos)
            logits_by_offset[f'logits_t{offset + 1}'] = head(conditioned)
        return logits_by_offset
260
+
261
+
262
class HorizonLoss(nn.Module):
    """Weighted cross-entropy over several future-token horizons.

    Horizon weights are learnable: stored in log-space and softmax-normalized
    on every forward pass. Defaults decay geometrically by 0.9 per step.
    """

    def __init__(self, num_future_tokens: int = 3, horizon_weights: Optional[List[float]] = None):
        super().__init__()
        self.num_future_tokens = num_future_tokens
        self.horizon_weights = (
            horizon_weights
            if horizon_weights is not None
            else [0.9 ** i for i in range(num_future_tokens)]
        )
        # Log-space parameterization keeps softmaxed weights positive and normalized.
        self.log_weights = nn.Parameter(torch.log(torch.tensor(self.horizon_weights)))

    def forward(
        self,
        mtp_outputs: Dict[str, torch.Tensor],
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Return {'loss', 'horizon_weights', 'horizon_loss_t<k>'...}.

        Horizons whose logits are absent, whose shift exceeds the sequence
        length, or whose mask keeps no positions are skipped; with no usable
        horizon the loss is a plain 0.0.
        """
        _, seq_len = input_ids.shape
        weights = F.softmax(self.log_weights, dim=0)
        per_horizon: Dict[str, torch.Tensor] = {}
        total = 0.0
        for idx in range(self.num_future_tokens):
            shift = idx + 1
            logits = mtp_outputs.get(f'logits_t{shift}')
            if logits is None or seq_len <= shift:
                continue
            flat_logits = logits[:, :-shift, :].contiguous().view(-1, logits.size(-1))
            flat_targets = input_ids[:, shift:].contiguous().view(-1)
            if attention_mask is not None:
                keep = attention_mask[:, shift:].contiguous().view(-1) == 1
                if keep.sum() == 0:
                    continue
                flat_logits = flat_logits[keep]
                flat_targets = flat_targets[keep]
            ce = F.cross_entropy(flat_logits, flat_targets, reduction='mean')
            per_horizon[f'horizon_loss_t{shift}'] = ce
            total = total + weights[idx] * ce
        return {'loss': total, 'horizon_weights': weights, **per_horizon}
305
+
306
+
307
class ChemQ3MTP(Qwen3ForCausalLM):
    """Qwen3 causal LM extended with a multi-token-prediction (MTP) head plus
    PPO-style RL utilities for SELFIES molecule generation.

    Training modes:
      * MTP supervised: forward() blends the multi-horizon loss with the
        ordinary next-token cross-entropy (70/30).
      * RL: ppo_step() samples molecules, scores them with the reward
        functions defined above, and builds a PPO-with-KL-penalty loss.
    """

    def __init__(self, config, num_future_tokens: int = 3):
        super().__init__(config)
        # Auxiliary multi-token prediction head and its horizon-weighted loss.
        self.mtp_head = MTPHead(config.hidden_size, config.vocab_size, num_future_tokens)
        self.horizon_loss = HorizonLoss(num_future_tokens=num_future_tokens)
        self.use_mtp_training = True
        self.post_init()
        # Adaptive entropy-bonus controller consumed by ppo_step().
        self.entropy_controller = EnhancedEntropyController(
            min_entropy=0.5,
            max_entropy=3.0,
            target_entropy=1.5,
            adaptation_rate=0.01,
        )


    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ):
        """Causal-LM forward pass.

        In training with ``use_mtp_training`` and labels present, the loss is
        0.7 * horizon (multi-token) loss + 0.3 * masked next-token CE;
        otherwise with labels it falls back to plain (unmasked) next-token CE.
        """
        # Default mask if not provided
        if attention_mask is None and input_ids is not None:
            attention_mask = (input_ids != self.config.pad_token_id).long()

        # Respect caller settings, only set defaults if missing
        kwargs.setdefault("output_hidden_states", True)
        kwargs.setdefault("return_dict", True)

        # Loss is computed manually below, so the base model gets labels=None.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=None,
            **kwargs
        )

        hidden_states = outputs.hidden_states[-1]
        lm_logits = outputs.logits
        loss = None

        if self.training and self.use_mtp_training and labels is not None: # labels, not kwargs
            mtp_outputs = self.mtp_head(hidden_states)
            horizon_loss_dict = self.horizon_loss(mtp_outputs, input_ids, attention_mask)

            # Standard next-token CE, restricted to positions the mask keeps.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous() # labels, not kwargs["labels"]

            if attention_mask is not None:
                shift_mask = attention_mask[..., 1:].contiguous()
                loss_mask = shift_mask.view(-1) == 1
                if loss_mask.sum() == 0:
                    # Fully masked batch: contribute zero rather than NaN.
                    causal_lm_loss = torch.tensor(0.0, device=lm_logits.device)
                else:
                    flat_logits = shift_logits.view(-1, shift_logits.size(-1))[loss_mask]
                    flat_labels = shift_labels.view(-1)[loss_mask]
                    causal_lm_loss = F.cross_entropy(flat_logits, flat_labels, reduction='mean')
            else:
                flat_logits = shift_logits.view(-1, shift_logits.size(-1))
                flat_labels = shift_labels.view(-1)
                causal_lm_loss = F.cross_entropy(flat_logits, flat_labels, reduction='mean')

            # Fixed blend: 70% multi-horizon loss, 30% ordinary LM loss.
            loss = 0.7 * horizon_loss_dict['loss'] + 0.3 * causal_lm_loss

        elif labels is not None: # labels, not kwargs.get("labels")
            # Eval / non-MTP path: plain next-token cross-entropy (no masking).
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous() # labels, not kwargs["labels"]
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        from transformers.modeling_outputs import CausalLMOutputWithPast
        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def set_mtp_training(self, use_mtp: bool):
        """Toggle the MTP branch of the training loss."""
        self.use_mtp_training = use_mtp

    # ================
    # RL SAMPLING + PPO
    # ================

    def generate_with_logprobs(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 50,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        do_sample: bool = True,
        return_probs: bool = True,
        tokenizer=None,  # allow passing explicitly
    ) -> Tuple[List[str], torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Autoregressively sample and record per-step log-probs.

        Returns (decoded_strings, log_probs [B,T], token_ids [B,T],
        action_probs [B,T,V] or None). Runs under no_grad with the model in
        eval mode; KV-cache is disabled (full re-forward each step).

        NOTE(review): max_new_tokens == 0 would make the final torch.cat
        calls fail on empty lists — callers always pass >= 1.
        """
        self.eval()
        device = input_ids.device

        # Normalize shapes: allow [L], [1,L], [B,L], [B,1,L]
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)  # [L] -> [1,L]
        if input_ids.dim() == 3 and input_ids.size(1) == 1:
            input_ids = input_ids.squeeze(1)  # [B,1,L] -> [B,L]
        assert input_ids.dim() == 2, f"input_ids must be 2-D, got {input_ids.shape}"

        batch_size, seq_len = input_ids.shape
        current_input = input_ids

        generated_tokens, generated_logprobs, generated_probs = [], [], []

        with torch.no_grad():
            for _ in range(max_new_tokens):
                outputs = self(current_input, use_cache=False)
                logits = outputs.logits[:, -1, :] / temperature

                # Top-k
                if top_k is not None:
                    values, indices = torch.topk(logits, k=top_k)
                    logits = torch.full_like(logits, float("-inf"))
                    logits.scatter_(1, indices, values)

                # Top-p
                if top_p is not None and top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    mask = cumprobs > top_p
                    # Shift right so the first token above the threshold is kept.
                    mask[..., 1:] = mask[..., :-1].clone()
                    mask[..., 0] = False
                    # Scatter the sorted-order mask back to vocab order.
                    logits[mask.scatter(1, sorted_indices, mask)] = float("-inf")

                probs = F.softmax(logits, dim=-1)

                if do_sample:
                    dist = Categorical(probs)
                    next_token = dist.sample()
                    log_p = dist.log_prob(next_token)
                else:
                    next_token = torch.argmax(probs, dim=-1)
                    log_p = torch.log(torch.gather(probs, 1, next_token.unsqueeze(1))).squeeze(1)

                generated_tokens.append(next_token.unsqueeze(1))
                generated_logprobs.append(log_p.unsqueeze(1))
                if return_probs:
                    generated_probs.append(probs.unsqueeze(1))

                current_input = torch.cat([current_input, next_token.unsqueeze(1)], dim=1)

        generated_tokens = torch.cat(generated_tokens, dim=1)  # [B, T]
        generated_logprobs = torch.cat(generated_logprobs, dim=1)  # [B, T]
        generated_probs = torch.cat(generated_probs, dim=1) if return_probs else None

        # Use passed tokenizer, fallback to self.tokenizer
        tok = tokenizer if tokenizer is not None else getattr(self, "tokenizer", None)
        if tok is None:
            raise ValueError("Tokenizer must be provided to decode generated tokens.")

        decoded_list = [
            tok.decode(tok_ids, skip_special_tokens=True)
            for tok_ids in generated_tokens
        ]
        return decoded_list, generated_logprobs, generated_tokens, generated_probs


    def ppo_step(
        self,
        input_ids: torch.LongTensor,
        old_log_probs: torch.Tensor,
        old_action_probs: torch.Tensor,
        tokenizer,
        max_new_tokens: int = 50,
        temperature: float = 1.0,
        top_k: Optional[int] = 50,
        top_p: Optional[float] = 0.95,
        validity_weight: float = 1.0,
        lipinski_weight: float = 1.0,
        entropy_weight: float = 0.01,
        clip_epsilon: float = 0.2,
        baseline: Optional[torch.Tensor] = None,
        reward_mode: str = "chemq3",  # "chemq3", "sa", or "mix"
        reward_mix: float = 0.5,      # used if mixing chemq3 + sa (0..1 weight for chemq3)
    ) -> Dict[str, Any]:
        """One PPO-with-KL-penalty update step.

        Samples a batch of SELFIES from the current policy, scores them
        according to *reward_mode*, and returns a dict with the total loss
        plus logging scalars and the fresh rollout tensors (for use as the
        next step's "old" policy).

        NOTE(review): validity_weight, lipinski_weight and entropy_weight are
        currently unused — the entropy weight actually applied comes from
        self.entropy_controller. Confirm whether these parameters should be
        wired in or removed.
        """

        # =========================
        # PPO-KL BODY (drop-in)
        # =========================
        self.train()
        self.set_mtp_training(False)
        if not hasattr(self, 'tokenizer'):
            self.tokenizer = tokenizer

        # Ensure entropy controller exists
        if not hasattr(self, 'entropy_controller'):
            # if you want different defaults, set them when constructing model instead
            self.entropy_controller = EnhancedEntropyController(
                min_entropy=0.5,
                max_entropy=3.0,
                target_entropy=1.5,
                adaptation_rate=0.01
            )

        # --- roll-out ---
        selfies_list, new_log_probs, token_ids, new_action_probs = self.generate_with_logprobs(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=True,
            return_probs=True,
            tokenizer=getattr(self, "tokenizer", None),
        )

        batch_size = len(selfies_list)
        device = new_log_probs.device

        # --- rewards: compute depending on mode ---
        validity_vals: List[float] = []
        lipinski_vals: List[float] = []
        total_rewards: List[float] = []
        sa_rewards: List[float] = []

        for s in selfies_list:
            if reward_mode == "chemq3":
                r = compute_comprehensive_reward(s)
                validity_vals.append(r.get('validity', 0.0))
                lipinski_vals.append(r.get('lipinski', 0.0))
                total_rewards.append(r.get('total', 0.0))

            elif reward_mode == "sa":
                # SA-only mode: validity/lipinski stats stay empty (reported as 0).
                sa = compute_sa_reward(s)
                sa_rewards.append(sa)

            elif reward_mode == "mix":
                r = compute_comprehensive_reward(s)
                sa = compute_sa_reward(s)
                mixed = reward_mix * r.get("total", 0.0) + (1.0 - reward_mix) * sa
                total_rewards.append(mixed)
                sa_rewards.append(sa)
                validity_vals.append(r.get('validity', 0.0))
                lipinski_vals.append(r.get('lipinski', 0.0))

            else:
                # unknown mode -> default to zero reward
                total_rewards.append(0.0)
                validity_vals.append(0.0)
                lipinski_vals.append(0.0)

        # Convert lists -> tensors, handle empty lists safely
        if reward_mode in ("chemq3", "mix"):
            rewards = torch.tensor(total_rewards, dtype=torch.float32, device=device)
        elif reward_mode == "sa":
            rewards = torch.tensor(sa_rewards, dtype=torch.float32, device=device)
        else:
            rewards = torch.zeros(batch_size, dtype=torch.float32, device=device)

        if len(validity_vals) > 0:
            validity_rewards = torch.tensor(validity_vals, dtype=torch.float32, device=device)
        else:
            validity_rewards = torch.zeros(batch_size, dtype=torch.float32, device=device)

        if len(lipinski_vals) > 0:
            lipinski_rewards = torch.tensor(lipinski_vals, dtype=torch.float32, device=device)
        else:
            lipinski_rewards = torch.zeros(batch_size, dtype=torch.float32, device=device)

        # baseline subtraction (broadcast if needed)
        if baseline is not None:
            # baseline can be scalar tensor or per-batch; support both
            # NOTE(review): both branches are identical — broadcasting handles
            # either shape, so the numel() check is redundant.
            if baseline.numel() == 1:
                rewards = rewards - baseline.to(device)
            else:
                rewards = rewards - baseline.to(device)

        # --- probability ratio ---
        # old_action_probs/new_action_probs expected shape: [B, T, V]
        # token_ids expected shape: [B, T]
        old_probs = torch.gather(old_action_probs, 2, token_ids.unsqueeze(2)).squeeze(2).clamp_min(1e-8)
        new_probs = torch.gather(new_action_probs, 2, token_ids.unsqueeze(2)).squeeze(2).clamp_min(1e-8)
        log_ratio = new_log_probs - old_log_probs  # shape [B, T]
        # total_ratio: product of per-step ratios -> exp(sum(log ratio))
        total_ratio = torch.exp(log_ratio.sum(dim=1))  # shape [B]

        # --- adaptive KL controller (singleton) ---
        if not hasattr(self, 'kl_controller'):
            self.kl_controller = AdaptiveKLController()
        # KL per example: sum over time of old * (log old - log new), averaged over V already via gather
        # Here compute KL between full distributions if available
        # NOTE(review): this only uses the chosen-token probabilities, so it is
        # an approximation of the sequence KL, not a full-distribution KL.
        kl = (old_probs * (torch.log(old_probs) - torch.log(new_probs))).sum(dim=1)  # shape [B]
        beta = self.kl_controller.update(kl.mean().item())

        # --- PPO-KL loss ---
        surr1 = total_ratio * rewards
        surr2 = torch.clamp(total_ratio, 1 - clip_epsilon, 1 + clip_epsilon) * rewards
        ppo_loss = -torch.min(surr1, surr2).mean()
        kl_penalty = beta * kl.mean()
        total_policy_loss = ppo_loss + kl_penalty

        # --- entropy bonus (adaptive) ---
        # compute token-level entropy averaged across batch/time
        with torch.no_grad():
            _probs = new_action_probs.clamp_min(1e-12)
            per_step_entropy = -(_probs * torch.log(_probs)).sum(dim=-1)  # [B, T]
            entropy = per_step_entropy.mean()  # scalar tensor

        adaptive_entropy_weight = self.entropy_controller.update_entropy_weight(entropy.item())
        entropy_bonus = adaptive_entropy_weight * entropy
        total_loss = total_policy_loss - entropy_bonus

        # regularization (optional)
        # NOTE(review): this L2 term sweeps over *all* model parameters every step.
        reg_loss = 1e-7 * sum(p.pow(2).sum() for p in self.parameters())
        total_loss = total_loss + reg_loss

        # prepare return (detach tensors where relevant)
        avg_sa = None
        if len(sa_rewards) > 0:
            avg_sa = float(torch.tensor(sa_rewards, dtype=torch.float32, device=device).mean().item())

        return {
            'loss': total_loss,
            'ppo_loss': ppo_loss.item(),
            'kl_penalty': kl_penalty.item(),
            'kl_coef': beta,
            'entropy': float(entropy.item()),
            'entropy_weight': float(adaptive_entropy_weight),
            'validity_rate': float(validity_rewards.mean().item()),
            'lipinski_score': float(lipinski_rewards.mean().item()),
            'avg_reward': float(rewards.mean().item()),
            'avg_sa_reward': avg_sa,
            'generated_selfies': selfies_list,
            'generated_smiles': [selfies_to_smiles(s) for s in selfies_list],
            'new_log_probs': new_log_probs.detach(),
            'new_action_probs': new_action_probs.detach(),
        }
644
+
645
+
646
+
647
+
648
+ # ========================
649
+ # CURRICULUM LEARNING MANAGER
650
+ # ========================
651
+
652
class CurriculumManager:
    """Cyclic length curriculum for generation.

    Every ``steps_per_level`` calls to step(), max_new_tokens grows by
    ``step_increase`` from ``start_len`` toward ``max_len``; once
    ``max_len`` is reached, the cycle wraps back to ``start_len``.
    """

    def __init__(self, start_len=10, max_len=30, step_increase=5, steps_per_level=30):
        self.start_len = start_len
        self.max_len = max_len
        self.step_increase = step_increase
        self.steps_per_level = steps_per_level
        self.step_counter = 0
        self.current_max_len = start_len

    def get_max_new_tokens(self):
        """Current generation-length cap."""
        return self.current_max_len

    def step(self):
        """Advance one training step; maybe bump (or reset) the length cap."""
        self.step_counter += 1
        at_boundary = (self.step_counter % self.steps_per_level == 0)
        if at_boundary:
            if self.current_max_len >= self.max_len:
                # Completed a full cycle: wrap back to the shortest length.
                self.current_max_len = self.start_len
                print(f" 🔄 Cycle reset: max_new_tokens -> {self.current_max_len}")
            else:
                self.current_max_len = min(self.current_max_len + self.step_increase, self.max_len)
            if self.current_max_len < self.max_len:
                print(f" 📈 Curriculum Update: max_new_tokens = {self.current_max_len}")
        return self.current_max_len
681
+
682
class AdaptiveKLController:
    """Keeps the observed KL near ``target_kl`` by scaling the coefficient β.

    KL samples accumulate in a buffer; once ``kl_horizon`` samples are
    collected, their mean is compared to the target band and β is multiplied
    up (mean too high) or down (mean too low), then the buffer is cleared.
    """

    def __init__(self, init_kl_coef: float = 0.1, target_kl: float = 0.01,
                 kl_horizon: int = 1000, increase_rate: float = 1.5, decrease_rate: float = 0.8):
        self.kl_coef = init_kl_coef
        self.target_kl = target_kl
        self.kl_horizon = kl_horizon
        self.inc = increase_rate
        self.dec = decrease_rate
        self.buffer = []

    def update(self, kl: float):
        """Record one KL sample and return the (possibly adjusted) β."""
        self.buffer.append(kl)
        if len(self.buffer) < self.kl_horizon:
            return self.kl_coef
        mean_kl = sum(self.buffer) / len(self.buffer)
        self.buffer.clear()
        if mean_kl > self.target_kl * 1.5:
            self.kl_coef *= self.inc
        elif mean_kl < self.target_kl * 0.5:
            self.kl_coef *= self.dec
        return self.kl_coef
705
+
706
+
707
class EnhancedEntropyController:
    """Adaptive entropy regularization for policy optimization.

    Tracks recent policy-entropy readings and nudges the entropy-bonus weight
    up when exploration collapses (entropy well below target) and down when
    the policy is too random (entropy well above target).
    """

    def __init__(self, min_entropy: float = 0.5, max_entropy: float = 3.0,
                 target_entropy: float = 1.5, adaptation_rate: float = 0.01):
        self.min_entropy = min_entropy
        self.max_entropy = max_entropy
        self.target_entropy = target_entropy
        self.adaptation_rate = adaptation_rate
        self.entropy_history = []
        self.entropy_weight = 0.01  # starting weight

    def update_entropy_weight(self, current_entropy: float) -> float:
        """Record *current_entropy* and return the adjusted bonus weight."""
        self.entropy_history.append(current_entropy)

        # Bound the history to a 100-sample rolling window.
        if len(self.entropy_history) > 100:
            self.entropy_history = self.entropy_history[-100:]

        # Only adapt once at least 10 samples are available.
        if len(self.entropy_history) >= 10:
            recent_mean = np.mean(self.entropy_history[-10:])
            if recent_mean < self.target_entropy * 0.8:
                # Too deterministic: push exploration harder (capped at 0.05).
                self.entropy_weight = min(0.05, self.entropy_weight * 1.1)
            elif recent_mean > self.target_entropy * 1.2:
                # Too random: relax the bonus (floored at 0.001).
                self.entropy_weight = max(0.001, self.entropy_weight * 0.95)

        return self.entropy_weight

    def compute_entropy_reward(self, entropy: float) -> float:
        """Gaussian-shaped reward peaking at target_entropy inside
        [min_entropy, max_entropy]; a flat 0.1 outside that band."""
        if not (self.min_entropy <= entropy <= self.max_entropy):
            return 0.1
        span = max(self.target_entropy - self.min_entropy,
                   self.max_entropy - self.target_entropy)
        return np.exp(-(abs(entropy - self.target_entropy) / span) ** 2)
FastChemTokenizerHF.py ADDED
@@ -0,0 +1,769 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import json
3
+ import os
4
+ from typing import List, Union, Optional, Tuple, Dict, Any
5
+ from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
6
+ from transformers.utils import PaddingStrategy, TensorType
7
+ from functools import lru_cache
8
+
9
+
10
class TrieNode:
    """Single node of the token trie used for longest-match tokenization."""
    __slots__ = ['children', 'token_id']

    def __init__(self):
        # Maps next character -> child TrieNode.
        self.children = {}
        # Non-None means the path from the root to this node spells a full token.
        self.token_id = None
15
+
16
+
17
+ class FastChemTokenizer(PreTrainedTokenizerBase):
18
+ """
19
+ Fully HuggingFace API compatible tokenizer for chemical representations.
20
+ """
21
+
22
+ vocab_files_names = {"vocab_file": "vocab.json"}
23
+
24
+ def __init__(
25
+ self,
26
+ token_to_id=None,
27
+ vocab_file=None,
28
+ model_max_length=512,
29
+ padding_side="right",
30
+ truncation_side="right",
31
+ chat_template=None,
32
+ **kwargs
33
+ ):
34
+ # Handle vocab loading
35
+ if token_to_id is None and vocab_file is None:
36
+ raise ValueError("Either token_to_id or vocab_file must be provided")
37
+
38
+ if vocab_file is not None:
39
+ with open(vocab_file, "r", encoding="utf-8") as f:
40
+ token_to_id = json.load(f)
41
+ token_to_id = {str(k): int(v) for k, v in token_to_id.items()}
42
+
43
+ self.token_to_id = token_to_id
44
+ self.id_to_token = {v: k for k, v in token_to_id.items()}
45
+
46
+ # Precompute max token length for possible use & clarity
47
+ self.max_token_len = max(len(t) for t in token_to_id.keys()) if token_to_id else 0
48
+
49
+ # Build trie for fast longest-match lookup
50
+ self.trie_root = self._build_trie(token_to_id)
51
+
52
+ # Validate required special tokens
53
+ required_special_tokens = ["<s>", "</s>", "<pad>", "<unk>", "<mask>"]
54
+ for tok in required_special_tokens:
55
+ if tok not in token_to_id:
56
+ raise KeyError(f"Required special token '{tok}' not found in vocab.")
57
+
58
+ # ✅ Assign special token IDs explicitly
59
+ self.bos_token_id = token_to_id["<s>"]
60
+ self.eos_token_id = token_to_id["</s>"]
61
+ self.pad_token_id = token_to_id["<pad>"]
62
+ self.unk_token_id = token_to_id["<unk>"]
63
+ self.mask_token_id = token_to_id["<mask>"]
64
+
65
+ # Special tokens
66
+ bos_token = "<s>"
67
+ eos_token = "</s>"
68
+ pad_token = "<pad>"
69
+ unk_token = "<unk>"
70
+ mask_token = "<mask>"
71
+
72
+ # Initialize parent class with all required parameters
73
+ super().__init__(
74
+ bos_token=bos_token,
75
+ eos_token=eos_token,
76
+ unk_token=unk_token,
77
+ sep_token=None,
78
+ pad_token=pad_token,
79
+ cls_token=None,
80
+ mask_token=mask_token,
81
+ additional_special_tokens=[],
82
+ model_max_length=model_max_length,
83
+ padding_side=padding_side,
84
+ truncation_side=truncation_side,
85
+ chat_template=chat_template,
86
+ **kwargs,
87
+ )
88
+
89
+ def _build_trie(self, token_to_id):
90
+ root = TrieNode()
91
+ for token, tid in token_to_id.items():
92
+ node = root
93
+ for char in token:
94
+ if char not in node.children:
95
+ node.children[char] = TrieNode()
96
+ node = node.children[char]
97
+ node.token_id = tid
98
+ return root
99
+
100
+ @property
101
+ def vocab_size(self):
102
+ return len(self.token_to_id)
103
+
104
+ def __len__(self):
105
+ return len(self.token_to_id)
106
+
107
+ def get_vocab(self) -> Dict[str, int]:
108
+ return self.token_to_id.copy()
109
+
110
+ @lru_cache(maxsize=10000)
111
+ def _cached_encode_str(self, s: str) -> Tuple[int, ...]:
112
+ return tuple(self._encode_core(s))
113
+
114
+ def _encode_core(self, text: str) -> List[int]:
115
+ """Core encoding logic using Trie — no caching."""
116
+ tokens = text
117
+ result_ids = []
118
+ i = 0
119
+ n = len(tokens)
120
+
121
+ while i < n:
122
+ node = self.trie_root
123
+ j = i
124
+ last_match_id = None
125
+ last_match_end = i
126
+
127
+ while j < n and tokens[j] in node.children:
128
+ node = node.children[tokens[j]]
129
+ j += 1
130
+ if node.token_id is not None:
131
+ last_match_id = node.token_id
132
+ last_match_end = j
133
+
134
+ if last_match_id is not None:
135
+ result_ids.append(last_match_id)
136
+ i = last_match_end
137
+ else:
138
+ tok = tokens[i]
139
+ result_ids.append(self.token_to_id.get(tok, self.unk_token_id))
140
+ i += 1
141
+
142
+ return result_ids
143
+
144
+ def _tokenize(self, text: str, **kwargs) -> List[str]:
145
+ token_ids = self._encode_core(text.strip())
146
+ return [self.id_to_token[tid] for tid in token_ids]
147
+
148
+ def _convert_token_to_id(self, token: str) -> int:
149
+ return self.token_to_id.get(token, self.unk_token_id)
150
+
151
+ def _convert_id_to_token(self, index: int) -> str:
152
+ return self.id_to_token.get(index, self.unk_token)
153
+
154
+ # ✅ Public methods
155
+ def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
156
+ if isinstance(tokens, str):
157
+ return self._convert_token_to_id(tokens)
158
+ return [self._convert_token_to_id(tok) for tok in tokens]
159
+
160
+ def convert_ids_to_tokens(self, ids: Union[int, List[int]]) -> Union[str, List[str]]:
161
+ if isinstance(ids, int):
162
+ return self._convert_id_to_token(ids)
163
+ return [self._convert_id_to_token(i) for i in ids]
164
+
165
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
166
+ """SMILES-style decoding: no spaces between tokens."""
167
+ return "".join(tokens)
168
+
169
+ def encode(
170
+ self,
171
+ text: str,
172
+ text_pair: Optional[str] = None,
173
+ add_special_tokens: bool = True,
174
+ padding: bool = False,
175
+ truncation: bool = False,
176
+ max_length: Optional[int] = None,
177
+ return_tensors: Optional[str] = None,
178
+ ) -> List[int]:
179
+ encoded = self.encode_plus(
180
+ text=text,
181
+ text_pair=text_pair,
182
+ add_special_tokens=add_special_tokens,
183
+ padding=padding,
184
+ truncation=truncation,
185
+ max_length=max_length,
186
+ return_tensors=return_tensors,
187
+ )
188
+
189
+ input_ids = encoded["input_ids"]
190
+ if isinstance(input_ids, torch.Tensor):
191
+ if input_ids.dim() > 1:
192
+ input_ids = input_ids.squeeze(0)
193
+ input_ids = input_ids.tolist()
194
+
195
+ return input_ids
196
+
197
+ def decode(
198
+ self,
199
+ token_ids: Union[List[int], torch.Tensor],
200
+ skip_special_tokens: bool = False,
201
+ clean_up_tokenization_spaces: bool = None,
202
+ **kwargs
203
+ ) -> str:
204
+ if isinstance(token_ids, torch.Tensor):
205
+ token_ids = token_ids.tolist()
206
+
207
+ if skip_special_tokens:
208
+ special_ids = {
209
+ self.bos_token_id,
210
+ self.eos_token_id,
211
+ self.pad_token_id,
212
+ self.mask_token_id,
213
+ }
214
+ else:
215
+ special_ids = set()
216
+
217
+ tokens = []
218
+ for tid in token_ids:
219
+ if tid in special_ids:
220
+ continue
221
+ token = self.id_to_token.get(tid, self.unk_token)
222
+ tokens.append(token)
223
+
224
+ return "".join(tokens)
225
+
226
+ def batch_decode(
227
+ self,
228
+ sequences: Union[List[List[int]], torch.Tensor],
229
+ skip_special_tokens: bool = False,
230
+ clean_up_tokenization_spaces: bool = None,
231
+ **kwargs
232
+ ) -> List[str]:
233
+ """Batch decode sequences."""
234
+ if isinstance(sequences, torch.Tensor):
235
+ sequences = sequences.tolist()
236
+
237
+ return [
238
+ self.decode(
239
+ seq,
240
+ skip_special_tokens=skip_special_tokens,
241
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
242
+ **kwargs
243
+ )
244
+ for seq in sequences
245
+ ]
246
+
247
+ def decode_with_trace(self, token_ids: List[int]) -> None:
248
+ print(f"\n🔍 Decoding {len(token_ids)} tokens:")
249
+ for i, tid in enumerate(token_ids):
250
+ token = self.id_to_token.get(tid, self.unk_token)
251
+ print(f" [{i:03d}] ID={tid:5d} → '{token}'")
252
+
253
+ def __call__(
254
+ self,
255
+ text: Union[str, List[str]],
256
+ text_pair: Optional[Union[str, List[str]]] = None,
257
+ add_special_tokens: bool = True,
258
+ padding: Union[bool, str, PaddingStrategy] = False,
259
+ truncation: Union[bool, str] = False,
260
+ max_length: Optional[int] = None,
261
+ stride: int = 0,
262
+ is_split_into_words: bool = False,
263
+ pad_to_multiple_of: Optional[int] = None,
264
+ return_tensors: Optional[Union[str, TensorType]] = None,
265
+ return_token_type_ids: Optional[bool] = None,
266
+ return_attention_mask: Optional[bool] = None,
267
+ return_overflowing_tokens: bool = False,
268
+ return_special_tokens_mask: bool = False,
269
+ return_offsets_mapping: bool = False,
270
+ return_length: bool = False,
271
+ verbose: bool = True,
272
+ **kwargs
273
+ ) -> BatchEncoding:
274
+ """
275
+ Main callable method that handles both single and batch inputs.
276
+ """
277
+ # Handle defaults
278
+ if return_token_type_ids is None:
279
+ return_token_type_ids = True
280
+ if return_attention_mask is None:
281
+ return_attention_mask = True
282
+
283
+ if isinstance(text, list):
284
+ if text_pair is not None:
285
+ batch = [(t, p) for t, p in zip(text, text_pair)]
286
+ else:
287
+ batch = text
288
+ return self.batch_encode_plus(
289
+ batch,
290
+ add_special_tokens=add_special_tokens,
291
+ padding=padding,
292
+ truncation=truncation,
293
+ max_length=max_length,
294
+ stride=stride,
295
+ is_split_into_words=is_split_into_words,
296
+ pad_to_multiple_of=pad_to_multiple_of,
297
+ return_tensors=return_tensors,
298
+ return_token_type_ids=return_token_type_ids,
299
+ return_attention_mask=return_attention_mask,
300
+ return_overflowing_tokens=return_overflowing_tokens,
301
+ return_special_tokens_mask=return_special_tokens_mask,
302
+ return_offsets_mapping=return_offsets_mapping,
303
+ return_length=return_length,
304
+ verbose=verbose,
305
+ **kwargs
306
+ )
307
+ else:
308
+ return self.encode_plus(
309
+ text=text,
310
+ text_pair=text_pair,
311
+ add_special_tokens=add_special_tokens,
312
+ padding=padding,
313
+ truncation=truncation,
314
+ max_length=max_length,
315
+ stride=stride,
316
+ is_split_into_words=is_split_into_words,
317
+ pad_to_multiple_of=pad_to_multiple_of,
318
+ return_tensors=return_tensors,
319
+ return_token_type_ids=return_token_type_ids,
320
+ return_attention_mask=return_attention_mask,
321
+ return_overflowing_tokens=return_overflowing_tokens,
322
+ return_special_tokens_mask=return_special_tokens_mask,
323
+ return_offsets_mapping=return_offsets_mapping,
324
+ return_length=return_length,
325
+ verbose=verbose,
326
+ **kwargs
327
+ )
328
+
329
+ def encode_plus(
330
+ self,
331
+ text: str,
332
+ text_pair: Optional[str] = None,
333
+ add_special_tokens: bool = True,
334
+ padding: Union[bool, str, PaddingStrategy] = False,
335
+ truncation: Union[bool, str] = False,
336
+ max_length: Optional[int] = None,
337
+ stride: int = 0,
338
+ is_split_into_words: bool = False,
339
+ pad_to_multiple_of: Optional[int] = None,
340
+ return_tensors: Optional[Union[str, TensorType]] = None,
341
+ return_token_type_ids: Optional[bool] = True,
342
+ return_attention_mask: Optional[bool] = True,
343
+ return_overflowing_tokens: bool = False,
344
+ return_special_tokens_mask: bool = False,
345
+ return_offsets_mapping: bool = False,
346
+ return_length: bool = False,
347
+ verbose: bool = True,
348
+ **kwargs
349
+ ) -> BatchEncoding:
350
+ if max_length is None:
351
+ max_length = self.model_max_length
352
+
353
+ ids_a = list(self._cached_encode_str(text.strip()))
354
+
355
+ if text_pair is not None:
356
+ ids_b = list(self._cached_encode_str(text_pair.strip()))
357
+ else:
358
+ ids_b = None
359
+
360
+ input_ids = []
361
+ token_type_ids = []
362
+
363
+ if add_special_tokens:
364
+ input_ids.append(self.bos_token_id)
365
+ token_type_ids.append(0)
366
+ if ids_b is not None:
367
+ input_ids.extend(ids_a)
368
+ token_type_ids.extend([0] * len(ids_a))
369
+ input_ids.append(self.eos_token_id)
370
+ token_type_ids.append(0)
371
+
372
+ input_ids.extend(ids_b)
373
+ token_type_ids.extend([1] * len(ids_b))
374
+ input_ids.append(self.eos_token_id)
375
+ token_type_ids.append(1)
376
+ else:
377
+ input_ids.extend(ids_a)
378
+ token_type_ids.extend([0] * len(ids_a))
379
+ input_ids.append(self.eos_token_id)
380
+ token_type_ids.append(0)
381
+ else:
382
+ input_ids = ids_a.copy()
383
+ token_type_ids = [0] * len(input_ids)
384
+ if ids_b is not None:
385
+ input_ids.extend(ids_b)
386
+ token_type_ids.extend([1] * len(ids_b))
387
+
388
+ # Handle truncation
389
+ if truncation and len(input_ids) > max_length:
390
+ input_ids = input_ids[:max_length]
391
+ token_type_ids = token_type_ids[:max_length]
392
+
393
+ # Handle padding
394
+ if padding == True or padding == "max_length":
395
+ pad_len = max_length - len(input_ids)
396
+ if pad_len > 0:
397
+ if self.padding_side == "right":
398
+ input_ids.extend([self.pad_token_id] * pad_len)
399
+ token_type_ids.extend([0] * pad_len)
400
+ else:
401
+ input_ids = [self.pad_token_id] * pad_len + input_ids
402
+ token_type_ids = [0] * pad_len + token_type_ids
403
+
404
+ attention_mask = [1 if tid != self.pad_token_id else 0 for tid in input_ids]
405
+
406
+ encoded_dict = {
407
+ "input_ids": input_ids,
408
+ }
409
+
410
+ if return_attention_mask:
411
+ encoded_dict["attention_mask"] = attention_mask
412
+
413
+ if return_token_type_ids:
414
+ encoded_dict["token_type_ids"] = token_type_ids
415
+
416
+ if return_special_tokens_mask:
417
+ special_tokens_mask = [
418
+ 1 if tid in {self.bos_token_id, self.eos_token_id, self.pad_token_id, self.mask_token_id} else 0
419
+ for tid in input_ids
420
+ ]
421
+ encoded_dict["special_tokens_mask"] = special_tokens_mask
422
+
423
+ if return_length:
424
+ encoded_dict["length"] = len([tid for tid in input_ids if tid != self.pad_token_id])
425
+
426
+ if return_tensors == "pt":
427
+ output = {}
428
+ for k, v in encoded_dict.items():
429
+ tensor = torch.tensor(v, dtype=torch.long)
430
+ if tensor.ndim == 1:
431
+ tensor = tensor.unsqueeze(0)
432
+ output[k] = tensor
433
+ else:
434
+ output = encoded_dict
435
+
436
+ return BatchEncoding(output, tensor_type=return_tensors)
437
+
438
+ def batch_encode_plus(
439
+ self,
440
+ batch_text_or_text_pairs: List[Union[str, Tuple[str, str]]],
441
+ add_special_tokens: bool = True,
442
+ padding: Union[bool, str, PaddingStrategy] = False,
443
+ truncation: Union[bool, str] = False,
444
+ max_length: Optional[int] = None,
445
+ stride: int = 0,
446
+ is_split_into_words: bool = False,
447
+ pad_to_multiple_of: Optional[int] = None,
448
+ return_tensors: Optional[Union[str, TensorType]] = None,
449
+ return_token_type_ids: Optional[bool] = True,
450
+ return_attention_mask: Optional[bool] = True,
451
+ return_overflowing_tokens: bool = False,
452
+ return_special_tokens_mask: bool = False,
453
+ return_offsets_mapping: bool = False,
454
+ return_length: bool = False,
455
+ verbose: bool = True,
456
+ **kwargs
457
+ ) -> BatchEncoding:
458
+ all_input_ids = []
459
+ all_attention_masks = []
460
+ all_token_type_ids = []
461
+ all_special_tokens_masks = []
462
+ all_lengths = []
463
+
464
+ for item in batch_text_or_text_pairs:
465
+ if isinstance(item, tuple):
466
+ text, text_pair = item
467
+ else:
468
+ text, text_pair = item, None
469
+
470
+ encoded = self.encode_plus(
471
+ text=text,
472
+ text_pair=text_pair,
473
+ add_special_tokens=add_special_tokens,
474
+ padding=False, # We'll handle batch padding later
475
+ truncation=truncation,
476
+ max_length=max_length,
477
+ stride=stride,
478
+ is_split_into_words=is_split_into_words,
479
+ pad_to_multiple_of=pad_to_multiple_of,
480
+ return_tensors=None, # Don't convert to tensors yet
481
+ return_token_type_ids=return_token_type_ids,
482
+ return_attention_mask=return_attention_mask,
483
+ return_overflowing_tokens=return_overflowing_tokens,
484
+ return_special_tokens_mask=return_special_tokens_mask,
485
+ return_offsets_mapping=return_offsets_mapping,
486
+ return_length=return_length,
487
+ verbose=verbose,
488
+ **kwargs
489
+ )
490
+
491
+ all_input_ids.append(encoded["input_ids"])
492
+ if "attention_mask" in encoded:
493
+ all_attention_masks.append(encoded["attention_mask"])
494
+ if "token_type_ids" in encoded:
495
+ all_token_type_ids.append(encoded["token_type_ids"])
496
+ if "special_tokens_mask" in encoded:
497
+ all_special_tokens_masks.append(encoded["special_tokens_mask"])
498
+ if "length" in encoded:
499
+ all_lengths.append(encoded["length"])
500
+
501
+ batched = {
502
+ "input_ids": all_input_ids,
503
+ }
504
+
505
+ if all_attention_masks:
506
+ batched["attention_mask"] = all_attention_masks
507
+ if all_token_type_ids:
508
+ batched["token_type_ids"] = all_token_type_ids
509
+ if all_special_tokens_masks:
510
+ batched["special_tokens_mask"] = all_special_tokens_masks
511
+ if all_lengths:
512
+ batched["length"] = all_lengths
513
+
514
+ # Handle batch padding
515
+ if padding == True or padding == "longest":
516
+ max_len = max(len(ids) for ids in all_input_ids)
517
+ for key in batched:
518
+ if key in ["input_ids", "attention_mask", "token_type_ids", "special_tokens_mask"]:
519
+ padded_seqs = []
520
+ for seq in batched[key]:
521
+ pad_len = max_len - len(seq)
522
+ if pad_len > 0:
523
+ if key == "input_ids":
524
+ padding_value = self.pad_token_id
525
+ else:
526
+ padding_value = 0
527
+
528
+ if self.padding_side == "right":
529
+ padded_seq = seq + [padding_value] * pad_len
530
+ else:
531
+ padded_seq = [padding_value] * pad_len + seq
532
+ else:
533
+ padded_seq = seq
534
+ padded_seqs.append(padded_seq)
535
+ batched[key] = padded_seqs
536
+
537
+ if return_tensors == "pt":
538
+ def to_tensor_list(lst):
539
+ return [torch.tensor(item, dtype=torch.long) for item in lst]
540
+
541
+ for key in ["input_ids", "attention_mask", "token_type_ids", "special_tokens_mask"]:
542
+ if key in batched:
543
+ batched[key] = torch.nn.utils.rnn.pad_sequence(
544
+ to_tensor_list(batched[key]),
545
+ batch_first=True,
546
+ padding_value=self.pad_token_id if key == "input_ids" else 0
547
+ )
548
+
549
+ # Handle non-sequence data
550
+ if "length" in batched:
551
+ batched["length"] = torch.tensor(batched["length"], dtype=torch.long)
552
+
553
+ return BatchEncoding(batched, tensor_type=return_tensors)
554
+
555
+ def pad(
556
+ self,
557
+ encoded_inputs,
558
+ padding: Union[bool, str, PaddingStrategy] = True,
559
+ max_length: Optional[int] = None,
560
+ pad_to_multiple_of: Optional[int] = None,
561
+ return_attention_mask: Optional[bool] = None,
562
+ return_tensors: Optional[Union[str, TensorType]] = None,
563
+ verbose: bool = True,
564
+ ) -> BatchEncoding:
565
+ """Pad encoded inputs."""
566
+ # This is a simplified version - full implementation would be more complex
567
+ return encoded_inputs
568
+
569
+ # Save/Load methods
570
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
571
+ """Save vocabulary to files."""
572
+ if not os.path.isdir(save_directory):
573
+ os.makedirs(save_directory)
574
+
575
+ vocab_file = os.path.join(
576
+ save_directory,
577
+ (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
578
+ )
579
+
580
+ with open(vocab_file, "w", encoding="utf-8") as f:
581
+ json.dump(self.token_to_id, f, ensure_ascii=False, indent=2)
582
+
583
+ return (vocab_file,)
584
+
585
+ def save_pretrained(
586
+ self,
587
+ save_directory: Union[str, os.PathLike],
588
+ legacy_format: bool = True,
589
+ filename_prefix: Optional[str] = None,
590
+ push_to_hub: bool = False,
591
+ **kwargs
592
+ ):
593
+ """Save tokenizer to directory."""
594
+ if not os.path.exists(save_directory):
595
+ os.makedirs(save_directory)
596
+
597
+ # Save vocabulary
598
+ vocab_files = self.save_vocabulary(save_directory, filename_prefix)
599
+
600
+ # Save tokenizer config
601
+ tokenizer_config = {
602
+ "tokenizer_class": self.__class__.__name__,
603
+ "model_max_length": self.model_max_length,
604
+ "padding_side": self.padding_side,
605
+ "truncation_side": self.truncation_side,
606
+ "special_tokens": {
607
+ "bos_token": self.bos_token,
608
+ "eos_token": self.eos_token,
609
+ "pad_token": self.pad_token,
610
+ "unk_token": self.unk_token,
611
+ "mask_token": self.mask_token,
612
+ }
613
+ }
614
+
615
+ config_file = os.path.join(save_directory, "tokenizer_config.json")
616
+ with open(config_file, "w", encoding="utf-8") as f:
617
+ json.dump(tokenizer_config, f, ensure_ascii=False, indent=2)
618
+
619
+ print(f"✅ Tokenizer saved to: {save_directory}")
620
+
621
+ return (save_directory,)
622
+
623
+ @classmethod
624
+ def from_pretrained(
625
+ cls,
626
+ pretrained_model_name_or_path: Union[str, os.PathLike],
627
+ *init_inputs,
628
+ **kwargs
629
+ ):
630
+ """Load tokenizer from pretrained directory or hub."""
631
+ if os.path.isdir(pretrained_model_name_or_path):
632
+ vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
633
+ config_file = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")
634
+
635
+ # Load config if available
636
+ config = {}
637
+ if os.path.exists(config_file):
638
+ with open(config_file, "r", encoding="utf-8") as f:
639
+ config = json.load(f)
640
+
641
+ # Merge config with kwargs
642
+ merged_config = {**config, **kwargs}
643
+
644
+ return cls(vocab_file=vocab_file, **merged_config)
645
+ else:
646
+ raise NotImplementedError("Loading from HuggingFace Hub not implemented yet")
647
+
648
+ def get_special_tokens_mask(
649
+ self,
650
+ token_ids_0: List[int],
651
+ token_ids_1: Optional[List[int]] = None,
652
+ already_has_special_tokens: bool = False
653
+ ) -> List[int]:
654
+ """Get special tokens mask."""
655
+ if already_has_special_tokens:
656
+ return [
657
+ 1 if tid in {self.bos_token_id, self.eos_token_id, self.pad_token_id, self.mask_token_id}
658
+ else 0 for tid in token_ids_0
659
+ ]
660
+
661
+ mask = [1] # BOS
662
+ mask.extend([0] * len(token_ids_0)) # Token sequence
663
+ mask.append(1) # EOS
664
+
665
+ if token_ids_1 is not None:
666
+ mask.extend([0] * len(token_ids_1)) # Second sequence
667
+ mask.append(1) # EOS
668
+
669
+ return mask
670
+
671
+ def create_token_type_ids_from_sequences(
672
+ self,
673
+ token_ids_0: List[int],
674
+ token_ids_1: Optional[List[int]] = None
675
+ ) -> List[int]:
676
+ """Create token type IDs for sequences."""
677
+ sep = [self.eos_token_id]
678
+ cls = [self.bos_token_id]
679
+
680
+ if token_ids_1 is None:
681
+ return len(cls + token_ids_0 + sep) * [0]
682
+
683
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
684
+
685
+ def build_inputs_with_special_tokens(
686
+ self,
687
+ token_ids_0: List[int],
688
+ token_ids_1: Optional[List[int]] = None
689
+ ) -> List[int]:
690
+ """Build inputs with special tokens."""
691
+ if token_ids_1 is None:
692
+ return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
693
+
694
+ return ([self.bos_token_id] + token_ids_0 + [self.eos_token_id] +
695
+ token_ids_1 + [self.eos_token_id])
696
+
697
+
698
+ class FastChemTokenizerSelfies(FastChemTokenizer):
699
+ """
700
+ SELFIES variant that handles whitespace-separated tokens.
701
+ Uses trie-based longest-match encoding (same as original working version).
702
+ """
703
+
704
+ def _encode_core(self, text: str) -> List[int]:
705
+ """Trie-based encoding for SELFIES with fragment + atom vocab."""
706
+ result_ids = []
707
+ i = 0
708
+ n = len(text)
709
+
710
+ while i < n:
711
+ if text[i].isspace(): # skip literal whitespace
712
+ i += 1
713
+ continue
714
+
715
+ node = self.trie_root
716
+ j = i
717
+ last_match_id = None
718
+ last_match_end = i
719
+
720
+ # Traverse trie character by character (including spaces if part of vocab key)
721
+ while j < n and text[j] in node.children:
722
+ node = node.children[text[j]]
723
+ j += 1
724
+ if node.token_id is not None:
725
+ last_match_id = node.token_id
726
+ last_match_end = j
727
+
728
+ if last_match_id is not None:
729
+ result_ids.append(last_match_id)
730
+ i = last_match_end
731
+ else:
732
+ # Fallback: encode one char as unk or atom
733
+ result_ids.append(self.token_to_id.get(text[i], self.unk_token_id))
734
+ i += 1
735
+
736
+ return result_ids
737
+
738
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
739
+ """SELFIES decoding: join tokens with spaces (preserve original format)."""
740
+ return " ".join(tokens)
741
+
742
+ def decode(
743
+ self,
744
+ token_ids: Union[List[int], torch.Tensor],
745
+ skip_special_tokens: bool = False,
746
+ clean_up_tokenization_spaces: bool = None,
747
+ **kwargs
748
+ ) -> str:
749
+ if isinstance(token_ids, torch.Tensor):
750
+ token_ids = token_ids.tolist()
751
+
752
+ if skip_special_tokens:
753
+ special_ids = {
754
+ self.bos_token_id,
755
+ self.eos_token_id,
756
+ self.pad_token_id,
757
+ self.mask_token_id,
758
+ }
759
+ else:
760
+ special_ids = set()
761
+
762
+ tokens = []
763
+ for tid in token_ids:
764
+ if tid in special_ids:
765
+ continue
766
+ token = self.id_to_token.get(tid, self.unk_token)
767
+ tokens.append(token)
768
+
769
+ return " ".join(tokens) # ✅ preserve spaces
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 gbyuvd
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "training": {
3
+ "batch_size": 16,
4
+ "num_epochs": 1,
5
+ "learning_rate": 5e-5,
6
+ "weight_decay": 0.01,
7
+ "gradient_accumulation_steps": 4,
8
+ "tokenize_batch_size": 100,
9
+ "train_split_ratio": 0.8,
10
+ "val_split_ratio": 0.1,
11
+ "test_split_ratio": 0.1,
12
+ "include_for_metrics": ["input_ids", "attention_mask", "labels"]
13
+ },
14
+ "model": {
15
+ "max_position_embeddings": 512,
16
+ "hidden_size": 320,
17
+ "num_hidden_layers": 6,
18
+ "num_attention_heads": 4,
19
+ "num_key_value_heads": 2,
20
+ "head_dim": 64,
21
+ "intermediate_size": 1280,
22
+ "sliding_window": 16,
23
+ "rope_theta": 10000.0,
24
+ "attention_dropout": 0.1
25
+ },
26
+ "generation": {
27
+ "max_length": 64,
28
+ "top_k": 50,
29
+ "top_p": 0.9,
30
+ "temperature": 1,
31
+ "do_sample": true,
32
+ "num_return_sequences": 3
33
+ }
34
+ }
demo_test_mtpresult.ipynb ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "4ff9650b",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "tensor([[ 0, 379, 1]])\n",
14
+ "tensor([[1, 1, 1]])\n",
15
+ "cuda:0\n"
16
+ ]
17
+ }
18
+ ],
19
+ "source": [
20
+ "from FastChemTokenizerHF import FastChemTokenizerSelfies\n",
21
+ "# --- Load the tokenizer ---\n",
22
+ "tokenizer = FastChemTokenizerSelfies.from_pretrained(\"./selftok_core\")\n",
23
+ "\n",
24
+ "# Test it\n",
25
+ "out = tokenizer(\"[C]\", return_tensors=\"pt\")\n",
26
+ "print(out.input_ids) # ← Attribute access works\n",
27
+ "print(out.attention_mask) # ← Also works\n",
28
+ "out = out.to(\"cuda\") # ← Moves all tensors to GPU\n",
29
+ "print(out.input_ids.device) # ← Should be cuda:0"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": 2,
35
+ "id": "d16aeaf7",
36
+ "metadata": {},
37
+ "outputs": [
38
+ {
39
+ "name": "stdout",
40
+ "output_type": "stream",
41
+ "text": [
42
+ "Model has 9,854,851 trainable parameters.\n",
43
+ "Input shape: torch.Size([2, 32])\n",
44
+ "Logits shape: torch.Size([2, 32, 782])\n"
45
+ ]
46
+ }
47
+ ],
48
+ "source": [
49
+ "import torch\n",
50
+ "from ChemQ3MTP import ChemQ3MTP\n",
51
+ "# --- Initialize model from scratch ---\n",
52
+ "\n",
53
+ "model = ChemQ3MTP.from_pretrained('./enhanced-qwen3-final')\n",
54
+ "\n",
55
+ "# --- Print model parameter count ---\n",
56
+ "def count_parameters(model):\n",
57
+ " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
58
+ "\n",
59
+ "print(f\"Model has {count_parameters(model):,} trainable parameters.\")\n",
60
+ "\n",
61
+ "# --- Quick forward pass sanity check ---\n",
62
+ "batch_size, seq_len = 2, 32\n",
63
+ "dummy_input = torch.randint(\n",
64
+ " low=0,\n",
65
+ " high=len(tokenizer),\n",
66
+ " size=(batch_size, seq_len),\n",
67
+ " dtype=torch.long,\n",
68
+ ")\n",
69
+ "\n",
70
+ "with torch.no_grad():\n",
71
+ " outputs = model(dummy_input)\n",
72
+ " logits = outputs.logits\n",
73
+ "\n",
74
+ "print(f\"Input shape: {dummy_input.shape}\")\n",
75
+ "print(f\"Logits shape: {logits.shape}\") # should be [batch_size, seq_len, vocab_size]\n"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": 3,
81
+ "id": "105b47a0",
82
+ "metadata": {},
83
+ "outputs": [
84
+ {
85
+ "name": "stdout",
86
+ "output_type": "stream",
87
+ "text": [
88
+ "[Branch2] [=Branch1] [Branch1] [C] [=Branch1] [C] [=O] [N] [C] [C] [N] [C] [=Branch1] [C] [=O] [C] [N] [C] [=Branch1] [C] [=O] [NH1] [C] [=Ring2] [Ring1] [=Branch1] [=C] [Branch2] [Ring1] [C] [C] [C] [O] [S] [=Branch1] [C] [=O] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [C] [Branch1] [=Branch2] [N] [C] [C] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [O] [C] [=C] [Ring1] [=C] [Ring1] [#Branch1] [C] [Branch2] [Ring1] [O] [C] [C] [O] [C] [=N]\n"
89
+ ]
90
+ }
91
+ ],
92
+ "source": [
93
+ "# Generate SELFIES\n",
94
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
95
+ "model.to(device)\n",
96
+ "input_ids = tokenizer(\"<s>\", return_tensors=\"pt\").input_ids.to(device)\n",
97
+ "gen = model.generate(input_ids, max_length=256, top_k=50, temperature=1, do_sample=True, pad_token_id=tokenizer.pad_token_id)\n",
98
+ "print(tokenizer.decode(gen[0], skip_special_tokens=True))"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": 4,
104
+ "id": "b041d311",
105
+ "metadata": {},
106
+ "outputs": [
107
+ {
108
+ "name": "stdout",
109
+ "output_type": "stream",
110
+ "text": [
111
+ "C1(=O)NCCNC(=O)CNC(=O)[NH1]C1C(CCOS(=O)(=O)C=C2C=CCNCCN(C)C)(O)OC=C2\n"
112
+ ]
113
+ }
114
+ ],
115
+ "source": [
116
+ "# Manually convert it to SMILES\n",
117
+ "import selfies as sf\n",
118
+ "\n",
119
+ "test = tokenizer.decode(gen[0], skip_special_tokens=True)\n",
120
+ "test = test.replace(' ', '')\n",
121
+ "print(sf.decoder(test))\n"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": 11,
127
+ "id": "f1608fa0",
128
+ "metadata": {},
129
+ "outputs": [
130
+ {
131
+ "name": "stdout",
132
+ "output_type": "stream",
133
+ "text": [
134
+ "C=1=NC2=CC=CC=C2N=1\n"
135
+ ]
136
+ },
137
+ {
138
+ "data": {
139
+ "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAEsASwDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigBskiRRtJI6oijLMxwAPc02G4huY/MgmjlTON0bBh+YrnfiJ/wAk71//AK83/lXk/wAK9ZuvB2oabYak/wDxJ/EMfm2sh4WOcEoR+OAD9UNerhcseJwk68Je9F6Lvpd287a28iJTtKx71PcQWsfmXE0cSZxukYKM/U08EMAQQQeQRXmfx2/5J6n/AF/RfyavQtL/AOQRZf8AXBP/AEEVzVMKoYWGIv8AE2relv8AMal7zRbooorjKCiiigAooooAKKKKACiiigAooooAKKKiurq3sraS5up44IIl3PLKwVVHqSeBQBLWVr/iXR/C+nm91m/itYei7zlnPoqjlj7CuOn8e6v4qnksfh/pwuI1YpLrV6pS1jPfYOsh+n5EVpeH/hxYadqA1nW7qbXteOCb29GRGfSJOiD07jtigDBk8T/EnxG7aj4X8PWthpUI3xJqx2zXw9AuRsBGCM4H+0elb3hv4jafq94NI1a2l0PXxw2n3vylz6xscBwe2OfbvXaVj+IfC+jeKrD7HrNhFcxjlGPDxn1Vhyp+lAGxRXmv2Xxr8P8AmyebxX4fT/l3lYC+t1/2W/5agenXsMV1nhnxlofi23aTSrwNLHxNbSDZNCe4ZDyOeM9PegDeooooAKKKKACiiigAooooAKKKKACiiigAooooA5n4if8AJO9f/wCvN/5VxWneEU8YfArSbNAFvoYWmtJOmJA7cZ9D0/I9q7rx3bT
3ngTWra1hknnktHVI41LMxx0AHU1X+G9nc2Hw90e1vLeW3uI4mDxSoVZTvbqDyK9nD4mVDAKdN2kqia/8Bf4GbV52fY8k8T+LZPE3waWC/JXV9Ov4re8R+GJCuA5Hvg59wa9jh8T6Hpi6dpl/qlra3klpHIkcz7NykYByeOoPGc15X8YPh/qLav8A2zoFlc3Md+QLy3toy5Eg5D7R2Pr6j/ar146Hp2raJaWuradb3KrAgKXEQYqdo9ehrtzGeDlhqMo/DJydla8W7aeid7baEwUuZmsrK6hkYMpGQQcgilrhW+GkOnM0vhbXNS0KTORFHKZrcn3jfr+dJ/aXxB0Ef6fpNj4gtl/5bWEnkzY9SjcE+y15H1OnU/gVE/KXuv8AH3f/ACY05mt0d3RXHaf8TfDd1cC1vp59Hve9tqkRgYfifl/WuuiljniWWGRJI2GVdDkEexFc1bDVqDtVi16jTT2H0UUViMKKKKACiiigAorm/E/jnRPCgSK9naa/l4gsLVfMuJiegCD19TgVzP8AY3jHx98/iCd/Dmgv00uzfNzOv/TWT+EH+6PoR3oA0tb+JFpb6i2i+G7OXxBrg4NvaH91D2zLL91QP/14qnbfD7UfElzHqPxA1EagynfFpFqSlnAfcdZD7n6ciuy0Pw/pPhvT1sNHsIbO3X+GMcsfVj1Y+5JNVPEvjDRPCVqs2rXqxu/EVug3zTH0RByee/T3oA2YIIbaBIIIkihjUKkcahVUDsAOgrkvEfxF0zRr7+yNPgm1rXm4TTrEbmU/9NG6IPXPPtWN5Xjb4gcztN4T8Pv/AMskP+n3C+56RA/n9RXY+HPCui+FLH7Jo1jHbq3MknWSU+rMeSaAMXw3pXjC51hdc8T6sluAjLFo1jjyYwe8jHl2Ht0PQ4OK7OiigArlPE3w/wBH8R3K6ghm03WY+YtTsW8uZT/tEfeHse3cV1dFAHmy+LPE3gdhB41s/wC0dKBwuu6fETtHrPEOV+o4+td9pmqWGs2Ed9pt3Dd2sgyssLhgfb2PtVplDKVYAgjBB71wWp/Dc2V/JrHgnUDoGpv80kCjdaXJ9Hj6D6r09M0Ad9RXAad8R20++j0jxxp50LUXO2O5J3WdwfVJP4fo3T17V3ysrqGVgykZBByCKAFooooAKKKKACiiigAooooAKa7rGjO7BUUEszHAA9TTq4DxZeXHizXR4I0qVkgCiTWbqP8A5ZRHpED/AHn/AJevNdGGw7rzteyWrfZd/wDLu9BSdkWfDGran4s8SXWuQ3EkHhu3VrWzixgXjZ+aU57AjA//AFg9tUFnZ2+n2UNnaRLDbwII441HCqBgCuI1K+u/GPjFNC0q5mg0nSZVm1O7hcqZJQcrArD3GW/+tzu4rFVW4LlhFfcl37tv72+wvhXmd9RRRXAUFFFFAFTUNLsNWtzb6jZW93D/AHJ4w4/WuQufhhZW2+Twzq+peHp2bfi0mLxE8/ejbII56ZFd1RWsa9WMHTUnyvdX0+4Vle5559s+Jvh0f6TYaZ4otV/5aWr/AGW5I9Sp+Q/Ras2PxZ8OSXC2esC80C+P/LDVoDD+T/dx7kiu6qtfafZanbNbX9pBdQN1injDqfwPFZDJLa6t7y3S4tZ4p4XGVkicMrD2I4NS1wNz8JtGgme68OXuo+HLtjktp1wwjY/7UZJBHsMVi65q/wASPBcdhbTX+iazHqV7Hp9rdzQNDKksmdpdE+UrwenNAHpGs67pfh7T3v8AV76Gztk6vK2Mn0A6k+w5rhf7e8XePf3fhm2bQNDfg6vex5nmX1hi7D0Y/hg1o6N8N7ddQTWvFV7J4h1ocrJcj9xB7RRfdA9/x4ruCQqlmICgZJPagDmvDHgTRPCrPcW0UlzqUvM+o3bebcSk9csen0GK3NR1Kx0ixkvdRu4bW1jGXlmcKo/E/wAq4vU/iR9tvpNH8E6edf1NflkmQ4tLc+ry9D9B19c0mnfDh9Rvo9X8c6gdd1BDujtcbbO
2Pokf8X1br3HegCs3i/xJ43c2/giy+xaYTtfXtQjIUj1hiPLn3PHqBW34a+Huk+H7ptSnebVdbk5l1O+bzJSf9nPCD2HbjJrrFVUUKqhVUYAAwAKWgAooooAKKKKACiiigAooooAq6jpljq9jJZajaQ3VrIMPFMgZT+B7+9cC3hDxJ4Ic3Hgi+N7pgOW0HUJCVA9IZTyh9jx6k16RRQByXhr4haT4hum02dZtK1uPiXTL5fLlB/2c8OPcducCutrC8S+D9D8W2qw6tZLI6cxXCHZNCfVHHI57dPauP0i68U+DfHej+FtV1WPW9I1cXH2K5nBF1B5Ue8hyOG6gZOSevGMUAem0UUUAFFFFABRRVHWNXstB0i61TUZhDaW0Zkkc+noPUk8AdyaAMfxp4mk0DTorfT4vtGtag/kWFuP4nPVj/sr1P4VP4Q8Mx+GNG+ztKbi+ncz3t03LTTNyxz6dh/8ArrD8Eadd61eyeONct2hvb2PZYWrnP2S27f8AAm6n6+5FdP4h16z8NaJcanek+XEvyxr96R+yqPU//X7V3ymvZxw1DXms2+76L0X4u77E21uzE8b+IbuyW20DQ8Pr+qEpB/07x/xTN6ADOPf1xitjwz4etPC+hwaZaZYJ80srfemkP3nb3J/oO1YvgjQLyJrnxLryg67qmGdD0tYf4YV9MDGff6ZrsaMVONOH1ak7pfE+8v8AJbL5vqEVf3mFFFFcBQUUUUAFFFFABRRRQAV5/wDFT7ngz/sabH/2evQK8/8Aip9zwZ/2NNj/AOz0AegVxviLwTeeLNZYatrlwPDqqu3SrUeV5rfxea4OWXPYY/DHPZUUAU9M0qw0WwjsdMs4bS1jGFihQKB7+59+tXKKKACiiigAooooAKKKKACiiigAooooAKhvJ/stlPcbd3lRs+3OM4GcVNVPV/8AkC33/XvJ/wCgmgCh4R18+KPCmna21sLY3ke/yQ+/ZyRjOBnp6VzXi3/kr3w6/wC4n/6TrVz4S/8AJK/D/wD17n/0Nqp+Lf8Akr3w6/7if/pOtAHoFFFFABRRRQAV5hJ/xdHxj5Q+fwhoU/zn+HULsdvdE/I+4PGh481q+1K/t/A3h6bZqmoJuvLlefsVr0Zz6M3QD37ZBrr9D0Wx8O6La6Tp0IitbZAiL3PqT6knJJ9TQBed0jjZ3ZURRksTgADvXmmjo3xL8Xr4iuFJ8MaPKU0qJhxdzjhpyO6g8L/TkVN4zvrnxfr6+AdHmeOHaJdcvI/+WEB6RA/33/l6jOO/sbG20ywgsbKFYba3QRxRoOFUDAFAFiiiigAooooAKKKKACiiigAooooAK8/+Kn3PBn/Y02P/ALPXoFef/FT7ngz/ALGmx/8AZ6APQKKKKACiiigAooooAKKKKACiiigAooooAKKKKACqer/8gW+/695P/QTVyqer/wDIFvv+veT/ANBNAHL/AAl/5JX4f/69z/6G1U/Fv/JXvh1/3E//AEnWrnwl/wCSV+H/APr3P/obVT8W/wDJXvh1/wBxP/0nWgD0CiiigArD8Ya+/hbwlqOtJaPdvaRb1hTuSQMn0UZyT6A1uU10SSNkdQyMCGVhkEHsaAOK+GWiLa+Hhr91cre6vroW8vLsc53DKxr6KoOMeufoNLx/4jn8KeCtR1e1hEtxEqrHu+6rMwUM3sM5/CuX0Z3+Gvi5fDlwxHhjV5S+lSseLWc8tbk9gTyv9eSPR7u0gvrOa0uolmt5kMckbjIZSMEGgDnvAnhmDw14eRVuRe3t6ftV7fZ3G5lfktnuvPHt9TXT15t4Ru5/A/iT/hBNUld9Pn3S6FdyHO6Pq0DH+8vb2+oFek0AFFFFABRRRQAUUUUAFFFFABRRRQAV5/8AFT7ngz/sabH/ANnr0CvP/ip9zwZ/2NNj/wCz0AegUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFU9X/5At9/17yf+gmrlU9X/AOQLff8AXvJ/6CaAOX+Ev/J
K/D//AF7n/wBDaqfi3/kr3w6/7if/AKTrVz4S/wDJK/D/AP17n/0Nqp+Lf+SvfDr/ALif/pOtAHoFFFFABRRRQBkeJ/Dlj4r8P3OkX6nypl+SRfvROPuuvuD/AId65/wD4jvp2u/C3iFgPEOk4WR+13D/AATL65GM+/1wO3rivH3hq9vVtfEfh/CeI9IzJb+lzH/HA3qCM49/TOaANTxn4Vg8XaA9i0ht7uJhPZXS8NbzLyrg/wA/aqXgPxVPr1hcafq0Yt/EOlv5GoW/TLdpF/2WHI/wxWr4V8SWfizw9batZZVZRtkib70Mg4ZG9wf8e9c5430K7stUtfG2gof7T09Cl5Ag/wCP216shHdh1B/ngCqhBzkordgd5RVHRtXs9e0i21OwlEltcJvQ9x6g+hByCPar1EoyhJxkrNAFFFFSAUUUUAFFFFABRRRQAV5/8VPueDP+xpsf/Z69Arz/AOKn3PBn/Y02P/s9AHoFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABVPV/wDkC33/AF7yf+gmrlU9X/5At9/17yf+gmgDl/hL/wAkr8P/APXuf/Q2qn4t/wCSvfDr/uJ/+k61c+Ev/JK/D/8A17n/ANDaqfi3/kr3w6/7if8A6TrQB6BRRRQAUUUUAFFFFAHnmuJJ4B8UHxNbKx0HUXWPVoVGRBIeFnA9Ozf1yMegxyJLGkkbq8bgMrKcgg9CDUd1awX1pNa3USywTIY5I2GQykYINcL4Wup/B+vnwVqcrPZyhpdFupD9+PqYSf7y9vb04Fei/wDa6N1/Egvviv1j/wCk+hHwvyIZP+LceLPOHy+Ftam/eD+Gxuj39kb9PbHPo1U9V0u01rSrnTb+IS2twhSRT6eo9COoPqK5HwXql3pGpS+CtclL3lom/T7l/wDl7tu3/Al6Ee3sTRU/2ul7VfHFa+a/m9Vs/Kz7gvdduh3VFFFecWFFFFABRRRQAUUUUAFef/FT7ngz/sabH/2evQK8v+LGvaRHceE7R9Us1uLfxHZ3E0RmXdFGu7c7DPyqMjk0AeoUxZY2kaNXUumCyg8rnpkUsciSxrJG6ujgMrKcgg9CDXIeJfAFvq+p/wBu6TqFzo3iBVCi+tmyJABgLIh4daAOxorzm38e6t4WnjsPiBpwtkYhItas1L2kp7bx1jJ9/wBBXoNtc297bR3NrPHPBINySxOGVh6gjg0AS0UUUAFFFFABRRRQAUUUUAFFFFABVPV/+QLff9e8n/oJq5RQBxfwl/5JX4f/AOvc/wDobVT8W/8AJXvh1/3E/wD0nWu/ACgBQAB2Fec+KL21n+NXgG0iuYpLm3GoGaJXBaMNbjbuHUZwcZ9KAPR6KKKACiiigAooooAKwfF3hqHxRojWhkMF3Ewms7leGgmX7rA/z9q3qK0pVZ0pqpB2aE1dWZy/grxNNrljPZanGINc05/Iv4OnzdnX/ZYcirviDwxaeIHsJ5ZZra7sLhZ7e5gIDpz8y8jow4IqvrVno+iXtx40uIZxc2dmySmBsebHweVyAxGOMn+QwvhHxppPjWwmutLMy+TJskinUK68ZBIBPB5wc9jXbNTu8XhotRW/ZN7r07X6OxK/lkdFRWH4r8VWHg7Rxqeox3EkBlWLFuoZskEjgkccetbEEy3FvFOgIWRA4B64IzXE6U1BVGvdeifpuVdXsSUUdBk1y+r/ABD8L6NJ5M2qRT3WcC2tP30hPphc4P1xTpUKtaXLSi5PyVwbS3Ooorg/+Em8Z67xoPhYadA3S71qTYf+/S/NSSfD3UddiI8WeKtQvUYgtaWJ+ywY/ukLyw574Nb1MG6UW6k4p9r3f4XS+bQlK+xq678RPCnh1jFf6zbm5Bx9mgPnSk+m1ckfjisT/hM/GHiAY8MeDZbaFvu3uuv5C/Xyh85HuDXUaH4Q8PeGkC6Po9paMBjzEjzIR7ucsfxNbdcZR55/wgHiHXlz4t8ZXssTdbHSlFrD/ulh8zj64rb
074c+D9KspLS28PWBjkXbIZohK7j3Z8n9a6iigDzeTwVr/g2RrrwHf+bY53PoOoSFoj6+U55Q+xOM9T2rY8OfETS9bvTpN9DNo2upw+nXw2OT/sN0cemOe+K7CsXxH4T0TxZZfZdYsY5wv+rlHyyRH1RxyP8AOaANa4t4bu3kt7mGOaGRdrxyKGVh6EHgivP7n4f6l4auZNR8AakLHcS8uj3ZL2cx77e8ZPqPpwKh2eNvh/8AcM3izw8n8Lf8f9uvsekoH5/QV13hrxfoni20M+kXqSsn+tgb5ZYj6Oh5H8vQ0AYmhfEizutRXRfEVnL4f13oLW7P7ub3ik+6w/yM129Zmu+HtJ8S6c1hrFhDd256LIOVPqp6qfcVxH9keMvAHzaFPJ4l0FOum3b4uoF9IpP4gP7p+gHegD0qiuc8L+ONE8WI6WFw0V7FxPY3K+XPCR1DIf5jIro6ACiiigAooooAKKKjuLiG0t5Li5mjhhjXc8kjBVUepJ4AoAkrM13xFpHhnT2vtYv4bS3HQyHlj6Ko5Y+wFcbc+P8AU/EtzJp/w/04Xu1tkusXYKWcJ77e8hHoP1FX9C+G9naaguteIbyXxBrvX7VdgeXCfSKP7qD/ACMUAZf9r+MvH/y6FBJ4a0F+upXaZup19Yo/4Qf7x+oPauk8LeAtA8ImSbT7Zpb+XJmv7lvMnlJ5JLHpnuBgGumooAKKKKACiiigAooooAKKKKAOZ+In/JO9f/683/lXjnh4z/D+08M+MrcO2k6lF9m1ONecHc2Gx9ACPdSP4q9j+In/ACTvX/8Arzf+VYvgnRrTxD8GNO0q9TdBc2zofVTvbDD3BwR9K+jy/Exw+Xt1FeEp2kvJx/NbrzRjON56djP+N80Vz8NoJ4ZFkiku4XR1OQylWIIrobyXxe9rp1p4dtdNS3a0jaS+vZGOxscqI15Jxg5PHNeJeIdVvNM8D33gTWWP27Sr+NrZz/y0hIbp7DII9m9q+kdL/wCQRZf9cE/9BFVj6LwOFpRaUlzTtfZpqNn9wRfNJnH/APCu7jVvm8V+JdR1YHrawt9mt/oUTk/XIrp9H8OaNoEXl6VpltaDGC0cYDN9W6n8TWpRXiVcbXqx5ZS93stF9ysjRRSCiiiuUoKKKKACiiigAooooAK5HxL8PNJ1+7Gp2zzaTracx6lYtskz/tgcOPrzjjIrrqKAPN08Y+I/BLrb+ObH7XpwIVNe0+MlAP8AptGOUPuOOwB6132n6jZatYx3un3UN1ayjKSwuGU/iKssqujI6hlYYIIyCK4DUPhxJpl7Jq3gXUf7Dv3O6W0I3WdyfRo/4fqvTsO9AGz4o8B6L4pdLm4jktNTi5g1Gzbyp4iOnzDqPY5/Cuc/4SHxb4C/d+KbZtd0ReBrFjHiaJfWaIf+hD9TV3S/iQLS/j0fxpp7eH9UY4jlkbdaXHvHL0H0PTpnNd4CGUEEEEcEd6AKOj63pniDT0v9Jvoby2fo8TZwfQjqD7Hmr9c5Z+BtC03xQfEGnW72V26Ms0dtIUhmz3eMcEjn8TnrXR0AFFc74o8baH4SiT+0bktdS8QWUC+ZPMewVB/M4HvXMf2Z4z8ffNrE0vhjQG6WNq+bydf+mj/wA+g56gjvQBqa98R7Gy1A6NoNpLr+u9PslmcrEfWWT7qD17jvis+38A6r4ouI7/4gaiLpFO+LRbNilpEe249ZCPf9RXY6D4b0fwxp62OjWENpAPvbB8zn1Zjyx9ya1KAIra2gs7aO3tYY4IIxtSONQqqPQAcCpaKKACiiigAooooAKKKKACiiigAooooAyvE2kvr3hnUdKjlWJ7uBoldhkKT3NQ+EdDk8N+FdP0eWZZpLVCpkQYDZYnp+NbdFbe3n7H2F/dvf52sKyvc89+JHwyj8cS2l5a3MdnfwgxvI6EiSPqAcdwen1Nd5aQm2soICQxijVCR3wMVNRV1cXWq0YUZu8YXt5XBRSdwooormGFFFFABRRRQAUUUUAFF
FFABRRRQAUUUUAU9U0nT9asJLHU7OG7tZPvRTIGH19j79a4I+FfFHgUmbwZeHVNJXltD1CXlB6QSn7v0PH1Nek0UAcJp3xd8KXFtdf2ndvo19Zj/SbLUEKSoR1Cj+P6Lk+wrPXxN4r8fAJ4RtTo2ivwdav48ySDv5MX/sx4+hrtNV8K6Brl5b3eqaPZ3lxb/6uSaEMR7c9R7HitcAAAAYA6CgDl/DHgLRvDEr3kSS3urS8z6leN5k8h7/ADH7o9h+Oa6iiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKAP/2Q==",
140
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAASwAAAEsCAIAAAD2HxkiAAAi4klEQVR4nO3deVxU9foH8IdhBxdQLA3JpVTE5WK4Y3nNel2X6WWloqajaTZR1hj3mlhpk+XCNavRNMVsGXfJROcm6g/NBTUlTTMUlxAEKU0EjU2Wmef3x3caCY0Zzpwz33OG5/3yD4HvnPOwfGbOnPN8v8cDEYEQwo+KdwGENHQUQkI4oxASwhmFkBDOKISEcEYhJIQzCiEhnFEICeGMQkgIZxRCQjijEBLCGYWQEM4ohIRwRiEkhDMKISGcUQgJ4YxCSAhnFEJCOKMQEsIZhZAQziiEhHBGISSEMwohIZxRCAnhjEJICGcUQkI4oxASwhmFkBDOKISEcEYhJIQzCiEhnFEICeGMQkgIZxRCQjijEBLCGYWQEM4ohIRwRiEkhDMKISGcUQgJ4YxCSAhnFEJCOKMQEsIZhZAQziiEhHBGISSEMwohIZxRCAnhjEJICGcUQkI4oxASwhmFkBDOKISEcEYhJIQzCiEhnFEICeGMQkgIZxRCQjijEBLCGYWQEM4ohIRwRiEkhDMKISGcUQgJ4YxCSAhnFELCTWpq6ksvvbRr1y7ehXDmgYi8ayANztWrV0eOHHnkyBH2Ybdu3ZKTkx966CG+VfFCr4TEpSwWyxtvvBEaGsoSGBgYCAA///xzx44d4+LizGYz7wI5oBAS10lPT4+Ojl68eLHFYgkMDPzss89KSkq2bt3arFkzi8ViMBi6deu2e/du3mW6HBIivStXrmg0Gg8PDwAIDQ1dsGCB2WyuOWDZsmXt2rVjf5NqtfrSpUu8SnU9BbwnzMwEk8n6/wEDIDq6rsFJSZCdDQAwdCh07y55bcSu8vLypUuXzps3r6SkxN/fX6fTzZ49u1GjRnePrKysXLFixZw5c4qLi318fGJjY+fNm9e4cWPX1+xqvJ8F7NuwAQGs/1q2xJs36xo8ZIh15OrVrqqP/D2TydS2bVv2l6ZWq7Ozs+0+JD8/X6vVqlQqAHjggQcSExNrvWa6H4WFEAB1uroGUwhl4sSJE48++iiLX48ePQ4ePFivh6enp/fv3589vGfPnocPH5aoTjlQXghVKjx69G8HUwi5Kygo0Ol0np6eANC8eXODwVBdXS1gOxaLJSkp6cEHHwQADw+P0aNHX758WfRq5UBJIWzUyPqfnj3x736tFEKOKisrDQZD06ZNAcDb21un092s+82DA0pKSvR6vZ+fHwAEBgbq9fry8nJRqpUPJYXwqacwKsr6/08+ufdgCiEvqampERER7ADyiSeeOHPmjIgbz83N1Wg0bONhYWFGo1HEjXOnpBCOGIG7d1v/36QJXrlyj8EUQtc7d+7csGHDWEI6der07bffSrSj7777rvufp7wHDRr0008/SbQjF1NYCBFx2DDrhzEx9xhMIXSlwsLC+Ph4Hx8fAAgODk5ISKioqJB0j2az2Wg03nfffQCgUqk0Gs21a9ck3aMLKC+EFy+ir6/1M3c/51IIXYNvGGqGPygoyAXhl5TyQoiIM2ZYP9OmDZaU/GUwhdAFxDosrK6utlgsgsuoeRjcsWNH6Q6DpabIEN66ha1aWT/51lt/GUwhlJSIJ0j2798fGRnp/CkWSU8IuYYiQ4iI69ZZP+njgzV/7BRCiYh4qSAnJ2f06NEsNv369XO+NikujbiSUkOIiIMGWT8/cCDaDmoohKJjF83DwsKcv2heWlqq1+v9/f0BICAgID4+vri4WKw6xWoScD0Fh/D8+TtnaNavt36SQiiumu1jvXr1Etw+xpLcpk0bW5JzcnLELZX58ccfa7bLHThwQIq9iEvBIUTE+Hjrl0JDrWdoKIRiEbGR+vjx49F/Tn6JiopKS0sTt9S7CWgc50jZISwtxbZtrV/V6xHvFcIxY/CFFzAlBZV8EtulysrKEhIS2BwiHx8fnU73xx9/CNvUr7/+aktyq1atEhMTXXaIyL4L
NmfK398/Pj5e8HchNWWHEBGTk61fDQjAvLzaISwsRC8v62eCgnDCBExOxrIy134DimIymUSZXMtOljRp0sR2suTWrVviluqIWpOJjUajMxdFJKL4ECKiWm0dMGnSPV4Js7LQYMDoaPTwsH7J3x/VajQakcdfhXydPHnyscceY/Hr3Lnzrl27BG/KZDLZVm1Sq9W//PKLiHUKcOzYsX79+tne2R45coRvPbW4QwizstDf3zrLqU2bv31PmJNTO41+fqhWY2Ii/v67xN+DvN24ccN2XrFZs2bOnFfMzMwcOnQo+3MPDw9PSUkRt1TBLBaL0Whs2bIlOzOk0Wh+++033kVZuUMIEfHdd/8y57DuEzO5uZiYiGr1nSNVT0+MjkaDAWXze3ERdtAYFBRkO2gsKioStimWZC8vL9ZHajAYqqqqRC1WBOxqp6+vr6wmRrlJCG/fxk6dHA2hzfXraDSiWo3e3rXTmJ8v3jcgV6mpqV26dLH1mmRkZAjbTlVVVWJiYosWLQDAy8tLq9X+Lu9Di4sXL9q6BR5++OGkpCS+9bhJCBHx//6v3iG0uXHDmkbbhUeVCqOiUK/Hixedq16Wzp8/r1arbV2X//vf/wRvas+ePd26dWObevzxx0+fPi1inZLau3evTCp3nxAi4qhRAkNoU1qKJhNqNHdm8QNgRATq9ZiZKWSDclNUVBQfH8+Ox9j8g9u3bwvblNxeTwSQyWu4AkK4ZQsGB2NwMD73nJ2ReXnYqpV18Nq1Tu20rMyaxiZNaqfx+HGntszL3ZOPrl69KmxTd7+zEpxkOSgsLOT7blYBIeSrvBxNJtRqsUWLO2ls3x51OkxLQ/ldc7q3ffv2/eMf/2CvWgMHDjx58qSw7cj5HKOTap7X7dSpkyvP61IIHVVRgTt34tSpGBLylzS+8QYeO2aW4SVgJi8vz3a1unXr1s5crT527Fjfvn3Zn2nv3r2///57cUuVAy5XOOUbwosXUa3GvDzeddyluhrT0lCnw9BQaxT79s0PCwvTarUmk0k+5+XZlAU2+SggIECv15cJ7RVSRN+JWFzf6yPTEFZUYM+eCIBaLe9S/p7ZjGlpGBeHw4evhD+1bNkyNjY2NTWVYxrvXrFT8JSFuzswRZx8JGes65U1MISEhEg6MUqmIWQLWLRvb2fRe/nIyMjQ6/Xh4eG2NAYHB2s0GpPJ5OKTFrWmLBw6dEjwppQ1F0EKJ06cGDBgAPsJPPLIIxLN/5BjCHfvRpUKvbxQZi1+DmFpjIqKsqUxICBArVYbjUapX0PunrIgePJRzUXsH3nkkfouYu9mTCYTmwnJnoxEnwkpuxBeu2ZdP2bhQt6lOCcrK8tgMERHR7O3UuxwjqVR9PcYFRUVtrcxbPKR4F0od366pEpLS22H5aKvCSCvEFos1ikRAwf+7UL3ipOTk1MrjX5+fmq1OjExUZRLwyaTqX379rbn6aysLGHbUfpKLS4g4qnmmuQVwg8+QABs0cI9Wzfz8vISExPVajW7LgwAnp6e0dHRBoNB2NW2zMzMIUOG2KYs7Ny5U3BtbrBmmcscPXq0T58+7GfVp08f5y/VyCiEx4+jjw96eOD27fYHl5bikiVKfbUsKCgwGo1qtZotX8taWFgar9xzcf+71JyywCYfCT4ZW2sR+x07dgjbToPC2o/uv/9+59uPUD4hLCmxToN4/XWHxr/0EgLglCkSlyWxwsJClkbWAsZ+o1FRUXq9/sKFC/d8CGt3DAkJsbU7Xr9+XfDeXbyIvZspLi62te81atRo9uzZwlbQkEsIJ05EAOzWDR2Z3vXNNwiAvr546pT0lblEaWmpyWTSaDQ17yMdERGh1+sza3SO79mzp2vXruyrgwcP/vnnn4Xtjj2Rs8Zlt7mjAy8XLlywNbJ7e3vHxcXVdwuyCOHmzQiAgYEOzVTIy8PmzREAly2TvjKXKykp+frrr8eOHVvzXu3du3ePi4t78skn2YcdOnQwmUyCd+Gu9zbiKyUl
hTUnsYv7u3fvdvyx/EOYlWWdqfD55/YHm83WNX+HDVNM87Qw5eXlJpNJq9WyeQ8eHh7e3t5OTllw77v8cVdWVhYTE8POnXp6ejr+QM4hrKrCvn0RAEeNcmi8Xm9dZVTo+yDlqays3LVrF7twl5ubK2wjDeF+tzJx/vx59jTn+EM4h3DmTATAsDAsLLQ/+OBB9PRElQr37pW+Mplhp08EvAbevYi94CQTBykphPv2WdvTHFlbvbDQupLanDnSVyY/wkKYnp5ec6k/wYvYk3pRTAh//93anjZ/vkPjY2IQAHv3xspKiSuTpfqGMD8/39bb4eQi9qS+lBFCiwWfeqoe7WnLl1uX0G54ffxWjodQxEXsiTDKCOGHHyIABgejI/fYysiwru27caP0lcmVgyEUaxF74gwFhPD0afTzQw8P3LbN/uDycuzeXe6ze13AkRCazeZevXoBQGRkpCJuCaY4xcXF2dnZBQUFdQ+TewhLSjA8HAFQp3NofGysdZmz0lKJK5M3B18JDx8+7Mo7HzU0n3/+OQBMnjy57mH1DaEXuNarr8K5c9C1KyQk2B+cnAwrV4KvL6xfDwEB0henfP3797fd05MohcqVO/v6a/jqKwgIgKQk8Pe3M/jKFXjxRQCAxYshMlL64gjhxHUhvHTJGqolS6BzZzuDLRaYOBFu3IBhw2DaNBdURwg3LgphdTVMmAC3bsHIkTB1qv3x770H+/ZBaCgYjfDnfHRC3JOLQjhnDnz/PYSFwapV9gcfOgTz5oFKBUYjhIRIXxwhXLkihAcOwAcfgJcXbNwIzZrZGXzzJkyYAGYzvPUWDB7sguoI4UzyEF6/fn3x4mQvL9Dr4c/lMOsSGwuXL0Pv3vDOO1KXRogsSBtCRJw8efK33z779NNz33zT/viVK2HzZmjaFDZtAm9vSUsjRC6kDeGSJUt27NgRHBy8aNFkT087g8+ehf/8BwBgxQr4s/WKEPcnYQh//vnnN998EwC+/PJLdl+EOpSXV40dC2Vl8MILMG6cdEURIjtShbC0tDQmJub27duvvvrqiBEj7I6fMeP1gICEqKjqjz+WqCJCZEqqtjWdTnfu3LkuXbosWrTI7uBt27Z9+umnvr6+x44Nbdz4HxKVRIg8SfJKuGXLli+++MLPz2/Dhg3+9vrT8vPzp06dCgCLFi2y3U2WkIZD/BDm5eVptVoAWLJkiW1pvb9jsVgmTpx448aNoUOHvvbaa6IXQ4j8iRzC6urqsWPHFhUVPfvssyyKdZs3b9533313//33f/nllx7Un0YaJJFDqNfrjxw50rp1688++8zu4PT09Hnz5qlUqnXr1rFl/QlpgMQM4YEDB/773/+qVKq1a9c2s9efdvPmzTFjxlRVVc2aNeuJJ54QsQxClEW0EBYVFU2cONFsNuv1+n/+8592x7/88ss5OTm9evV69913xaqBECUSJ4SsPS03N/fRRx99++237Y5ftWrVpk2bmjZtunnzZm/qTyMNmzgh/OSTT7Zv3x4cHLx27VpPe/1pZ8+ejYuLA4BPP/20HfWnkQZPhBBmZGTMmjULAFauXNmmTZu6B1dUVIwfP76srGzKlCnPPfec83snROmcDSG7E015efkrr7wSExNjd/yMGTNOnTr18MMPGwwGJ3dNiHtwNoTTp0/PzMzs0qXL4sWL7Q5OSUlZvny5r69vUlJSzfvvEdKQORXCb775ZvXq1Y63p02aNAkRExISevTo4cx+CXEnwkNoa0/7+OOPHWlPmzRpUkFBwZAhQ6ZPny54p4S4H4EhrK6uHjduXGFh4TPPPBMbG2t3/IIFC/bu3UvtaYTcTWAI586de/jwYcfb09577z3WSdOyZUtheyTEXQkJ4cGDBxcuXKhSqdasWdO8efO6B9+6dWvs2LFVVVUzZ8588sknBRVJiDurdwiLioo0Go3ZbJ4zZ86gQYPsjn/55Zezs7N79uw5d+5cQRUS4ubqHULWnjZgwIA5c+bYHbx69eqN
Gzc2atRo/fr17L5ChJBa6hfC5cuXb9++PSgoaN26dXbb0y5evPjvf/8bAFauXNmxY0fhNRLi1uoRwm3btrGrCw62p40ZM6a4uPj5558fP368UzUS4tYcDWFZWZlWqzWbzZGRkWPGjLE7Pj09/ezZsw8++OAnn3ziXIWEuDmX3p+QEHI3R0MYEBCwatUqT0/PU6dObd682e743r17R0RE5Obm0vJNhNStHq+ETz/99JIlSwAgNjb28uXLdQ/29fXdvHlz48aNv/rqq/Xr1ztVIyFurX6Ho9OmTRsxYsTNmzcnTJhgNpvrHtyhQ4ePPvoIAGJjYy9cuCC8RkLcWr3fE7IbSxw6dOj999+3O3jq1Knjxo0rKSkZP358ZWWloAoJcXP1DqFtDYv3339/3759dsevWLGiXbt2x48f1+v1giokxM0JOTv62GOPvfnmm7bFs+se3LRp002bNnl7ey9atCg1NVVQkYS4M4GXKPR6fXR09JUrV1588UW7g3v37v3OO+9YLBaNRnP16lVheyTEXQkMoZeX18aNG5s1a5acnLxy5Uq74996663Bgwdfu3Zt8uTJiChsp4S4JeEX68PCwlatWgUAcXFxp0+ftrMblcpoNIaEhOzatYtd5yCEME51zIwcOXLq1Km3b99+7rnnysvL6x4cGhpqNBo9PDxmzZp18uRJZ/ZLiDtxtm1tyZIlnTt3PnPmzIwZM+wOHjZs2LRp0yoqKmJiYoqLi53cNSHuwdkQBgQEJCUl+fv7f/rpp0lJSXbHL168ODIy8pdffnn99ded3DUh7kGEBu6uXbsmJCSAw+1s69evDwgI+OKLLzZs2OD83glROnFmUbz22msjRoywrXxR9+CIiIiPP/4YAF555ZXs7GxRCiBEucQJoYeHB2tnS0tLmz9/vt3xWq127Nixt27dYrcoFKUGQhRKtPmEwcHBa9as8fT0nDt37v79++2OX7FiRdu2bX/44Qe6PyFp4MSc1Dtw4MD4+HjWGVNYWFj34KCgIHZzwoSEhD179ohYBiHKIvLM+rlz5/bv39/xdrbZs2dbLJYJEyZcu3ZN3EoIUQqRQ+jl5bVp06bg4OCtW7eyfpq6zZ49+/HHH6d2NtKQib/GjK2dbfr06Y60s7FlvHfu3ElLQpGGSZKFnkaNGjVlyhTH29lWr14NADNnzvzpp5+kqIcQOZNqtbWlS5eGh4efOXNm5syZdgc//fTTr7zySmTkuy+80IW62UhDI1UIAwMDk5KS/Pz8li1btn37drvjFy82lJXNOnHCKy5OoooIkSkJ1x3t1q3bwoUL4c/bV9Q92N/fe9MmCAiAzz+HjRulK4oQ2ZF28d/p06cPHz68qKho5swv7XWzQUQEfPghAMDLLwN1s5GGQ9oQsnY2tXrrtm36hQvtj4+NhTFj4NYtGDsWqJuNNBCSL4PfokWLGTOeqa6GuXPh8GH741euhDZtID0d3ntP6tIIkQVX3Iti4EB44w2oroZx48BeNxsEBcG6deDpCQsWwN69LqiOEM5cdEOY99+Hfv0gLw+0WvuDBwyA2bPBYoFJk6CgQPriCOHKRSH08oJ166BpU/jmG1i92v74d96BQYMgPx8mTQLqZiPuzXW3RmvfHj77DABg+nTIzLQzWKWCNWugeXNISYHly11QHSHcuPT+hKNHw/PPQ1kZxMSAvW42aN3aGtoZM+DUKemLI4QTV98kdNkyCA+HjAyYNcv+4GeegdhYqKiA8eOhrEz64pTvyJEjq1atsrvCCJEXdLnTp9HPDz08cNs2+4PLy7F7dwRArVb6ymTMx8cHAG7fvl3HGLPZ3LNnTwDo3Lnzrl27XFZbw1FcXJydnV1QUFD3sPomi0MIEfHDDxEAg4Px8mX7gzMy0N8fAXDjRukrkytHQoiIJpOpXbt27I9ArVZfunTJNeWRmpQRQosFn3oKAXDgQKyutj9++XIEwKAgzM6WvDZ5cjCEiFhRUWEwGBo3bgwAPj4+Op3ujz/+
cEGFxEYZIUTE33/HVq0QAOfPd2h8TAwCYO/eWFkpcWWy5HgImfz8fK1Wq1KpAOCBBx5ITEw0m82SVkhsFBNCRNy3D1Uq9PLCw4ftDy4sxDZtEADnzJG+MvmpbwiZ9PT0/v37s7+Jnj17HnbkB02cpqQQIuLMmQiAYWFYWGh/8MGD6OmJKhXu3St9ZTIjLISIaLFYkpKSwsLCAMDDw2P06NGXHXkjTpygsBBWVWHfvgiAo0Y5NF6vRwAMDcXr1yWuTDYqKyt37drl6ekJALm5ucI2UlJSotfr/fz8ACAwMFCv15eXl4tbJ2HOnz+vsBAiYlYWNmmCAPj55/YHm804aBAC4LBhaLFIXxw/5eXlJpNJq9Xed9997EXM29ub5UfA6yGTm5ur0WjYn0hYWJjRaBS35gautLQ0JibGw8MDADw9PR1/IP8QIuLmzQiAgYGYmWl/cF4eNm+OALhsmfSVuVxJScnXX389duxYdnqT6d69e1xc3JNPPsk+7NChg8lkEryL7777rnv37mxTgwYN+umnn0Ssv8FKSUlhBxoAEBISsnv3bscfK4sQIuLEiQiA3bqhI0dJ33yDAOjri6dOSV+ZS5SWlppMJo1G06hRI1v2IiIi9Hp9Zo1npj179nTt2pV9dfDgwadPnxa2O7PZbDQa2WusSqXSaDTXrl0T6VtpcC5cuDB69Gj2S/H29o6Li6vvFuQSwpIS7NQJAfD11x0a/9JLCIBTpkhclsQKCwuNRqNarfb19WW/RZVKFRUVpdfrL1y4cM+HVFVVJSYmhoSEAICXl5dWq/39998F7z0+Pp6d8gkODk5ISKioqHDiu2lwiouL9Xo9+901atRo9uzZwi7JyiWEiHj8OPr4oIcHbt9uf3BpKS5Z4tCFfhkqKChg2WMBYNmLjo42GAxXrlxxZAs3btzQ6XReXl4A0KxZM4PBUFVVJayYc+fODRs2jJXRqVOnb7/9Vth2GhR2KHH//ffbDiWuXr0qeGsyCiEifvABAmCLFpifz7sUCeTl5SUmJqrVahYe9vadZe+3334TsMHMzMwhQ4awTYWHh6ekpAiuLTU1NSIigm3qiSeeOHPmjOBNub2jR4/26dOH/az69Olz9OhRJzcorxBaLKhW16OdTRFycnIMBkN0dDQ7bwYAfn5+arU6MTFR8JFkTSaTqX379rZ+0V9++UXYdiorKw0GQ9OmTdl7G51Od/PmTefLcyd5eXkajYb9Hlu3bm00Gi1inKOXVwgR8do1azvbwoW8S3FOVlZWrez5+/ur1Wqj0Xjr1i1x98X6RZs0aWLLj+BdFBQU6HQ6dlmyefPmBoOh2m2eDp1QWlqq1+v9/f0BICAgID4+vri4WKyNyy6EiLh7t7Wd7cgR3qXUX0ZGhl6vj4qKsp3kDAgIYNkT8dd2T7/++qutX7RVq1bO9IueOHHi0UcfZfX36NHj4MGD4paqIKzlqE2bNrZjjZycHHF3IccQIuKMGQiA7dujUg6IWPbCw8Nt2QsODtZoNCaTSfC1dWGOHz8eHR3NaoiKikpLSxO8KZPJ1LZtW9sfX3bDm8Ny/PjxAQMGsJ/AI4884swPsw4yDWFFBfbsKfe5vGYzpqVhXBwOH/6qLXstW7aMjY1NTU0VfLrSeTWfvFm/qOAn77KysoSEBHb10t/fX9zDMDljhxXssJwdVkh3WC7TECLixYuoVmNeHu867lJdjWlpqNNhaCgCIAD26bM1LCxMq9WaTCaO2aul1tsYvV5fVlYmbFNXrlyxnZAIDQ0V64SEPLETVKK8wXaQfEMoNxUVuHMnTp2KISHW7LED5jfewGPHzLL9oxTxhN6xY8f69u3LXvB79ep1RIlv2e0xmUwPPfSQ86ea64VCaEd5OZpMqNViixZ/yZ5Oh2lpimki379/f2RkJPvbGjhw4MmTJ4Vtx2KxGI3Gli1bsgNdjUYj7AqnDGVmZg4dOlSUi671pYAQbtmCwcHWf9Om2Rk8apR15Nq1Tu20rAxN
JtRorDM82L+ICNTr8fhxp7bMi4hNHmxiFGvXcoOJUYWFhbb2o+DgYGfaj4RRQAg3bLgTA5XKzjT8IUOsI1evFrKv0lJr9ho1qp09R2Z4yF9RUVF8fDzLT1BQUEJCguCTtxcvXrQ1Lj/88MNJSUniluoCrBG3RYsWzjfiOkNhIQTArl3rWmZGWAhv3ECjEdVq9PW9k/aoKNTr8eJF578D2Tl//rxarWb56dChgzP52bt3b7du3dimHn/8ccETO1xPPpUrL4QAuGjR3w6uVwivX7dmz9vb+ihPT4yORoPBPZtXa0lNTe3SpYutXzQjI0PYdmTyeuI4ub2GKymEXbta0xIYiH933cuREObmYmIiqtXo5VU7e+5ylsFR7HR8UFAQOx2v1WqvC104hPs7K0fc/W7Wxa0U96SkEI4Yga++av3/0KH3HlxHCHNy0GDA6Gj08LCO8fNDtRoTE1HeT9ySYxOj2IVpNjFK8IVpjucY6ybn87oKC+G1a3dOVyYn32Pw3SHMyqqdPX9/VKvRaESJr8EqzNmzZ//1r3+x/Di5kD6Xq211qHmFs3fv3nK7wqmwEOKfcw7ZQol3d1DVCmFh4Z1jzqAgnDABk5NRaN9Ig1BrIf2srCxh23F938k9KaLXR3khrKiwLoQBgP/5T+3Bd78SjhmDL7yAKSlISzc4SMSF9F3ZgVmLgrpelRdCRNyxw/oZL6/aaz05eZ2Q2Ii4kL5r5iLUpKz5H4oMIaL1fjIA2KsX1vzboBCKS6yF9F0wK49R4kxIpYbw8mUMDLR+fsWKO5+nEIqO5efBBx8EpxfSLy0ttR0iij4/XblrAig1hIg4f77188HBaFs1k0IoEREX0s/JybFdK+/Xr5/ztSl9dRwFh7CiAsPDrV+yzf2lEEpKxIX02Zpla9ascbIkN1gnTsEhRMTU1Dutnj/8gEghdAmxFtKvrq525oJBrRVTd+zYIXhTfCk7hIg4Zoz1qwMHIlIIXYXvQvputna44kP422/YtKl1wNatFEKXqhkGNjFK6jCw8LNmcbe5i4biQ4iIH35oHdCxIw4eTCF0tXPnzg0fPpwdFnbs2FG6hfTd9X5S7hDCqirs3t06xjYpiULoYpKeIHHvOyu6QwgR8dChO/3ZFEJepLhU0BDuMewmIUTEyZMphLIg1kVz1iQQFhbmfJOAzLlPCAsK/rIYIYWQrx9//LFm+9iBAwfq9fD09PR+/fqxh/fq1Utwu5wiuE8IEXHFCgqhvAhopM7Pz7dNPnKycVwpPBAR5C0zE0wmAIAOHeDZZ+saabHA0qVQUQEAMHQo/HkijfBUXl6+dOnSefPmlZSU+Pv763S6t99+m82TuufI+fPnFxcX+/j4xMbGzps3754j3Q3vZwHSINSaXLtgwYJar2/Lli2rOZn40qVLvEp1PQohcZ1jx47Z3ukFBgauWrUKETdu3MhOqAJAZGTk/v37eZfpago4HCXuxGKxxMfHf/TRRxaLBQACAwNLS0sBQKVS6XS6xYsXs9OqDQqFkHBw9erVkSNHHjlyhH3Yo0eP5ORk25TfhoZCSLhJTU3dsmXLM888M2TIEN618EQhJIQzFe8CCGnoKISEcEYhJIQzCiEhnFEICeGMQkgIZxRCQjijEBLCGYWQEM4ohIRwRiEkhDMKISGcUQgJ4YxCSAhnFEJCOKMQEsIZhZAQziiEhHBGISSEMwohIZxRCAnhjEJICGcUQkI4oxASwhmFkBDOKISEcEYhJIQzCiEhnFEICeGMQkgIZxRCQjijEBLCGYWQEM4ohIRwRiEkhDMKISGcUQgJ4YxCSAhnFEJCOKMQEsIZhZAQziiEhHBGISSEMwohIZxRCAnhjEJICGcUQkI4oxASwhmFkBDOKISEcEYhJIQzCiEhnFEICeGMQkgIZxRCQjijEBLCGYWQEM4ohIRwRiEkhDMKISGcUQgJ4YxCSAhnFEJCOKMQEsIZhZAQzv4f6jkImFkQ6PQAAAAASUVORK5CYII=
",
141
+ "text/plain": [
142
+ "<PIL.PngImagePlugin.PngImageFile image mode=RGB size=300x300>"
143
+ ]
144
+ },
145
+ "execution_count": 11,
146
+ "metadata": {},
147
+ "output_type": "execute_result"
148
+ }
149
+ ],
150
+ "source": [
151
+ "# Generate Mol Viz\n",
152
+ "from rdkit import Chem\n",
153
+ "from rdkit.Chem import Draw\n",
154
+ "\n",
155
+ "input_ids = tokenizer(\"<s>\", return_tensors=\"pt\").input_ids.to(device)\n",
156
+ "gen = model.generate(input_ids, max_length=256, top_k=50, temperature=1, do_sample=True, pad_token_id=tokenizer.pad_token_id)\n",
157
+ "generatedmol = tokenizer.decode(gen[0], skip_special_tokens=True)\n",
158
+ "\n",
159
+ "test = generatedmol.replace(' ', '')\n",
160
+ "csmi_gen = sf.decoder(test)\n",
161
+ "print(csmi_gen)\n",
162
+ "mol = Chem.MolFromSmiles(csmi_gen)\n",
163
+ "\n",
164
+ "# Draw the molecule\n",
165
+ "Draw.MolToImage(mol)"
166
+ ]
167
+ }
168
+ ],
169
+ "metadata": {
170
+ "kernelspec": {
171
+ "display_name": "base",
172
+ "language": "python",
173
+ "name": "python3"
174
+ },
175
+ "language_info": {
176
+ "codemirror_mode": {
177
+ "name": "ipython",
178
+ "version": 3
179
+ },
180
+ "file_extension": ".py",
181
+ "mimetype": "text/x-python",
182
+ "name": "python",
183
+ "nbconvert_exporter": "python",
184
+ "pygments_lexer": "ipython3",
185
+ "version": "3.13.0"
186
+ }
187
+ },
188
+ "nbformat": 4,
189
+ "nbformat_minor": 5
190
+ }
train-withmtp.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========================
2
+ # Train with NTP + MTP
3
+ # by gbyuvd
4
+ # ========================
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import json
10
+ import os
11
+ import math
12
+ from typing import List, Union, Optional, Tuple, Dict, Any
13
+ from transformers.tokenization_utils_base import BatchEncoding
14
+ from transformers import Qwen3Config, Qwen3ForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
15
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2PreTrainedModel
16
+ from datasets import load_dataset, DatasetDict
17
+ import pandas as pd
18
+ from torch.utils.data import Dataset, DataLoader, random_split
19
+ from sklearn.model_selection import train_test_split
20
+ from ranger21 import Ranger21
21
+ from tqdm.notebook import tqdm
22
+ from FastChemTokenizerHF import FastChemTokenizerSelfies
23
+ from ChemQ3MTP import ChemQ3MTP
24
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
25
+ from transformers import TrainerCallback
26
+ import datetime
27
+
28
+ # ==============================
29
+ # Load external configuration
30
+ # ==============================
31
# All hyperparameters are externalized to config.json so runs can be
# reproduced/tweaked without editing this script. Raises FileNotFoundError
# if config.json is missing and KeyError if a required section is absent.
with open("config.json", "r") as f:
    CONFIG = json.load(f)

TRAINING_CFG = CONFIG["training"]  # required section
MODEL_CFG = CONFIG["model"]        # required section; forwarded to Qwen3Config
GENERATION_CFG = CONFIG.get("generation", {})  # optional; defaults to empty

# Training params
BATCH_SIZE = TRAINING_CFG["batch_size"]
NUM_EPOCHS = TRAINING_CFG["num_epochs"]
LEARNING_RATE = TRAINING_CFG["learning_rate"]
WEIGHT_DECAY = TRAINING_CFG["weight_decay"]
GRAD_ACCUM_STEPS = TRAINING_CFG["gradient_accumulation_steps"]
TOKENIZE_BATCH_SIZE = TRAINING_CFG["tokenize_batch_size"]
# Split ratios: train ratio is present in config but the actual train size is
# derived as total - test - val further down, so only val/test are read here.
TRAIN_SPLIT_RATIO = TRAINING_CFG["train_split_ratio"]
VAL_SPLIT_RATIO = TRAINING_CFG["val_split_ratio"]
TEST_SPLIT_RATIO = TRAINING_CFG["test_split_ratio"]
# Columns the HF Trainer should forward to metric computation; optional key.
INCLUDE_FOR_METRICS = TRAINING_CFG.get("include_for_metrics", ["input_ids", "attention_mask", "labels"])
49
+ # ==============================
50
+
51
class LossLoggerCallback(TrainerCallback):
    """Append train/eval losses to a tab-separated log file on every log event.

    The file is truncated and given a header row at construction time; each
    subsequent Trainer logging step appends one row. Missing metrics are
    written as empty cells so the columns stay aligned.
    """

    def __init__(self, log_file="training_losses.txt", with_timestamp=False):
        self.log_file = log_file
        self.with_timestamp = with_timestamp
        header = (
            "time\tstep\tloss\teval_loss\n"
            if self.with_timestamp
            else "step\tloss\teval_loss\n"
        )
        # "w" mode deliberately truncates any log from a previous run.
        with open(self.log_file, "w") as f:
            f.write(header)

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None:
            return
        loss = logs.get("loss")
        eval_loss = logs.get("eval_loss")
        cells = [
            str(state.global_step),
            "" if loss is None else str(loss),
            "" if eval_loss is None else str(eval_loss),
        ]
        if self.with_timestamp:
            cells.insert(0, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
        # Re-open in append mode each event; cheap at Trainer logging frequency
        # and guarantees rows hit disk even if training is interrupted.
        with open(self.log_file, "a") as f:
            f.write("\t".join(cells) + "\n")
74
+
75
+
76
+ def main():
77
+ # --- Load the tokenizer ---
78
+ tokenizer = FastChemTokenizerSelfies.from_pretrained("./selftok_core")
79
+
80
+ out = tokenizer("[C] [=C] [Branch1]", return_tensors="pt")
81
+ print(out.input_ids)
82
+ print(out.attention_mask)
83
+ out = out.to("cuda" if torch.cuda.is_available() else "cpu")
84
+ print(out.input_ids.device)
85
+
86
+ # --- Define config ---
87
+ config = Qwen3Config(
88
+ vocab_size=len(tokenizer),
89
+ bos_token_id=tokenizer.bos_token_id,
90
+ eos_token_id=tokenizer.eos_token_id,
91
+ pad_token_id=tokenizer.pad_token_id,
92
+ tie_word_embeddings=True,
93
+ use_cache=False,
94
+ **MODEL_CFG
95
+ )
96
+
97
+ model = ChemQ3MTP(config, num_future_tokens=3)
98
+
99
+ def count_parameters(model):
100
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
101
+
102
+ print(f"Enhanced model has {count_parameters(model):,} trainable parameters.")
103
+
104
+ batch_size, seq_len = 2, 32
105
+ dummy_input = torch.randint(
106
+ low=0,
107
+ high=len(tokenizer),
108
+ size=(batch_size, seq_len),
109
+ dtype=torch.long,
110
+ )
111
+ with torch.no_grad():
112
+ outputs = model(dummy_input)
113
+ logits = outputs.logits
114
+ print(f"Input shape: {dummy_input.shape}")
115
+ print(f"Logits shape: {logits.shape}")
116
+
117
+ print("Loading dataset...")
118
+ dataset = load_dataset(
119
+ 'csv',
120
+ data_files='./data/sample_all_14k.csv',
121
+ split='train',
122
+ streaming=True
123
+ )
124
+
125
+ print("Shuffling and splitting dataset...")
126
+ shuffled_dataset = dataset.shuffle(seed=42, buffer_size=10000)
127
+
128
+ total_lines = 14000
129
+ test_size = int(TEST_SPLIT_RATIO * total_lines)
130
+ val_size = int(VAL_SPLIT_RATIO * total_lines)
131
+ train_size = total_lines - test_size - val_size
132
+
133
+ test_dataset = shuffled_dataset.take(test_size)
134
+ remaining = shuffled_dataset.skip(test_size)
135
+ val_dataset = remaining.take(val_size)
136
+ train_dataset = remaining.skip(val_size)
137
+
138
+ print(f"Dataset split: train={train_size}, val={val_size}, test={test_size}")
139
+
140
def tokenize_function(examples):
    """Tokenize a batch of SELFIES strings for causal-LM training.

    Accepts either a single string or a list under the 'SELFIES' key and
    returns per-example lists of input_ids, attention_mask, and labels.
    Labels are an unshifted copy of input_ids (the model shifts internally).
    """
    batch_results = {"input_ids": [], "attention_mask": [], "labels": []}
    selfies_batch = examples["SELFIES"]
    if not isinstance(selfies_batch, list):
        selfies_batch = [selfies_batch]
    for selfies in selfies_batch:
        encoded = tokenizer(
            selfies,
            truncation=True,
            padding=False,  # padding is deferred to the data collator
            max_length=MODEL_CFG["max_position_embeddings"],
            return_tensors=None,
            add_special_tokens=True,
        )
        batch_results["input_ids"].append(encoded["input_ids"])
        batch_results["attention_mask"].append(encoded["attention_mask"])
        # Copy so later label padding (-100) can't mutate input_ids.
        batch_results["labels"].append(encoded["input_ids"].copy())
    return batch_results
159
+
160
+ print("Tokenizing datasets...")
161
+ train_dataset = train_dataset.map(tokenize_function, batched=True, batch_size=TOKENIZE_BATCH_SIZE, remove_columns=["SELFIES"])
162
+ val_dataset = val_dataset.map(tokenize_function, batched=True, batch_size=TOKENIZE_BATCH_SIZE, remove_columns=["SELFIES"])
163
+
164
class EnhancedDataCollator:
    """Pad variable-length tokenized features into fixed-size LongTensor batches.

    Pads ``input_ids`` with the tokenizer's pad id, ``attention_mask`` with 0,
    and ``labels`` with -100 so padded positions are ignored by the loss.
    The padded length is rounded up to a multiple of ``pad_to_multiple_of``
    (set falsy to disable rounding).
    """

    def __init__(self, tokenizer, pad_to_multiple_of=8):
        self.tokenizer = tokenizer
        self.pad_to_multiple_of = pad_to_multiple_of

    def __call__(self, features):
        # Fix: the original crashed with ValueError on max() of an empty
        # batch; return well-typed empty tensors instead.
        if not features:
            return {
                key: torch.empty(0, dtype=torch.long)
                for key in ("input_ids", "attention_mask", "labels")
            }
        max_length = max(len(f["input_ids"]) for f in features)
        if self.pad_to_multiple_of:
            multiple = self.pad_to_multiple_of
            # Round up to the next multiple (tensor-core friendly shapes).
            max_length = ((max_length + multiple - 1) // multiple) * multiple
        batch = {"input_ids": [], "attention_mask": [], "labels": []}
        for feature in features:
            padding_length = max_length - len(feature["input_ids"])
            batch["input_ids"].append(
                feature["input_ids"] + [self.tokenizer.pad_token_id] * padding_length
            )
            batch["attention_mask"].append(feature["attention_mask"] + [0] * padding_length)
            # -100 is the ignore_index convention for cross-entropy labels.
            batch["labels"].append(feature["labels"] + [-100] * padding_length)
        return {key: torch.tensor(values, dtype=torch.long) for key, values in batch.items()}
186
+
187
+ data_collator = EnhancedDataCollator(tokenizer, pad_to_multiple_of=8)
188
+
189
def create_enhanced_optimizer(model_params):
    """Build the Ranger21 optimizer for this run.

    Enables the AdaBelief + MADGRAD variants with gradient centralization,
    warmup and warmdown, sized from the training-loop geometry (epochs and
    batches per epoch) captured from the enclosing scope.
    """
    # Fix: train_size // BATCH_SIZE is 0 when the train split is smaller
    # than one batch, which would hand Ranger21 a zero batches-per-epoch
    # for its warmup/warmdown schedule; clamp to at least 1.
    # NOTE(review): exact Ranger21 failure mode unverified — confirm against
    # the Ranger21 documentation.
    num_batches_per_epoch = max(1, train_size // BATCH_SIZE)
    optimizer_params = {
        "lr": LEARNING_RATE,
        "weight_decay": WEIGHT_DECAY,
        "use_adabelief": True,
        "use_cheb": False,
        "use_warmup": True,
        "use_madgrad": True,
        "num_epochs": NUM_EPOCHS,
        "using_gc": True,  # gradient centralization
        "warmdown_active": True,
        "num_batches_per_epoch": num_batches_per_epoch,
    }
    return Ranger21(model_params, **optimizer_params)
204
+
205
+ from torch.optim.lr_scheduler import LambdaLR
206
class EnhancedCustomTrainer(Trainer):
    """Trainer that plugs in Ranger21 and defers LR scheduling to it."""

    def create_optimizer(self):
        # Replace the HF default optimizer with the project's Ranger21 setup.
        self.optimizer = create_enhanced_optimizer(self.model.parameters())
        return self.optimizer

    def create_scheduler(self, num_training_steps, optimizer=None):
        # Constant LR multiplier: Ranger21 manages warmup/warmdown itself,
        # so the HF-side scheduler must be a no-op.
        target = self.optimizer if optimizer is None else optimizer
        self.lr_scheduler = LambdaLR(target, lr_lambda=lambda _step: 1.0)
        return self.lr_scheduler

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # The model computes its own loss (NTP or MTP); just forward it.
        outputs = model(**inputs)
        if return_outputs:
            return outputs.loss, outputs
        return outputs.loss
219
+
220
+ steps_per_epoch = train_size // BATCH_SIZE
221
+ total_steps = steps_per_epoch * NUM_EPOCHS
222
+
223
+ training_args = TrainingArguments(
224
+ output_dir='./chemq3minipret',
225
+ max_steps=total_steps,
226
+ per_device_train_batch_size=BATCH_SIZE,
227
+ per_device_eval_batch_size=BATCH_SIZE,
228
+ gradient_accumulation_steps=GRAD_ACCUM_STEPS,
229
+ logging_dir='./gptlo-1',
230
+ logging_strategy="steps",
231
+ logging_steps=max(1, steps_per_epoch // 4),
232
+ eval_strategy="steps",
233
+ eval_steps=max(1, steps_per_epoch // 4),
234
+ save_strategy="steps",
235
+ save_steps=steps_per_epoch,
236
+ save_total_limit=1,
237
+ dataloader_num_workers=0,
238
+ dataloader_pin_memory=False,
239
+ remove_unused_columns=False,
240
+ prediction_loss_only=False,
241
+ fp16=torch.cuda.is_available(),
242
+ gradient_checkpointing=True,
243
+ dataloader_drop_last=True,
244
+ report_to=None,
245
+ include_for_metrics=INCLUDE_FOR_METRICS,
246
+ )
247
+
248
+ print("Initializing enhanced trainer with MTP capabilities...")
249
+ trainer = EnhancedCustomTrainer(
250
+ model=model,
251
+ args=training_args,
252
+ train_dataset=train_dataset,
253
+ eval_dataset=val_dataset,
254
+ data_collator=data_collator,
255
+ processing_class=tokenizer,
256
+ callbacks=[LossLoggerCallback("training_losses.txt", with_timestamp=True)]
257
+ )
258
+
259
+ model.set_mtp_training(True)
260
+ print(" MTP training mode enabled")
261
+
262
+ print("Starting enhanced training with MTP and Horizon Loss...")
263
+ try:
264
+ print("\n Phase 1: Warmup with standard Causal LM...")
265
+ model.set_mtp_training(False)
266
+ warmup_steps = max(1, total_steps // 5)
267
+ trainer.args.max_steps = warmup_steps
268
+ trainer.train()
269
+ print("\n Phase 2: Full MTP + Horizon Loss training...")
270
+ model.set_mtp_training(True)
271
+ trainer.args.max_steps = total_steps
272
+ trainer.train(resume_from_checkpoint=True)
273
+ print("Enhanced training completed successfully!")
274
+ trainer.save_model("./enhanced-qwen3-final")
275
+ tokenizer.save_pretrained("./enhanced-qwen3-final")
276
+ training_config = {
277
+ "model_type": "EnhancedQwen3ForCausalLM",
278
+ "num_future_tokens": 3,
279
+ "horizon_loss_enabled": True,
280
+ "mtp_head_enabled": True,
281
+ "training_phases": ["causal_lm_warmup", "mtp_horizon_training"],
282
+ "total_parameters": count_parameters(model),
283
+ }
284
+ config_path = "./enhanced-qwen3-final/training_config.json"
285
+ with open(config_path, "w") as f:
286
+ json.dump(training_config, f, indent=2)
287
+ print(f" Enhanced model, tokenizer, and config saved!")
288
+ except Exception as e:
289
+ print(f"Enhanced training failed with error: {e}")
290
+ import traceback
291
+ traceback.print_exc()
292
+ return
293
+
294
+ print("\nmTesting enhanced generation capabilities...")
295
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
296
+ model.to(device)
297
+ model.eval()
298
+ try:
299
+ print("\n--- Standard Generation Test ---")
300
+ input_ids = tokenizer("<s> [C]", return_tensors="pt").input_ids.to(device)
301
+ with torch.no_grad():
302
+ model.set_mtp_training(False)
303
+ gen = model.generate(
304
+ input_ids,
305
+ max_length=GENERATION_CFG.get("max_length", 64),
306
+ top_k=GENERATION_CFG.get("top_k", 50),
307
+ top_p=GENERATION_CFG.get("top_p", 0.9),
308
+ temperature=GENERATION_CFG.get("temperature", 0.8),
309
+ do_sample=GENERATION_CFG.get("do_sample", True),
310
+ pad_token_id=tokenizer.pad_token_id,
311
+ eos_token_id=tokenizer.eos_token_id,
312
+ num_return_sequences=GENERATION_CFG.get("num_return_sequences", 3),
313
+ )
314
+ for i, sequence in enumerate(gen):
315
+ result = tokenizer.decode(sequence, skip_special_tokens=True)
316
+ print(f"Generated SELFIES {i+1}: {result}")
317
+ print("\n--- MTP Analysis Test ---")
318
+ model.set_mtp_training(True)
319
+ test_smiles = "[C]"
320
+ test_input = tokenizer(test_smiles, return_tensors="pt", add_special_tokens=True).to(device)
321
+ with torch.no_grad():
322
+ outputs = model(**test_input)
323
+ if hasattr(model.mtp_head, 'prediction_heads'):
324
+ hidden_states = model.model(test_input['input_ids']).last_hidden_state
325
+ mtp_outputs = model.mtp_head(hidden_states)
326
+ print(f"Input SELFIES: {test_smiles}")
327
+ print(f"Tokenized: {tokenizer.convert_ids_to_tokens(test_input['input_ids'][0].tolist())}")
328
+ for i, (key, logits) in enumerate(mtp_outputs.items()):
329
+ top_tokens = torch.topk(logits[0], k=3, dim=-1)
330
+ print(f"\n{key} predictions:")
331
+ for pos in range(min(5, logits.size(1))):
332
+ pos_preds = []
333
+ for j in range(3):
334
+ token_id = top_tokens.indices[pos, j].item()
335
+ prob = torch.softmax(logits[0, pos], dim=-1)[token_id].item()
336
+ token = tokenizer.id_to_token.get(token_id, '<UNK>')
337
+ pos_preds.append(f"{token}({prob:.3f})")
338
+ print(f" Position {pos}: {', '.join(pos_preds)}")
339
+ print("\nEnhanced generation tests completed!")
340
+ except Exception as e:
341
+ print(f"Enhanced generation test failed: {e}")
342
+ import traceback
343
+ traceback.print_exc()
344
+
345
+ print("\nEnhanced Model Analysis:")
346
+ print(f"Total parameters: {count_parameters(model):,}")
347
+ mtp_params = sum(p.numel() for p in model.mtp_head.parameters() if p.requires_grad)
348
+ horizon_params = sum(p.numel() for p in model.horizon_loss.parameters() if p.requires_grad)
349
+ base_params = count_parameters(model) - mtp_params - horizon_params
350
+ print(f"Base model parameters: {base_params:,}")
351
+ print(f"MTP head parameters: {mtp_params:,}")
352
+ print(f"Horizon loss parameters: {horizon_params:,}")
353
+ print(f"Enhancement overhead: {((mtp_params + horizon_params) / base_params * 100):.2f}%")
354
+ print(f"\n Enhanced Model Architecture:")
355
+ print(f"- Base Model: Qwen3 with {config.num_hidden_layers} layers")
356
+ print(f"- Hidden Size: {config.hidden_size}")
357
+ print(f"- Attention Heads: {config.num_attention_heads}")
358
+ print(f"- Vocab Size: {config.vocab_size}")
359
+ print(f"- MTP Future Tokens: {model.mtp_head.num_future_tokens}")
360
+ print(f"- Horizon Loss Weights: Learnable")
361
+ print(f"- Training Mode: {'MTP + Horizon Loss' if model.use_mtp_training else 'Standard Causal LM'}")
362
+ print("\n Enhanced training pipeline completed successfully!")
363
+
364
+ if __name__ == "__main__":
365
+ main()
train_ppokl_withsa.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Refactored PPO-KL training script using ChemQ3MTP module
3
+
4
+ import os
5
+ import torch
6
+ from tqdm import tqdm
7
+ from FastChemTokenizerHF import FastChemTokenizerSelfies
8
+ from ChemQ3MTP import ChemQ3MTP, CurriculumManager
9
+
10
+
11
def main():
    """Run PPO-KL fine-tuning of ChemQ3MTP with SA-only rewards.

    Pipeline:
      1. Load the SELFIES tokenizer and the MTP-pretrained model.
      2. Disable MTP heads (RL uses the standard next-token policy).
      3. Loop: PPO rollout -> PPO update -> EMA reward-baseline update ->
         curriculum step, with periodic checkpointing and logging.

    Side effects: writes checkpoints under ``./ppo_checkpoints`` and prints
    progress to stdout. No return value.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🚀 Using device: {device}")

    # --- Load tokenizer ---
    tokenizer = FastChemTokenizerSelfies.from_pretrained("../selftok_core")

    # --- Load model ---
    model = ChemQ3MTP.from_pretrained("../pretrained/sample-e1-mtp")
    model.tokenizer = tokenizer
    model.to(device)

    # --- RL fine-tuning setup ---
    print("\n🎯 Phase 2: RL Fine-tuning with PPO + Curriculum Learning")
    model.set_mtp_training(False)  # RL operates on the standard causal-LM head
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
    curriculum = CurriculumManager(start_len=10, max_len=35, step_increase=5, steps_per_level=70)
    baseline = None  # EMA reward baseline; initialized on the first step
    gamma = 0.95     # EMA decay factor for the baseline

    # Dummy input (BOS-only batch): every rollout is generated from scratch.
    batch_size = 4
    dummy_input = tokenizer([tokenizer.bos_token] * batch_size, return_tensors="pt", padding=True)
    input_ids = dummy_input.input_ids.to(device)

    # Training config
    total_steps = 14000
    checkpoint_steps = {total_steps // 4, total_steps // 2, 3 * total_steps // 4, total_steps}
    checkpoint_dir = "./ppo_checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)

    # --- RL Training Loop with tqdm ---
    for step in tqdm(range(total_steps), desc="RL Training"):
        max_new_tokens = curriculum.get_max_new_tokens()

        # === PPO Rollout (no grad: snapshot of the behaviour policy) ===
        with torch.no_grad():
            selfies_list, old_log_probs, _, old_action_probs = model.generate_with_logprobs(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                temperature=1.0,
                top_k=50,
                top_p=0.95,
                do_sample=True,
                return_probs=True
            )
        old_log_probs = old_log_probs.detach()
        old_action_probs = old_action_probs.detach()

        # === PPO Update ===
        ppo_result = model.ppo_step(
            input_ids=input_ids,
            old_log_probs=old_log_probs,
            old_action_probs=old_action_probs,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            # validity_weight=1.0,  # only used in ChemQ3 mode
            # lipinski_weight=1.0,  # only used in ChemQ3 mode
            entropy_weight=0.01,
            clip_epsilon=0.2,
            baseline=baseline,
            reward_mode="sa",  # 🔑 SA-only mode
        )

        loss = ppo_result['loss']
        optimizer.zero_grad(set_to_none=True)  # slightly more efficient than zeroing
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # === Update baseline (exponential moving average of avg reward) ===
        reward_tensor = torch.tensor(ppo_result['avg_reward'], device=device)
        baseline = reward_tensor if baseline is None else gamma * baseline + (1 - gamma) * reward_tensor

        # Curriculum update
        curriculum.step()

        # Checkpointing
        if (step + 1) in checkpoint_steps:
            checkpoint_path = os.path.join(checkpoint_dir, f"model_step_{step+1}")
            model.save_pretrained(checkpoint_path)
            tokenizer.save_pretrained(checkpoint_path)
            torch.save({
                'step': step + 1,
                'optimizer_state_dict': optimizer.state_dict(),
                # baseline is set on step 0, so it is non-None by the first
                # checkpoint; guard anyway in case checkpoint_steps changes.
                'baseline': baseline.item() if baseline is not None else None,
                'curriculum_state': {
                    'current_max_len': curriculum.current_max_len,
                    'step_counter': curriculum.step_counter
                }
            }, os.path.join(checkpoint_path, 'training_state.pt'))
            print(f"\n💾 Checkpoint saved at step {step+1} -> {checkpoint_path}")

        # Logging every 50 steps
        if step % 50 == 0:
            print(f"\n[RL Step {step}] "
                  f"Loss={loss.item():.4f} | "
                  f"Valid={ppo_result['validity_rate']:.3f} | "
                  f"Lipinski={ppo_result['lipinski_score']:.3f} | "
                  f"Reward={ppo_result['avg_reward']:.3f} | "
                  f"Entropy={ppo_result['entropy']:.3f} | "
                  f"EntropyW={ppo_result['entropy_weight']:.4f}")

            # Show one sample rollout. Fix: this block appeared twice in the
            # original (the second copy ran on every step), flooding stdout;
            # it is now printed once per logging interval only.
            sample_selfies = ppo_result['generated_selfies'][0][:100]
            sample_smiles = ppo_result['generated_smiles'][0] or "Invalid"
            print(f"   Sample SELFIES: {sample_selfies}")
            print(f"   Sample SMILES: {sample_smiles}")

    print("🎉 Training complete!")
129
+
130
+ if __name__ == "__main__":
131
+ main()