Upload 7 files
Browse files- ChemQ3MTP.py +753 -0
- FastChemTokenizerHF.py +769 -0
- LICENSE +21 -0
- config.json +34 -0
- demo_test_mtpresult.ipynb +190 -0
- train-withmtp.py +365 -0
- train_ppokl_withsa.py +131 -0
ChemQ3MTP.py
ADDED
|
@@ -0,0 +1,753 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ========================
|
| 2 |
+
# ChemQ3-MTP
|
| 3 |
+
# MODEL COMPONENTS
|
| 4 |
+
# by gbyuvd
|
| 5 |
+
# ========================
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.distributions import Categorical
|
| 12 |
+
from typing import List, Union, Optional, Tuple, Dict, Any
|
| 13 |
+
from transformers import Qwen3Config, Qwen3ForCausalLM, AutoTokenizer
|
| 14 |
+
from rdkit import Chem
|
| 15 |
+
from rdkit.Chem import Descriptors, Lipinski
|
| 16 |
+
import selfies as sf
|
| 17 |
+
from rdkit import RDLogger
|
| 18 |
+
RDLogger.DisableLog('rdApp.*') # suppress all SMILES parse messages
|
| 19 |
+
import json
|
| 20 |
+
from typing import List, Union, Optional, Tuple
|
| 21 |
+
from transformers.tokenization_utils_base import BatchEncoding
|
| 22 |
+
from FastChemTokenizer import FastChemTokenizerSelfies
|
| 23 |
+
import numpy as np
|
| 24 |
+
from collections import Counter
|
| 25 |
+
from rdkit.Chem import Descriptors, Lipinski, rdMolDescriptors
|
| 26 |
+
|
| 27 |
+
# ========================
|
| 28 |
+
# UTILS: SELFIES -> SMILES -> VALIDITY & LIPINSKI
|
| 29 |
+
# ========================
|
| 30 |
+
|
| 31 |
+
def selfies_to_smiles(selfies_str: str) -> str | None:
    """Decode a SELFIES string into SMILES.

    Whitespace inserted by the tokenizer is stripped before decoding.
    Returns None when decoding fails for any reason.
    """
    try:
        return sf.decoder(selfies_str.replace(" ", ""))
    except Exception:
        return None
|
| 38 |
+
|
| 39 |
+
def is_valid_smiles(smiles: str) -> bool:
    """Return True when *smiles* is a non-empty string that RDKit can parse."""
    if not isinstance(smiles, str):
        return False
    candidate = smiles.strip()
    if not candidate:
        return False
    return Chem.MolFromSmiles(candidate) is not None
|
| 43 |
+
|
| 44 |
+
# SA Classifier
|
| 45 |
+
from transformers import pipeline
|
| 46 |
+
|
| 47 |
+
# Optional: lazy load so we don’t reload every time
|
| 48 |
+
# Lazily-initialized synthetic-accessibility classifier (module-level singleton
# so the HF pipeline is loaded at most once per process).
_sa_classifier = None


def get_sa_classifier():
    """Return the shared SA text-classification pipeline, loading it on first use."""
    global _sa_classifier
    if _sa_classifier is None:
        _sa_classifier = pipeline(
            "text-classification", model="gbyuvd/synthaccess-chemselfies"
        )
    return _sa_classifier
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def compute_sa_reward(selfies_str: str) -> float:
    """Reward molecules predicted to have easy synthetic accessibility (SA).

    Returns the classifier confidence when the label is "easy", the negated
    confidence otherwise, and 0.0 on any failure (load/inference error).
    """
    try:
        prediction = get_sa_classifier()(selfies_str, truncation=True, max_length=128)[0]
        confidence = prediction["score"]
        return confidence if prediction["label"].lower() == "easy" else -confidence
    except Exception:
        return 0.0
|
| 67 |
+
|
| 68 |
+
# ==========================
|
| 69 |
+
# Reward Components
|
| 70 |
+
# ==========================
|
| 71 |
+
def compute_biological_diversity_score(mol) -> float:
    """Reward molecules with diverse CHONP atoms, normalized to [0, 1].

    Score = 0.3 base + up to 0.4 for the number of distinct bio elements
    present + up to 0.3 for the Shannon entropy of their abundances.
    Returns 0.0 for None, on error, or with fewer than two bio elements.
    """
    if mol is None:
        return 0.0
    try:
        symbols = [a.GetSymbol() for a in mol.GetAtoms()]
        counts = Counter(symbols)
        # NOTE(review): RDKit molecules usually carry implicit hydrogens, so
        # "H" rarely appears among explicit atoms — confirm whether H should
        # really participate in this diversity measure.
        present = set(symbols) & {"C", "H", "O", "N", "P"}
        if len(present) < 2:
            return 0.0

        base = 0.3
        variety_bonus = (len(present) - 2) / 3 * 0.4

        entropy_bonus = 0.0
        n_bio_atoms = sum(counts.get(e, 0) for e in present)
        if n_bio_atoms > 0:
            probs = [counts.get(e, 0) / n_bio_atoms for e in present]
            if len(probs) > 1:
                shannon = -sum(p * np.log2(p) for p in probs if p > 0)
                entropy_bonus = (shannon / np.log2(len(probs))) * 0.3

        return min(1.0, base + variety_bonus + entropy_bonus)
    except Exception:
        return 0.0
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def compute_charge_neutrality_score(mol) -> float:
    """Return 1.0 when the molecule is globally neutral (net formal charge 0).

    Returns 0.0 for a charged molecule, a None input, or on RDKit failure.
    """
    if mol is None:
        return 0.0
    try:
        net_charge = Chem.rdmolops.GetFormalCharge(mol)
    except Exception:
        return 0.0
    return 1.0 if net_charge == 0 else 0.0
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def compute_local_charge_penalty(mol) -> float:
    """Penalize localized charges (carbocations/anions).

    Returns 1.0 when no atom carries a formal charge, decreasing linearly
    with the fraction of charged atoms; 0.0 for None or on RDKit failure.
    """
    if mol is None:
        return 0.0
    try:
        n_total = 0
        n_charged = 0
        for atom in mol.GetAtoms():
            n_total += 1
            if atom.GetFormalCharge() != 0:
                n_charged += 1
        if n_total == 0:
            # Degenerate atom-free molecule: nothing to penalize.
            return 1.0
        return max(0.0, 1.0 - n_charged / n_total)
    except Exception:
        return 0.0
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def compute_enhanced_lipinski_reward(mol) -> float:
    """Soft Lipinski scoring with partial credit, averaged over four rules.

    Unlike a strict pass/fail rule-of-five, near-misses earn fractional
    credit. Returns 0.0 for None or on descriptor failure.
    """
    if mol is None:
        return 0.0
    try:
        mw = Descriptors.MolWt(mol)
        logp = Descriptors.MolLogP(mol)
        donors = Lipinski.NumHDonors(mol)
        acceptors = Lipinski.NumHAcceptors(mol)

        # Molecular weight: full credit in the drug-like 250-500 window,
        # partial credit for slightly small or slightly large molecules.
        if 250 <= mw <= 500:
            mw_score = 1.0
        elif 150 <= mw < 250:
            mw_score = 0.5
        elif 500 < mw <= 600:
            mw_score = 0.7
        else:
            mw_score = 0.0

        # LogP: full credit in [-1, 5], half credit just outside.
        if -1 <= logp <= 5:
            logp_score = 1.0
        elif -2 <= logp < -1 or 5 < logp <= 6:
            logp_score = 0.5
        else:
            logp_score = 0.0

        # H-bond donors/acceptors: linear penalty beyond the Lipinski limits.
        donor_score = 1.0 if donors <= 5 else max(0.0, 1.0 - 0.2 * (donors - 5))
        acceptor_score = 1.0 if acceptors <= 10 else max(0.0, 1.0 - 0.1 * (acceptors - 10))

        parts = (mw_score, logp_score, donor_score, acceptor_score)
        return sum(parts) / len(parts)
    except Exception:
        return 0.0
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def compute_structural_complexity_reward(mol) -> float:
    """Reward moderate structural complexity: 1-3 rings and some flexibility.

    Ring score favors 1-3 rings; flexibility score favors 2-8 rotatable
    bonds. Returns the mean of the two sub-scores, or 0.0 for None / on
    RDKit failure.

    Bug fix: the original branch order tested `rot_bonds <= 12` before
    `rot_bonds in (0, 1)`, so rigid molecules (0-1 rotatable bonds) fell
    into the 0.7 branch and the intended 0.5 branch was unreachable dead
    code. The (0, 1) case is now checked first so it scores 0.5 as written.
    """
    if mol is None:
        return 0.0
    try:
        ring_count = rdMolDescriptors.CalcNumRings(mol)
        if 1 <= ring_count <= 3:
            ring_score = 1.0
        elif ring_count == 0:
            ring_score = 0.3
        elif ring_count <= 5:
            ring_score = 0.7
        else:
            ring_score = 0.1

        rot_bonds = Descriptors.NumRotatableBonds(mol)
        if 2 <= rot_bonds <= 8:
            flex_score = 1.0
        elif rot_bonds in (0, 1):
            # Overly rigid: slight penalty (previously unreachable branch).
            flex_score = 0.5
        elif rot_bonds <= 12:
            flex_score = 0.7
        else:
            flex_score = 0.2

        return (ring_score + flex_score) / 2
    except Exception:
        return 0.0
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# ==========================
|
| 187 |
+
# Unified Reward
|
| 188 |
+
# ==========================
|
| 189 |
+
def compute_comprehensive_reward(selfies_str: str) -> dict[str, float]:
    """Compute every reward component plus a weighted "total" for a SELFIES string.

    The returned dict always contains the individual components; "total" is
    0.0 for invalid molecules, otherwise the weighted mean of the components.
    """
    smiles = selfies_to_smiles(selfies_str)
    mol = Chem.MolFromSmiles(smiles) if smiles else None

    rewards = {
        "validity": 1.0 if mol is not None else 0.0,
        "biological_diversity": compute_biological_diversity_score(mol),
        "charge_neutrality": compute_charge_neutrality_score(mol),
        "local_charge_penalty": compute_local_charge_penalty(mol),
        "lipinski": compute_enhanced_lipinski_reward(mol),
        "structural_complexity": compute_structural_complexity_reward(mol),
    }

    if not rewards["validity"]:
        rewards["total"] = 0.0
        return rewards

    # Component weights: diversity and neutrality dominate the total.
    weights = {
        "validity": 1.0,
        "biological_diversity": 2.0,
        "charge_neutrality": 1.5,
        "local_charge_penalty": 1.0,
        "lipinski": 1.0,
        "structural_complexity": 0.5,
    }
    weighted = sum(rewards[name] * w for name, w in weights.items())
    rewards["total"] = weighted / sum(weights.values())
    return rewards
|
| 217 |
+
|
| 218 |
+
def compute_lipinski_reward(mol) -> float:
    """Fraction of (modified) Lipinski rules satisfied, in [0, 1].

    The molecular-weight rule uses a 250 lower bound (Lipinski has none) so
    that very small fragments are not rewarded. Returns 0.0 for None or when
    descriptor computation fails.

    Bug fix: the original used a bare `except:`, which also swallows
    KeyboardInterrupt and SystemExit; narrowed to `except Exception:`.
    """
    if mol is None:
        return 0.0
    try:
        mw = Descriptors.MolWt(mol)
        logp = Descriptors.MolLogP(mol)
        hbd = Lipinski.NumHDonors(mol)
        hba = Lipinski.NumHAcceptors(mol)
        rules = [250 < mw <= 500, logp <= 5, hbd <= 5, hba <= 10]  # we dont want too small of fragments
        return sum(rules) / 4.0
    except Exception:
        return 0.0
|
| 230 |
+
|
| 231 |
+
def selfies_to_lipinski_reward(selfies_str: str) -> float:
    """Decode a SELFIES string and score it with compute_lipinski_reward.

    Returns 0.0 when the SELFIES string cannot be decoded.
    """
    smiles = selfies_to_smiles(selfies_str)
    if smiles is None:
        return 0.0
    return compute_lipinski_reward(Chem.MolFromSmiles(smiles))
|
| 238 |
+
|
| 239 |
+
class MTPHead(nn.Module):
    """Multi-token prediction head: one linear projection per future offset.

    For offset i (1-based), a learned position embedding is added to the
    hidden states, normalized with a shared LayerNorm, then projected to
    vocabulary logits by that offset's own unbiased linear head.
    """

    def __init__(self, hidden_size: int, vocab_size: int, num_future_tokens: int = 3):
        super().__init__()
        self.num_future_tokens = num_future_tokens
        self.vocab_size = vocab_size
        # Independent vocabulary projection per predicted future token.
        self.prediction_heads = nn.ModuleList(
            nn.Linear(hidden_size, vocab_size, bias=False)
            for _ in range(num_future_tokens)
        )
        # One embedding row per future offset, added before normalization.
        self.position_embeddings = nn.Embedding(num_future_tokens, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return {'logits_t1': [B, S, V], ..., 'logits_tK': [B, S, V]}."""
        outputs: Dict[str, torch.Tensor] = {}
        for offset in range(self.num_future_tokens):
            index = torch.tensor(offset, device=hidden_states.device)
            conditioned = self.layer_norm(hidden_states + self.position_embeddings(index))
            outputs[f'logits_t{offset + 1}'] = self.prediction_heads[offset](conditioned)
        return outputs
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class HorizonLoss(nn.Module):
    """Weighted cross-entropy over multiple future-token horizons.

    Horizon weights are learnable: they are stored as log-weights
    (initialized from a geometric 0.9**i decay unless given explicitly)
    and re-normalized with a softmax on every forward pass.
    """

    def __init__(self, num_future_tokens: int = 3, horizon_weights: Optional[List[float]] = None):
        super().__init__()
        self.num_future_tokens = num_future_tokens
        # Default: geometric decay so nearer horizons weigh more.
        self.horizon_weights = (
            [0.9 ** i for i in range(num_future_tokens)]
            if horizon_weights is None
            else horizon_weights
        )
        self.log_weights = nn.Parameter(torch.log(torch.tensor(self.horizon_weights)))

    def forward(
        self,
        mtp_outputs: Dict[str, torch.Tensor],
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute the weighted multi-horizon loss against *input_ids*.

        Horizons missing from *mtp_outputs*, longer than the sequence, or
        fully masked out are skipped. Returns a dict with the total 'loss',
        the softmaxed 'horizon_weights', and each per-horizon loss.
        """
        seq_len = input_ids.size(1)
        weights = F.softmax(self.log_weights, dim=0)

        total_loss = 0.0
        per_horizon = {}
        for idx in range(self.num_future_tokens):
            key = f'logits_t{idx + 1}'
            if key not in mtp_outputs:
                continue
            shift = idx + 1
            if seq_len <= shift:
                # Sequence too short to supply targets at this horizon.
                continue
            logits = mtp_outputs[key]
            predictions = logits[:, :-shift, :].contiguous()
            targets = input_ids[:, shift:].contiguous()

            if attention_mask is not None:
                keep = attention_mask[:, shift:].contiguous().view(-1) == 1
                if keep.sum() == 0:
                    continue
                flat_logits = predictions.view(-1, logits.size(-1))[keep]
                flat_targets = targets.view(-1)[keep]
            else:
                flat_logits = predictions.view(-1, logits.size(-1))
                flat_targets = targets.view(-1)

            loss_i = F.cross_entropy(flat_logits, flat_targets, reduction='mean')
            per_horizon[f'horizon_loss_t{idx + 1}'] = loss_i
            total_loss = total_loss + weights[idx] * loss_i

        return {'loss': total_loss, 'horizon_weights': weights, **per_horizon}
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class ChemQ3MTP(Qwen3ForCausalLM):
    """Qwen3 causal LM extended with a multi-token-prediction (MTP) head,
    a multi-horizon loss, and PPO-style RL utilities for SELFIES generation.

    Training combines the MTP horizon loss (weight 0.7) with the standard
    next-token loss (weight 0.3) when `use_mtp_training` is on; otherwise it
    falls back to plain causal-LM cross-entropy.
    """

    def __init__(self, config, num_future_tokens: int = 3):
        # num_future_tokens: how many future offsets the MTP head predicts.
        super().__init__(config)
        self.mtp_head = MTPHead(config.hidden_size, config.vocab_size, num_future_tokens)
        self.horizon_loss = HorizonLoss(num_future_tokens=num_future_tokens)
        self.use_mtp_training = True
        self.post_init()
        # Adaptive entropy bonus controller used by ppo_step.
        # NOTE(review): EnhancedEntropyController is defined elsewhere in this
        # project (not visible here) — verify its constructor signature.
        self.entropy_controller = EnhancedEntropyController(
            min_entropy=0.5,
            max_entropy=3.0,
            target_entropy=1.5,
            adaptation_rate=0.01,
        )


    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ):
        """Run the base model; compute MTP + causal-LM loss when training.

        Loss is 0.7*horizon_loss + 0.3*causal CE in MTP training mode,
        plain shifted CE otherwise; None when no labels are given.
        """
        # Default mask if not provided
        if attention_mask is None and input_ids is not None:
            attention_mask = (input_ids != self.config.pad_token_id).long()

        # Respect caller settings, only set defaults if missing
        kwargs.setdefault("output_hidden_states", True)
        kwargs.setdefault("return_dict", True)

        # labels=None here: loss is computed manually below so that the MTP
        # and masked variants share one code path.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=None,
            **kwargs
        )

        hidden_states = outputs.hidden_states[-1]
        lm_logits = outputs.logits
        loss = None

        if self.training and self.use_mtp_training and labels is not None:  # labels, not kwargs
            mtp_outputs = self.mtp_head(hidden_states)
            # Horizon loss targets input_ids shifted by each future offset.
            horizon_loss_dict = self.horizon_loss(mtp_outputs, input_ids, attention_mask)

            # Standard next-token CE on the base LM logits.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()  # labels, not kwargs["labels"]

            if attention_mask is not None:
                shift_mask = attention_mask[..., 1:].contiguous()
                loss_mask = shift_mask.view(-1) == 1
                if loss_mask.sum() == 0:
                    # Fully-masked batch: zero loss rather than NaN from empty CE.
                    causal_lm_loss = torch.tensor(0.0, device=lm_logits.device)
                else:
                    flat_logits = shift_logits.view(-1, shift_logits.size(-1))[loss_mask]
                    flat_labels = shift_labels.view(-1)[loss_mask]
                    causal_lm_loss = F.cross_entropy(flat_logits, flat_labels, reduction='mean')
            else:
                flat_logits = shift_logits.view(-1, shift_logits.size(-1))
                flat_labels = shift_labels.view(-1)
                causal_lm_loss = F.cross_entropy(flat_logits, flat_labels, reduction='mean')

            # Fixed 70/30 blend of multi-horizon and next-token objectives.
            loss = 0.7 * horizon_loss_dict['loss'] + 0.3 * causal_lm_loss

        elif labels is not None:  # labels, not kwargs.get("labels")
            # Eval / non-MTP path: plain shifted cross-entropy (no masking).
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()  # labels, not kwargs["labels"]
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        from transformers.modeling_outputs import CausalLMOutputWithPast
        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def set_mtp_training(self, use_mtp: bool):
        """Toggle the MTP branch of the training loss."""
        self.use_mtp_training = use_mtp

    # ================
    # RL SAMPLING + PPO
    # ================

    def generate_with_logprobs(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 50,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        do_sample: bool = True,
        return_probs: bool = True,
        tokenizer=None,  # allow passing explicitly
    ) -> Tuple[List[str], torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Autoregressively sample and return decoded strings plus per-step
        log-probs, token ids, and (optionally) the full action distributions.

        Runs under torch.no_grad(), so returned tensors carry no gradient.
        """
        self.eval()
        device = input_ids.device

        # Normalize shapes: allow [L], [1,L], [B,L], [B,1,L]
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)  # [L] -> [1,L]
        if input_ids.dim() == 3 and input_ids.size(1) == 1:
            input_ids = input_ids.squeeze(1)  # [B,1,L] -> [B,L]
        assert input_ids.dim() == 2, f"input_ids must be 2-D, got {input_ids.shape}"

        batch_size, seq_len = input_ids.shape
        current_input = input_ids

        generated_tokens, generated_logprobs, generated_probs = [], [], []

        with torch.no_grad():
            for _ in range(max_new_tokens):
                # use_cache=False: full forward each step (simple but O(T^2)).
                outputs = self(current_input, use_cache=False)
                logits = outputs.logits[:, -1, :] / temperature

                # Top-k
                if top_k is not None:
                    values, indices = torch.topk(logits, k=top_k)
                    logits = torch.full_like(logits, float("-inf"))
                    logits.scatter_(1, indices, values)

                # Top-p (nucleus) filtering; keep at least the top token.
                if top_p is not None and top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    mask = cumprobs > top_p
                    mask[..., 1:] = mask[..., :-1].clone()
                    mask[..., 0] = False
                    logits[mask.scatter(1, sorted_indices, mask)] = float("-inf")

                probs = F.softmax(logits, dim=-1)

                if do_sample:
                    dist = Categorical(probs)
                    next_token = dist.sample()
                    log_p = dist.log_prob(next_token)
                else:
                    next_token = torch.argmax(probs, dim=-1)
                    log_p = torch.log(torch.gather(probs, 1, next_token.unsqueeze(1))).squeeze(1)

                generated_tokens.append(next_token.unsqueeze(1))
                generated_logprobs.append(log_p.unsqueeze(1))
                if return_probs:
                    generated_probs.append(probs.unsqueeze(1))

                current_input = torch.cat([current_input, next_token.unsqueeze(1)], dim=1)

        generated_tokens = torch.cat(generated_tokens, dim=1)  # [B, T]
        generated_logprobs = torch.cat(generated_logprobs, dim=1)  # [B, T]
        generated_probs = torch.cat(generated_probs, dim=1) if return_probs else None

        # Use passed tokenizer, fallback to self.tokenizer
        tok = tokenizer if tokenizer is not None else getattr(self, "tokenizer", None)
        if tok is None:
            raise ValueError("Tokenizer must be provided to decode generated tokens.")

        decoded_list = [
            tok.decode(tok_ids, skip_special_tokens=True)
            for tok_ids in generated_tokens
        ]
        return decoded_list, generated_logprobs, generated_tokens, generated_probs


    def ppo_step(
        self,
        input_ids: torch.LongTensor,
        old_log_probs: torch.Tensor,
        old_action_probs: torch.Tensor,
        tokenizer,
        max_new_tokens: int = 50,
        temperature: float = 1.0,
        top_k: Optional[int] = 50,
        top_p: Optional[float] = 0.95,
        validity_weight: float = 1.0,
        lipinski_weight: float = 1.0,
        entropy_weight: float = 0.01,
        clip_epsilon: float = 0.2,
        baseline: Optional[torch.Tensor] = None,
        reward_mode: str = "chemq3",  # "chemq3", "sa", or "mix"
        reward_mix: float = 0.5,  # used if mixing chemq3 + sa (0..1 weight for chemq3)
    ) -> Dict[str, Any]:
        """One PPO-with-KL-penalty update step over a fresh rollout.

        Rewards come from the comprehensive chemistry score ("chemq3"), the
        SA classifier ("sa"), or a reward_mix blend ("mix").

        NOTE(review): validity_weight / lipinski_weight / entropy_weight are
        accepted but never read in this body — confirm whether they are
        intentionally superseded by the adaptive controllers.
        NOTE(review): the rollout (new_log_probs / new_action_probs) is
        produced under torch.no_grad() in generate_with_logprobs, so the
        returned 'loss' appears to have no gradient path through the policy
        — verify how callers backpropagate this.
        """

        # =========================
        # PPO-KL BODY (drop-in)
        # =========================
        self.train()
        self.set_mtp_training(False)
        if not hasattr(self, 'tokenizer'):
            self.tokenizer = tokenizer

        # Ensure entropy controller exists
        if not hasattr(self, 'entropy_controller'):
            # if you want different defaults, set them when constructing model instead
            self.entropy_controller = EnhancedEntropyController(
                min_entropy=0.5,
                max_entropy=3.0,
                target_entropy=1.5,
                adaptation_rate=0.01
            )

        # --- roll-out ---
        selfies_list, new_log_probs, token_ids, new_action_probs = self.generate_with_logprobs(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=True,
            return_probs=True,
            tokenizer=getattr(self, "tokenizer", None),
        )

        batch_size = len(selfies_list)
        device = new_log_probs.device

        # --- rewards: compute depending on mode ---
        validity_vals: List[float] = []
        lipinski_vals: List[float] = []
        total_rewards: List[float] = []
        sa_rewards: List[float] = []

        for s in selfies_list:
            if reward_mode == "chemq3":
                r = compute_comprehensive_reward(s)
                validity_vals.append(r.get('validity', 0.0))
                lipinski_vals.append(r.get('lipinski', 0.0))
                total_rewards.append(r.get('total', 0.0))

            elif reward_mode == "sa":
                sa = compute_sa_reward(s)
                sa_rewards.append(sa)

            elif reward_mode == "mix":
                # Blend chemistry total with SA score by reward_mix.
                r = compute_comprehensive_reward(s)
                sa = compute_sa_reward(s)
                mixed = reward_mix * r.get("total", 0.0) + (1.0 - reward_mix) * sa
                total_rewards.append(mixed)
                sa_rewards.append(sa)
                validity_vals.append(r.get('validity', 0.0))
                lipinski_vals.append(r.get('lipinski', 0.0))

            else:
                # unknown mode -> default to zero reward
                total_rewards.append(0.0)
                validity_vals.append(0.0)
                lipinski_vals.append(0.0)

        # Convert lists -> tensors, handle empty lists safely
        if reward_mode in ("chemq3", "mix"):
            rewards = torch.tensor(total_rewards, dtype=torch.float32, device=device)
        elif reward_mode == "sa":
            rewards = torch.tensor(sa_rewards, dtype=torch.float32, device=device)
        else:
            rewards = torch.zeros(batch_size, dtype=torch.float32, device=device)

        if len(validity_vals) > 0:
            validity_rewards = torch.tensor(validity_vals, dtype=torch.float32, device=device)
        else:
            validity_rewards = torch.zeros(batch_size, dtype=torch.float32, device=device)

        if len(lipinski_vals) > 0:
            lipinski_rewards = torch.tensor(lipinski_vals, dtype=torch.float32, device=device)
        else:
            lipinski_rewards = torch.zeros(batch_size, dtype=torch.float32, device=device)

        # baseline subtraction (broadcast if needed)
        if baseline is not None:
            # baseline can be scalar tensor or per-batch; support both
            # NOTE(review): both branches are identical — broadcasting handles
            # both cases, so the numel() check is redundant.
            if baseline.numel() == 1:
                rewards = rewards - baseline.to(device)
            else:
                rewards = rewards - baseline.to(device)

        # --- probability ratio ---
        # old_action_probs/new_action_probs expected shape: [B, T, V]
        # token_ids expected shape: [B, T]
        old_probs = torch.gather(old_action_probs, 2, token_ids.unsqueeze(2)).squeeze(2).clamp_min(1e-8)
        new_probs = torch.gather(new_action_probs, 2, token_ids.unsqueeze(2)).squeeze(2).clamp_min(1e-8)
        log_ratio = new_log_probs - old_log_probs  # shape [B, T]
        # total_ratio: product of per-step ratios -> exp(sum(log ratio))
        total_ratio = torch.exp(log_ratio.sum(dim=1))  # shape [B]

        # --- adaptive KL controller (singleton) ---
        if not hasattr(self, 'kl_controller'):
            self.kl_controller = AdaptiveKLController()
        # KL per example: sum over time of old * (log old - log new), averaged over V already via gather
        # Here compute KL between full distributions if available
        kl = (old_probs * (torch.log(old_probs) - torch.log(new_probs))).sum(dim=1)  # shape [B]
        beta = self.kl_controller.update(kl.mean().item())

        # --- PPO-KL loss ---
        # Clipped surrogate objective (sequence-level ratio, not per-token).
        surr1 = total_ratio * rewards
        surr2 = torch.clamp(total_ratio, 1 - clip_epsilon, 1 + clip_epsilon) * rewards
        ppo_loss = -torch.min(surr1, surr2).mean()
        kl_penalty = beta * kl.mean()
        total_policy_loss = ppo_loss + kl_penalty

        # --- entropy bonus (adaptive) ---
        # compute token-level entropy averaged across batch/time
        with torch.no_grad():
            _probs = new_action_probs.clamp_min(1e-12)
            per_step_entropy = -(_probs * torch.log(_probs)).sum(dim=-1)  # [B, T]
            entropy = per_step_entropy.mean()  # scalar tensor

        adaptive_entropy_weight = self.entropy_controller.update_entropy_weight(entropy.item())
        entropy_bonus = adaptive_entropy_weight * entropy
        total_loss = total_policy_loss - entropy_bonus

        # regularization (optional)
        # L2 penalty over all parameters with a very small coefficient.
        reg_loss = 1e-7 * sum(p.pow(2).sum() for p in self.parameters())
        total_loss = total_loss + reg_loss

        # prepare return (detach tensors where relevant)
        avg_sa = None
        if len(sa_rewards) > 0:
            avg_sa = float(torch.tensor(sa_rewards, dtype=torch.float32, device=device).mean().item())

        return {
            'loss': total_loss,
            'ppo_loss': ppo_loss.item(),
            'kl_penalty': kl_penalty.item(),
            'kl_coef': beta,
            'entropy': float(entropy.item()),
            'entropy_weight': float(adaptive_entropy_weight),
            'validity_rate': float(validity_rewards.mean().item()),
            'lipinski_score': float(lipinski_rewards.mean().item()),
            'avg_reward': float(rewards.mean().item()),
            'avg_sa_reward': avg_sa,
            'generated_selfies': selfies_list,
            'generated_smiles': [selfies_to_smiles(s) for s in selfies_list],
            'new_log_probs': new_log_probs.detach(),
            'new_action_probs': new_action_probs.detach(),
        }
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
|
| 648 |
+
# ========================
|
| 649 |
+
# CURRICULUM LEARNING MANAGER
|
| 650 |
+
# ========================
|
| 651 |
+
|
| 652 |
+
class CurriculumManager:
    """Cyclic curriculum over generation length.

    Gradually raises ``max_new_tokens`` from ``start_len`` up to ``max_len``
    in increments of ``step_increase`` every ``steps_per_level`` calls to
    :meth:`step`; once the ceiling is reached, the schedule wraps back to
    ``start_len`` and the cycle repeats.
    """

    def __init__(self, start_len=10, max_len=30, step_increase=5, steps_per_level=30):
        self.start_len = start_len
        self.max_len = max_len
        self.step_increase = step_increase
        self.steps_per_level = steps_per_level
        self.step_counter = 0
        self.current_max_len = start_len

    def get_max_new_tokens(self):
        """Return the current generation-length budget."""
        return self.current_max_len

    def step(self):
        """Advance the schedule one training step; return the current budget."""
        self.step_counter += 1
        boundary = self.step_counter % self.steps_per_level == 0
        if boundary:
            if self.current_max_len >= self.max_len:
                # Ceiling reached: wrap the schedule back to the start.
                self.current_max_len = self.start_len
                print(f" 🔄 Cycle reset: max_new_tokens -> {self.current_max_len}")
            else:
                self.current_max_len = min(
                    self.current_max_len + self.step_increase, self.max_len
                )
            # Progress log while still below the ceiling (NOTE(review): the
            # original render leaves this print's exact indentation ambiguous;
            # only its frequency differs, the returned value is unaffected).
            if self.current_max_len < self.max_len:
                print(f" 📈 Curriculum Update: max_new_tokens = {self.current_max_len}")
        return self.current_max_len
|
| 681 |
+
|
| 682 |
+
class AdaptiveKLController:
    """
    Keeps E[KL] near ``target_kl`` by scaling the coefficient β.

    KL observations accumulate in a buffer; once ``kl_horizon`` samples are
    collected, β is multiplied by ``increase_rate`` if the mean KL exceeds
    1.5× the target, or by ``decrease_rate`` if it falls below 0.5× the
    target, and the buffer is emptied.
    """

    def __init__(self, init_kl_coef: float = 0.1, target_kl: float = 0.01,
                 kl_horizon: int = 1000, increase_rate: float = 1.5, decrease_rate: float = 0.8):
        self.kl_coef = init_kl_coef
        self.target_kl = target_kl
        self.kl_horizon = kl_horizon
        self.inc = increase_rate
        self.dec = decrease_rate
        self.buffer = []

    def update(self, kl: float):
        """Record one KL sample and return the (possibly adjusted) β."""
        self.buffer.append(kl)
        if len(self.buffer) < self.kl_horizon:
            return self.kl_coef
        # Horizon reached: adjust β from the window mean, then reset.
        avg_kl = sum(self.buffer) / len(self.buffer)
        self.buffer.clear()
        if avg_kl > 1.5 * self.target_kl:
            self.kl_coef *= self.inc
        elif avg_kl < 0.5 * self.target_kl:
            self.kl_coef *= self.dec
        return self.kl_coef
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
class EnhancedEntropyController:
    """
    Adaptive entropy regularisation with dynamic targets.

    Tracks a rolling window of observed policy entropies and nudges the
    entropy-bonus weight so that entropy hovers near ``target_entropy``;
    also provides a shaped reward that peaks at the target entropy.
    """

    def __init__(self, min_entropy: float = 0.5, max_entropy: float = 3.0,
                 target_entropy: float = 1.5, adaptation_rate: float = 0.01):
        self.min_entropy = min_entropy
        self.max_entropy = max_entropy
        self.target_entropy = target_entropy
        self.adaptation_rate = adaptation_rate
        self.entropy_history = []
        self.entropy_weight = 0.01  # initial bonus weight

    def update_entropy_weight(self, current_entropy: float) -> float:
        """
        Record *current_entropy* and return the adjusted entropy weight.
        """
        history = self.entropy_history
        history.append(current_entropy)

        # Bound the rolling window at the 100 most recent observations.
        if len(history) > 100:
            history = history[-100:]
            self.entropy_history = history

        # Only adapt once enough samples exist to form a stable average.
        if len(history) >= 10:
            recent_avg = np.mean(history[-10:])
            if recent_avg < 0.8 * self.target_entropy:
                # Too deterministic: push exploration harder (capped at 0.05).
                self.entropy_weight = min(0.05, self.entropy_weight * 1.1)
            elif recent_avg > 1.2 * self.target_entropy:
                # Too random: ease the bonus off (floored at 0.001).
                self.entropy_weight = max(0.001, self.entropy_weight * 0.95)

        return self.entropy_weight

    def compute_entropy_reward(self, entropy: float) -> float:
        """
        Shaped reward for entropy: Gaussian peak at ``target_entropy`` inside
        [min_entropy, max_entropy], small constant 0.1 outside that range.
        """
        if not (self.min_entropy <= entropy <= self.max_entropy):
            return 0.1  # small penalty for leaving the acceptable band
        distance = abs(entropy - self.target_entropy)
        max_distance = max(self.target_entropy - self.min_entropy,
                           self.max_entropy - self.target_entropy)
        return np.exp(-(distance / max_distance) ** 2)
|
FastChemTokenizerHF.py
ADDED
|
@@ -0,0 +1,769 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from typing import List, Union, Optional, Tuple, Dict, Any
|
| 5 |
+
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
|
| 6 |
+
from transformers.utils import PaddingStrategy, TensorType
|
| 7 |
+
from functools import lru_cache
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TrieNode:
    # Minimal trie node: ``children`` maps one character to the next node;
    # ``token_id`` is set when the root-to-here path spells a complete
    # vocabulary token (None on intermediate nodes). __slots__ keeps the
    # per-node footprint small — the vocab trie can hold many nodes.
    __slots__ = ['children', 'token_id']
    def __init__(self):
        self.children = {}
        self.token_id = None  # If set, this node completes a valid token
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class FastChemTokenizer(PreTrainedTokenizerBase):
|
| 18 |
+
"""
|
| 19 |
+
Fully HuggingFace API compatible tokenizer for chemical representations.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
vocab_files_names = {"vocab_file": "vocab.json"}
|
| 23 |
+
|
| 24 |
+
    def __init__(
        self,
        token_to_id=None,
        vocab_file=None,
        model_max_length=512,
        padding_side="right",
        truncation_side="right",
        chat_template=None,
        **kwargs
    ):
        """Build the tokenizer from an in-memory vocab or a vocab.json file.

        Args:
            token_to_id: Mapping token string -> integer id. One of
                ``token_to_id`` / ``vocab_file`` is required; ``vocab_file``
                takes precedence when both are given.
            vocab_file: Path to a JSON file containing the same mapping.
            model_max_length: Max sequence length reported to HF utilities.
            padding_side: "right" or "left" (HF convention).
            truncation_side: "right" or "left" (HF convention).
            chat_template: Optional chat template forwarded to the base class.

        Raises:
            ValueError: if neither ``token_to_id`` nor ``vocab_file`` is given.
            KeyError: if any required special token is missing from the vocab.
        """
        # Handle vocab loading
        if token_to_id is None and vocab_file is None:
            raise ValueError("Either token_to_id or vocab_file must be provided")

        if vocab_file is not None:
            with open(vocab_file, "r", encoding="utf-8") as f:
                token_to_id = json.load(f)
            # JSON keys are always strings; coerce ids to int defensively.
            token_to_id = {str(k): int(v) for k, v in token_to_id.items()}

        self.token_to_id = token_to_id
        # Reverse map for decoding (assumes ids are unique — duplicates would
        # silently collapse here; TODO confirm vocab has no duplicate ids).
        self.id_to_token = {v: k for k, v in token_to_id.items()}

        # Precompute max token length for possible use & clarity
        self.max_token_len = max(len(t) for t in token_to_id.keys()) if token_to_id else 0

        # Build trie for fast longest-match lookup
        self.trie_root = self._build_trie(token_to_id)

        # Validate required special tokens
        required_special_tokens = ["<s>", "</s>", "<pad>", "<unk>", "<mask>"]
        for tok in required_special_tokens:
            if tok not in token_to_id:
                raise KeyError(f"Required special token '{tok}' not found in vocab.")

        # ✅ Assign special token IDs explicitly
        self.bos_token_id = token_to_id["<s>"]
        self.eos_token_id = token_to_id["</s>"]
        self.pad_token_id = token_to_id["<pad>"]
        self.unk_token_id = token_to_id["<unk>"]
        self.mask_token_id = token_to_id["<mask>"]

        # Special tokens
        bos_token = "<s>"
        eos_token = "</s>"
        pad_token = "<pad>"
        unk_token = "<unk>"
        mask_token = "<mask>"

        # Initialize parent class with all required parameters
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=None,
            pad_token=pad_token,
            cls_token=None,
            mask_token=mask_token,
            additional_special_tokens=[],
            model_max_length=model_max_length,
            padding_side=padding_side,
            truncation_side=truncation_side,
            chat_template=chat_template,
            **kwargs,
        )
|
| 88 |
+
|
| 89 |
+
def _build_trie(self, token_to_id):
|
| 90 |
+
root = TrieNode()
|
| 91 |
+
for token, tid in token_to_id.items():
|
| 92 |
+
node = root
|
| 93 |
+
for char in token:
|
| 94 |
+
if char not in node.children:
|
| 95 |
+
node.children[char] = TrieNode()
|
| 96 |
+
node = node.children[char]
|
| 97 |
+
node.token_id = tid
|
| 98 |
+
return root
|
| 99 |
+
|
| 100 |
+
    @property
    def vocab_size(self):
        # Number of entries in the vocabulary (HF-required property).
        return len(self.token_to_id)
|
| 103 |
+
|
| 104 |
+
    def __len__(self):
        # len(tokenizer) == vocabulary size, matching the HF convention.
        return len(self.token_to_id)
|
| 106 |
+
|
| 107 |
+
def get_vocab(self) -> Dict[str, int]:
|
| 108 |
+
return self.token_to_id.copy()
|
| 109 |
+
|
| 110 |
+
    @lru_cache(maxsize=10000)
    def _cached_encode_str(self, s: str) -> Tuple[int, ...]:
        # Memoized wrapper around _encode_core; returns a tuple so results are
        # hashable/immutable. NOTE(review): lru_cache on an instance method
        # keys on `self` and keeps the tokenizer alive for the cache's
        # lifetime — acceptable for a long-lived singleton tokenizer, but
        # verify no short-lived instances are created in hot paths.
        return tuple(self._encode_core(s))
|
| 113 |
+
|
| 114 |
+
def _encode_core(self, text: str) -> List[int]:
|
| 115 |
+
"""Core encoding logic using Trie — no caching."""
|
| 116 |
+
tokens = text
|
| 117 |
+
result_ids = []
|
| 118 |
+
i = 0
|
| 119 |
+
n = len(tokens)
|
| 120 |
+
|
| 121 |
+
while i < n:
|
| 122 |
+
node = self.trie_root
|
| 123 |
+
j = i
|
| 124 |
+
last_match_id = None
|
| 125 |
+
last_match_end = i
|
| 126 |
+
|
| 127 |
+
while j < n and tokens[j] in node.children:
|
| 128 |
+
node = node.children[tokens[j]]
|
| 129 |
+
j += 1
|
| 130 |
+
if node.token_id is not None:
|
| 131 |
+
last_match_id = node.token_id
|
| 132 |
+
last_match_end = j
|
| 133 |
+
|
| 134 |
+
if last_match_id is not None:
|
| 135 |
+
result_ids.append(last_match_id)
|
| 136 |
+
i = last_match_end
|
| 137 |
+
else:
|
| 138 |
+
tok = tokens[i]
|
| 139 |
+
result_ids.append(self.token_to_id.get(tok, self.unk_token_id))
|
| 140 |
+
i += 1
|
| 141 |
+
|
| 142 |
+
return result_ids
|
| 143 |
+
|
| 144 |
+
def _tokenize(self, text: str, **kwargs) -> List[str]:
|
| 145 |
+
token_ids = self._encode_core(text.strip())
|
| 146 |
+
return [self.id_to_token[tid] for tid in token_ids]
|
| 147 |
+
|
| 148 |
+
    def _convert_token_to_id(self, token: str) -> int:
        # Unknown tokens map to the <unk> id rather than raising.
        return self.token_to_id.get(token, self.unk_token_id)
|
| 150 |
+
|
| 151 |
+
    def _convert_id_to_token(self, index: int) -> str:
        # Out-of-vocab ids map to the <unk> token string rather than raising.
        return self.id_to_token.get(index, self.unk_token)
|
| 153 |
+
|
| 154 |
+
# ✅ Public methods
|
| 155 |
+
def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
|
| 156 |
+
if isinstance(tokens, str):
|
| 157 |
+
return self._convert_token_to_id(tokens)
|
| 158 |
+
return [self._convert_token_to_id(tok) for tok in tokens]
|
| 159 |
+
|
| 160 |
+
def convert_ids_to_tokens(self, ids: Union[int, List[int]]) -> Union[str, List[str]]:
|
| 161 |
+
if isinstance(ids, int):
|
| 162 |
+
return self._convert_id_to_token(ids)
|
| 163 |
+
return [self._convert_id_to_token(i) for i in ids]
|
| 164 |
+
|
| 165 |
+
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """SMILES-style decoding: no spaces between tokens."""
        # Chemical strings carry meaning in adjacency, so tokens are simply
        # concatenated — unlike word-level tokenizers that join with spaces.
        return "".join(tokens)
|
| 168 |
+
|
| 169 |
+
def encode(
|
| 170 |
+
self,
|
| 171 |
+
text: str,
|
| 172 |
+
text_pair: Optional[str] = None,
|
| 173 |
+
add_special_tokens: bool = True,
|
| 174 |
+
padding: bool = False,
|
| 175 |
+
truncation: bool = False,
|
| 176 |
+
max_length: Optional[int] = None,
|
| 177 |
+
return_tensors: Optional[str] = None,
|
| 178 |
+
) -> List[int]:
|
| 179 |
+
encoded = self.encode_plus(
|
| 180 |
+
text=text,
|
| 181 |
+
text_pair=text_pair,
|
| 182 |
+
add_special_tokens=add_special_tokens,
|
| 183 |
+
padding=padding,
|
| 184 |
+
truncation=truncation,
|
| 185 |
+
max_length=max_length,
|
| 186 |
+
return_tensors=return_tensors,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
input_ids = encoded["input_ids"]
|
| 190 |
+
if isinstance(input_ids, torch.Tensor):
|
| 191 |
+
if input_ids.dim() > 1:
|
| 192 |
+
input_ids = input_ids.squeeze(0)
|
| 193 |
+
input_ids = input_ids.tolist()
|
| 194 |
+
|
| 195 |
+
return input_ids
|
| 196 |
+
|
| 197 |
+
    def decode(
        self,
        token_ids: Union[List[int], torch.Tensor],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        **kwargs
    ) -> str:
        """Decode ids back to a string by direct concatenation (no spaces).

        Args:
            token_ids: Sequence of ids (list or 1-D tensor).
            skip_special_tokens: If True, drop bos/eos/pad/mask ids.
                NOTE(review): <unk> is NOT skipped here — confirm that keeping
                literal "<unk>" text in decoded output is intended.
            clean_up_tokenization_spaces: Accepted for HF API compatibility;
                unused, since tokens are joined without spaces.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()

        if skip_special_tokens:
            special_ids = {
                self.bos_token_id,
                self.eos_token_id,
                self.pad_token_id,
                self.mask_token_id,
            }
        else:
            special_ids = set()

        tokens = []
        for tid in token_ids:
            if tid in special_ids:
                continue
            token = self.id_to_token.get(tid, self.unk_token)
            tokens.append(token)

        return "".join(tokens)
|
| 225 |
+
|
| 226 |
+
def batch_decode(
|
| 227 |
+
self,
|
| 228 |
+
sequences: Union[List[List[int]], torch.Tensor],
|
| 229 |
+
skip_special_tokens: bool = False,
|
| 230 |
+
clean_up_tokenization_spaces: bool = None,
|
| 231 |
+
**kwargs
|
| 232 |
+
) -> List[str]:
|
| 233 |
+
"""Batch decode sequences."""
|
| 234 |
+
if isinstance(sequences, torch.Tensor):
|
| 235 |
+
sequences = sequences.tolist()
|
| 236 |
+
|
| 237 |
+
return [
|
| 238 |
+
self.decode(
|
| 239 |
+
seq,
|
| 240 |
+
skip_special_tokens=skip_special_tokens,
|
| 241 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 242 |
+
**kwargs
|
| 243 |
+
)
|
| 244 |
+
for seq in sequences
|
| 245 |
+
]
|
| 246 |
+
|
| 247 |
+
def decode_with_trace(self, token_ids: List[int]) -> None:
|
| 248 |
+
print(f"\n🔍 Decoding {len(token_ids)} tokens:")
|
| 249 |
+
for i, tid in enumerate(token_ids):
|
| 250 |
+
token = self.id_to_token.get(tid, self.unk_token)
|
| 251 |
+
print(f" [{i:03d}] ID={tid:5d} → '{token}'")
|
| 252 |
+
|
| 253 |
+
    def __call__(
        self,
        text: Union[str, List[str]],
        text_pair: Optional[Union[str, List[str]]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str] = False,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs
    ) -> BatchEncoding:
        """
        Main callable method that handles both single and batch inputs.

        Dispatches to :meth:`batch_encode_plus` when ``text`` is a list
        (pairing with ``text_pair`` element-wise when given) and to
        :meth:`encode_plus` for a single string. ``return_token_type_ids``
        and ``return_attention_mask`` default to True when left as None.
        """
        # Handle defaults
        if return_token_type_ids is None:
            return_token_type_ids = True
        if return_attention_mask is None:
            return_attention_mask = True

        if isinstance(text, list):
            if text_pair is not None:
                # zip truncates to the shorter list — assumes text and
                # text_pair have equal length; TODO confirm at call sites.
                batch = [(t, p) for t, p in zip(text, text_pair)]
            else:
                batch = text
            return self.batch_encode_plus(
                batch,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                is_split_into_words=is_split_into_words,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors=return_tensors,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                verbose=verbose,
                **kwargs
            )
        else:
            return self.encode_plus(
                text=text,
                text_pair=text_pair,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                is_split_into_words=is_split_into_words,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors=return_tensors,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                verbose=verbose,
                **kwargs
            )
|
| 328 |
+
|
| 329 |
+
    def encode_plus(
        self,
        text: str,
        text_pair: Optional[str] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str] = False,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = True,
        return_attention_mask: Optional[bool] = True,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs
    ) -> BatchEncoding:
        """Encode a single text (optionally a pair) into model inputs.

        Layout with special tokens: single → <s> A </s>;
        pair → <s> A </s> B </s> (segment ids 0 for A, 1 for B).
        Truncation is a hard cut at ``max_length`` (may remove the final
        </s>). Padding here applies only when ``padding`` is True or
        "max_length", and pads up to ``max_length``. ``stride``,
        ``is_split_into_words`` and ``pad_to_multiple_of`` are accepted for
        HF API compatibility but not used by this implementation.
        """
        if max_length is None:
            max_length = self.model_max_length

        ids_a = list(self._cached_encode_str(text.strip()))

        if text_pair is not None:
            ids_b = list(self._cached_encode_str(text_pair.strip()))
        else:
            ids_b = None

        input_ids = []
        token_type_ids = []

        if add_special_tokens:
            input_ids.append(self.bos_token_id)
            token_type_ids.append(0)
            if ids_b is not None:
                input_ids.extend(ids_a)
                token_type_ids.extend([0] * len(ids_a))
                input_ids.append(self.eos_token_id)
                token_type_ids.append(0)

                input_ids.extend(ids_b)
                token_type_ids.extend([1] * len(ids_b))
                input_ids.append(self.eos_token_id)
                token_type_ids.append(1)
            else:
                input_ids.extend(ids_a)
                token_type_ids.extend([0] * len(ids_a))
                input_ids.append(self.eos_token_id)
                token_type_ids.append(0)
        else:
            input_ids = ids_a.copy()
            token_type_ids = [0] * len(input_ids)
            if ids_b is not None:
                input_ids.extend(ids_b)
                token_type_ids.extend([1] * len(ids_b))

        # Handle truncation
        # NOTE(review): a hard cut can drop the closing </s>; confirm callers
        # are fine with truncated sequences lacking the final eos token.
        if truncation and len(input_ids) > max_length:
            input_ids = input_ids[:max_length]
            token_type_ids = token_type_ids[:max_length]

        # Handle padding
        if padding == True or padding == "max_length":
            pad_len = max_length - len(input_ids)
            if pad_len > 0:
                if self.padding_side == "right":
                    input_ids.extend([self.pad_token_id] * pad_len)
                    token_type_ids.extend([0] * pad_len)
                else:
                    input_ids = [self.pad_token_id] * pad_len + input_ids
                    token_type_ids = [0] * pad_len + token_type_ids

        # Attention mask is derived by comparing against the pad id, so a
        # legitimate <pad> token appearing inside the text would be masked
        # out too (not expected for chemical input, but worth noting).
        attention_mask = [1 if tid != self.pad_token_id else 0 for tid in input_ids]

        encoded_dict = {
            "input_ids": input_ids,
        }

        if return_attention_mask:
            encoded_dict["attention_mask"] = attention_mask

        if return_token_type_ids:
            encoded_dict["token_type_ids"] = token_type_ids

        if return_special_tokens_mask:
            special_tokens_mask = [
                1 if tid in {self.bos_token_id, self.eos_token_id, self.pad_token_id, self.mask_token_id} else 0
                for tid in input_ids
            ]
            encoded_dict["special_tokens_mask"] = special_tokens_mask

        if return_length:
            # Length excludes padding tokens.
            encoded_dict["length"] = len([tid for tid in input_ids if tid != self.pad_token_id])

        if return_tensors == "pt":
            output = {}
            for k, v in encoded_dict.items():
                tensor = torch.tensor(v, dtype=torch.long)
                # Add a batch dimension so the result is [1, T].
                if tensor.ndim == 1:
                    tensor = tensor.unsqueeze(0)
                output[k] = tensor
        else:
            output = encoded_dict

        return BatchEncoding(output, tensor_type=return_tensors)
|
| 437 |
+
|
| 438 |
+
    def batch_encode_plus(
        self,
        batch_text_or_text_pairs: List[Union[str, Tuple[str, str]]],
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str] = False,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = True,
        return_attention_mask: Optional[bool] = True,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs
    ) -> BatchEncoding:
        """Encode a batch of texts or (text, text_pair) tuples.

        Each item is encoded individually via :meth:`encode_plus` (without
        padding or tensor conversion), then list-level padding is applied for
        ``padding in (True, "longest")``, and finally everything is stacked
        into tensors when ``return_tensors == "pt"``.

        NOTE(review): with ``return_tensors="pt"`` but ``padding=False``,
        pad_sequence below still right-pads to the longest item — confirm
        this implicit padding is intended.
        """
        all_input_ids = []
        all_attention_masks = []
        all_token_type_ids = []
        all_special_tokens_masks = []
        all_lengths = []

        for item in batch_text_or_text_pairs:
            if isinstance(item, tuple):
                text, text_pair = item
            else:
                text, text_pair = item, None

            encoded = self.encode_plus(
                text=text,
                text_pair=text_pair,
                add_special_tokens=add_special_tokens,
                padding=False,  # We'll handle batch padding later
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                is_split_into_words=is_split_into_words,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors=None,  # Don't convert to tensors yet
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                verbose=verbose,
                **kwargs
            )

            all_input_ids.append(encoded["input_ids"])
            if "attention_mask" in encoded:
                all_attention_masks.append(encoded["attention_mask"])
            if "token_type_ids" in encoded:
                all_token_type_ids.append(encoded["token_type_ids"])
            if "special_tokens_mask" in encoded:
                all_special_tokens_masks.append(encoded["special_tokens_mask"])
            if "length" in encoded:
                all_lengths.append(encoded["length"])

        batched = {
            "input_ids": all_input_ids,
        }

        if all_attention_masks:
            batched["attention_mask"] = all_attention_masks
        if all_token_type_ids:
            batched["token_type_ids"] = all_token_type_ids
        if all_special_tokens_masks:
            batched["special_tokens_mask"] = all_special_tokens_masks
        if all_lengths:
            batched["length"] = all_lengths

        # Handle batch padding
        if padding == True or padding == "longest":
            max_len = max(len(ids) for ids in all_input_ids)
            for key in batched:
                if key in ["input_ids", "attention_mask", "token_type_ids", "special_tokens_mask"]:
                    padded_seqs = []
                    for seq in batched[key]:
                        pad_len = max_len - len(seq)
                        if pad_len > 0:
                            # input_ids pad with the pad id; masks/segments pad with 0.
                            if key == "input_ids":
                                padding_value = self.pad_token_id
                            else:
                                padding_value = 0

                            if self.padding_side == "right":
                                padded_seq = seq + [padding_value] * pad_len
                            else:
                                padded_seq = [padding_value] * pad_len + seq
                        else:
                            padded_seq = seq
                        padded_seqs.append(padded_seq)
                    batched[key] = padded_seqs

        if return_tensors == "pt":
            def to_tensor_list(lst):
                # One 1-D tensor per sequence, for pad_sequence stacking.
                return [torch.tensor(item, dtype=torch.long) for item in lst]

            for key in ["input_ids", "attention_mask", "token_type_ids", "special_tokens_mask"]:
                if key in batched:
                    # pad_sequence right-pads to the longest sequence in the batch.
                    batched[key] = torch.nn.utils.rnn.pad_sequence(
                        to_tensor_list(batched[key]),
                        batch_first=True,
                        padding_value=self.pad_token_id if key == "input_ids" else 0
                    )

            # Handle non-sequence data
            if "length" in batched:
                batched["length"] = torch.tensor(batched["length"], dtype=torch.long)

        return BatchEncoding(batched, tensor_type=return_tensors)
|
| 554 |
+
|
| 555 |
+
    def pad(
        self,
        encoded_inputs,
        padding: Union[bool, str, PaddingStrategy] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        verbose: bool = True,
    ) -> BatchEncoding:
        """Pad encoded inputs."""
        # NOTE(review): this override is a NO-OP — it returns the inputs
        # unchanged regardless of the padding arguments. HF data collators
        # and Trainer call tokenizer.pad() expecting actual padding; confirm
        # that batch_encode_plus-side padding covers every call path, or
        # implement real padding here.
        # This is a simplified version - full implementation would be more complex
        return encoded_inputs
|
| 568 |
+
|
| 569 |
+
# Save/Load methods
|
| 570 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 571 |
+
"""Save vocabulary to files."""
|
| 572 |
+
if not os.path.isdir(save_directory):
|
| 573 |
+
os.makedirs(save_directory)
|
| 574 |
+
|
| 575 |
+
vocab_file = os.path.join(
|
| 576 |
+
save_directory,
|
| 577 |
+
(filename_prefix + "-" if filename_prefix else "") + "vocab.json"
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
with open(vocab_file, "w", encoding="utf-8") as f:
|
| 581 |
+
json.dump(self.token_to_id, f, ensure_ascii=False, indent=2)
|
| 582 |
+
|
| 583 |
+
return (vocab_file,)
|
| 584 |
+
|
| 585 |
+
def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    legacy_format: bool = True,
    filename_prefix: Optional[str] = None,
    push_to_hub: bool = False,
    **kwargs
):
    """Persist the vocabulary and ``tokenizer_config.json`` to *save_directory*.

    Args:
        save_directory: Target directory; created (with parents) if missing.
        legacy_format: Accepted for HF API compatibility; unused here.
        filename_prefix: Optional prefix forwarded to :meth:`save_vocabulary`.
        push_to_hub: Accepted for HF API compatibility; unused here.

    Returns:
        One-element tuple containing the save directory.
    """
    # exist_ok: race-free and creates parents (previous code used a bare
    # `if not os.path.exists(...)` guard).
    os.makedirs(save_directory, exist_ok=True)

    # Vocabulary file(s)
    self.save_vocabulary(save_directory, filename_prefix)

    # Tokenizer configuration mirrored from instance state
    tokenizer_config = {
        "tokenizer_class": self.__class__.__name__,
        "model_max_length": self.model_max_length,
        "padding_side": self.padding_side,
        "truncation_side": self.truncation_side,
        "special_tokens": {
            "bos_token": self.bos_token,
            "eos_token": self.eos_token,
            "pad_token": self.pad_token,
            "unk_token": self.unk_token,
            "mask_token": self.mask_token,
        },
    }

    config_file = os.path.join(save_directory, "tokenizer_config.json")
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(tokenizer_config, f, ensure_ascii=False, indent=2)

    print(f"✅ Tokenizer saved to: {save_directory}")

    return (save_directory,)
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Union[str, os.PathLike],
    *init_inputs,
    **kwargs
):
    """Instantiate the tokenizer from a local directory.

    Reads ``vocab.json`` and, when present, ``tokenizer_config.json`` from
    the directory; explicit keyword arguments override stored config values.
    Loading by hub model id is not supported and raises NotImplementedError.
    """
    if not os.path.isdir(pretrained_model_name_or_path):
        raise NotImplementedError("Loading from HuggingFace Hub not implemented yet")

    vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
    config_file = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")

    stored_config = {}
    if os.path.exists(config_file):
        with open(config_file, "r", encoding="utf-8") as f:
            stored_config = json.load(f)

    # Caller-supplied kwargs win over values persisted on disk.
    return cls(vocab_file=vocab_file, **{**stored_config, **kwargs})
def get_special_tokens_mask(
    self,
    token_ids_0: List[int],
    token_ids_1: Optional[List[int]] = None,
    already_has_special_tokens: bool = False
) -> List[int]:
    """Return a 0/1 mask marking special-token positions.

    When ``already_has_special_tokens`` is True, ``token_ids_0`` is scanned
    as-is and known special ids are flagged. Otherwise the mask mirrors the
    layout of ``build_inputs_with_special_tokens``:
    ``<bos> seq0 <eos> [seq1 <eos>]``.
    """
    if already_has_special_tokens:
        specials = {
            self.bos_token_id,
            self.eos_token_id,
            self.pad_token_id,
            self.mask_token_id,
        }
        return [int(tid in specials) for tid in token_ids_0]

    mask = [1] + [0] * len(token_ids_0) + [1]  # BOS + seq0 + EOS
    if token_ids_1 is not None:
        mask += [0] * len(token_ids_1) + [1]   # seq1 + EOS
    return mask
def create_token_type_ids_from_sequences(
    self,
    token_ids_0: List[int],
    token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """Segment ids: 0 for ``<bos> seq0 <eos>``, 1 for the optional ``seq1 <eos>``."""
    # Only lengths matter; avoid building throwaway cls/sep token lists.
    _ = (self.bos_token_id, self.eos_token_id)  # same attribute reads as before
    first_segment = len(token_ids_0) + 2  # bos + seq0 + eos
    if token_ids_1 is None:
        return [0] * first_segment
    return [0] * first_segment + [1] * (len(token_ids_1) + 1)  # seq1 + eos
def build_inputs_with_special_tokens(
    self,
    token_ids_0: List[int],
    token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """Wrap sequence(s) with special tokens: ``<bos> seq0 <eos> [seq1 <eos>]``."""
    out = [self.bos_token_id]
    out.extend(token_ids_0)
    out.append(self.eos_token_id)
    if token_ids_1 is not None:
        out.extend(token_ids_1)
        out.append(self.eos_token_id)
    return out
class FastChemTokenizerSelfies(FastChemTokenizer):
    """
    SELFIES variant that handles whitespace-separated tokens.
    Uses trie-based longest-match encoding (same as original working version).
    """

    def _encode_core(self, text: str) -> List[int]:
        """Trie-based encoding for SELFIES with fragment + atom vocab.

        Greedy longest-match: at each position the trie (``self.trie_root``,
        built elsewhere from the vocab) is walked as far as possible and the
        longest vocab entry ending on a valid token id wins. Unmatched single
        characters fall back to a direct vocab lookup or ``unk_token_id``.
        """
        result_ids = []
        i = 0
        n = len(text)

        while i < n:
            if text[i].isspace():  # skip literal whitespace
                i += 1
                continue

            node = self.trie_root
            j = i
            # Track the longest completed match seen so far; a trie node may
            # be a prefix of a longer vocab entry, so we keep walking.
            last_match_id = None
            last_match_end = i

            # Traverse trie character by character (including spaces if part of vocab key)
            while j < n and text[j] in node.children:
                node = node.children[text[j]]
                j += 1
                if node.token_id is not None:
                    last_match_id = node.token_id
                    last_match_end = j

            if last_match_id is not None:
                result_ids.append(last_match_id)
                i = last_match_end
            else:
                # Fallback: encode one char as unk or atom
                result_ids.append(self.token_to_id.get(text[i], self.unk_token_id))
                i += 1

        return result_ids

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """SELFIES decoding: join tokens with spaces (preserve original format)."""
        return " ".join(tokens)

    def decode(
        self,
        token_ids: Union[List[int], torch.Tensor],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool = None,
        **kwargs
    ) -> str:
        """Map token ids back to a space-separated SELFIES string.

        Overrides the base decode so tokens stay space-joined (SELFIES
        symbols must remain whitespace-separated for downstream parsing).
        ``clean_up_tokenization_spaces`` and extra kwargs are accepted for
        HF API compatibility but are not used here.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()

        # Only drop special ids when asked; an empty set keeps everything.
        if skip_special_tokens:
            special_ids = {
                self.bos_token_id,
                self.eos_token_id,
                self.pad_token_id,
                self.mask_token_id,
            }
        else:
            special_ids = set()

        tokens = []
        for tid in token_ids:
            if tid in special_ids:
                continue
            token = self.id_to_token.get(tid, self.unk_token)
            tokens.append(token)

        return " ".join(tokens)  # ✅ preserve spaces
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 gbyuvd
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
config.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"training": {
|
| 3 |
+
"batch_size": 16,
|
| 4 |
+
"num_epochs": 1,
|
| 5 |
+
"learning_rate": 5e-5,
|
| 6 |
+
"weight_decay": 0.01,
|
| 7 |
+
"gradient_accumulation_steps": 4,
|
| 8 |
+
"tokenize_batch_size": 100,
|
| 9 |
+
"train_split_ratio": 0.8,
|
| 10 |
+
"val_split_ratio": 0.1,
|
| 11 |
+
"test_split_ratio": 0.1,
|
| 12 |
+
"include_for_metrics": ["input_ids", "attention_mask", "labels"]
|
| 13 |
+
},
|
| 14 |
+
"model": {
|
| 15 |
+
"max_position_embeddings": 512,
|
| 16 |
+
"hidden_size": 320,
|
| 17 |
+
"num_hidden_layers": 6,
|
| 18 |
+
"num_attention_heads": 4,
|
| 19 |
+
"num_key_value_heads": 2,
|
| 20 |
+
"head_dim": 64,
|
| 21 |
+
"intermediate_size": 1280,
|
| 22 |
+
"sliding_window": 16,
|
| 23 |
+
"rope_theta": 10000.0,
|
| 24 |
+
"attention_dropout": 0.1
|
| 25 |
+
},
|
| 26 |
+
"generation": {
|
| 27 |
+
"max_length": 64,
|
| 28 |
+
"top_k": 50,
|
| 29 |
+
"top_p": 0.9,
|
| 30 |
+
"temperature": 1,
|
| 31 |
+
"do_sample": true,
|
| 32 |
+
"num_return_sequences": 3
|
| 33 |
+
}
|
| 34 |
+
}
|
demo_test_mtpresult.ipynb
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "4ff9650b",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"tensor([[ 0, 379, 1]])\n",
|
| 14 |
+
"tensor([[1, 1, 1]])\n",
|
| 15 |
+
"cuda:0\n"
|
| 16 |
+
]
|
| 17 |
+
}
|
| 18 |
+
],
|
| 19 |
+
"source": [
|
| 20 |
+
"from FastChemTokenizerHF import FastChemTokenizerSelfies\n",
|
| 21 |
+
"# --- Load the tokenizer ---\n",
|
| 22 |
+
"tokenizer = FastChemTokenizerSelfies.from_pretrained(\"./selftok_core\")\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"# Test it\n",
|
| 25 |
+
"out = tokenizer(\"[C]\", return_tensors=\"pt\")\n",
|
| 26 |
+
"print(out.input_ids) # ← Attribute access works\n",
|
| 27 |
+
"print(out.attention_mask) # ← Also works\n",
|
| 28 |
+
"out = out.to(\"cuda\") # ← Moves all tensors to GPU\n",
|
| 29 |
+
"print(out.input_ids.device) # ← Should be cuda:0"
|
| 30 |
+
]
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"cell_type": "code",
|
| 34 |
+
"execution_count": 2,
|
| 35 |
+
"id": "d16aeaf7",
|
| 36 |
+
"metadata": {},
|
| 37 |
+
"outputs": [
|
| 38 |
+
{
|
| 39 |
+
"name": "stdout",
|
| 40 |
+
"output_type": "stream",
|
| 41 |
+
"text": [
|
| 42 |
+
"Model has 9,854,851 trainable parameters.\n",
|
| 43 |
+
"Input shape: torch.Size([2, 32])\n",
|
| 44 |
+
"Logits shape: torch.Size([2, 32, 782])\n"
|
| 45 |
+
]
|
| 46 |
+
}
|
| 47 |
+
],
|
| 48 |
+
"source": [
|
| 49 |
+
"import torch\n",
|
| 50 |
+
"from ChemQ3MTP import ChemQ3MTP\n",
|
| 51 |
+
"# --- Initialize model from scratch ---\n",
|
| 52 |
+
"\n",
|
| 53 |
+
"model = ChemQ3MTP.from_pretrained('./enhanced-qwen3-final')\n",
|
| 54 |
+
"\n",
|
| 55 |
+
"# --- Print model parameter count ---\n",
|
| 56 |
+
"def count_parameters(model):\n",
|
| 57 |
+
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"print(f\"Model has {count_parameters(model):,} trainable parameters.\")\n",
|
| 60 |
+
"\n",
|
| 61 |
+
"# --- Quick forward pass sanity check ---\n",
|
| 62 |
+
"batch_size, seq_len = 2, 32\n",
|
| 63 |
+
"dummy_input = torch.randint(\n",
|
| 64 |
+
" low=0,\n",
|
| 65 |
+
" high=len(tokenizer),\n",
|
| 66 |
+
" size=(batch_size, seq_len),\n",
|
| 67 |
+
" dtype=torch.long,\n",
|
| 68 |
+
")\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"with torch.no_grad():\n",
|
| 71 |
+
" outputs = model(dummy_input)\n",
|
| 72 |
+
" logits = outputs.logits\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"print(f\"Input shape: {dummy_input.shape}\")\n",
|
| 75 |
+
"print(f\"Logits shape: {logits.shape}\") # should be [batch_size, seq_len, vocab_size]\n"
|
| 76 |
+
]
|
| 77 |
+
},
|
| 78 |
+
{
|
| 79 |
+
"cell_type": "code",
|
| 80 |
+
"execution_count": 3,
|
| 81 |
+
"id": "105b47a0",
|
| 82 |
+
"metadata": {},
|
| 83 |
+
"outputs": [
|
| 84 |
+
{
|
| 85 |
+
"name": "stdout",
|
| 86 |
+
"output_type": "stream",
|
| 87 |
+
"text": [
|
| 88 |
+
"[Branch2] [=Branch1] [Branch1] [C] [=Branch1] [C] [=O] [N] [C] [C] [N] [C] [=Branch1] [C] [=O] [C] [N] [C] [=Branch1] [C] [=O] [NH1] [C] [=Ring2] [Ring1] [=Branch1] [=C] [Branch2] [Ring1] [C] [C] [C] [O] [S] [=Branch1] [C] [=O] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [C] [Branch1] [=Branch2] [N] [C] [C] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [O] [C] [=C] [Ring1] [=C] [Ring1] [#Branch1] [C] [Branch2] [Ring1] [O] [C] [C] [O] [C] [=N]\n"
|
| 89 |
+
]
|
| 90 |
+
}
|
| 91 |
+
],
|
| 92 |
+
"source": [
|
| 93 |
+
"# Generate SELFIES\n",
|
| 94 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 95 |
+
"model.to(device)\n",
|
| 96 |
+
"input_ids = tokenizer(\"<s>\", return_tensors=\"pt\").input_ids.to(device)\n",
|
| 97 |
+
"gen = model.generate(input_ids, max_length=256, top_k=50, temperature=1, do_sample=True, pad_token_id=tokenizer.pad_token_id)\n",
|
| 98 |
+
"print(tokenizer.decode(gen[0], skip_special_tokens=True))"
|
| 99 |
+
]
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"cell_type": "code",
|
| 103 |
+
"execution_count": 4,
|
| 104 |
+
"id": "b041d311",
|
| 105 |
+
"metadata": {},
|
| 106 |
+
"outputs": [
|
| 107 |
+
{
|
| 108 |
+
"name": "stdout",
|
| 109 |
+
"output_type": "stream",
|
| 110 |
+
"text": [
|
| 111 |
+
"C1(=O)NCCNC(=O)CNC(=O)[NH1]C1C(CCOS(=O)(=O)C=C2C=CCNCCN(C)C)(O)OC=C2\n"
|
| 112 |
+
]
|
| 113 |
+
}
|
| 114 |
+
],
|
| 115 |
+
"source": [
|
| 116 |
+
"# Manually convert it to SMILES\n",
|
| 117 |
+
"import selfies as sf\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"test = tokenizer.decode(gen[0], skip_special_tokens=True)\n",
|
| 120 |
+
"test = test.replace(' ', '')\n",
|
| 121 |
+
"print(sf.decoder(test))\n"
|
| 122 |
+
]
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"cell_type": "code",
|
| 126 |
+
"execution_count": 11,
|
| 127 |
+
"id": "f1608fa0",
|
| 128 |
+
"metadata": {},
|
| 129 |
+
"outputs": [
|
| 130 |
+
{
|
| 131 |
+
"name": "stdout",
|
| 132 |
+
"output_type": "stream",
|
| 133 |
+
"text": [
|
| 134 |
+
"C=1=NC2=CC=CC=C2N=1\n"
|
| 135 |
+
]
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"data": {
|
| 139 |
+
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAEsASwDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigBskiRRtJI6oijLMxwAPc02G4huY/MgmjlTON0bBh+YrnfiJ/wAk71//AK83/lXk/wAK9ZuvB2oabYak/wDxJ/EMfm2sh4WOcEoR+OAD9UNerhcseJwk68Je9F6Lvpd287a28iJTtKx71PcQWsfmXE0cSZxukYKM/U08EMAQQQeQRXmfx2/5J6n/AF/RfyavQtL/AOQRZf8AXBP/AEEVzVMKoYWGIv8AE2relv8AMal7zRbooorjKCiiigAooooAKKKKACiiigAooooAKKKiurq3sraS5up44IIl3PLKwVVHqSeBQBLWVr/iXR/C+nm91m/itYei7zlnPoqjlj7CuOn8e6v4qnksfh/pwuI1YpLrV6pS1jPfYOsh+n5EVpeH/hxYadqA1nW7qbXteOCb29GRGfSJOiD07jtigDBk8T/EnxG7aj4X8PWthpUI3xJqx2zXw9AuRsBGCM4H+0elb3hv4jafq94NI1a2l0PXxw2n3vylz6xscBwe2OfbvXaVj+IfC+jeKrD7HrNhFcxjlGPDxn1Vhyp+lAGxRXmv2Xxr8P8AmyebxX4fT/l3lYC+t1/2W/5agenXsMV1nhnxlofi23aTSrwNLHxNbSDZNCe4ZDyOeM9PegDeooooAKKKKACiiigAooooAKKKKACiiigAooooA5n4if8AJO9f/wCvN/5VxWneEU8YfArSbNAFvoYWmtJOmJA7cZ9D0/I9q7rx3bT3n
gTWra1hknnktHVI41LMxx0AHU1X+G9nc2Hw90e1vLeW3uI4mDxSoVZTvbqDyK9nD4mVDAKdN2kqia/8Bf4GbV52fY8k8T+LZPE3waWC/JXV9Ov4re8R+GJCuA5Hvg59wa9jh8T6Hpi6dpl/qlra3klpHIkcz7NykYByeOoPGc15X8YPh/qLav8A2zoFlc3Md+QLy3toy5Eg5D7R2Pr6j/ar146Hp2raJaWuradb3KrAgKXEQYqdo9ehrtzGeDlhqMo/DJydla8W7aeid7baEwUuZmsrK6hkYMpGQQcgilrhW+GkOnM0vhbXNS0KTORFHKZrcn3jfr+dJ/aXxB0Ef6fpNj4gtl/5bWEnkzY9SjcE+y15H1OnU/gVE/KXuv8AH3f/ACY05mt0d3RXHaf8TfDd1cC1vp59Hve9tqkRgYfifl/WuuiljniWWGRJI2GVdDkEexFc1bDVqDtVi16jTT2H0UUViMKKKKACiiigAorm/E/jnRPCgSK9naa/l4gsLVfMuJiegCD19TgVzP8AY3jHx98/iCd/Dmgv00uzfNzOv/TWT+EH+6PoR3oA0tb+JFpb6i2i+G7OXxBrg4NvaH91D2zLL91QP/14qnbfD7UfElzHqPxA1EagynfFpFqSlnAfcdZD7n6ciuy0Pw/pPhvT1sNHsIbO3X+GMcsfVj1Y+5JNVPEvjDRPCVqs2rXqxu/EVug3zTH0RByee/T3oA2YIIbaBIIIkihjUKkcahVUDsAOgrkvEfxF0zRr7+yNPgm1rXm4TTrEbmU/9NG6IPXPPtWN5Xjb4gcztN4T8Pv/AMskP+n3C+56RA/n9RXY+HPCui+FLH7Jo1jHbq3MknWSU+rMeSaAMXw3pXjC51hdc8T6sluAjLFo1jjyYwe8jHl2Ht0PQ4OK7OiigArlPE3w/wBH8R3K6ghm03WY+YtTsW8uZT/tEfeHse3cV1dFAHmy+LPE3gdhB41s/wC0dKBwuu6fETtHrPEOV+o4+td9pmqWGs2Ed9pt3Dd2sgyssLhgfb2PtVplDKVYAgjBB71wWp/Dc2V/JrHgnUDoGpv80kCjdaXJ9Hj6D6r09M0Ad9RXAad8R20++j0jxxp50LUXO2O5J3WdwfVJP4fo3T17V3ysrqGVgykZBByCKAFooooAKKKKACiiigAooooAKa7rGjO7BUUEszHAA9TTq4DxZeXHizXR4I0qVkgCiTWbqP8A5ZRHpED/AHn/AJevNdGGw7rzteyWrfZd/wDLu9BSdkWfDGran4s8SXWuQ3EkHhu3VrWzixgXjZ+aU57AjA//AFg9tUFnZ2+n2UNnaRLDbwII441HCqBgCuI1K+u/GPjFNC0q5mg0nSZVm1O7hcqZJQcrArD3GW/+tzu4rFVW4LlhFfcl37tv72+wvhXmd9RRRXAUFFFFAFTUNLsNWtzb6jZW93D/AHJ4w4/WuQufhhZW2+Twzq+peHp2bfi0mLxE8/ejbII56ZFd1RWsa9WMHTUnyvdX0+4Vle5559s+Jvh0f6TYaZ4otV/5aWr/AGW5I9Sp+Q/Ras2PxZ8OSXC2esC80C+P/LDVoDD+T/dx7kiu6qtfafZanbNbX9pBdQN1injDqfwPFZDJLa6t7y3S4tZ4p4XGVkicMrD2I4NS1wNz8JtGgme68OXuo+HLtjktp1wwjY/7UZJBHsMVi65q/wASPBcdhbTX+iazHqV7Hp9rdzQNDKksmdpdE+UrwenNAHpGs67pfh7T3v8AV76Gztk6vK2Mn0A6k+w5rhf7e8XePf3fhm2bQNDfg6vex5nmX1hi7D0Y/hg1o6N8N7ddQTWvFV7J4h1ocrJcj9xB7RRfdA9/x4ruCQqlmICgZJPagDmvDHgTRPCrPcW0UlzqUvM+o3bebcSk9csen0GK3NR1Kx0ixkvdRu4bW1jGXlmcKo/E/wAq4vU/iR9tvpNH8E6edf1NflkmQ4tLc+ry9D9B19c0mnfDh9Rvo9X8c6gdd1BDujtcbbO2P
okf8X1br3HegCs3i/xJ43c2/giy+xaYTtfXtQjIUj1hiPLn3PHqBW34a+Huk+H7ptSnebVdbk5l1O+bzJSf9nPCD2HbjJrrFVUUKqhVUYAAwAKWgAooooAKKKKACiiigAooooAq6jpljq9jJZajaQ3VrIMPFMgZT+B7+9cC3hDxJ4Ic3Hgi+N7pgOW0HUJCVA9IZTyh9jx6k16RRQByXhr4haT4hum02dZtK1uPiXTL5fLlB/2c8OPcducCutrC8S+D9D8W2qw6tZLI6cxXCHZNCfVHHI57dPauP0i68U+DfHej+FtV1WPW9I1cXH2K5nBF1B5Ue8hyOG6gZOSevGMUAem0UUUAFFFFABRRVHWNXstB0i61TUZhDaW0Zkkc+noPUk8AdyaAMfxp4mk0DTorfT4vtGtag/kWFuP4nPVj/sr1P4VP4Q8Mx+GNG+ztKbi+ncz3t03LTTNyxz6dh/8ArrD8Eadd61eyeONct2hvb2PZYWrnP2S27f8AAm6n6+5FdP4h16z8NaJcanek+XEvyxr96R+yqPU//X7V3ymvZxw1DXms2+76L0X4u77E21uzE8b+IbuyW20DQ8Pr+qEpB/07x/xTN6ADOPf1xitjwz4etPC+hwaZaZYJ80srfemkP3nb3J/oO1YvgjQLyJrnxLryg67qmGdD0tYf4YV9MDGff6ZrsaMVONOH1ak7pfE+8v8AJbL5vqEVf3mFFFFcBQUUUUAFFFFABRRRQAV5/wDFT7ngz/sabH/2evQK8/8Aip9zwZ/2NNj/AOz0AegVxviLwTeeLNZYatrlwPDqqu3SrUeV5rfxea4OWXPYY/DHPZUUAU9M0qw0WwjsdMs4bS1jGFihQKB7+59+tXKKKACiiigAooooAKKKKACiiigAooooAKhvJ/stlPcbd3lRs+3OM4GcVNVPV/8AkC33/XvJ/wCgmgCh4R18+KPCmna21sLY3ke/yQ+/ZyRjOBnp6VzXi3/kr3w6/wC4n/6TrVz4S/8AJK/D/wD17n/0Nqp+Lf8Akr3w6/7if/pOtAHoFFFFABRRRQAV5hJ/xdHxj5Q+fwhoU/zn+HULsdvdE/I+4PGh481q+1K/t/A3h6bZqmoJuvLlefsVr0Zz6M3QD37ZBrr9D0Wx8O6La6Tp0IitbZAiL3PqT6knJJ9TQBed0jjZ3ZURRksTgADvXmmjo3xL8Xr4iuFJ8MaPKU0qJhxdzjhpyO6g8L/TkVN4zvrnxfr6+AdHmeOHaJdcvI/+WEB6RA/33/l6jOO/sbG20ywgsbKFYba3QRxRoOFUDAFAFiiiigAooooAKKKKACiiigAooooAK8/+Kn3PBn/Y02P/ALPXoFef/FT7ngz/ALGmx/8AZ6APQKKKKACiiigAooooAKKKKACiiigAooooAKKKKACqer/8gW+/695P/QTVyqer/wDIFvv+veT/ANBNAHL/AAl/5JX4f/69z/6G1U/Fv/JXvh1/3E//AEnWrnwl/wCSV+H/APr3P/obVT8W/wDJXvh1/wBxP/0nWgD0CiiigArD8Ya+/hbwlqOtJaPdvaRb1hTuSQMn0UZyT6A1uU10SSNkdQyMCGVhkEHsaAOK+GWiLa+Hhr91cre6vroW8vLsc53DKxr6KoOMeufoNLx/4jn8KeCtR1e1hEtxEqrHu+6rMwUM3sM5/CuX0Z3+Gvi5fDlwxHhjV5S+lSseLWc8tbk9gTyv9eSPR7u0gvrOa0uolmt5kMckbjIZSMEGgDnvAnhmDw14eRVuRe3t6ftV7fZ3G5lfktnuvPHt9TXT15t4Ru5/A/iT/hBNUld9Pn3S6FdyHO6Pq0DH+8vb2+oFek0AFFFFABRRRQAUUUUAFFFFABRRRQAV5/8AFT7ngz/sabH/ANnr0CvP/ip9zwZ/2NNj/wCz0AegUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFU9X/5At9/17yf+gmrlU9X/AOQLff8AXvJ/6CaAOX+Ev/JK/
D//AF7n/wBDaqfi3/kr3w6/7if/AKTrVz4S/wDJK/D/AP17n/0Nqp+Lf+SvfDr/ALif/pOtAHoFFFFABRRRQBkeJ/Dlj4r8P3OkX6nypl+SRfvROPuuvuD/AId65/wD4jvp2u/C3iFgPEOk4WR+13D/AATL65GM+/1wO3rivH3hq9vVtfEfh/CeI9IzJb+lzH/HA3qCM49/TOaANTxn4Vg8XaA9i0ht7uJhPZXS8NbzLyrg/wA/aqXgPxVPr1hcafq0Yt/EOlv5GoW/TLdpF/2WHI/wxWr4V8SWfizw9batZZVZRtkib70Mg4ZG9wf8e9c5430K7stUtfG2gof7T09Cl5Ag/wCP216shHdh1B/ngCqhBzkordgd5RVHRtXs9e0i21OwlEltcJvQ9x6g+hByCPar1EoyhJxkrNAFFFFSAUUUUAFFFFABRRRQAV5/8VPueDP+xpsf/Z69Arz/AOKn3PBn/Y02P/s9AHoFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABVPV/wDkC33/AF7yf+gmrlU9X/5At9/17yf+gmgDl/hL/wAkr8P/APXuf/Q2qn4t/wCSvfDr/uJ/+k61c+Ev/JK/D/8A17n/ANDaqfi3/kr3w6/7if8A6TrQB6BRRRQAUUUUAFFFFAHnmuJJ4B8UHxNbKx0HUXWPVoVGRBIeFnA9Ozf1yMegxyJLGkkbq8bgMrKcgg9CDUd1awX1pNa3USywTIY5I2GQykYINcL4Wup/B+vnwVqcrPZyhpdFupD9+PqYSf7y9vb04Fei/wDa6N1/Egvviv1j/wCk+hHwvyIZP+LceLPOHy+Ftam/eD+Gxuj39kb9PbHPo1U9V0u01rSrnTb+IS2twhSRT6eo9COoPqK5HwXql3pGpS+CtclL3lom/T7l/wDl7tu3/Al6Ee3sTRU/2ul7VfHFa+a/m9Vs/Kz7gvdduh3VFFFecWFFFFABRRRQAUUUUAFef/FT7ngz/sabH/2evQK8v+LGvaRHceE7R9Us1uLfxHZ3E0RmXdFGu7c7DPyqMjk0AeoUxZY2kaNXUumCyg8rnpkUsciSxrJG6ujgMrKcgg9CDXIeJfAFvq+p/wBu6TqFzo3iBVCi+tmyJABgLIh4daAOxorzm38e6t4WnjsPiBpwtkYhItas1L2kp7bx1jJ9/wBBXoNtc297bR3NrPHPBINySxOGVh6gjg0AS0UUUAFFFFABRRRQAUUUUAFFFFABVPV/+QLff9e8n/oJq5RQBxfwl/5JX4f/AOvc/wDobVT8W/8AJXvh1/3E/wD0nWu/ACgBQAB2Fec+KL21n+NXgG0iuYpLm3GoGaJXBaMNbjbuHUZwcZ9KAPR6KKKACiiigAooooAKwfF3hqHxRojWhkMF3Ewms7leGgmX7rA/z9q3qK0pVZ0pqpB2aE1dWZy/grxNNrljPZanGINc05/Iv4OnzdnX/ZYcirviDwxaeIHsJ5ZZra7sLhZ7e5gIDpz8y8jow4IqvrVno+iXtx40uIZxc2dmySmBsebHweVyAxGOMn+QwvhHxppPjWwmutLMy+TJskinUK68ZBIBPB5wc9jXbNTu8XhotRW/ZN7r07X6OxK/lkdFRWH4r8VWHg7Rxqeox3EkBlWLFuoZskEjgkccetbEEy3FvFOgIWRA4B64IzXE6U1BVGvdeifpuVdXsSUUdBk1y+r/ABD8L6NJ5M2qRT3WcC2tP30hPphc4P1xTpUKtaXLSi5PyVwbS3Ooorg/+Em8Z67xoPhYadA3S71qTYf+/S/NSSfD3UddiI8WeKtQvUYgtaWJ+ywY/ukLyw574Nb1MG6UW6k4p9r3f4XS+bQlK+xq678RPCnh1jFf6zbm5Bx9mgPnSk+m1ckfjisT/hM/GHiAY8MeDZbaFvu3uuv5C/Xyh85HuDXUaH4Q8PeGkC6Po9paMBjzEjzIR7ucsfxNbdcZR55/wgHiHXlz4t8ZXssTdbHSlFrD/ulh8zj64rb07
4c+D9KspLS28PWBjkXbIZohK7j3Z8n9a6iigDzeTwVr/g2RrrwHf+bY53PoOoSFoj6+U55Q+xOM9T2rY8OfETS9bvTpN9DNo2upw+nXw2OT/sN0cemOe+K7CsXxH4T0TxZZfZdYsY5wv+rlHyyRH1RxyP8AOaANa4t4bu3kt7mGOaGRdrxyKGVh6EHgivP7n4f6l4auZNR8AakLHcS8uj3ZL2cx77e8ZPqPpwKh2eNvh/8AcM3izw8n8Lf8f9uvsekoH5/QV13hrxfoni20M+kXqSsn+tgb5ZYj6Oh5H8vQ0AYmhfEizutRXRfEVnL4f13oLW7P7ub3ik+6w/yM129Zmu+HtJ8S6c1hrFhDd256LIOVPqp6qfcVxH9keMvAHzaFPJ4l0FOum3b4uoF9IpP4gP7p+gHegD0qiuc8L+ONE8WI6WFw0V7FxPY3K+XPCR1DIf5jIro6ACiiigAooooAKKKjuLiG0t5Li5mjhhjXc8kjBVUepJ4AoAkrM13xFpHhnT2vtYv4bS3HQyHlj6Ko5Y+wFcbc+P8AU/EtzJp/w/04Xu1tkusXYKWcJ77e8hHoP1FX9C+G9naaguteIbyXxBrvX7VdgeXCfSKP7qD/ACMUAZf9r+MvH/y6FBJ4a0F+upXaZup19Yo/4Qf7x+oPauk8LeAtA8ImSbT7Zpb+XJmv7lvMnlJ5JLHpnuBgGumooAKKKKACiiigAooooAKKKKAOZ+In/JO9f/683/lXjnh4z/D+08M+MrcO2k6lF9m1ONecHc2Gx9ACPdSP4q9j+In/ACTvX/8Arzf+VYvgnRrTxD8GNO0q9TdBc2zofVTvbDD3BwR9K+jy/Exw+Xt1FeEp2kvJx/NbrzRjON56djP+N80Vz8NoJ4ZFkiku4XR1OQylWIIrobyXxe9rp1p4dtdNS3a0jaS+vZGOxscqI15Jxg5PHNeJeIdVvNM8D33gTWWP27Sr+NrZz/y0hIbp7DII9m9q+kdL/wCQRZf9cE/9BFVj6LwOFpRaUlzTtfZpqNn9wRfNJnH/APCu7jVvm8V+JdR1YHrawt9mt/oUTk/XIrp9H8OaNoEXl6VpltaDGC0cYDN9W6n8TWpRXiVcbXqx5ZS93stF9ysjRRSCiiiuUoKKKKACiiigAooooAK5HxL8PNJ1+7Gp2zzaTracx6lYtskz/tgcOPrzjjIrrqKAPN08Y+I/BLrb+ObH7XpwIVNe0+MlAP8AptGOUPuOOwB6132n6jZatYx3un3UN1ayjKSwuGU/iKssqujI6hlYYIIyCK4DUPhxJpl7Jq3gXUf7Dv3O6W0I3WdyfRo/4fqvTsO9AGz4o8B6L4pdLm4jktNTi5g1Gzbyp4iOnzDqPY5/Cuc/4SHxb4C/d+KbZtd0ReBrFjHiaJfWaIf+hD9TV3S/iQLS/j0fxpp7eH9UY4jlkbdaXHvHL0H0PTpnNd4CGUEEEEcEd6AKOj63pniDT0v9Jvoby2fo8TZwfQjqD7Hmr9c5Z+BtC03xQfEGnW72V26Ms0dtIUhmz3eMcEjn8TnrXR0AFFc74o8baH4SiT+0bktdS8QWUC+ZPMewVB/M4HvXMf2Z4z8ffNrE0vhjQG6WNq+bydf+mj/wA+g56gjvQBqa98R7Gy1A6NoNpLr+u9PslmcrEfWWT7qD17jvis+38A6r4ouI7/4gaiLpFO+LRbNilpEe249ZCPf9RXY6D4b0fwxp62OjWENpAPvbB8zn1Zjyx9ya1KAIra2gs7aO3tYY4IIxtSONQqqPQAcCpaKKACiiigAooooAKKKKACiiigAooooAyvE2kvr3hnUdKjlWJ7uBoldhkKT3NQ+EdDk8N+FdP0eWZZpLVCpkQYDZYnp+NbdFbe3n7H2F/dvf52sKyvc89+JHwyj8cS2l5a3MdnfwgxvI6EiSPqAcdwen1Nd5aQm2soICQxijVCR3wMVNRV1cXWq0YUZu8YXt5XBRSdwooormGFFFFABRRRQAUUUUAFFFF
ABRRRQAUUUUAU9U0nT9asJLHU7OG7tZPvRTIGH19j79a4I+FfFHgUmbwZeHVNJXltD1CXlB6QSn7v0PH1Nek0UAcJp3xd8KXFtdf2ndvo19Zj/SbLUEKSoR1Cj+P6Lk+wrPXxN4r8fAJ4RtTo2ivwdav48ySDv5MX/sx4+hrtNV8K6Brl5b3eqaPZ3lxb/6uSaEMR7c9R7HitcAAAAYA6CgDl/DHgLRvDEr3kSS3urS8z6leN5k8h7/ADH7o9h+Oa6iiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKAP/2Q==",
|
| 140 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAASwAAAEsCAIAAAD2HxkiAAAi4klEQVR4nO3deVxU9foH8IdhBxdQLA3JpVTE5WK4Y3nNel2X6WWloqajaTZR1hj3mlhpk+XCNavRNMVsGXfJROcm6g/NBTUlTTMUlxAEKU0EjU2Wmef3x3caCY0Zzpwz33OG5/3yD4HvnPOwfGbOnPN8v8cDEYEQwo+KdwGENHQUQkI4oxASwhmFkBDOKISEcEYhJIQzCiEhnFEICeGMQkgIZxRCQjijEBLCGYWQEM4ohIRwRiEkhDMKISGcUQgJ4YxCSAhnFEJCOKMQEsIZhZAQziiEhHBGISSEMwohIZxRCAnhjEJICGcUQkI4oxASwhmFkBDOKISEcEYhJIQzCiEhnFEICeGMQkgIZxRCQjijEBLCGYWQEM4ohIRwRiEkhDMKISGcUQgJ4YxCSAhnFEJCOKMQEsIZhZAQziiEhHBGISSEMwohIZxRCAnhjEJICGcUQkI4oxASwhmFkBDOKISEcEYhJIQzCiEhnFEICeGMQkgIZxRCQjijEBLCGYWQEM4ohIRwRiEkhDMKISGcUQgJ4YxCSAhnFELCTWpq6ksvvbRr1y7ehXDmgYi8ayANztWrV0eOHHnkyBH2Ybdu3ZKTkx966CG+VfFCr4TEpSwWyxtvvBEaGsoSGBgYCAA///xzx44d4+LizGYz7wI5oBAS10lPT4+Ojl68eLHFYgkMDPzss89KSkq2bt3arFkzi8ViMBi6deu2e/du3mW6HBIivStXrmg0Gg8PDwAIDQ1dsGCB2WyuOWDZsmXt2rVjf5NqtfrSpUu8SnU9BbwnzMwEk8n6/wEDIDq6rsFJSZCdDQAwdCh07y55bcSu8vLypUuXzps3r6SkxN/fX6fTzZ49u1GjRnePrKysXLFixZw5c4qLi318fGJjY+fNm9e4cWPX1+xqvJ8F7NuwAQGs/1q2xJs36xo8ZIh15OrVrqqP/D2TydS2bVv2l6ZWq7Ozs+0+JD8/X6vVqlQqAHjggQcSExNrvWa6H4WFEAB1uroGUwhl4sSJE48++iiLX48ePQ4ePFivh6enp/fv3589vGfPnocPH5aoTjlQXghVKjx69G8HUwi5Kygo0Ol0np6eANC8eXODwVBdXS1gOxaLJSkp6cEHHwQADw+P0aNHX758WfRq5UBJIWzUyPqfnj3x736tFEKOKisrDQZD06ZNAcDb21un092s+82DA0pKSvR6vZ+fHwAEBgbq9fry8nJRqpUPJYXwqacwKsr6/08+ufdgCiEvqampERER7ADyiSeeOHPmjIgbz83N1Wg0bONhYWFGo1HEjXOnpBCOGIG7d1v/36QJXrlyj8EUQtc7d+7csGHDWEI6der07bffSrSj7777rvufp7wHDRr0008/SbQjF1NYCBFx2DDrhzEx9xhMIXSlwsLC+Ph4Hx8fAAgODk5ISKioqJB0j2az2Wg03nfffQCgUqk0Gs21a9ck3aMLKC+EFy+ir6/1M3c/51IIXYNvGGqGPygoyAXhl5TyQoiIM2ZYP9OmDZaU/GUwhdAFxDosrK6utlgsgsuoeRjcsWNH6Q6DpabIEN66ha1aWT/51lt/GUwhlJSIJ0j2798fGRnp/CkWSU8IuYYiQ4iI69ZZP+njgzV/7BRCiYh4qSAnJ2f06NEsNv369XO+NikujbiSUkOIiIMGWT8/cCDaDmoohKJjF83DwsKcv2heWlqq1+v9/f0BICAgID4+vri4WKw6xWoScD0Fh/D8+TtnaNavt36SQiiumu1jvXr1Etw+xpLcpk0bW5JzcnLELZX58ccfa7bLHThwQIq9iEvBIUTE+Hjrl0JDrWdoKIRiEbGR+vjx49F/Tn6JiopKS0sTt9S7CWgc50jZISwtxbZtrV/V6xHvFcIxY/CFFzAlBZV8EtulysrKEhIS2BwiHx8fnU73xx9/CNvUr7/+aktyq1atEhMTXXaIyL4LNm
fK398/Pj5e8HchNWWHEBGTk61fDQjAvLzaISwsRC8v62eCgnDCBExOxrIy134DimIymUSZXMtOljRp0sR2suTWrVviluqIWpOJjUajMxdFJKL4ECKiWm0dMGnSPV4Js7LQYMDoaPTwsH7J3x/VajQakcdfhXydPHnyscceY/Hr3Lnzrl27BG/KZDLZVm1Sq9W//PKLiHUKcOzYsX79+tne2R45coRvPbW4QwizstDf3zrLqU2bv31PmJNTO41+fqhWY2Ii/v67xN+DvN24ccN2XrFZs2bOnFfMzMwcOnQo+3MPDw9PSUkRt1TBLBaL0Whs2bIlOzOk0Wh+++033kVZuUMIEfHdd/8y57DuEzO5uZiYiGr1nSNVT0+MjkaDAWXze3ERdtAYFBRkO2gsKioStimWZC8vL9ZHajAYqqqqRC1WBOxqp6+vr6wmRrlJCG/fxk6dHA2hzfXraDSiWo3e3rXTmJ8v3jcgV6mpqV26dLH1mmRkZAjbTlVVVWJiYosWLQDAy8tLq9X+Lu9Di4sXL9q6BR5++OGkpCS+9bhJCBHx//6v3iG0uXHDmkbbhUeVCqOiUK/Hixedq16Wzp8/r1arbV2X//vf/wRvas+ePd26dWObevzxx0+fPi1inZLau3evTCp3nxAi4qhRAkNoU1qKJhNqNHdm8QNgRATq9ZiZKWSDclNUVBQfH8+Ox9j8g9u3bwvblNxeTwSQyWu4AkK4ZQsGB2NwMD73nJ2ReXnYqpV18Nq1Tu20rMyaxiZNaqfx+HGntszL3ZOPrl69KmxTd7+zEpxkOSgsLOT7blYBIeSrvBxNJtRqsUWLO2ls3x51OkxLQ/ldc7q3ffv2/eMf/2CvWgMHDjx58qSw7cj5HKOTap7X7dSpkyvP61IIHVVRgTt34tSpGBLylzS+8QYeO2aW4SVgJi8vz3a1unXr1s5crT527Fjfvn3Zn2nv3r2///57cUuVAy5XOOUbwosXUa3GvDzeddyluhrT0lCnw9BQaxT79s0PCwvTarUmk0k+5+XZlAU2+SggIECv15cJ7RVSRN+JWFzf6yPTEFZUYM+eCIBaLe9S/p7ZjGlpGBeHw4evhD+1bNkyNjY2NTWVYxrvXrFT8JSFuzswRZx8JGes65U1MISEhEg6MUqmIWQLWLRvb2fRe/nIyMjQ6/Xh4eG2NAYHB2s0GpPJ5OKTFrWmLBw6dEjwppQ1F0EKJ06cGDBgAPsJPPLIIxLN/5BjCHfvRpUKvbxQZi1+DmFpjIqKsqUxICBArVYbjUapX0PunrIgePJRzUXsH3nkkfouYu9mTCYTmwnJnoxEnwkpuxBeu2ZdP2bhQt6lOCcrK8tgMERHR7O3UuxwjqVR9PcYFRUVtrcxbPKR4F0od366pEpLS22H5aKvCSCvEFos1ikRAwf+7UL3ipOTk1MrjX5+fmq1OjExUZRLwyaTqX379rbn6aysLGHbUfpKLS4g4qnmmuQVwg8+QABs0cI9Wzfz8vISExPVajW7LgwAnp6e0dHRBoNB2NW2zMzMIUOG2KYs7Ny5U3BtbrBmmcscPXq0T58+7GfVp08f5y/VyCiEx4+jjw96eOD27fYHl5bikiVKfbUsKCgwGo1qtZotX8taWFgar9xzcf+71JyywCYfCT4ZW2sR+x07dgjbToPC2o/uv/9+59uPUD4hLCmxToN4/XWHxr/0EgLglCkSlyWxwsJClkbWAsZ+o1FRUXq9/sKFC/d8CGt3DAkJsbU7Xr9+XfDeXbyIvZspLi62te81atRo9uzZwlbQkEsIJ05EAOzWDR2Z3vXNNwiAvr546pT0lblEaWmpyWTSaDQ17yMdERGh1+sza3SO79mzp2vXruyrgwcP/vnnn4Xtjj2Rs8Zlt7mjAy8XLlywNbJ7e3vHxcXVdwuyCOHmzQiAgYEOzVTIy8PmzREAly2TvjKXKykp+frrr8eOHVvzXu3du3ePi4t78skn2YcdOnQwmUyCd+Gu9zbiKyUlhT
UnsYv7u3fvdvyx/EOYlWWdqfD55/YHm83WNX+HDVNM87Qw5eXlJpNJq9WyeQ8eHh7e3t5OTllw77v8cVdWVhYTE8POnXp6ejr+QM4hrKrCvn0RAEeNcmi8Xm9dZVTo+yDlqays3LVrF7twl5ubK2wjDeF+tzJx/vx59jTn+EM4h3DmTATAsDAsLLQ/+OBB9PRElQr37pW+Mplhp08EvAbevYi94CQTBykphPv2WdvTHFlbvbDQupLanDnSVyY/wkKYnp5ec6k/wYvYk3pRTAh//93anjZ/vkPjY2IQAHv3xspKiSuTpfqGMD8/39bb4eQi9qS+lBFCiwWfeqoe7WnLl1uX0G54ffxWjodQxEXsiTDKCOGHHyIABgejI/fYysiwru27caP0lcmVgyEUaxF74gwFhPD0afTzQw8P3LbN/uDycuzeXe6ze13AkRCazeZevXoBQGRkpCJuCaY4xcXF2dnZBQUFdQ+TewhLSjA8HAFQp3NofGysdZmz0lKJK5M3B18JDx8+7Mo7HzU0n3/+OQBMnjy57mH1DaEXuNarr8K5c9C1KyQk2B+cnAwrV4KvL6xfDwEB0henfP3797fd05MohcqVO/v6a/jqKwgIgKQk8Pe3M/jKFXjxRQCAxYshMlL64gjhxHUhvHTJGqolS6BzZzuDLRaYOBFu3IBhw2DaNBdURwg3LgphdTVMmAC3bsHIkTB1qv3x770H+/ZBaCgYjfDnfHRC3JOLQjhnDnz/PYSFwapV9gcfOgTz5oFKBUYjhIRIXxwhXLkihAcOwAcfgJcXbNwIzZrZGXzzJkyYAGYzvPUWDB7sguoI4UzyEF6/fn3x4mQvL9Dr4c/lMOsSGwuXL0Pv3vDOO1KXRogsSBtCRJw8efK33z779NNz33zT/viVK2HzZmjaFDZtAm9vSUsjRC6kDeGSJUt27NgRHBy8aNFkT087g8+ehf/8BwBgxQr4s/WKEPcnYQh//vnnN998EwC+/PJLdl+EOpSXV40dC2Vl8MILMG6cdEURIjtShbC0tDQmJub27duvvvrqiBEj7I6fMeP1gICEqKjqjz+WqCJCZEqqtjWdTnfu3LkuXbosWrTI7uBt27Z9+umnvr6+x44Nbdz4HxKVRIg8SfJKuGXLli+++MLPz2/Dhg3+9vrT8vPzp06dCgCLFi2y3U2WkIZD/BDm5eVptVoAWLJkiW1pvb9jsVgmTpx448aNoUOHvvbaa6IXQ4j8iRzC6urqsWPHFhUVPfvssyyKdZs3b9533313//33f/nllx7Un0YaJJFDqNfrjxw50rp1688++8zu4PT09Hnz5qlUqnXr1rFl/QlpgMQM4YEDB/773/+qVKq1a9c2s9efdvPmzTFjxlRVVc2aNeuJJ54QsQxClEW0EBYVFU2cONFsNuv1+n/+8592x7/88ss5OTm9evV69913xaqBECUSJ4SsPS03N/fRRx99++237Y5ftWrVpk2bmjZtunnzZm/qTyMNmzgh/OSTT7Zv3x4cHLx27VpPe/1pZ8+ejYuLA4BPP/20HfWnkQZPhBBmZGTMmjULAFauXNmmTZu6B1dUVIwfP76srGzKlCnPPfec83snROmcDSG7E015efkrr7wSExNjd/yMGTNOnTr18MMPGwwGJ3dNiHtwNoTTp0/PzMzs0qXL4sWL7Q5OSUlZvny5r69vUlJSzfvvEdKQORXCb775ZvXq1Y63p02aNAkRExISevTo4cx+CXEnwkNoa0/7+OOPHWlPmzRpUkFBwZAhQ6ZPny54p4S4H4EhrK6uHjduXGFh4TPPPBMbG2t3/IIFC/bu3UvtaYTcTWAI586de/jwYcfb09577z3WSdOyZUtheyTEXQkJ4cGDBxcuXKhSqdasWdO8efO6B9+6dWvs2LFVVVUzZ8588sknBRVJiDurdwiLioo0Go3ZbJ4zZ86gQYPsjn/55Zezs7N79uw5d+5cQRUS4ubqHULWnjZgwIA5c+bYHbx69eqNGz
c2atRo/fr17L5ChJBa6hfC5cuXb9++PSgoaN26dXbb0y5evPjvf/8bAFauXNmxY0fhNRLi1uoRwm3btrGrCw62p40ZM6a4uPj5558fP368UzUS4tYcDWFZWZlWqzWbzZGRkWPGjLE7Pj09/ezZsw8++OAnn3ziXIWEuDmX3p+QEHI3R0MYEBCwatUqT0/PU6dObd682e743r17R0RE5Obm0vJNhNStHq+ETz/99JIlSwAgNjb28uXLdQ/29fXdvHlz48aNv/rqq/Xr1ztVIyFurX6Ho9OmTRsxYsTNmzcnTJhgNpvrHtyhQ4ePPvoIAGJjYy9cuCC8RkLcWr3fE7IbSxw6dOj999+3O3jq1Knjxo0rKSkZP358ZWWloAoJcXP1DqFtDYv3339/3759dsevWLGiXbt2x48f1+v1giokxM0JOTv62GOPvfnmm7bFs+se3LRp002bNnl7ey9atCg1NVVQkYS4M4GXKPR6fXR09JUrV1588UW7g3v37v3OO+9YLBaNRnP16lVheyTEXQkMoZeX18aNG5s1a5acnLxy5Uq74996663Bgwdfu3Zt8uTJiChsp4S4JeEX68PCwlatWgUAcXFxp0+ftrMblcpoNIaEhOzatYtd5yCEME51zIwcOXLq1Km3b99+7rnnysvL6x4cGhpqNBo9PDxmzZp18uRJZ/ZLiDtxtm1tyZIlnTt3PnPmzIwZM+wOHjZs2LRp0yoqKmJiYoqLi53cNSHuwdkQBgQEJCUl+fv7f/rpp0lJSXbHL168ODIy8pdffnn99ded3DUh7kGEBu6uXbsmJCSAw+1s69evDwgI+OKLLzZs2OD83glROnFmUbz22msjRoywrXxR9+CIiIiPP/4YAF555ZXs7GxRCiBEucQJoYeHB2tnS0tLmz9/vt3xWq127Nixt27dYrcoFKUGQhRKtPmEwcHBa9as8fT0nDt37v79++2OX7FiRdu2bX/44Qe6PyFp4MSc1Dtw4MD4+HjWGVNYWFj34KCgIHZzwoSEhD179ohYBiHKIvLM+rlz5/bv39/xdrbZs2dbLJYJEyZcu3ZN3EoIUQqRQ+jl5bVp06bg4OCtW7eyfpq6zZ49+/HHH6d2NtKQib/GjK2dbfr06Y60s7FlvHfu3ElLQpGGSZKFnkaNGjVlyhTH29lWr14NADNnzvzpp5+kqIcQOZNqtbWlS5eGh4efOXNm5syZdgc//fTTr7zySmTkuy+80IW62UhDI1UIAwMDk5KS/Pz8li1btn37drvjFy82lJXNOnHCKy5OoooIkSkJ1x3t1q3bwoUL4c/bV9Q92N/fe9MmCAiAzz+HjRulK4oQ2ZF28d/p06cPHz68qKho5swv7XWzQUQEfPghAMDLLwN1s5GGQ9oQsnY2tXrrtm36hQvtj4+NhTFj4NYtGDsWqJuNNBCSL4PfokWLGTOeqa6GuXPh8GH741euhDZtID0d3ntP6tIIkQVX3Iti4EB44w2oroZx48BeNxsEBcG6deDpCQsWwN69LqiOEM5cdEOY99+Hfv0gLw+0WvuDBwyA2bPBYoFJk6CgQPriCOHKRSH08oJ166BpU/jmG1i92v74d96BQYMgPx8mTQLqZiPuzXW3RmvfHj77DABg+nTIzLQzWKWCNWugeXNISYHly11QHSHcuPT+hKNHw/PPQ1kZxMSAvW42aN3aGtoZM+DUKemLI4QTV98kdNkyCA+HjAyYNcv+4GeegdhYqKiA8eOhrEz64pTvyJEjq1atsrvCCJEXdLnTp9HPDz08cNs2+4PLy7F7dwRArVb6ymTMx8cHAG7fvl3HGLPZ3LNnTwDo3Lnzrl27XFZbw1FcXJydnV1QUFD3sPomi0MIEfHDDxEAg4Px8mX7gzMy0N8fAXDjRukrkytHQoiIJpOpXbt27I9ArVZfunTJNeWRmpQRQosFn3oKAXDgQKyutj9++XIEwKAgzM6WvDZ5cjCEiFhRUWEwGBo3bgwAPj4+Op3ujz/+cE
GFxEYZIUTE33/HVq0QAOfPd2h8TAwCYO/eWFkpcWWy5HgImfz8fK1Wq1KpAOCBBx5ITEw0m82SVkhsFBNCRNy3D1Uq9PLCw4ftDy4sxDZtEADnzJG+MvmpbwiZ9PT0/v37s7+Jnj17HnbkB02cpqQQIuLMmQiAYWFYWGh/8MGD6OmJKhXu3St9ZTIjLISIaLFYkpKSwsLCAMDDw2P06NGXHXkjTpygsBBWVWHfvgiAo0Y5NF6vRwAMDcXr1yWuTDYqKyt37drl6ekJALm5ucI2UlJSotfr/fz8ACAwMFCv15eXl4tbJ2HOnz+vsBAiYlYWNmmCAPj55/YHm804aBAC4LBhaLFIXxw/5eXlJpNJq9Xed9997EXM29ub5UfA6yGTm5ur0WjYn0hYWJjRaBS35gautLQ0JibGw8MDADw9PR1/IP8QIuLmzQiAgYGYmWl/cF4eNm+OALhsmfSVuVxJScnXX389duxYdnqT6d69e1xc3JNPPsk+7NChg8lkEryL7777rnv37mxTgwYN+umnn0Ssv8FKSUlhBxoAEBISsnv3bscfK4sQIuLEiQiA3bqhI0dJ33yDAOjri6dOSV+ZS5SWlppMJo1G06hRI1v2IiIi9Hp9Zo1npj179nTt2pV9dfDgwadPnxa2O7PZbDQa2WusSqXSaDTXrl0T6VtpcC5cuDB69Gj2S/H29o6Li6vvFuQSwpIS7NQJAfD11x0a/9JLCIBTpkhclsQKCwuNRqNarfb19WW/RZVKFRUVpdfrL1y4cM+HVFVVJSYmhoSEAICXl5dWq/39998F7z0+Pp6d8gkODk5ISKioqHDiu2lwiouL9Xo9+901atRo9uzZwi7JyiWEiHj8OPr4oIcHbt9uf3BpKS5Z4tCFfhkqKChg2WMBYNmLjo42GAxXrlxxZAs3btzQ6XReXl4A0KxZM4PBUFVVJayYc+fODRs2jJXRqVOnb7/9Vth2GhR2KHH//ffbDiWuXr0qeGsyCiEifvABAmCLFpifz7sUCeTl5SUmJqrVahYe9vadZe+3334TsMHMzMwhQ4awTYWHh6ekpAiuLTU1NSIigm3qiSeeOHPmjOBNub2jR4/26dOH/az69Olz9OhRJzcorxBaLKhW16OdTRFycnIMBkN0dDQ7bwYAfn5+arU6MTFR8JFkTSaTqX379rZ+0V9++UXYdiorKw0GQ9OmTdl7G51Od/PmTefLcyd5eXkajYb9Hlu3bm00Gi1inKOXVwgR8do1azvbwoW8S3FOVlZWrez5+/ur1Wqj0Xjr1i1x98X6RZs0aWLLj+BdFBQU6HQ6dlmyefPmBoOh2m2eDp1QWlqq1+v9/f0BICAgID4+vri4WKyNyy6EiLh7t7Wd7cgR3qXUX0ZGhl6vj4qKsp3kDAgIYNkT8dd2T7/++qutX7RVq1bO9IueOHHi0UcfZfX36NHj4MGD4paqIKzlqE2bNrZjjZycHHF3IccQIuKMGQiA7dujUg6IWPbCw8Nt2QsODtZoNCaTSfC1dWGOHz8eHR3NaoiKikpLSxO8KZPJ1LZtW9sfX3bDm8Ny/PjxAQMGsJ/AI4884swPsw4yDWFFBfbsKfe5vGYzpqVhXBwOH/6qLXstW7aMjY1NTU0VfLrSeTWfvFm/qOAn77KysoSEBHb10t/fX9zDMDljhxXssJwdVkh3WC7TECLixYuoVmNeHu867lJdjWlpqNNhaCgCIAD26bM1LCxMq9WaTCaO2aul1tsYvV5fVlYmbFNXrlyxnZAIDQ0V64SEPLETVKK8wXaQfEMoNxUVuHMnTp2KISHW7LED5jfewGPHzLL9oxTxhN6xY8f69u3LXvB79ep1RIlv2e0xmUwPPfSQ86ea64VCaEd5OZpMqNViixZ/yZ5Oh2lpimki379/f2RkJPvbGjhw4MmTJ4Vtx2KxGI3Gli1bsgNdjUYj7AqnDGVmZg4dOlSUi671pYAQbtmCwcHWf9Om2Rk8apR15Nq1Tu20rAxNJt
RorDM82L+ICNTr8fhxp7bMi4hNHmxiFGvXcoOJUYWFhbb2o+DgYGfaj4RRQAg3bLgTA5XKzjT8IUOsI1evFrKv0lJr9ho1qp09R2Z4yF9RUVF8fDzLT1BQUEJCguCTtxcvXrQ1Lj/88MNJSUniluoCrBG3RYsWzjfiOkNhIQTArl3rWmZGWAhv3ECjEdVq9PW9k/aoKNTr8eJF578D2Tl//rxarWb56dChgzP52bt3b7du3dimHn/8ccETO1xPPpUrL4QAuGjR3w6uVwivX7dmz9vb+ihPT4yORoPBPZtXa0lNTe3SpYutXzQjI0PYdmTyeuI4ub2GKymEXbta0xIYiH933cuREObmYmIiqtXo5VU7e+5ylsFR7HR8UFAQOx2v1WqvC104hPs7K0fc/W7Wxa0U96SkEI4Yga++av3/0KH3HlxHCHNy0GDA6Gj08LCO8fNDtRoTE1HeT9ySYxOj2IVpNjFK8IVpjucY6ybn87oKC+G1a3dOVyYn32Pw3SHMyqqdPX9/VKvRaESJr8EqzNmzZ//1r3+x/Di5kD6Xq211qHmFs3fv3nK7wqmwEOKfcw7ZQol3d1DVCmFh4Z1jzqAgnDABk5NRaN9Ig1BrIf2srCxh23F938k9KaLXR3khrKiwLoQBgP/5T+3Bd78SjhmDL7yAKSlISzc4SMSF9F3ZgVmLgrpelRdCRNyxw/oZL6/aaz05eZ2Q2Ii4kL5r5iLUpKz5H4oMIaL1fjIA2KsX1vzboBCKS6yF9F0wK49R4kxIpYbw8mUMDLR+fsWKO5+nEIqO5efBBx8EpxfSLy0ttR0iij4/XblrAig1hIg4f77188HBaFs1k0IoEREX0s/JybFdK+/Xr5/ztSl9dRwFh7CiAsPDrV+yzf2lEEpKxIX02Zpla9ascbIkN1gnTsEhRMTU1Dutnj/8gEghdAmxFtKvrq525oJBrRVTd+zYIXhTfCk7hIg4Zoz1qwMHIlIIXYXvQvputna44kP422/YtKl1wNatFEKXqhkGNjFK6jCw8LNmcbe5i4biQ4iIH35oHdCxIw4eTCF0tXPnzg0fPpwdFnbs2FG6hfTd9X5S7hDCqirs3t06xjYpiULoYpKeIHHvOyu6QwgR8dChO/3ZFEJepLhU0BDuMewmIUTEyZMphLIg1kVz1iQQFhbmfJOAzLlPCAsK/rIYIYWQrx9//LFm+9iBAwfq9fD09PR+/fqxh/fq1Utwu5wiuE8IEXHFCgqhvAhopM7Pz7dNPnKycVwpPBAR5C0zE0wmAIAOHeDZZ+saabHA0qVQUQEAMHQo/HkijfBUXl6+dOnSefPmlZSU+Pv763S6t99+m82TuufI+fPnFxcX+/j4xMbGzps3754j3Q3vZwHSINSaXLtgwYJar2/Lli2rOZn40qVLvEp1PQohcZ1jx47Z3ukFBgauWrUKETdu3MhOqAJAZGTk/v37eZfpago4HCXuxGKxxMfHf/TRRxaLBQACAwNLS0sBQKVS6XS6xYsXs9OqDQqFkHBw9erVkSNHHjlyhH3Yo0eP5ORk25TfhoZCSLhJTU3dsmXLM888M2TIEN618EQhJIQzFe8CCGnoKISEcEYhJIQzCiEhnFEICeGMQkgIZxRCQjijEBLCGYWQEM4ohIRwRiEkhDMKISGcUQgJ4YxCSAhnFEJCOKMQEsIZhZAQziiEhHBGISSEMwohIZxRCAnhjEJICGcUQkI4oxASwhmFkBDOKISEcEYhJIQzCiEhnFEICeGMQkgIZxRCQjijEBLCGYWQEM4ohIRwRiEkhDMKISGcUQgJ4YxCSAhnFEJCOKMQEsIZhZAQziiEhHBGISSEMwohIZxRCAnhjEJICGcUQkI4oxASwhmFkBDOKISEcEYhJIQzCiEhnFEICeGMQkgIZxRCQjijEBLCGYWQEM4ohIRwRiEkhDMKISGcUQgJ4YxCSAhnFEJCOKMQEsIZhZAQzv4f6jkImFkQ6PQAAAAASUVORK5CYII=",
|
| 141 |
+
"text/plain": [
|
| 142 |
+
"<PIL.PngImagePlugin.PngImageFile image mode=RGB size=300x300>"
|
| 143 |
+
]
|
| 144 |
+
},
|
| 145 |
+
"execution_count": 11,
|
| 146 |
+
"metadata": {},
|
| 147 |
+
"output_type": "execute_result"
|
| 148 |
+
}
|
| 149 |
+
],
|
| 150 |
+
"source": [
|
| 151 |
+
"# Generate Mol Viz\n",
|
| 152 |
+
"from rdkit import Chem\n",
|
| 153 |
+
"from rdkit.Chem import Draw\n",
|
| 154 |
+
"\n",
|
| 155 |
+
"input_ids = tokenizer(\"<s>\", return_tensors=\"pt\").input_ids.to(device)\n",
|
| 156 |
+
"gen = model.generate(input_ids, max_length=256, top_k=50, temperature=1, do_sample=True, pad_token_id=tokenizer.pad_token_id)\n",
|
| 157 |
+
"generatedmol = tokenizer.decode(gen[0], skip_special_tokens=True)\n",
|
| 158 |
+
"\n",
|
| 159 |
+
"test = generatedmol.replace(' ', '')\n",
|
| 160 |
+
"csmi_gen = sf.decoder(test)\n",
|
| 161 |
+
"print(csmi_gen)\n",
|
| 162 |
+
"mol = Chem.MolFromSmiles(csmi_gen)\n",
|
| 163 |
+
"\n",
|
| 164 |
+
"# Draw the molecule\n",
|
| 165 |
+
"Draw.MolToImage(mol)"
|
| 166 |
+
]
|
| 167 |
+
}
|
| 168 |
+
],
|
| 169 |
+
"metadata": {
|
| 170 |
+
"kernelspec": {
|
| 171 |
+
"display_name": "base",
|
| 172 |
+
"language": "python",
|
| 173 |
+
"name": "python3"
|
| 174 |
+
},
|
| 175 |
+
"language_info": {
|
| 176 |
+
"codemirror_mode": {
|
| 177 |
+
"name": "ipython",
|
| 178 |
+
"version": 3
|
| 179 |
+
},
|
| 180 |
+
"file_extension": ".py",
|
| 181 |
+
"mimetype": "text/x-python",
|
| 182 |
+
"name": "python",
|
| 183 |
+
"nbconvert_exporter": "python",
|
| 184 |
+
"pygments_lexer": "ipython3",
|
| 185 |
+
"version": "3.13.0"
|
| 186 |
+
}
|
| 187 |
+
},
|
| 188 |
+
"nbformat": 4,
|
| 189 |
+
"nbformat_minor": 5
|
| 190 |
+
}
|
train-withmtp.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ========================
|
| 2 |
+
# Train with NTP + MTP
|
| 3 |
+
# by gbyuvd
|
| 4 |
+
# ========================
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import math
|
| 12 |
+
from typing import List, Union, Optional, Tuple, Dict, Any
|
| 13 |
+
from transformers.tokenization_utils_base import BatchEncoding
|
| 14 |
+
from transformers import Qwen3Config, Qwen3ForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
|
| 15 |
+
from transformers.models.qwen2.modeling_qwen2 import Qwen2PreTrainedModel
|
| 16 |
+
from datasets import load_dataset, DatasetDict
|
| 17 |
+
import pandas as pd
|
| 18 |
+
from torch.utils.data import Dataset, DataLoader, random_split
|
| 19 |
+
from sklearn.model_selection import train_test_split
|
| 20 |
+
from ranger21 import Ranger21
|
| 21 |
+
from tqdm.notebook import tqdm
|
| 22 |
+
from FastChemTokenizerHF import FastChemTokenizerSelfies
|
| 23 |
+
from ChemQ3MTP import ChemQ3MTP
|
| 24 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 25 |
+
from transformers import TrainerCallback
|
| 26 |
+
import datetime
|
| 27 |
+
|
| 28 |
+
# ==============================
|
| 29 |
+
# Load external configuration
|
| 30 |
+
# ==============================
|
| 31 |
+
with open("config.json", "r") as f:
|
| 32 |
+
CONFIG = json.load(f)
|
| 33 |
+
|
| 34 |
+
TRAINING_CFG = CONFIG["training"]
|
| 35 |
+
MODEL_CFG = CONFIG["model"]
|
| 36 |
+
GENERATION_CFG = CONFIG.get("generation", {})
|
| 37 |
+
|
| 38 |
+
# Training params
|
| 39 |
+
BATCH_SIZE = TRAINING_CFG["batch_size"]
|
| 40 |
+
NUM_EPOCHS = TRAINING_CFG["num_epochs"]
|
| 41 |
+
LEARNING_RATE = TRAINING_CFG["learning_rate"]
|
| 42 |
+
WEIGHT_DECAY = TRAINING_CFG["weight_decay"]
|
| 43 |
+
GRAD_ACCUM_STEPS = TRAINING_CFG["gradient_accumulation_steps"]
|
| 44 |
+
TOKENIZE_BATCH_SIZE = TRAINING_CFG["tokenize_batch_size"]
|
| 45 |
+
TRAIN_SPLIT_RATIO = TRAINING_CFG["train_split_ratio"]
|
| 46 |
+
VAL_SPLIT_RATIO = TRAINING_CFG["val_split_ratio"]
|
| 47 |
+
TEST_SPLIT_RATIO = TRAINING_CFG["test_split_ratio"]
|
| 48 |
+
INCLUDE_FOR_METRICS = TRAINING_CFG.get("include_for_metrics", ["input_ids", "attention_mask", "labels"])
|
| 49 |
+
# ==============================
|
| 50 |
+
|
| 51 |
+
class LossLoggerCallback(TrainerCallback):
|
| 52 |
+
def __init__(self, log_file="training_losses.txt", with_timestamp=False):
|
| 53 |
+
self.log_file = log_file
|
| 54 |
+
self.with_timestamp = with_timestamp
|
| 55 |
+
with open(self.log_file, "w") as f:
|
| 56 |
+
if self.with_timestamp:
|
| 57 |
+
f.write("time\tstep\tloss\teval_loss\n")
|
| 58 |
+
else:
|
| 59 |
+
f.write("step\tloss\teval_loss\n")
|
| 60 |
+
|
| 61 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 62 |
+
if logs is None:
|
| 63 |
+
return
|
| 64 |
+
step = state.global_step
|
| 65 |
+
loss = logs.get("loss")
|
| 66 |
+
eval_loss = logs.get("eval_loss")
|
| 67 |
+
|
| 68 |
+
with open(self.log_file, "a") as f:
|
| 69 |
+
if self.with_timestamp:
|
| 70 |
+
ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 71 |
+
f.write(f"{ts}\t{step}\t{loss if loss is not None else ''}\t{eval_loss if eval_loss is not None else ''}\n")
|
| 72 |
+
else:
|
| 73 |
+
f.write(f"{step}\t{loss if loss is not None else ''}\t{eval_loss if eval_loss is not None else ''}\n")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def main():
|
| 77 |
+
# --- Load the tokenizer ---
|
| 78 |
+
tokenizer = FastChemTokenizerSelfies.from_pretrained("./selftok_core")
|
| 79 |
+
|
| 80 |
+
out = tokenizer("[C] [=C] [Branch1]", return_tensors="pt")
|
| 81 |
+
print(out.input_ids)
|
| 82 |
+
print(out.attention_mask)
|
| 83 |
+
out = out.to("cuda" if torch.cuda.is_available() else "cpu")
|
| 84 |
+
print(out.input_ids.device)
|
| 85 |
+
|
| 86 |
+
# --- Define config ---
|
| 87 |
+
config = Qwen3Config(
|
| 88 |
+
vocab_size=len(tokenizer),
|
| 89 |
+
bos_token_id=tokenizer.bos_token_id,
|
| 90 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 91 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 92 |
+
tie_word_embeddings=True,
|
| 93 |
+
use_cache=False,
|
| 94 |
+
**MODEL_CFG
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
model = ChemQ3MTP(config, num_future_tokens=3)
|
| 98 |
+
|
| 99 |
+
def count_parameters(model):
|
| 100 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 101 |
+
|
| 102 |
+
print(f"Enhanced model has {count_parameters(model):,} trainable parameters.")
|
| 103 |
+
|
| 104 |
+
batch_size, seq_len = 2, 32
|
| 105 |
+
dummy_input = torch.randint(
|
| 106 |
+
low=0,
|
| 107 |
+
high=len(tokenizer),
|
| 108 |
+
size=(batch_size, seq_len),
|
| 109 |
+
dtype=torch.long,
|
| 110 |
+
)
|
| 111 |
+
with torch.no_grad():
|
| 112 |
+
outputs = model(dummy_input)
|
| 113 |
+
logits = outputs.logits
|
| 114 |
+
print(f"Input shape: {dummy_input.shape}")
|
| 115 |
+
print(f"Logits shape: {logits.shape}")
|
| 116 |
+
|
| 117 |
+
print("Loading dataset...")
|
| 118 |
+
dataset = load_dataset(
|
| 119 |
+
'csv',
|
| 120 |
+
data_files='./data/sample_all_14k.csv',
|
| 121 |
+
split='train',
|
| 122 |
+
streaming=True
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
print("Shuffling and splitting dataset...")
|
| 126 |
+
shuffled_dataset = dataset.shuffle(seed=42, buffer_size=10000)
|
| 127 |
+
|
| 128 |
+
total_lines = 14000
|
| 129 |
+
test_size = int(TEST_SPLIT_RATIO * total_lines)
|
| 130 |
+
val_size = int(VAL_SPLIT_RATIO * total_lines)
|
| 131 |
+
train_size = total_lines - test_size - val_size
|
| 132 |
+
|
| 133 |
+
test_dataset = shuffled_dataset.take(test_size)
|
| 134 |
+
remaining = shuffled_dataset.skip(test_size)
|
| 135 |
+
val_dataset = remaining.take(val_size)
|
| 136 |
+
train_dataset = remaining.skip(val_size)
|
| 137 |
+
|
| 138 |
+
print(f"Dataset split: train={train_size}, val={val_size}, test={test_size}")
|
| 139 |
+
|
| 140 |
+
def tokenize_function(examples):
|
| 141 |
+
batch_results = {"input_ids": [], "attention_mask": [], "labels": []}
|
| 142 |
+
smiles_list = examples['SELFIES'] if isinstance(examples['SELFIES'], list) else [examples['SELFIES']]
|
| 143 |
+
for smiles in smiles_list:
|
| 144 |
+
tokenized = tokenizer(
|
| 145 |
+
smiles,
|
| 146 |
+
truncation=True,
|
| 147 |
+
padding=False,
|
| 148 |
+
max_length=MODEL_CFG["max_position_embeddings"],
|
| 149 |
+
return_tensors=None,
|
| 150 |
+
add_special_tokens=True
|
| 151 |
+
)
|
| 152 |
+
input_ids = tokenized["input_ids"]
|
| 153 |
+
attention_mask = tokenized["attention_mask"]
|
| 154 |
+
labels = input_ids.copy()
|
| 155 |
+
batch_results["input_ids"].append(input_ids)
|
| 156 |
+
batch_results["attention_mask"].append(attention_mask)
|
| 157 |
+
batch_results["labels"].append(labels)
|
| 158 |
+
return batch_results
|
| 159 |
+
|
| 160 |
+
print("Tokenizing datasets...")
|
| 161 |
+
train_dataset = train_dataset.map(tokenize_function, batched=True, batch_size=TOKENIZE_BATCH_SIZE, remove_columns=["SELFIES"])
|
| 162 |
+
val_dataset = val_dataset.map(tokenize_function, batched=True, batch_size=TOKENIZE_BATCH_SIZE, remove_columns=["SELFIES"])
|
| 163 |
+
|
| 164 |
+
class EnhancedDataCollator:
|
| 165 |
+
def __init__(self, tokenizer, pad_to_multiple_of=8):
|
| 166 |
+
self.tokenizer = tokenizer
|
| 167 |
+
self.pad_to_multiple_of = pad_to_multiple_of
|
| 168 |
+
def __call__(self, features):
|
| 169 |
+
max_length = max(len(f["input_ids"]) for f in features)
|
| 170 |
+
if self.pad_to_multiple_of:
|
| 171 |
+
max_length = ((max_length + self.pad_to_multiple_of - 1) // self.pad_to_multiple_of) * self.pad_to_multiple_of
|
| 172 |
+
batch = {"input_ids": [], "attention_mask": [], "labels": []}
|
| 173 |
+
for feature in features:
|
| 174 |
+
input_ids = feature["input_ids"]
|
| 175 |
+
attention_mask = feature["attention_mask"]
|
| 176 |
+
labels = feature["labels"]
|
| 177 |
+
padding_length = max_length - len(input_ids)
|
| 178 |
+
padded_input_ids = input_ids + [self.tokenizer.pad_token_id] * padding_length
|
| 179 |
+
padded_attention_mask = attention_mask + [0] * padding_length
|
| 180 |
+
padded_labels = labels + [-100] * padding_length
|
| 181 |
+
batch["input_ids"].append(padded_input_ids)
|
| 182 |
+
batch["attention_mask"].append(padded_attention_mask)
|
| 183 |
+
batch["labels"].append(padded_labels)
|
| 184 |
+
batch = {key: torch.tensor(values, dtype=torch.long) for key, values in batch.items()}
|
| 185 |
+
return batch
|
| 186 |
+
|
| 187 |
+
data_collator = EnhancedDataCollator(tokenizer, pad_to_multiple_of=8)
|
| 188 |
+
|
| 189 |
+
def create_enhanced_optimizer(model_params):
|
| 190 |
+
num_batches_per_epoch = train_size // BATCH_SIZE
|
| 191 |
+
optimizer_params = {
|
| 192 |
+
'lr': LEARNING_RATE,
|
| 193 |
+
'weight_decay': WEIGHT_DECAY,
|
| 194 |
+
'use_adabelief': True,
|
| 195 |
+
'use_cheb': False,
|
| 196 |
+
'use_warmup': True,
|
| 197 |
+
'use_madgrad': True,
|
| 198 |
+
'num_epochs': NUM_EPOCHS,
|
| 199 |
+
'using_gc': True,
|
| 200 |
+
'warmdown_active': True,
|
| 201 |
+
'num_batches_per_epoch': num_batches_per_epoch
|
| 202 |
+
}
|
| 203 |
+
return Ranger21(model_params, **optimizer_params)
|
| 204 |
+
|
| 205 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 206 |
+
class EnhancedCustomTrainer(Trainer):
|
| 207 |
+
def create_optimizer(self):
|
| 208 |
+
self.optimizer = create_enhanced_optimizer(self.model.parameters())
|
| 209 |
+
return self.optimizer
|
| 210 |
+
def create_scheduler(self, num_training_steps, optimizer=None):
|
| 211 |
+
if optimizer is None:
|
| 212 |
+
optimizer = self.optimizer
|
| 213 |
+
self.lr_scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
|
| 214 |
+
return self.lr_scheduler
|
| 215 |
+
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
| 216 |
+
outputs = model(**inputs)
|
| 217 |
+
loss = outputs.loss
|
| 218 |
+
return (loss, outputs) if return_outputs else loss
|
| 219 |
+
|
| 220 |
+
steps_per_epoch = train_size // BATCH_SIZE
|
| 221 |
+
total_steps = steps_per_epoch * NUM_EPOCHS
|
| 222 |
+
|
| 223 |
+
training_args = TrainingArguments(
|
| 224 |
+
output_dir='./chemq3minipret',
|
| 225 |
+
max_steps=total_steps,
|
| 226 |
+
per_device_train_batch_size=BATCH_SIZE,
|
| 227 |
+
per_device_eval_batch_size=BATCH_SIZE,
|
| 228 |
+
gradient_accumulation_steps=GRAD_ACCUM_STEPS,
|
| 229 |
+
logging_dir='./gptlo-1',
|
| 230 |
+
logging_strategy="steps",
|
| 231 |
+
logging_steps=max(1, steps_per_epoch // 4),
|
| 232 |
+
eval_strategy="steps",
|
| 233 |
+
eval_steps=max(1, steps_per_epoch // 4),
|
| 234 |
+
save_strategy="steps",
|
| 235 |
+
save_steps=steps_per_epoch,
|
| 236 |
+
save_total_limit=1,
|
| 237 |
+
dataloader_num_workers=0,
|
| 238 |
+
dataloader_pin_memory=False,
|
| 239 |
+
remove_unused_columns=False,
|
| 240 |
+
prediction_loss_only=False,
|
| 241 |
+
fp16=torch.cuda.is_available(),
|
| 242 |
+
gradient_checkpointing=True,
|
| 243 |
+
dataloader_drop_last=True,
|
| 244 |
+
report_to=None,
|
| 245 |
+
include_for_metrics=INCLUDE_FOR_METRICS,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
print("Initializing enhanced trainer with MTP capabilities...")
|
| 249 |
+
trainer = EnhancedCustomTrainer(
|
| 250 |
+
model=model,
|
| 251 |
+
args=training_args,
|
| 252 |
+
train_dataset=train_dataset,
|
| 253 |
+
eval_dataset=val_dataset,
|
| 254 |
+
data_collator=data_collator,
|
| 255 |
+
processing_class=tokenizer,
|
| 256 |
+
callbacks=[LossLoggerCallback("training_losses.txt", with_timestamp=True)]
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
model.set_mtp_training(True)
|
| 260 |
+
print(" MTP training mode enabled")
|
| 261 |
+
|
| 262 |
+
print("Starting enhanced training with MTP and Horizon Loss...")
|
| 263 |
+
try:
|
| 264 |
+
print("\n Phase 1: Warmup with standard Causal LM...")
|
| 265 |
+
model.set_mtp_training(False)
|
| 266 |
+
warmup_steps = max(1, total_steps // 5)
|
| 267 |
+
trainer.args.max_steps = warmup_steps
|
| 268 |
+
trainer.train()
|
| 269 |
+
print("\n Phase 2: Full MTP + Horizon Loss training...")
|
| 270 |
+
model.set_mtp_training(True)
|
| 271 |
+
trainer.args.max_steps = total_steps
|
| 272 |
+
trainer.train(resume_from_checkpoint=True)
|
| 273 |
+
print("Enhanced training completed successfully!")
|
| 274 |
+
trainer.save_model("./enhanced-qwen3-final")
|
| 275 |
+
tokenizer.save_pretrained("./enhanced-qwen3-final")
|
| 276 |
+
training_config = {
|
| 277 |
+
"model_type": "EnhancedQwen3ForCausalLM",
|
| 278 |
+
"num_future_tokens": 3,
|
| 279 |
+
"horizon_loss_enabled": True,
|
| 280 |
+
"mtp_head_enabled": True,
|
| 281 |
+
"training_phases": ["causal_lm_warmup", "mtp_horizon_training"],
|
| 282 |
+
"total_parameters": count_parameters(model),
|
| 283 |
+
}
|
| 284 |
+
config_path = "./enhanced-qwen3-final/training_config.json"
|
| 285 |
+
with open(config_path, "w") as f:
|
| 286 |
+
json.dump(training_config, f, indent=2)
|
| 287 |
+
print(f" Enhanced model, tokenizer, and config saved!")
|
| 288 |
+
except Exception as e:
|
| 289 |
+
print(f"Enhanced training failed with error: {e}")
|
| 290 |
+
import traceback
|
| 291 |
+
traceback.print_exc()
|
| 292 |
+
return
|
| 293 |
+
|
| 294 |
+
print("\nmTesting enhanced generation capabilities...")
|
| 295 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 296 |
+
model.to(device)
|
| 297 |
+
model.eval()
|
| 298 |
+
try:
|
| 299 |
+
print("\n--- Standard Generation Test ---")
|
| 300 |
+
input_ids = tokenizer("<s> [C]", return_tensors="pt").input_ids.to(device)
|
| 301 |
+
with torch.no_grad():
|
| 302 |
+
model.set_mtp_training(False)
|
| 303 |
+
gen = model.generate(
|
| 304 |
+
input_ids,
|
| 305 |
+
max_length=GENERATION_CFG.get("max_length", 64),
|
| 306 |
+
top_k=GENERATION_CFG.get("top_k", 50),
|
| 307 |
+
top_p=GENERATION_CFG.get("top_p", 0.9),
|
| 308 |
+
temperature=GENERATION_CFG.get("temperature", 0.8),
|
| 309 |
+
do_sample=GENERATION_CFG.get("do_sample", True),
|
| 310 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 311 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 312 |
+
num_return_sequences=GENERATION_CFG.get("num_return_sequences", 3),
|
| 313 |
+
)
|
| 314 |
+
for i, sequence in enumerate(gen):
|
| 315 |
+
result = tokenizer.decode(sequence, skip_special_tokens=True)
|
| 316 |
+
print(f"Generated SELFIES {i+1}: {result}")
|
| 317 |
+
print("\n--- MTP Analysis Test ---")
|
| 318 |
+
model.set_mtp_training(True)
|
| 319 |
+
test_smiles = "[C]"
|
| 320 |
+
test_input = tokenizer(test_smiles, return_tensors="pt", add_special_tokens=True).to(device)
|
| 321 |
+
with torch.no_grad():
|
| 322 |
+
outputs = model(**test_input)
|
| 323 |
+
if hasattr(model.mtp_head, 'prediction_heads'):
|
| 324 |
+
hidden_states = model.model(test_input['input_ids']).last_hidden_state
|
| 325 |
+
mtp_outputs = model.mtp_head(hidden_states)
|
| 326 |
+
print(f"Input SELFIES: {test_smiles}")
|
| 327 |
+
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(test_input['input_ids'][0].tolist())}")
|
| 328 |
+
for i, (key, logits) in enumerate(mtp_outputs.items()):
|
| 329 |
+
top_tokens = torch.topk(logits[0], k=3, dim=-1)
|
| 330 |
+
print(f"\n{key} predictions:")
|
| 331 |
+
for pos in range(min(5, logits.size(1))):
|
| 332 |
+
pos_preds = []
|
| 333 |
+
for j in range(3):
|
| 334 |
+
token_id = top_tokens.indices[pos, j].item()
|
| 335 |
+
prob = torch.softmax(logits[0, pos], dim=-1)[token_id].item()
|
| 336 |
+
token = tokenizer.id_to_token.get(token_id, '<UNK>')
|
| 337 |
+
pos_preds.append(f"{token}({prob:.3f})")
|
| 338 |
+
print(f" Position {pos}: {', '.join(pos_preds)}")
|
| 339 |
+
print("\nEnhanced generation tests completed!")
|
| 340 |
+
except Exception as e:
|
| 341 |
+
print(f"Enhanced generation test failed: {e}")
|
| 342 |
+
import traceback
|
| 343 |
+
traceback.print_exc()
|
| 344 |
+
|
| 345 |
+
print("\nEnhanced Model Analysis:")
|
| 346 |
+
print(f"Total parameters: {count_parameters(model):,}")
|
| 347 |
+
mtp_params = sum(p.numel() for p in model.mtp_head.parameters() if p.requires_grad)
|
| 348 |
+
horizon_params = sum(p.numel() for p in model.horizon_loss.parameters() if p.requires_grad)
|
| 349 |
+
base_params = count_parameters(model) - mtp_params - horizon_params
|
| 350 |
+
print(f"Base model parameters: {base_params:,}")
|
| 351 |
+
print(f"MTP head parameters: {mtp_params:,}")
|
| 352 |
+
print(f"Horizon loss parameters: {horizon_params:,}")
|
| 353 |
+
print(f"Enhancement overhead: {((mtp_params + horizon_params) / base_params * 100):.2f}%")
|
| 354 |
+
print(f"\n Enhanced Model Architecture:")
|
| 355 |
+
print(f"- Base Model: Qwen3 with {config.num_hidden_layers} layers")
|
| 356 |
+
print(f"- Hidden Size: {config.hidden_size}")
|
| 357 |
+
print(f"- Attention Heads: {config.num_attention_heads}")
|
| 358 |
+
print(f"- Vocab Size: {config.vocab_size}")
|
| 359 |
+
print(f"- MTP Future Tokens: {model.mtp_head.num_future_tokens}")
|
| 360 |
+
print(f"- Horizon Loss Weights: Learnable")
|
| 361 |
+
print(f"- Training Mode: {'MTP + Horizon Loss' if model.use_mtp_training else 'Standard Causal LM'}")
|
| 362 |
+
print("\n Enhanced training pipeline completed successfully!")
|
| 363 |
+
|
| 364 |
+
if __name__ == "__main__":
|
| 365 |
+
main()
|
train_ppokl_withsa.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Refactored PPO-KL training script using ChemQ3MTP module
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from FastChemTokenizerHF import FastChemTokenizerSelfies
|
| 8 |
+
from ChemQ3MTP import ChemQ3MTP, CurriculumManager
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def main():
|
| 12 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
+
print(f"🚀 Using device: {device}")
|
| 14 |
+
|
| 15 |
+
# --- Load tokenizer ---
|
| 16 |
+
tokenizer = FastChemTokenizerSelfies.from_pretrained("../selftok_core")
|
| 17 |
+
|
| 18 |
+
# --- Load model ---
|
| 19 |
+
model = ChemQ3MTP.from_pretrained("../pretrained/sample-e1-mtp")
|
| 20 |
+
model.tokenizer = tokenizer
|
| 21 |
+
model.to(device)
|
| 22 |
+
|
| 23 |
+
# --- RL fine-tuning setup ---
|
| 24 |
+
print("\n🎯 Phase 2: RL Fine-tuning with PPO + Curriculum Learning")
|
| 25 |
+
model.set_mtp_training(False)
|
| 26 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
|
| 27 |
+
curriculum = CurriculumManager(start_len=10, max_len=35, step_increase=5, steps_per_level=70)
|
| 28 |
+
baseline = None
|
| 29 |
+
gamma = 0.95
|
| 30 |
+
|
| 31 |
+
# Dummy input (BOS-only batch)
|
| 32 |
+
batch_size = 4
|
| 33 |
+
dummy_input = tokenizer([tokenizer.bos_token] * batch_size, return_tensors="pt", padding=True)
|
| 34 |
+
input_ids = dummy_input.input_ids.to(device)
|
| 35 |
+
|
| 36 |
+
# Training config
|
| 37 |
+
total_steps = 14000
|
| 38 |
+
checkpoint_steps = {total_steps // 4, total_steps // 2, 3 * total_steps // 4, total_steps}
|
| 39 |
+
checkpoint_dir = "./ppo_checkpoints"
|
| 40 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 41 |
+
|
| 42 |
+
# --- RL Training Loop with tqdm ---
|
| 43 |
+
for step in tqdm(range(total_steps), desc="RL Training"):
|
| 44 |
+
max_new_tokens = curriculum.get_max_new_tokens()
|
| 45 |
+
|
| 46 |
+
# === PPO Rollout ===
|
| 47 |
+
with torch.no_grad():
|
| 48 |
+
selfies_list, old_log_probs, _, old_action_probs = model.generate_with_logprobs(
|
| 49 |
+
input_ids=input_ids,
|
| 50 |
+
max_new_tokens=max_new_tokens,
|
| 51 |
+
temperature=1.0,
|
| 52 |
+
top_k=50,
|
| 53 |
+
top_p=0.95,
|
| 54 |
+
do_sample=True,
|
| 55 |
+
return_probs=True
|
| 56 |
+
)
|
| 57 |
+
old_log_probs = old_log_probs.detach()
|
| 58 |
+
old_action_probs = old_action_probs.detach()
|
| 59 |
+
|
| 60 |
+
# === PPO Update ===
|
| 61 |
+
ppo_result = model.ppo_step(
|
| 62 |
+
input_ids=input_ids,
|
| 63 |
+
old_log_probs=old_log_probs,
|
| 64 |
+
old_action_probs=old_action_probs,
|
| 65 |
+
tokenizer=tokenizer,
|
| 66 |
+
max_new_tokens=max_new_tokens,
|
| 67 |
+
# validity_weight=1.0, # only used in ChemQ3 mode
|
| 68 |
+
# lipinski_weight=1.0, # only used in ChemQ3 mode
|
| 69 |
+
entropy_weight=0.01,
|
| 70 |
+
clip_epsilon=0.2,
|
| 71 |
+
baseline=baseline,
|
| 72 |
+
reward_mode="sa", # 🔑 SA-only mode
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
loss = ppo_result['loss']
|
| 78 |
+
optimizer.zero_grad(set_to_none=True) # slightly more efficient than zeroing
|
| 79 |
+
loss.backward()
|
| 80 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 81 |
+
optimizer.step()
|
| 82 |
+
|
| 83 |
+
# === Update baseline ===
|
| 84 |
+
reward_tensor = torch.tensor(ppo_result['avg_reward'], device=device)
|
| 85 |
+
baseline = reward_tensor if baseline is None else gamma * baseline + (1 - gamma) * reward_tensor
|
| 86 |
+
|
| 87 |
+
# Curriculum update
|
| 88 |
+
curriculum.step()
|
| 89 |
+
|
| 90 |
+
# Checkpointing
|
| 91 |
+
if (step + 1) in checkpoint_steps:
|
| 92 |
+
checkpoint_path = os.path.join(checkpoint_dir, f"model_step_{step+1}")
|
| 93 |
+
model.save_pretrained(checkpoint_path)
|
| 94 |
+
tokenizer.save_pretrained(checkpoint_path)
|
| 95 |
+
torch.save({
|
| 96 |
+
'step': step + 1,
|
| 97 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 98 |
+
'baseline': baseline.item(),
|
| 99 |
+
'curriculum_state': {
|
| 100 |
+
'current_max_len': curriculum.current_max_len,
|
| 101 |
+
'step_counter': curriculum.step_counter
|
| 102 |
+
}
|
| 103 |
+
}, os.path.join(checkpoint_path, 'training_state.pt'))
|
| 104 |
+
print(f"\n💾 Checkpoint saved at step {step+1} -> {checkpoint_path}")
|
| 105 |
+
|
| 106 |
+
# Logging every 50 steps
|
| 107 |
+
if step % 50 == 0:
|
| 108 |
+
print(f"\n[RL Step {step}] "
|
| 109 |
+
f"Loss={loss.item():.4f} | "
|
| 110 |
+
f"Valid={ppo_result['validity_rate']:.3f} | "
|
| 111 |
+
f"Lipinski={ppo_result['lipinski_score']:.3f} | "
|
| 112 |
+
f"Reward={ppo_result['avg_reward']:.3f} | "
|
| 113 |
+
f"Entropy={ppo_result['entropy']:.3f} | "
|
| 114 |
+
f"EntropyW={ppo_result['entropy_weight']:.4f}")
|
| 115 |
+
|
| 116 |
+
sample_selfies = ppo_result['generated_selfies'][0][:100]
|
| 117 |
+
sample_smiles = ppo_result['generated_smiles'][0] or "Invalid"
|
| 118 |
+
print(f" Sample SELFIES: {sample_selfies}")
|
| 119 |
+
print(f" Sample SMILES: {sample_smiles}")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
sample_selfies = ppo_result['generated_selfies'][0][:100]
|
| 123 |
+
sample_smiles = ppo_result['generated_smiles'][0] or "Invalid"
|
| 124 |
+
print(f" Sample SELFIES: {sample_selfies}")
|
| 125 |
+
print(f" Sample SMILES: {sample_smiles}")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
print("🎉 Training complete!")
|
| 129 |
+
|
# Entry point: start PPO-KL fine-tuning only when this file is executed
# directly as a script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|