gbyuvd committed on
Commit
e367722
·
verified ·
1 Parent(s): b70883e

Add Durrant's Lab filtering

Browse files
Files changed (1) hide show
  1. rl_utils.py +58 -6
rl_utils.py CHANGED
@@ -1,9 +1,11 @@
1
  # ========================
2
  # RL_UTILS.PY
 
3
  # Chemistry RL Training Utilities for ChemQ3-MTP
4
  # by gbyuvd
5
  # Patched: reward normalization, KL/entropy reset per phase,
6
- # entropy target annealing, and symmetric curriculum (kept old naming).
 
7
  # ========================
8
 
9
  import torch
@@ -47,6 +49,53 @@ def is_valid_smiles(smiles: str) -> bool:
47
  return False
48
  return Chem.MolFromSmiles(smiles.strip()) is not None
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  # ========================
51
  # SA CLASSIFIER
52
  # ========================
@@ -249,7 +298,7 @@ def compute_comprehensive_reward(selfies_str: str) -> Dict[str, float]:
249
  Dictionary containing individual reward components and total
250
  """
251
  smiles = selfies_to_smiles(selfies_str)
252
- mol = Chem.MolFromSmiles(smiles) if smiles else None
253
 
254
  rewards = {
255
  "validity": 1.0 if mol is not None else 0.0,
@@ -280,7 +329,7 @@ def compute_comprehensive_reward(selfies_str: str) -> Dict[str, float]:
280
  def selfies_to_lipinski_reward(selfies_str: str) -> float:
281
  """Convert SELFIES to SMILES, then compute Lipinski reward."""
282
  smiles = selfies_to_smiles(selfies_str)
283
- if smiles is None:
284
  return 0.0
285
  mol = Chem.MolFromSmiles(smiles)
286
  return compute_lipinski_reward(mol)
@@ -513,6 +562,9 @@ def batch_compute_rewards(
513
  sa_rewards = []
514
 
515
  for selfies_str in selfies_list:
 
 
 
516
  if reward_mode == "chemq3":
517
  r = compute_comprehensive_reward(selfies_str)
518
  validity_vals.append(r.get('validity', 0.0))
@@ -520,13 +572,13 @@ def batch_compute_rewards(
520
  total_rewards.append(r.get('total', 0.0))
521
 
522
  elif reward_mode == "sa":
523
- sa = compute_sa_reward(selfies_str)
524
  sa_rewards.append(sa)
525
  total_rewards.append(sa)
526
 
527
  elif reward_mode == "mix":
528
  r = compute_comprehensive_reward(selfies_str)
529
- sa = compute_sa_reward(selfies_str)
530
  mixed = reward_mix * r.get("total", 0.0) + (1.0 - reward_mix) * sa
531
 
532
  total_rewards.append(mixed)
@@ -596,7 +648,7 @@ def compute_training_metrics(
596
  valid_smiles = []
597
  for selfies_str in selfies_list:
598
  smiles = selfies_to_smiles(selfies_str)
599
- if smiles and is_valid_smiles(smiles):
600
  valid_smiles.append(smiles)
601
 
602
  metrics["num_valid"] = len(valid_smiles)
 
1
  # ========================
2
  # RL_UTILS.PY
3
+ # v3
4
  # Chemistry RL Training Utilities for ChemQ3-MTP
5
  # by gbyuvd
6
  # Patched: reward normalization, KL/entropy reset per phase,
7
+ # entropy target annealing, and symmetric curriculum
8
+ # and now with Durrant Lab substructure filtering included
9
  # ========================
10
 
11
  import torch
 
49
  return False
50
  return Chem.MolFromSmiles(smiles.strip()) is not None
51
 
52
# SMARTS queries for substructures that cause a molecule to be rejected.
# Kept as a module-level tuple so the list is built once, not per call.
_DURRANT_PROBLEMATIC_SMARTS = (
    "C=[N-]",                    # carbon double-bonded to an anionic nitrogen
    "[N-]C=[N+]",                # N anion - C = N cation (charge-separated amidine-like form)
    "[nH+]c[n-]",                # protonated aromatic N next-but-one to an anionic aromatic N
    "[#7+]~[#7+]",               # two directly bonded cationic nitrogens
    "[#7-]~[#7-]",               # two directly bonded anionic nitrogens
    "[!#7]~[#7+]~[#7-]~[!#7]",   # N+ bonded to N-, each flanked by a non-nitrogen atom
    "[#5]",                      # any boron atom
    "O=[PH](=O)([#8])([#8])",    # P-H phosphorus carrying two =O and two oxygen substituents
    "N=c1cc[#7]c[#7]1",          # exocyclic N= on an aromatic ring that contains two nitrogens
    "[$([NX2H1]),$([NX3H2])]=C[$([OH]),$([O-])]",  # HN=C-OH / H2N=C-O(-) style tautomers
)

# Metals that are tolerated (common biologically relevant ions).
_DURRANT_ALLOWED_METALS = frozenset({11, 12, 19, 20})  # Na, Mg, K, Ca

# Atomic numbers of metallic elements that cause rejection. Metalloids and
# non-metals heavier than calcium (e.g. Se, Br, I) are deliberately absent:
# the previous `atomic_num > 20` test rejected them although they are not
# metals, and let lighter metals (Li, Be, Al) through.
_DURRANT_REJECTED_METALS = frozenset(
    {3, 4, 11, 12, 13, 19, 20}   # Li, Be, Na, Mg, Al, K, Ca
    | set(range(21, 32))         # Sc .. Ga
    | set(range(37, 51))         # Rb .. Sn
    | set(range(55, 85))         # Cs .. Po
    | set(range(87, 119))        # Fr and heavier
) - _DURRANT_ALLOWED_METALS

# Lazily filled cache of compiled SMARTS queries (see _durrant_compiled_patterns).
_DURRANT_COMPILED_PATTERNS = []


def _durrant_compiled_patterns():
    """Return the compiled SMARTS queries, compiling them on first use."""
    if not _DURRANT_COMPILED_PATTERNS:
        for smarts in _DURRANT_PROBLEMATIC_SMARTS:
            query = Chem.MolFromSmarts(smarts)
            # MolFromSmarts signals an unparsable pattern by returning None;
            # such a pattern is skipped (best effort), as before.
            if query is not None:
                _DURRANT_COMPILED_PATTERNS.append(query)
    return _DURRANT_COMPILED_PATTERNS


def passes_durrant_lab_filter(smiles: str) -> bool:
    """
    Apply the Durrant lab filter to reject improbable substructures.

    Args:
        smiles: SMILES string of the molecule to check.

    Returns:
        True if the molecule is acceptable. False if `smiles` is not a
        non-empty string, cannot be parsed by RDKit, contains a metal other
        than Na, Mg, K or Ca, or matches any of the problematic SMARTS
        patterns.
    """
    # Type check first so non-string inputs never reach a truthiness test.
    if not isinstance(smiles, str) or not smiles:
        return False

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return False

    # Reject metals that are not in the allowed list.
    if any(atom.GetAtomicNum() in _DURRANT_REJECTED_METALS for atom in mol.GetAtoms()):
        return False

    # Only existence of a match matters, so stop at the first hit.
    return not any(
        mol.HasSubstructMatch(query) for query in _durrant_compiled_patterns()
    )
98
+
99
  # ========================
100
  # SA CLASSIFIER
101
  # ========================
 
298
  Dictionary containing individual reward components and total
299
  """
300
  smiles = selfies_to_smiles(selfies_str)
301
+ mol = Chem.MolFromSmiles(smiles) if smiles and passes_durrant_lab_filter(smiles) else None
302
 
303
  rewards = {
304
  "validity": 1.0 if mol is not None else 0.0,
 
329
  def selfies_to_lipinski_reward(selfies_str: str) -> float:
330
  """Convert SELFIES to SMILES, then compute Lipinski reward."""
331
  smiles = selfies_to_smiles(selfies_str)
332
+ if smiles is None or not passes_durrant_lab_filter(smiles):
333
  return 0.0
334
  mol = Chem.MolFromSmiles(smiles)
335
  return compute_lipinski_reward(mol)
 
562
  sa_rewards = []
563
 
564
  for selfies_str in selfies_list:
565
+ smiles = selfies_to_smiles(selfies_str)
566
+ passes_filter = passes_durrant_lab_filter(smiles) if smiles else False
567
+
568
  if reward_mode == "chemq3":
569
  r = compute_comprehensive_reward(selfies_str)
570
  validity_vals.append(r.get('validity', 0.0))
 
572
  total_rewards.append(r.get('total', 0.0))
573
 
574
  elif reward_mode == "sa":
575
+ sa = compute_sa_reward(selfies_str) if passes_filter else 0.0
576
  sa_rewards.append(sa)
577
  total_rewards.append(sa)
578
 
579
  elif reward_mode == "mix":
580
  r = compute_comprehensive_reward(selfies_str)
581
+ sa = compute_sa_reward(selfies_str) if passes_filter else 0.0
582
  mixed = reward_mix * r.get("total", 0.0) + (1.0 - reward_mix) * sa
583
 
584
  total_rewards.append(mixed)
 
648
  valid_smiles = []
649
  for selfies_str in selfies_list:
650
  smiles = selfies_to_smiles(selfies_str)
651
+ if smiles and is_valid_smiles(smiles) and passes_durrant_lab_filter(smiles):
652
  valid_smiles.append(smiles)
653
 
654
  metrics["num_valid"] = len(valid_smiles)