Commit d0d0961 by banulaperera · verified · 1 Parent(s): 55705fc

Enhance inference: tokenizer, classifier, scorer

Refactors and extends the METANO inference stack: adds a rule-based MolecularClassifier for coarse molecule typing, hardens CharacterLevelChemicalTokenizer (marker handling, normalization, vocab markers), and introduces a SymbolicScorer with heuristic penalties and balanced-bracket checks. Implements BalancedBracketsLogitsProcessor to constrain generation, updates predict_neurosymbolic to a hybrid decode/repair flow (rescore neural candidates with symbolic heuristics, run constrained repair rounds), normalizes candidate scoring and deduplication, and adjusts defaults (e.g. sym_lambda, generation modes). MetanoModel now accepts a pretrained T5 instance from HF and model loading wraps that accordingly. Also updates README formatting and test_run outputs to show richer test diagnostics.
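
To make the rescoring step concrete, here is a minimal Python sketch of how neural candidates could be combined with symbolic penalties. The penalty weights and the `combined = neural + sym_lambda * symbolic` rule follow the pre-commit lines visible in this diff; the post-commit weights may differ. The names `symbolic_score` and `rerank`, the simplified deduplication key, and the sample candidates are hypothetical and chosen only for illustration.

```python
import re
from typing import Dict, List, Tuple

# Closing bracket -> matching opening bracket.
PAIRS = {")": "(", "]": "[", "}": "{"}


def balanced(s: str) -> bool:
    """Return True if every bracket in s is closed in the right order."""
    stack: List[str] = []
    for ch in s:
        if ch in PAIRS.values():
            stack.append(ch)
        elif ch in PAIRS:
            if not stack or stack[-1] != PAIRS[ch]:
                return False
            stack.pop()
    return not stack


def symbolic_score(pred: str) -> Tuple[float, List[str]]:
    """Heuristic score: small bonus when clean, summed penalties otherwise."""
    reasons: List[str] = []
    if not balanced(pred):
        reasons.append("Unbalanced bracket")
    if not pred.strip():
        reasons.append("Empty prediction")
    if re.search(r"[,.\-]{3,}", pred):
        reasons.append("Repeated punctuation")
    if "  " in pred:
        reasons.append("Double spaces")
    if not reasons:
        return 0.5, reasons
    # Weights assumed from the pre-commit scorer: -2.0, -3.0, else -0.5.
    weights = {"Unbalanced bracket": -2.0, "Empty prediction": -3.0}
    return sum(weights.get(r, -0.5) for r in reasons), reasons


def rerank(
    candidates: List[Tuple[str, float]], sym_lambda: float = 1.0
) -> List[Tuple[float, str]]:
    """Combine neural and symbolic scores, keeping the best entry per key."""
    pool: Dict[str, Tuple[float, str]] = {}
    for text, neural in candidates:
        sym, _ = symbolic_score(text)
        combined = neural + sym_lambda * sym
        key = " ".join(text.lower().split())  # simplified dedup key
        if key not in pool or combined > pool[key][0]:
            pool[key] = (combined, text)
    return sorted(pool.values(), key=lambda item: item[0], reverse=True)


# Hypothetical candidates: (decoded text, neural log-score).
ranked = rerank(
    [
        ("tetracarbonylnickel", -0.9),
        ("tetracarbonyl(nickel", -0.4),
        ("Tetracarbonylnickel", -1.2),
    ]
)
```

In this toy run the candidate with the unbalanced bracket has the best neural score but drops to second place after the symbolic penalty, and the case-variant duplicate is merged into the better-scoring entry.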

Files changed (3)
  1. README.md +37 -27
  2. metano_inference.py +576 -106
  3. test_run.py +4 -3
README.md CHANGED
@@ -56,16 +56,16 @@ strings into human‑readable IUPAC names**.
56
 
57
  It is intended for:
58
 
59
- - Cheminformatics researchers\
60
- - Computational chemists\
61
- - Chemical database maintainers\
62
  - AI-driven chemistry pipelines
63
 
64
  The model is particularly useful for molecules containing:
65
 
66
- - Transition metals\
67
- - Alkali metals\
68
- - Lanthanides\
69
  - Actinides
70
 
71
  ------------------------------------------------------------------------
@@ -74,9 +74,9 @@ The model is particularly useful for molecules containing:
74
 
75
  The model is **not intended for:**
76
 
77
- - Generating molecular 3D structures\
78
- - Predicting chemical properties\
79
- - Reaction prediction\
80
  - Translating formats other than InChI (e.g., SMILES) directly to
81
  IUPAC without conversion
82
 
@@ -98,13 +98,13 @@ bracket balancing and basic chemical syntax constraints, it remains a
98
 
99
  Potential issues include:
100
 
101
- - Hallucinated nomenclature for unseen structures\
102
- - Reduced accuracy for extremely large molecules\
103
  - Errors for polymeric or highly unusual compounds
104
 
105
  Training limits:
106
 
107
- - **Maximum InChI length:** 400 characters\
108
  - **Maximum IUPAC length:** 150 characters
109
 
110
  ------------------------------------------------------------------------
@@ -147,13 +147,23 @@ out = predict_neurosymbolic(
147
  inchi=test_inchi,
148
  scorer=scorer,
149
  num_candidates=5,
150
- sym_lambda=1.0,
151
  repair_num_candidates=5,
152
  max_repair_rounds=1
153
  )
154
 
155
- print("Predicted IUPAC:", out["predicted_iupac"])
156
- print("Combined Score:", out["combined_score"])
157
  ```
158
 
159
  ------------------------------------------------------------------------
@@ -167,8 +177,8 @@ covering diverse chemical classes.
167
 
168
  Training subsets include:
169
 
170
- - \~294K inorganic combinations\
171
- - \~123K organometallic compounds\
172
  - \~82K coordination complexes
173
 
174
  Both **standard and reconnected (/r) InChI strings** were included.
@@ -206,13 +216,13 @@ were added using a custom **CharacterLevelChemicalTokenizer**.
206
 
207
  ## Training Hyperparameters
208
 
209
- - **Training regime:** fp16 mixed precision (AMP)\
210
- - **Optimizer:** AdamW\
211
- - **Learning Rate:** 3e‑4 with 10% linear warmup and linear decay\
212
- - **Weight Decay:** 0.01\
213
- - **Batch Size:** 128 (effective via gradient accumulation = 2)\
214
- - **Max Input Length:** 410 tokens\
215
- - **Max Output Length:** 160 tokens\
216
  - **Gradient Clipping:** 1.0
217
 
218
  ------------------------------------------------------------------------
@@ -224,9 +234,9 @@ were added using a custom **CharacterLevelChemicalTokenizer**.
224
  Evaluation was conducted on a **held‑out test split** containing a
225
  balanced distribution of:
226
 
227
- - **Inorganic Compounds:** METANO achieves a Top-1 accuracy of 0.378, outperforming previously reported results of 0.14.\
228
- - **Organometallic Compounds:** METANO achieves a Top-1 accuracy of 0.364, outperforming previously reported results of 0.20.\
229
- - **Co-ordination Compounds:** METANO achieves a Top-1 accuracy of 0.394.\
230
  - **Top-K Decoding** Additional gains are observed using Top-K decoding, reaching Top-5 accuracies of 0.481 (inorganic), 0.488 (organometallic) and 0.521 (Co-ordination).
231
 
232
  ------------------------------------------------------------------------
 
56
 
57
  It is intended for:
58
 
59
+ - Cheminformatics researchers
60
+ - Computational chemists
61
+ - Chemical database maintainers
62
  - AI-driven chemistry pipelines
63
 
64
  The model is particularly useful for molecules containing:
65
 
66
+ - Transition metals
67
+ - Alkali metals
68
+ - Lanthanides
69
  - Actinides
70
 
71
  ------------------------------------------------------------------------
 
74
 
75
  The model is **not intended for:**
76
 
77
+ - Generating molecular 3D structures
78
+ - Predicting chemical properties
79
+ - Reaction prediction
80
  - Translating formats other than InChI (e.g., SMILES) directly to
81
  IUPAC without conversion
82
 
 
98
 
99
  Potential issues include:
100
 
101
+ - Hallucinated nomenclature for unseen structures
102
+ - Reduced accuracy for extremely large molecules
103
  - Errors for polymeric or highly unusual compounds
104
 
105
  Training limits:
106
 
107
+ - **Maximum InChI length:** 400 characters
108
  - **Maximum IUPAC length:** 150 characters
109
 
110
  ------------------------------------------------------------------------
 
147
  inchi=test_inchi,
148
  scorer=scorer,
149
  num_candidates=5,
 
150
  repair_num_candidates=5,
151
  max_repair_rounds=1
152
  )
153
 
154
+ print("=== TEST RESULTS ===")
155
+ print(f"Predicted IUPAC: {out['predicted_iupac']}")
156
+ print(f"Hard Fail Triggered: {out['hard_fail']}")
157
+ print(f"Combined Score: {out['combined_score']:.3f}")
158
+ print(f"Symbolic Score: {out['symbolic_score']:.3f}")
159
+ print(f"Neural Score: {out['neural_score']:.3f}")
160
+
161
+ if out['reasons']:
162
+ print(f"Penalty Reasons: {out['reasons']}")
163
+
164
+ print("\nTop Candidates:")
165
+ for cand in out["candidates"][1:]:
166
+ print(f" [{cand['combined']:.3f}] {cand['text']}")
167
  ```
168
 
169
  ------------------------------------------------------------------------
 
177
 
178
  Training subsets include:
179
 
180
+ - \~294K inorganic combinations
181
+ - \~123K organometallic compounds
182
  - \~82K coordination complexes
183
 
184
  Both **standard and reconnected (/r) InChI strings** were included.
 
216
 
217
  ## Training Hyperparameters
218
 
219
+ - **Training regime:** fp16 mixed precision (AMP)
220
+ - **Optimizer:** AdamW
221
+ - **Learning Rate:** 3e‑4 with 10% linear warmup and linear decay
222
+ - **Weight Decay:** 0.01
223
+ - **Batch Size:** 128 (effective via gradient accumulation = 2)
224
+ - **Max Input Length:** 410 tokens
225
+ - **Max Output Length:** 160 tokens
226
  - **Gradient Clipping:** 1.0
227
 
228
  ------------------------------------------------------------------------
 
234
  Evaluation was conducted on a **held‑out test split** containing a
235
  balanced distribution of:
236
 
237
+ - **Inorganic Compounds:** METANO achieves a Top-1 accuracy of 0.378, outperforming previously reported results of 0.14.
238
+ - **Organometallic Compounds:** METANO achieves a Top-1 accuracy of 0.364, outperforming previously reported results of 0.20.
239
+ - **Co-ordination Compounds:** METANO achieves a Top-1 accuracy of 0.394.
240
  - **Top-K Decoding** Additional gains are observed using Top-K decoding, reaching Top-5 accuracies of 0.481 (inorganic), 0.488 (organometallic) and 0.521 (Co-ordination).
241
 
242
  ------------------------------------------------------------------------
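
The bracket constraint applied during repair rounds can be illustrated without PyTorch. The sketch below follows the counting rule in the removed `BalancedBracketsLogitsProcessor` lines of `metano_inference.py`: a closing `)` or `]` is masked whenever the sequence so far has no unmatched opener of the same kind. The function name `mask_unopened_closers`, the toy vocabulary, and the use of plain lists instead of tensors are assumptions for illustration; the post-commit processor may apply additional rules.

```python
import math
from typing import Dict, List


def mask_unopened_closers(
    seq: List[int], scores: List[float], ids: Dict[str, int]
) -> List[float]:
    """Return a copy of scores with unmatched closing brackets set to -inf."""
    out = list(scores)
    for open_ch, close_ch in (("(", ")"), ("[", "]")):
        if close_ch not in ids:
            continue
        # A missing opener id maps to -1, which never occurs in seq.
        opens = seq.count(ids.get(open_ch, -1))
        closes = seq.count(ids[close_ch])
        if closes >= opens:
            out[ids[close_ch]] = -math.inf
    return out


# Toy vocabulary (hypothetical): four bracket tokens plus one letter at id 4.
toy_ids = {"(": 0, ")": 1, "[": 2, "]": 3}
next_scores = [0.1, 0.2, 0.3, 0.4, 0.5]

# One "(" is still open: ")" stays allowed, "]" is masked.
masked_open = mask_unopened_closers([0, 4, 4], next_scores, toy_ids)

# The "(" has already been closed: both closers are masked.
masked_closed = mask_unopened_closers([0, 4, 1], next_scores, toy_ids)
```

Counting openers and closers prevents a closer from appearing too early, but it does not by itself enforce nesting order across bracket types, which is why a separate balanced-bracket check in the scorer is still useful.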
metano_inference.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import re
 
3
  import unicodedata
4
  import numpy as np
5
  import torch
@@ -13,29 +14,130 @@ from transformers.generation.logits_process import LogitsProcessor
13
  # Define device globally for inference
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
 
16
  @dataclass
17
  class ModelConfig:
18
  """Configuration for METANO Model"""
 
19
  model_name: str = "t5-small"
20
  max_input_length: int = 410
21
  max_output_length: int = 160
22
  metal_elements: List[str] = field(
23
  default_factory=lambda: [
24
- "Li", "Na", "K", "Rb", "Cs", "Be", "Mg", "Ca", "Sr", "Ba", "Sc", "Ti",
25
- "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Y", "Zr", "Nb", "Mo",
26
- "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "Hf", "Ta", "W", "Re", "Os", "Ir",
27
- "Pt", "Au", "Hg", "Al", "Ga", "In", "Tl", "Sn", "Pb", "Bi",
28
  ]
29
  )
30
 
 
31
  class MolecularClassifier:
 
 
32
  def __init__(self):
33
- self.transition_metals = {22, 23, 24, 25, 26, 27, 28, 29, 30, 40, 41, 42, 43, 44, 45, 46, 47, 48, 72, 73, 74, 75, 76, 77, 78, 79, 80}
34
- self.main_group_metals = {3, 4, 11, 12, 13, 19, 20, 31, 37, 38, 49, 50, 55, 56, 81, 82, 83}
35
  self.lanthanides = set(range(57, 72))
36
  self.actinides = set(range(89, 104))
37
- self.all_metals = self.transition_metals | self.main_group_metals | self.lanthanides | self.actinides
 
 
 
 
 
38
 
 
 
39
  self.organometallic_patterns = [
40
  "[Fe,Co,Ni,Cr,Mn,Mo,W,Ru,Os,Rh,Ir]-C=O",
41
  "[Fe,Co,Ni,Cr,Mn,Mo,W,Ru,Os,Rh,Ir]-[C-]#[O+]",
@@ -43,6 +145,7 @@ class MolecularClassifier:
43
  "[Fe,Co,Ni,Ru,Rh,Os,Ir,Ti,V,Cr,Mn,Zr,Mo,W]~c1ccccc1",
44
  ]
45
 
 
46
  self.compiled_organometallic = []
47
  for pattern in self.organometallic_patterns:
48
  try:
@@ -50,16 +153,37 @@ class MolecularClassifier:
50
  if mol is not None:
51
  self.compiled_organometallic.append(mol)
52
  except Exception:
 
53
  pass
54
 
55
  def classify_molecule(self, mol: Chem.Mol) -> Dict[str, any]:
 
 
 
 
 
 
 
56
  try:
57
  has_carbon = self._has_element(mol, 6)
58
  has_metal = self._has_metals(mol)
59
  classification = self._classify_by_composition(mol, has_carbon, has_metal)
60
- metal_info = self._extract_metals(mol) if has_metal else {"metal_atomic_nums": set(), "metal_symbols": [], "primary_metal": None}
61
- return {"classification": classification, "has_metal": has_metal, "primary_metal": metal_info["primary_metal"]}
62
  except Exception as e:
 
63
  return {"classification": "error", "error": str(e)}
64
 
65
  def _has_element(self, mol: Chem.Mol, atomic_num: int) -> bool:
@@ -76,26 +200,42 @@ class MolecularClassifier:
76
  if z in self.all_metals:
77
  metal_atomic_nums.add(z)
78
  metal_symbols.append(atom.GetSymbol())
79
-
80
  seen = set()
81
  metal_symbols = [m for m in metal_symbols if not (m in seen or seen.add(m))]
82
- return {"metal_atomic_nums": metal_atomic_nums, "metal_symbols": metal_symbols, "primary_metal": metal_symbols[0] if metal_symbols else None}
 
 
 
 
83
 
84
  def _has_metal_carbon_bond(self, mol: Chem.Mol) -> bool:
85
  for bond in mol.GetBonds():
86
- a1_num, a2_num = bond.GetBeginAtom().GetAtomicNum(), bond.GetEndAtom().GetAtomicNum()
87
- if (a1_num in self.all_metals and a2_num == 6) or (a1_num == 6 and a2_num in self.all_metals):
 
 
 
 
 
88
  return True
89
  return False
90
 
91
  def _recover_organometallic_by_smarts(self, mol: Chem.Mol) -> bool:
92
- return any(mol.HasSubstructMatch(pattern) for pattern in self.compiled_organometallic)
 
 
93
 
94
  def _has_metal_heteroatom_bond(self, mol: Chem.Mol) -> bool:
95
  donor_atoms = {7, 8, 9, 15, 16, 17, 35, 53}
96
  for bond in mol.GetBonds():
97
- z1, z2 = bond.GetBeginAtom().GetAtomicNum(), bond.GetEndAtom().GetAtomicNum()
98
- if (z1 in self.all_metals and z2 in donor_atoms) or (z2 in self.all_metals and z1 in donor_atoms):
 
 
 
 
 
99
  return True
100
  return False
101
 
@@ -109,12 +249,20 @@ class MolecularClassifier:
109
  return True
110
  return False
111
 
112
- def _classify_by_composition(self, mol: Chem.Mol, has_carbon: bool, has_metal: bool) -> str:
 
 
113
  if has_metal and has_carbon:
114
- if self._has_metal_carbon_bond(mol) or self._recover_organometallic_by_smarts(mol):
 
 
115
  return "organometallic"
116
  elif self._has_metal_heteroatom_bond(mol):
117
- return "inorganic" if self._is_simple_inorganic_salt(mol) else "coordination"
 
 
 
 
118
  return "inorganic"
119
  elif (has_metal and not has_carbon) or (not has_carbon and not has_metal):
120
  return "inorganic"
@@ -122,20 +270,96 @@ class MolecularClassifier:
122
  return "organic"
123
  return "unclassified"
124
 
 
125
  class CharacterLevelChemicalTokenizer:
126
  def __init__(self, config):
127
  self.config = config
128
  self.metals = set(self.config.metal_elements)
129
 
130
- self.control_tokens = ["<ORGANIC>", "<ORGANOMETALLIC>", "<INORGANIC>", "<COORDINATION>", "<STANDARD_INCHI>", "<RECONNECTED_INCHI>"]
131
- self.structural_markers = ["<METAL>"] + [f"<METAL_{metal.upper()}>" for metal in sorted(self.metals)]
 
 
 
 
 
 
 
 
 
132
  self.specials = ["<PAD>", "<UNK>", "<START>", "<END>"]
133
 
134
  base_chars = [
135
- " ", "-", "=", "#", "+", "(", ")", "[", "]", "{", "}", "/", "\\", ",", ".", ":", ";", "@", "*", "&", "|", "'", '"',
136
- *"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", *"0123456789",
137
- "α", "β", "γ", "δ", "ε", "ζ", "η", "θ", "κ", "λ", "μ", "ν", "ξ", "π", "ρ", "σ", "τ", "φ", "χ", "ψ", "ω", "Δ", "Λ",
138
- "", "¹", "²", "³", "⁴", "⁵", "⁶", "⁷", "⁸", "⁹", "⁺", "⁻", "₀", "₁", "₂", "₃", "₄", "₅", "₆", "₇", "₈", "₉",
139
  ]
140
 
141
  all_markers = self.control_tokens + self.structural_markers
@@ -151,23 +375,31 @@ class CharacterLevelChemicalTokenizer:
151
  self.bos_token_id = self.token2idx["<START>"]
152
  self.eos_token_id = self.token2idx["<END>"]
153
 
154
- self.marker_pattern = re.compile("|".join(map(re.escape, self.sorted_markers))) if self.sorted_markers else None
 
 
 
 
155
 
156
  def _normalize(self, text: str) -> str:
157
- if text is None: return ""
 
158
  text = unicodedata.normalize("NFKC", str(text)).replace("\u00a0", " ").strip()
159
  return " ".join(text.split())
160
 
161
  def tokenize(self, text: str) -> List[str]:
162
  text = self._normalize(text)
163
- if not text: return []
 
164
  tokens, pos = [], 0
165
  if self.marker_pattern:
166
  for m in self.marker_pattern.finditer(text):
167
- if m.start() > pos: tokens.extend(list(text[pos : m.start()]))
 
168
  tokens.append(m.group())
169
  pos = m.end()
170
- if pos < len(text): tokens.extend(list(text[pos:]))
 
171
  return tokens
172
 
173
  def encode(self, text: str, max_length: int, is_target: bool = False) -> Dict:
@@ -185,21 +417,32 @@ class CharacterLevelChemicalTokenizer:
185
  padded_ids = input_ids + [-100] * pad_len
186
  else:
187
  padded_ids = input_ids + [self.pad_token_id] * pad_len
188
-
189
  attention_mask = [1] * len(input_ids) + [0] * pad_len
190
  return {
191
  "input_ids": torch.tensor(padded_ids, dtype=torch.long),
192
  "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
193
  }
194
 
195
- def decode(self, token_ids: Union[torch.Tensor, List[int]], skip_special_tokens: bool = True) -> str:
196
- if isinstance(token_ids, torch.Tensor): token_ids = token_ids.tolist()
 
 
 
 
 
197
  out_tokens = []
198
  for idx in token_ids:
199
- if idx == self.eos_token_id or idx == -100: break
200
- if idx == self.pad_token_id: continue
 
 
201
  tok = self.idx2token.get(idx, "<UNK>")
202
- if skip_special_tokens and (tok in self.specials or tok in self.control_tokens or tok in self.structural_markers):
 
 
 
 
203
  continue
204
  out_tokens.append(tok)
205
  return "".join(out_tokens).strip()
@@ -207,155 +450,376 @@ class CharacterLevelChemicalTokenizer:
207
  def get_vocab_size(self) -> int:
208
  return self.vocab_size
209
 
210
- def preprocess_inchi(self, inchi: str, category: Optional[str] = None, has_metal: bool = False, primary_metal: Optional[str] = None) -> str:
211
- if not inchi: return ""
 
 
 
 
 
 
 
212
  control_prefix = []
213
  category_lower = category.lower() if category else "organic"
214
 
215
- if "organometallic" in category_lower: control_prefix.append("<ORGANOMETALLIC>")
216
- elif "coordination" in category_lower: control_prefix.append("<COORDINATION>")
217
- elif "inorganic" in category_lower: control_prefix.append("<INORGANIC>")
218
- else: control_prefix.append("<ORGANIC>")
 
 
 
 
219
 
220
- control_prefix.append("<RECONNECTED_INCHI>" if "/r" in inchi else "<STANDARD_INCHI>")
 
 
221
 
222
  if has_metal:
223
- metal_tok = f"<METAL_{primary_metal.upper()}>" if primary_metal else "<METAL>"
224
- control_prefix.append(metal_tok if metal_tok in self.token2idx else "<METAL>")
 
 
 
 
225
 
226
  return "".join(control_prefix) + inchi
227
 
228
  def preprocess_iupac(self, iupac: str) -> str:
229
  return self._normalize(iupac)
230
 
 
231
  class MetanoModel(nn.Module):
232
- def __init__(self, config: ModelConfig, classifier: MolecularClassifier, pretrained_t5: Optional[T5ForConditionalGeneration] = None):
 
 
 
 
 
233
  super().__init__()
234
  self.config = config
235
  self.classifier = classifier
236
  self.tokenizer = CharacterLevelChemicalTokenizer(config)
237
-
238
- if pretrained_t5 is not None:
239
- self.model = pretrained_t5
240
- else:
241
- t5_config = T5Config.from_pretrained(config.model_name)
242
- t5_config.vocab_size = self.tokenizer.get_vocab_size()
243
- t5_config.pad_token_id = self.tokenizer.pad_token_id
244
- t5_config.eos_token_id = self.tokenizer.eos_token_id
245
- t5_config.decoder_start_token_id = self.tokenizer.pad_token_id
246
- self.model = T5ForConditionalGeneration(config=t5_config)
247
- self.model.resize_token_embeddings(self.tokenizer.get_vocab_size())
248
-
249
  self.model.config.use_cache = True
250
 
 
251
  @dataclass
252
  class SymbolicResult:
253
  score: float
254
  hard_fail: bool
255
  reasons: List[str]
256
 
 
257
  class SymbolicScorer:
 
 
258
  def __init__(self, metals: List[str]):
259
  self.metals = [m.lower() for m in metals]
260
 
261
  def _balanced(self, s: str) -> Tuple[bool, List[str]]:
 
262
  stack, pairs = [], {")": "(", "]": "[", "}": "{"}
263
  opens, closes = set(pairs.values()), set(pairs.keys())
264
  for ch in s:
265
- if ch in opens: stack.append(ch)
 
266
  elif ch in closes:
267
- if not stack or stack[-1] != pairs[ch]: return False, [f"Unbalanced bracket: found '{ch}' without matching '{pairs[ch]}'"]
 
 
 
268
  stack.pop()
269
- if stack: return False, [f"Unbalanced bracket: missing closers for {stack}"]
 
270
  return True, []
271
 
272
  def score(self, src: str, pred: str) -> SymbolicResult:
 
 
 
 
 
273
  reasons = []
 
 
 
 
274
  ok_balance, balance_reasons = self._balanced(pred)
275
  reasons.extend(balance_reasons)
276
-
277
- if len(pred.strip()) == 0: reasons.append("Empty prediction")
278
- if re.search(r"[,\\.\\-]{3,}", pred): reasons.append("Repeated punctuation")
279
- if "  " in pred: reasons.append("Double spaces")
280
 
281
  hard_fail = (not ok_balance) or ("Empty prediction" in reasons)
282
- score = 0.5 if not reasons else sum([-2.0 if "Unbalanced" in r else -3.0 if "Empty" in r else -0.5 for r in reasons])
283
  return SymbolicResult(score=score, hard_fail=hard_fail, reasons=reasons)
284
 
 
285
  class BalancedBracketsLogitsProcessor(LogitsProcessor):
 
 
286
  def __init__(self, tok2id: Dict[str, int]):
287
- self.ids = {ch: tok2id[ch] for ch in ["(", ")", "[", "]", "{", "}"] if ch in tok2id}
 
 
 
288
 
289
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
290
- if not self.ids: return scores
 
 
 
291
  for b in range(input_ids.size(0)):
292
  seq = input_ids[b].tolist()
293
- open_par, close_par = seq.count(self.ids.get("(", -1)), seq.count(self.ids.get(")", -1))
294
- open_sq, close_sq = seq.count(self.ids.get("[", -1)), seq.count(self.ids.get("]", -1))
295
- if close_par >= open_par and ")" in self.ids: scores[b, self.ids[")"]] = -float("inf")
296
- if close_sq >= open_sq and "]" in self.ids: scores[b, self.ids["]"]] = -float("inf")
 
 
 
 
 
 
 
 
297
  return scores
298
 
 
299
  @torch.no_grad()
300
  def predict_neurosymbolic(
301
- model: MetanoModel, inchi: str, scorer: SymbolicScorer,
302
- category: Optional[str] = None, has_metal: Optional[bool] = None, primary_metal: Optional[str] = None,
303
- num_candidates: int = 8, sym_lambda: float = 1.0, repair_num_candidates: int = 16, max_repair_rounds: int = 3
 
 
 
 
 
 
 
304
  ):
 
 
 
 
 
 
305
  model.model.eval()
306
 
 
307
  if category is None or has_metal is None:
308
  mol = Chem.MolFromInchi(inchi)
309
  if mol is not None:
310
  classification = model.classifier.classify_molecule(mol)
311
- if category is None: category = classification.get("classification", "organic")
312
- if has_metal is None: has_metal = classification.get("has_metal", False)
313
- if primary_metal is None and has_metal: primary_metal = classification.get("primary_metal")
 
 
 
314
  else:
315
- if category is None: category = "organic"
316
- if has_metal is None: has_metal = False
 
 
317
 
 
318
  category = category or "organic"
319
- src = model.tokenizer.preprocess_inchi(inchi, category=category, has_metal=has_metal, primary_metal=primary_metal)
320
-
 
 
321
  enc = model.tokenizer.encode(src, model.config.max_input_length)
322
  input_ids = enc["input_ids"].unsqueeze(0).to(device)
323
  attention_mask = enc["attention_mask"].unsqueeze(0).to(device)
324
 
 
325
  def _dedup_key(s: str) -> str:
326
- if not s: return ""
 
327
  s = " ".join(s.strip().lower().split())
328
  return re.sub(r"\s*([(),\[\]{}\-+/=.:;·])\s*", r"\1", s)
329
 
 
 
 
 
 
 
 
 
 
330
  def _generate(ncand: int, use_constraints: bool, mode: str = "beam"):
331
  kwargs = dict(
332
- input_ids=input_ids, attention_mask=attention_mask, max_length=model.config.max_output_length,
333
- num_beams=ncand, num_return_sequences=ncand, early_stopping=True,
334
- pad_token_id=model.tokenizer.pad_token_id, eos_token_id=model.tokenizer.eos_token_id,
335
- return_dict_in_generate=True, output_scores=True,
336
  )
337
- if use_constraints: kwargs["logits_processor"] = [BalancedBracketsLogitsProcessor(model.tokenizer.token2idx)]
338
- if mode == "sample": kwargs.update(do_sample=True, top_p=0.92, temperature=0.8, num_beams=1)
339
-
340
- out = model.model.generate(**kwargs) if device.type != "cuda" else \
341
- torch.autocast(device_type="cuda", dtype=torch.float16)(model.model.generate)(**kwargs)
342
 
343
- preds = [model.tokenizer.decode(seq, skip_special_tokens=True) for seq in out.sequences]
344
- neural_scores = out.sequences_scores.detach().float().cpu().numpy() if hasattr(out, "sequences_scores") else np.zeros(ncand)
 
 
 
 
 
 
 
345
  return preds, neural_scores
346
 
 
347
  pool = {}
 
348
  def _add_to_pool(preds, nscores):
349
  for p, ns in zip(preds, nscores):
350
  sym = scorer.score(src, p)
351
- combined = float(ns) + sym_lambda * float(sym.score)
 
352
  key = _dedup_key(p)
353
  if key not in pool or combined > pool[key][0]:
354
- pool[key] = (combined, float(ns), float(sym.score), sym.hard_fail, sym.reasons, p)
355
-
 
 
 
 
 
 
 
 
356
  preds1, ns1 = _generate(num_candidates, use_constraints=False, mode="beam")
357
  _add_to_pool(preds1, ns1)
358
-
 
359
  best = sorted(pool.values(), key=lambda x: x[0], reverse=True)[0]
360
  repair_modes, repair_round = ["beam", "diverse", "sample"], 0
361
 
@@ -363,16 +827,22 @@ def predict_neurosymbolic(
363
  mode = repair_modes[min(repair_round, len(repair_modes) - 1)]
364
  preds2, ns2 = _generate(repair_num_candidates, use_constraints=True, mode=mode)
365
  _add_to_pool(preds2, ns2)
366
- best = sorted(pool.values(), key=lambda x: x[0], reverse=True)[0]
367
  repair_round += 1
368
 
369
  ranked_all = sorted(pool.values(), key=lambda x: x[0], reverse=True)
370
  return {
371
- "inchi": inchi, "predicted_iupac": best[5], "neural_score": best[1], "symbolic_score": best[2],
372
- "combined_score": best[0], "hard_fail": best[3], "reasons": best[4],
 
 
 
 
 
373
  "candidates": [{"text": r[5], "combined": r[0]} for r in ranked_all[:10]],
374
  }
375
 
 
376
  def load_model_from_hf(repo_id: str) -> MetanoModel:
377
  """
378
  Downloads and loads the METANO T5 model directly from the Hugging Face Hub.
@@ -380,13 +850,13 @@ def load_model_from_hf(repo_id: str) -> MetanoModel:
380
  print(f"Loading METANO model from Hugging Face Hub: {repo_id}")
381
  config = ModelConfig()
382
  classifier = MolecularClassifier()
383
-
384
  # Load the underlying T5 model weights and config from the Hub
385
  t5_model = T5ForConditionalGeneration.from_pretrained(repo_id)
386
-
387
  # Wrap it in the custom MetanoModel architecture
388
  model = MetanoModel(config, classifier, pretrained_t5=t5_model)
389
  model.to(device)
390
  print("Model successfully loaded to device:", device)
391
-
392
- return model
 
1
  import os
2
  import re
3
+ import math
4
  import unicodedata
5
  import numpy as np
6
  import torch
 
14
  # Define device globally for inference
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
+
18
  @dataclass
19
  class ModelConfig:
20
  """Configuration for METANO Model"""
21
+
22
  model_name: str = "t5-small"
23
  max_input_length: int = 410
24
  max_output_length: int = 160
25
  metal_elements: List[str] = field(
26
  default_factory=lambda: [
27
+ "Li",
28
+ "Na",
29
+ "K",
30
+ "Rb",
31
+ "Cs",
32
+ "Be",
33
+ "Mg",
34
+ "Ca",
35
+ "Sr",
36
+ "Ba",
37
+ "Sc",
38
+ "Ti",
39
+ "V",
40
+ "Cr",
41
+ "Mn",
42
+ "Fe",
43
+ "Co",
44
+ "Ni",
45
+ "Cu",
46
+ "Zn",
47
+ "Y",
48
+ "Zr",
49
+ "Nb",
50
+ "Mo",
51
+ "Tc",
52
+ "Ru",
53
+ "Rh",
54
+ "Pd",
55
+ "Ag",
56
+ "Cd",
57
+ "Hf",
58
+ "Ta",
59
+ "W",
60
+ "Re",
61
+ "Os",
62
+ "Ir",
63
+ "Pt",
64
+ "Au",
65
+ "Hg",
66
+ "Al",
67
+ "Ga",
68
+ "In",
69
+ "Tl",
70
+ "Sn",
71
+ "Pb",
72
+ "Bi",
73
  ]
74
  )
75
 
76
+
77
  class MolecularClassifier:
78
+ """Rule-based molecular category classifier used to condition generation."""
79
+
80
  def __init__(self):
81
+ # Atomic-number groups used by classifier heuristics.
82
+ self.transition_metals = {
83
+ 22,
84
+ 23,
85
+ 24,
86
+ 25,
87
+ 26,
88
+ 27,
89
+ 28,
90
+ 29,
91
+ 30,
92
+ 40,
93
+ 41,
94
+ 42,
95
+ 43,
96
+ 44,
97
+ 45,
98
+ 46,
99
+ 47,
100
+ 48,
101
+ 72,
102
+ 73,
103
+ 74,
104
+ 75,
105
+ 76,
106
+ 77,
107
+ 78,
108
+ 79,
109
+ 80,
110
+ }
111
+ self.main_group_metals = {
112
+ 3,
113
+ 4,
114
+ 11,
115
+ 12,
116
+ 13,
117
+ 19,
118
+ 20,
119
+ 31,
120
+ 37,
121
+ 38,
122
+ 49,
123
+ 50,
124
+ 55,
125
+ 56,
126
+ 81,
127
+ 82,
128
+ 83,
129
+ }
130
  self.lanthanides = set(range(57, 72))
131
  self.actinides = set(range(89, 104))
132
+ self.all_metals = (
133
+ self.transition_metals
134
+ | self.main_group_metals
135
+ | self.lanthanides
136
+ | self.actinides
137
+ )
138
 
139
+ # SMARTS recovery patterns for common organometallic motifs that may not
140
+ # be captured by simple direct metal–carbon checks.
141
  self.organometallic_patterns = [
142
  "[Fe,Co,Ni,Cr,Mn,Mo,W,Ru,Os,Rh,Ir]-C=O",
143
  "[Fe,Co,Ni,Cr,Mn,Mo,W,Ru,Os,Rh,Ir]-[C-]#[O+]",
 
145
  "[Fe,Co,Ni,Ru,Rh,Os,Ir,Ti,V,Cr,Mn,Zr,Mo,W]~c1ccccc1",
146
  ]
147
 
148
+ # Compile once for faster repeated substructure checks during inference.
149
  self.compiled_organometallic = []
150
  for pattern in self.organometallic_patterns:
151
  try:
 
153
  if mol is not None:
154
  self.compiled_organometallic.append(mol)
155
  except Exception:
156
+ # Ignore malformed SMARTS entries instead of failing startup.
157
  pass
158
 
159
  def classify_molecule(self, mol: Chem.Mol) -> Dict[str, any]:
160
+ """
161
+ Return coarse category metadata for prompt conditioning:
162
+ - classification (organic / inorganic / organometallic / coordination)
163
+ - has_metal (bool)
164
+ - primary_metal (first detected metal symbol, if any)
165
+ """
166
+
167
  try:
168
  has_carbon = self._has_element(mol, 6)
169
  has_metal = self._has_metals(mol)
170
  classification = self._classify_by_composition(mol, has_carbon, has_metal)
171
+ metal_info = (
172
+ self._extract_metals(mol)
173
+ if has_metal
174
+ else {
175
+ "metal_atomic_nums": set(),
176
+ "metal_symbols": [],
177
+ "primary_metal": None,
178
+ }
179
+ )
180
+ return {
181
+ "classification": classification,
182
+ "has_metal": has_metal,
183
+ "primary_metal": metal_info["primary_metal"],
184
+ }
185
  except Exception as e:
186
+ # Keep inference robust if RDKit parsing/classification fails
187
  return {"classification": "error", "error": str(e)}
188
 
189
  def _has_element(self, mol: Chem.Mol, atomic_num: int) -> bool:
 
200
  if z in self.all_metals:
201
  metal_atomic_nums.add(z)
202
  metal_symbols.append(atom.GetSymbol())
203
+
204
  seen = set()
205
  metal_symbols = [m for m in metal_symbols if not (m in seen or seen.add(m))]
206
+ return {
207
+ "metal_atomic_nums": metal_atomic_nums,
208
+ "metal_symbols": metal_symbols,
209
+ "primary_metal": metal_symbols[0] if metal_symbols else None,
210
+ }
211
 
212
  def _has_metal_carbon_bond(self, mol: Chem.Mol) -> bool:
213
  for bond in mol.GetBonds():
214
+ a1_num, a2_num = (
215
+ bond.GetBeginAtom().GetAtomicNum(),
216
+ bond.GetEndAtom().GetAtomicNum(),
217
+ )
218
+ if (a1_num in self.all_metals and a2_num == 6) or (
219
+ a1_num == 6 and a2_num in self.all_metals
220
+ ):
221
  return True
222
  return False
223
 
224
  def _recover_organometallic_by_smarts(self, mol: Chem.Mol) -> bool:
225
+ return any(
226
+ mol.HasSubstructMatch(pattern) for pattern in self.compiled_organometallic
227
+ )
228
 
229
  def _has_metal_heteroatom_bond(self, mol: Chem.Mol) -> bool:
230
  donor_atoms = {7, 8, 9, 15, 16, 17, 35, 53}
231
  for bond in mol.GetBonds():
232
+ z1, z2 = (
233
+ bond.GetBeginAtom().GetAtomicNum(),
234
+ bond.GetEndAtom().GetAtomicNum(),
235
+ )
236
+ if (z1 in self.all_metals and z2 in donor_atoms) or (
237
+ z2 in self.all_metals and z1 in donor_atoms
238
+ ):
239
  return True
240
  return False
241
 
 
249
  return True
250
  return False
251
 
252
+ def _classify_by_composition(
253
+ self, mol: Chem.Mol, has_carbon: bool, has_metal: bool
254
+ ) -> str:
255
  if has_metal and has_carbon:
256
+ if self._has_metal_carbon_bond(
257
+ mol
258
+ ) or self._recover_organometallic_by_smarts(mol):
259
  return "organometallic"
260
  elif self._has_metal_heteroatom_bond(mol):
261
+ return (
262
+ "inorganic"
263
+ if self._is_simple_inorganic_salt(mol)
264
+ else "coordination"
265
+ )
266
  return "inorganic"
267
  elif (has_metal and not has_carbon) or (not has_carbon and not has_metal):
268
  return "inorganic"
 
270
  return "organic"
271
  return "unclassified"
272
 
273
+
274
  class CharacterLevelChemicalTokenizer:
275
  def __init__(self, config):
276
  self.config = config
277
  self.metals = set(self.config.metal_elements)
278
 
279
+ self.control_tokens = [
280
+ "<ORGANIC>",
281
+ "<ORGANOMETALLIC>",
282
+ "<INORGANIC>",
283
+ "<COORDINATION>",
284
+ "<STANDARD_INCHI>",
285
+ "<RECONNECTED_INCHI>",
286
+ ]
287
+ self.structural_markers = ["<METAL>"] + [
288
+ f"<METAL_{metal.upper()}>" for metal in sorted(self.metals)
289
+ ]
290
  self.specials = ["<PAD>", "<UNK>", "<START>", "<END>"]
291
 
292
  base_chars = [
293
+ " ",
294
+ "-",
295
+ "=",
296
+ "#",
297
+ "+",
298
+ "(",
299
+ ")",
300
+ "[",
301
+ "]",
302
+ "{",
303
+ "}",
304
+ "/",
305
+ "\\",
306
+ ",",
307
+ ".",
308
+ ":",
309
+ ";",
310
+ "@",
311
+ "*",
312
+ "&",
313
+ "|",
314
+ "'",
315
+ '"',
316
+ *"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
317
+ *"0123456789",
318
+ "α",
319
+ "β",
320
+ "γ",
321
+ "δ",
322
+ "ε",
323
+ "ζ",
324
+ "η",
325
+ "θ",
326
+ "κ",
327
+ "λ",
328
+ "μ",
329
+ "ν",
330
+ "ξ",
331
+ "π",
332
+ "ρ",
333
+ "σ",
334
+ "τ",
335
+ "φ",
336
+ "χ",
337
+ "ψ",
338
+ "ω",
339
+ "Δ",
340
+ "Λ",
341
+ "⁰",
342
+ "¹",
343
+ "²",
344
+ "³",
345
+ "⁴",
346
+ "⁵",
347
+ "⁶",
348
+ "⁷",
349
+ "⁸",
350
+ "⁹",
351
+ "⁺",
352
+ "⁻",
353
+ "₀",
354
+ "₁",
355
+ "₂",
356
+ "₃",
357
+ "₄",
358
+ "₅",
359
+ "₆",
360
+ "₇",
361
+ "₈",
362
+ "₉",
363
  ]
364
 
365
  all_markers = self.control_tokens + self.structural_markers
 
375
  self.bos_token_id = self.token2idx["<START>"]
376
  self.eos_token_id = self.token2idx["<END>"]
377
 
378
+ self.marker_pattern = (
379
+ re.compile("|".join(map(re.escape, self.sorted_markers)))
380
+ if self.sorted_markers
381
+ else None
382
+ )
383
 
384
  def _normalize(self, text: str) -> str:
385
+ if text is None:
386
+ return ""
387
  text = unicodedata.normalize("NFKC", str(text)).replace("\u00a0", " ").strip()
388
  return " ".join(text.split())
389
 
390
  def tokenize(self, text: str) -> List[str]:
391
  text = self._normalize(text)
392
+ if not text:
393
+ return []
394
  tokens, pos = [], 0
395
  if self.marker_pattern:
396
  for m in self.marker_pattern.finditer(text):
397
+ if m.start() > pos:
398
+ tokens.extend(list(text[pos : m.start()]))
399
  tokens.append(m.group())
400
  pos = m.end()
401
+ if pos < len(text):
402
+ tokens.extend(list(text[pos:]))
403
  return tokens
404
 
405
  def encode(self, text: str, max_length: int, is_target: bool = False) -> Dict:
 
417
  padded_ids = input_ids + [-100] * pad_len
418
  else:
419
  padded_ids = input_ids + [self.pad_token_id] * pad_len
420
+
421
  attention_mask = [1] * len(input_ids) + [0] * pad_len
422
  return {
423
  "input_ids": torch.tensor(padded_ids, dtype=torch.long),
424
  "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
425
  }
426
 
427
+ def decode(
428
+ self,
429
+ token_ids: Union[torch.Tensor, List[int]],
430
+ skip_special_tokens: bool = True,
431
+ ) -> str:
432
+ if isinstance(token_ids, torch.Tensor):
433
+ token_ids = token_ids.tolist()
434
  out_tokens = []
435
  for idx in token_ids:
436
+ if idx == self.eos_token_id or idx == -100:
437
+ break
438
+ if idx == self.pad_token_id:
439
+ continue
440
  tok = self.idx2token.get(idx, "<UNK>")
441
+ if skip_special_tokens and (
442
+ tok in self.specials
443
+ or tok in self.control_tokens
444
+ or tok in self.structural_markers
445
+ ):
446
  continue
447
  out_tokens.append(tok)
448
  return "".join(out_tokens).strip()
 
450
  def get_vocab_size(self) -> int:
451
  return self.vocab_size
452
 
453
+ def preprocess_inchi(
454
+ self,
455
+ inchi: str,
456
+ category: str,
457
+ has_metal: bool = False,
458
+ primary_metal: Optional[str] = None,
459
+ ) -> str:
460
+ if not inchi:
461
+ return ""
462
  control_prefix = []
463
  category_lower = category.lower() if category else "organic"
464
 
465
+ if "organometallic" in category_lower:
466
+ control_prefix.append("<ORGANOMETALLIC>")
467
+ elif "coordination" in category_lower:
468
+ control_prefix.append("<COORDINATION>")
469
+ elif "inorganic" in category_lower:
470
+ control_prefix.append("<INORGANIC>")
471
+ else:
472
+ control_prefix.append("<ORGANIC>")
473
 
474
+ control_prefix.append(
475
+ "<RECONNECTED_INCHI>" if "/r" in inchi else "<STANDARD_INCHI>"
476
+ )
477
 
478
  if has_metal:
479
+ metal_tok = (
480
+ f"<METAL_{primary_metal.upper()}>" if primary_metal else "<METAL>"
481
+ )
482
+ control_prefix.append(
483
+ metal_tok if metal_tok in self.token2idx else "<METAL>"
484
+ )
485
 
486
  return "".join(control_prefix) + inchi
487
 
488
  def preprocess_iupac(self, iupac: str) -> str:
489
  return self._normalize(iupac)
490
 
491
+
492
  class MetanoModel(nn.Module):
493
+ def __init__(
494
+ self,
495
+ config: ModelConfig,
496
+ classifier: MolecularClassifier,
497
+ pretrained_t5: T5ForConditionalGeneration,
498
+ ):
499
  super().__init__()
500
  self.config = config
501
  self.classifier = classifier
502
  self.tokenizer = CharacterLevelChemicalTokenizer(config)
503
+ self.model = pretrained_t5
504
  self.model.config.use_cache = True
505
 
506
+
507
  @dataclass
508
  class SymbolicResult:
509
  score: float
510
  hard_fail: bool
511
  reasons: List[str]
512
 
513
+
514
  class SymbolicScorer:
515
+ """Symbolic validator/penalizer for generated IUPAC candidates."""
516
+
517
  def __init__(self, metals: List[str]):
518
  self.metals = [m.lower() for m in metals]
519
 
520
+ # Minimal lexical hints used by heuristic checks.
521
+ self.VALID_SUFFIXES = [
522
+ "ane",
523
+ "ene",
524
+ "yne",
525
+ "ol",
526
+ "one",
527
+ "al",
528
+ "amine",
529
+ "amide",
530
+ "acid",
531
+ "ate",
532
+ "ether",
533
+ "ester",
534
+ "thiol",
535
+ "imine",
536
+ "benzene",
537
+ ]
538
+
539
+ # Prefixes expected to be attached/hyphenated consistently in IUPAC-like text.
540
+ self.MULTIPLICATIVE_PREFIXES = [
541
+ "mono",
542
+ "di",
543
+ "tri",
544
+ "tetra",
545
+ "penta",
546
+ "hexa",
547
+ "hepta",
548
+ "octa",
549
+ "nona",
550
+ "deca",
551
+ "bis",
552
+ "tris",
553
+ "tetrakis",
554
+ "pentakis",
555
+ "hexakis",
556
+ ]
557
+
558
+ # Penalty table used to compute symbolic score in [0, 1] via 1 - total_penalty.
559
+ self.PENALTY_WEIGHTS = {
560
+ "Empty prediction": 1.0,
561
+ "Unbalanced bracket": 1.0,
562
+ "Double spaces": 0.1,
563
+ "Repeated punctuation": 0.4,
564
+ "Repeated comma": 0.3,
565
+ "Invalid hyphen usage": 0.3,
566
+ "Locant without substituent": 0.6,
567
+ "Prediction too short": 0.5,
568
+ "Prediction too long": 0.4,
569
+ "Repeated token": 0.3,
570
+ "Invalid spacing after multiplicative prefix": 0.2,
571
+ }
572
+
     def _balanced(self, s: str) -> Tuple[bool, List[str]]:
+        """Check (), [], {} bracket balance and nesting correctness."""
         stack, pairs = [], {")": "(", "]": "[", "}": "{"}
         opens, closes = set(pairs.values()), set(pairs.keys())
         for ch in s:
+            if ch in opens:
+                stack.append(ch)
             elif ch in closes:
+                if not stack or stack[-1] != pairs[ch]:
+                    return False, [
+                        f"Unbalanced bracket: found '{ch}' without matching '{pairs[ch]}'"
+                    ]
                 stack.pop()
+        if stack:
+            return False, [f"Unbalanced bracket: missing closers for {stack}"]
         return True, []
589
 
590
  def score(self, src: str, pred: str) -> SymbolicResult:
591
+ """
592
+ Score candidate text with rule-based penalties.
593
+ hard_fail is set for structurally invalid outputs (empty or unbalanced).
594
+ """
595
+ pred = pred.strip()
596
  reasons = []
597
+ if len(pred) == 0:
598
+ reasons.append("Empty prediction")
599
+
600
+ # Structural check first (most important hard-fail signal).
601
  ok_balance, balance_reasons = self._balanced(pred)
602
  reasons.extend(balance_reasons)
603
+
604
+ # Surface-form sanity checks.
605
+ # Double spaces
606
+ if "  " in pred:
607
+ reasons.append("Double spaces")
608
+ # Repeated punctuation
609
+ if re.search(r"[,\.\-]{3,}", pred):
610
+ reasons.append("Repeated punctuation")
611
+ # Repeated commas
612
+ if ",," in pred:
613
+ reasons.append("Repeated comma")
614
+ # Invalid hyphen
615
+ if re.search(r"--|,-|-,", pred):
616
+ reasons.append("Invalid hyphen usage")
617
+ # Invalid locant
618
+ if re.search(r"\b\d+(,\d+)*-$", pred):
619
+ reasons.append("Locant without substituent")
620
+ # Length sanity
621
+ if len(pred) < 4:
622
+ reasons.append("Prediction too short")
623
+ if len(pred) > 200:
624
+ reasons.append("Prediction too long")
625
+ # Repeated words
626
+ if re.search(r"\b(\w+)\s+\1\b", pred.lower()):
627
+ reasons.append("Repeated token")
628
+
629
+ # Prefix spacing check for IUPAC-like style.
630
+ for prefix in self.MULTIPLICATIVE_PREFIXES:
631
+ if re.search(rf"\b{prefix}\s+", pred.lower()):
632
+ reasons.append(
633
+ f"Invalid spacing after multiplicative prefix '{prefix}'"
634
+ )
635
+
636
+ # Convert reason strings to numeric penalty.
637
+ total_penalty = 0.0
638
+ for reason in reasons:
639
+ for key, weight in self.PENALTY_WEIGHTS.items():
640
+ if key in reason:
641
+ total_penalty += weight
642
 
643
  hard_fail = (not ok_balance) or ("Empty prediction" in reasons)
644
+ score = max(0.0, 1.0 - total_penalty)
645
  return SymbolicResult(score=score, hard_fail=hard_fail, reasons=reasons)
646
 
647
+
 class BalancedBracketsLogitsProcessor(LogitsProcessor):
+    """Generation-time constraint: suppress unmatched closing brackets."""
+
     def __init__(self, tok2id: Dict[str, int]):
+        # Keep only bracket tokens present in tokenizer vocabulary.
+        self.ids = {
+            ch: tok2id[ch] for ch in ["(", ")", "[", "]", "{", "}"] if ch in tok2id
+        }
 
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+    ) -> torch.FloatTensor:
+        if not self.ids:
+            return scores
         for b in range(input_ids.size(0)):
             seq = input_ids[b].tolist()
+            open_par, close_par = seq.count(self.ids.get("(", -1)), seq.count(
+                self.ids.get(")", -1)
+            )
+            open_sq, close_sq = seq.count(self.ids.get("[", -1)), seq.count(
+                self.ids.get("]", -1)
+            )
+
+            # If closers already meet/exceed openers, block emitting more closers.
+            if close_par >= open_par and ")" in self.ids:
+                scores[b, self.ids[")"]] = -float("inf")
+            if close_sq >= open_sq and "]" in self.ids:
+                scores[b, self.ids["]"]] = -float("inf")
         return scores
677
 
678
+
679
  @torch.no_grad()
680
  def predict_neurosymbolic(
681
+ model: MetanoModel,
682
+ inchi: str,
683
+ scorer: SymbolicScorer,
684
+ category: Optional[str] = None,
685
+ has_metal: Optional[bool] = None,
686
+ primary_metal: Optional[str] = None,
687
+ num_candidates: int = 8,
688
+ sym_lambda: float = 0.5,
689
+ repair_num_candidates: int = 16,
690
+ max_repair_rounds: int = 3,
691
  ):
692
+ """
693
+ Hybrid decoding:
694
+ 1) Generate candidates with neural model.
695
+ 2) Rescore with symbolic heuristics.
696
+ 3) If best candidate hard-fails, run constrained repair rounds.
697
+ """
698
  model.model.eval()
699
 
700
+ # Derive molecule metadata when not provided by caller.
701
  if category is None or has_metal is None:
702
  mol = Chem.MolFromInchi(inchi)
703
  if mol is not None:
704
  classification = model.classifier.classify_molecule(mol)
705
+ if category is None:
706
+ category = classification.get("classification", "organic")
707
+ if has_metal is None:
708
+ has_metal = classification.get("has_metal", False)
709
+ if primary_metal is None and has_metal:
710
+ primary_metal = classification.get("primary_metal")
711
  else:
712
+ if category is None:
713
+ category = "organic"
714
+ if has_metal is None:
715
+ has_metal = False
716
 
717
+ # Prepare source sequence with control and structural markers.
718
  category = category or "organic"
719
+ src = model.tokenizer.preprocess_inchi(
720
+ inchi, category=category, has_metal=has_metal, primary_metal=primary_metal
721
+ )
722
+
723
  enc = model.tokenizer.encode(src, model.config.max_input_length)
724
  input_ids = enc["input_ids"].unsqueeze(0).to(device)
725
  attention_mask = enc["attention_mask"].unsqueeze(0).to(device)
726
 
727
+ # Normalize prediction text for de-duplication across beam/sample outputs.
728
  def _dedup_key(s: str) -> str:
729
+ if not s:
730
+ return ""
731
  s = " ".join(s.strip().lower().split())
732
  return re.sub(r"\s*([(),\[\]{}\-+/=.:;·])\s*", r"\1", s)
733
 
734
+ # HF requirement: num_beams must be divisible by num_beam_groups.
735
+ def _choose_beam_groups(ncand: int, max_groups: int = 4) -> int:
736
+ gmax = min(max_groups, ncand)
737
+ for g in range(gmax, 1, -1):
738
+ if ncand % g == 0:
739
+ return g
740
+ return 1
741
+
742
+ # Shared candidate generation helper for beam/diverse/sample modes.
743
  def _generate(ncand: int, use_constraints: bool, mode: str = "beam"):
744
  kwargs = dict(
745
+ input_ids=input_ids,
746
+ attention_mask=attention_mask,
747
+ max_length=model.config.max_output_length,
748
+ num_beams=ncand,
749
+ num_return_sequences=ncand,
750
+ early_stopping=True,
751
+ pad_token_id=model.tokenizer.pad_token_id,
752
+ eos_token_id=model.tokenizer.eos_token_id,
753
+ return_dict_in_generate=True,
754
+ output_scores=True,
755
+ )
756
+ # ---- Constraint tweaks ----
757
+ if use_constraints:
758
+ kwargs["logits_processor"] = [
759
+ BalancedBracketsLogitsProcessor(model.tokenizer.token2idx)
760
+ ]
761
+ # ---- Mode tweaks ----
762
+ if mode == "diverse":
763
+ g = _choose_beam_groups(ncand, max_groups=4)
764
+ # Only enable diverse beams if we can form >1 groups
765
+ if g > 1:
766
+ kwargs.update(
767
+ num_beam_groups=g,
768
+ diversity_penalty=0.2,
769
+ )
770
+ elif mode == "sample":
771
+ # Sampling fallback: get ncand independent samples
772
+ kwargs.update(
773
+ do_sample=True,
774
+ top_p=0.92,
775
+ temperature=0.8,
776
+ num_beams=1,
777
+ num_return_sequences=ncand,
778
+ )
779
+
780
+ out = (
781
+ model.model.generate(**kwargs)
782
+ if device.type != "cuda"
783
+ else torch.autocast(device_type="cuda", dtype=torch.float16)(
784
+ model.model.generate
785
+ )(**kwargs)
786
  )
787
 
788
+ preds = [
789
+ model.tokenizer.decode(seq, skip_special_tokens=True)
790
+ for seq in out.sequences
791
+ ]
792
+ neural_scores = (
793
+ out.sequences_scores.detach().float().cpu().numpy()
794
+ if hasattr(out, "sequences_scores")
795
+ else np.zeros(ncand)
796
+ )
797
  return preds, neural_scores
798
 
799
+ # Pool entry format: (combined, neural_prob_like, symbolic, hard_fail, reasons, text)
800
  pool = {}
801
+
802
  def _add_to_pool(preds, nscores):
803
  for p, ns in zip(preds, nscores):
804
  sym = scorer.score(src, p)
805
+ ns = math.exp(ns) # convert log-like beam score to positive scale
806
+ combined = (sym_lambda * float(ns)) + ((1 - sym_lambda) * float(sym.score))
807
  key = _dedup_key(p)
808
  if key not in pool or combined > pool[key][0]:
809
+ pool[key] = (
810
+ combined,
811
+ float(ns),
812
+ float(sym.score),
813
+ sym.hard_fail,
814
+ sym.reasons,
815
+ p,
816
+ )
817
+
818
+ # Initial unconstrained generation.
819
  preds1, ns1 = _generate(num_candidates, use_constraints=False, mode="beam")
820
  _add_to_pool(preds1, ns1)
821
+
822
+ # If top result fails symbolic hard checks, run constrained repair rounds.
823
  best = sorted(pool.values(), key=lambda x: x[0], reverse=True)[0]
824
  repair_modes, repair_round = ["beam", "diverse", "sample"], 0
825
 
 
827
  mode = repair_modes[min(repair_round, len(repair_modes) - 1)]
828
  preds2, ns2 = _generate(repair_num_candidates, use_constraints=True, mode=mode)
829
  _add_to_pool(preds2, ns2)
830
+ best = sorted(pool.values(), key=lambda x: x[0], reverse=True)[0] # update best after adding repairs
831
  repair_round += 1
832
 
833
  ranked_all = sorted(pool.values(), key=lambda x: x[0], reverse=True)
834
  return {
835
+ "inchi": inchi,
836
+ "predicted_iupac": best[5],
837
+ "neural_score": best[1],
838
+ "symbolic_score": best[2],
839
+ "combined_score": best[0],
840
+ "hard_fail": best[3],
841
+ "reasons": best[4],
842
  "candidates": [{"text": r[5], "combined": r[0]} for r in ranked_all[:10]],
843
  }
844
 
845
+
846
  def load_model_from_hf(repo_id: str) -> MetanoModel:
847
  """
848
  Downloads and loads the METANO T5 model directly from the Hugging Face Hub.
 
850
  print(f"Loading METANO model from Hugging Face Hub: {repo_id}")
851
  config = ModelConfig()
852
  classifier = MolecularClassifier()
853
+
854
  # Load the underlying T5 model weights and config from the Hub
855
  t5_model = T5ForConditionalGeneration.from_pretrained(repo_id)
856
+
857
  # Wrap it in the custom MetanoModel architecture
858
  model = MetanoModel(config, classifier, pretrained_t5=t5_model)
859
  model.to(device)
860
  print("Model successfully loaded to device:", device)
861
+
862
+ return model
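
The rescoring step in `predict_neurosymbolic` can be illustrated in isolation. The sketch below is a simplified stand-in, not the code from this commit: it mirrors the weighting used in `_add_to_pool` (`sym_lambda` multiplies the exponentiated neural score, `1 - sym_lambda` multiplies the symbolic score), but replaces `SymbolicScorer` with a bare bracket-balance check that returns 1.0 or 0.0. The names `brackets_balanced`, `combined_score`, and `rerank` are hypothetical and do not exist in `metano_inference.py`.

```python
import math
from typing import List, Tuple

# Closing bracket -> expected opening bracket.
PAIRS = {")": "(", "]": "[", "}": "{"}


def brackets_balanced(text: str) -> bool:
    """Return True if (), [], {} are balanced and correctly nested."""
    stack: List[str] = []
    for ch in text:
        if ch in PAIRS.values():
            stack.append(ch)
        elif ch in PAIRS:
            if not stack or stack[-1] != PAIRS[ch]:
                return False
            stack.pop()
    return not stack


def combined_score(log_score: float, symbolic: float, sym_lambda: float = 0.5) -> float:
    """Blend exp(log_score) with a symbolic score in [0, 1].

    sym_lambda weights the neural term, so sym_lambda=1.0 ignores the
    symbolic score entirely.
    """
    neural = math.exp(log_score)
    return sym_lambda * neural + (1.0 - sym_lambda) * symbolic


def rerank(
    candidates: List[Tuple[str, float]], sym_lambda: float = 0.5
) -> List[Tuple[float, str]]:
    """Sort (text, log_score) candidates by combined score, best first.

    Assumption: the symbolic score is 1.0 for balanced brackets, else 0.0.
    """
    scored = [
        (combined_score(ls, 1.0 if brackets_balanced(t) else 0.0, sym_lambda), t)
        for t, ls in candidates
    ]
    return sorted(scored, key=lambda x: x[0], reverse=True)


if __name__ == "__main__":
    ranked = rerank([("iron(2+ sulfate", -0.1), ("iron(2+) sulfate", -0.5)])
    for score, text in ranked:
        print(f"[{score:.3f}] {text}")
```

With the default `sym_lambda=0.5`, the well-formed candidate outranks the unbalanced one even though its neural log-score is lower. With `sym_lambda=1.0` the ordering falls back to the neural score alone.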
test_run.py CHANGED
@@ -11,7 +11,7 @@ def main():
11
  scorer = SymbolicScorer(metals=config.metal_elements)
12
 
13
  # A sample coordination/organometallic InChI from your notebook
14
- test_inchi = "InChI=1/C15H16N2O3S.Na/c1-10-3-4-12(9-11(10)2)15(18)17-21(19,20)14-7-5-13(16)6-8-14;/h3-9H,16H2,1-2H3,(H,17,18);/q;+1"
15
 
16
  print(f"\nRunning prediction for InChI:\n{test_inchi}\n")
17
 
@@ -20,7 +20,6 @@ def main():
20
  inchi=test_inchi,
21
  scorer=scorer,
22
  num_candidates=5,
23
- sym_lambda=1.0,
24
  repair_num_candidates=5,
25
  max_repair_rounds=1
26
  )
@@ -29,12 +28,14 @@ def main():
29
  print(f"Predicted IUPAC: {out['predicted_iupac']}")
30
  print(f"Hard Fail Triggered: {out['hard_fail']}")
31
  print(f"Combined Score: {out['combined_score']:.3f}")
 
 
32
 
33
  if out['reasons']:
34
  print(f"Penalty Reasons: {out['reasons']}")
35
 
36
  print("\nTop Candidates:")
37
- for cand in out["candidates"][:3]:
38
  print(f" [{cand['combined']:.3f}] {cand['text']}")
39
 
40
  if __name__ == "__main__":
 
11
  scorer = SymbolicScorer(metals=config.metal_elements)
12
 
13
  # A sample coordination/organometallic InChI from your notebook
14
+ test_inchi = "InChI=1/Fe.Na.H2O4S.H2O.H/c;;1-5(2,3)4;;/h;;(H2,1,2,3,4);1H2;/q;+1;;;-1"
15
 
16
  print(f"\nRunning prediction for InChI:\n{test_inchi}\n")
17
 
 
20
  inchi=test_inchi,
21
  scorer=scorer,
22
  num_candidates=5,
 
23
  repair_num_candidates=5,
24
  max_repair_rounds=1
25
  )
 
28
  print(f"Predicted IUPAC: {out['predicted_iupac']}")
29
  print(f"Hard Fail Triggered: {out['hard_fail']}")
30
  print(f"Combined Score: {out['combined_score']:.3f}")
31
+ print(f"Symbolic Score: {out['symbolic_score']:.3f}")
32
+ print(f"Neural Score: {out['neural_score']:.3f}")
33
 
34
  if out['reasons']:
35
  print(f"Penalty Reasons: {out['reasons']}")
36
 
37
  print("\nTop Candidates:")
38
+ for cand in out["candidates"][1:]:
39
  print(f" [{cand['combined']:.3f}] {cand['text']}")
40
 
41
  if __name__ == "__main__":