Add Durrant's Lab filtering
Browse files- rl_utils.py +58 -6
rl_utils.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
| 1 |
# ========================
|
| 2 |
# RL_UTILS.PY
|
|
|
|
| 3 |
# Chemistry RL Training Utilities for ChemQ3-MTP
|
| 4 |
# by gbyuvd
|
| 5 |
# Patched: reward normalization, KL/entropy reset per phase,
|
| 6 |
-
# entropy target annealing, and symmetric curriculum
|
|
|
|
| 7 |
# ========================
|
| 8 |
|
| 9 |
import torch
|
|
@@ -47,6 +49,53 @@ def is_valid_smiles(smiles: str) -> bool:
|
|
| 47 |
return False
|
| 48 |
return Chem.MolFromSmiles(smiles.strip()) is not None
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
# ========================
|
| 51 |
# SA CLASSIFIER
|
| 52 |
# ========================
|
|
@@ -249,7 +298,7 @@ def compute_comprehensive_reward(selfies_str: str) -> Dict[str, float]:
|
|
| 249 |
Dictionary containing individual reward components and total
|
| 250 |
"""
|
| 251 |
smiles = selfies_to_smiles(selfies_str)
|
| 252 |
-
mol = Chem.MolFromSmiles(smiles) if smiles else None
|
| 253 |
|
| 254 |
rewards = {
|
| 255 |
"validity": 1.0 if mol is not None else 0.0,
|
|
@@ -280,7 +329,7 @@ def compute_comprehensive_reward(selfies_str: str) -> Dict[str, float]:
|
|
| 280 |
def selfies_to_lipinski_reward(selfies_str: str) -> float:
|
| 281 |
"""Convert SELFIES to SMILES, then compute Lipinski reward."""
|
| 282 |
smiles = selfies_to_smiles(selfies_str)
|
| 283 |
-
if smiles is None:
|
| 284 |
return 0.0
|
| 285 |
mol = Chem.MolFromSmiles(smiles)
|
| 286 |
return compute_lipinski_reward(mol)
|
|
@@ -513,6 +562,9 @@ def batch_compute_rewards(
|
|
| 513 |
sa_rewards = []
|
| 514 |
|
| 515 |
for selfies_str in selfies_list:
|
|
|
|
|
|
|
|
|
|
| 516 |
if reward_mode == "chemq3":
|
| 517 |
r = compute_comprehensive_reward(selfies_str)
|
| 518 |
validity_vals.append(r.get('validity', 0.0))
|
|
@@ -520,13 +572,13 @@ def batch_compute_rewards(
|
|
| 520 |
total_rewards.append(r.get('total', 0.0))
|
| 521 |
|
| 522 |
elif reward_mode == "sa":
|
| 523 |
-
sa = compute_sa_reward(selfies_str)
|
| 524 |
sa_rewards.append(sa)
|
| 525 |
total_rewards.append(sa)
|
| 526 |
|
| 527 |
elif reward_mode == "mix":
|
| 528 |
r = compute_comprehensive_reward(selfies_str)
|
| 529 |
-
sa = compute_sa_reward(selfies_str)
|
| 530 |
mixed = reward_mix * r.get("total", 0.0) + (1.0 - reward_mix) * sa
|
| 531 |
|
| 532 |
total_rewards.append(mixed)
|
|
@@ -596,7 +648,7 @@ def compute_training_metrics(
|
|
| 596 |
valid_smiles = []
|
| 597 |
for selfies_str in selfies_list:
|
| 598 |
smiles = selfies_to_smiles(selfies_str)
|
| 599 |
-
if smiles and is_valid_smiles(smiles):
|
| 600 |
valid_smiles.append(smiles)
|
| 601 |
|
| 602 |
metrics["num_valid"] = len(valid_smiles)
|
|
|
|
| 1 |
# ========================
|
| 2 |
# RL_UTILS.PY
|
| 3 |
+
# v3
|
| 4 |
# Chemistry RL Training Utilities for ChemQ3-MTP
|
| 5 |
# by gbyuvd
|
| 6 |
# Patched: reward normalization, KL/entropy reset per phase,
|
| 7 |
+
# entropy target annealing, and symmetric curriculum
|
| 8 |
+
# now also includes Durrant Lab filtering
|
| 9 |
# ========================
|
| 10 |
|
| 11 |
import torch
|
|
|
|
| 49 |
return False
|
| 50 |
return Chem.MolFromSmiles(smiles.strip()) is not None
|
| 51 |
|
| 52 |
+
# SMARTS for substructures the Durrant Lab filter treats as improbable.
# Kept at module level so the list is built once, not on every call.
_DURRANT_PROBLEMATIC_SMARTS = (
    "C=[N-]",                   # carbon double-bonded to an anionic nitrogen
    "[N-]C=[N+]",               # anionic N on a carbon that is double-bonded to a cationic N
    "[nH+]c[n-]",               # protonated and anionic aromatic N separated by one aromatic carbon
    "[#7+]~[#7+]",              # two directly bonded cationic nitrogens
    "[#7-]~[#7-]",              # two directly bonded anionic nitrogens
    "[!#7]~[#7+]~[#7-]~[!#7]",  # N+ bonded to N-, each flanked by a non-nitrogen atom
    "[#5]",                     # any boron atom
    "O=[PH](=O)([#8])([#8])",   # P-H phosphorus with two =O and two further oxygen neighbours
    "N=c1cc[#7]c[#7]1",         # exocyclic =N on a six-membered aromatic ring with two ring nitrogens
    "[$([NX2H1]),$([NX3H2])]=C[$([OH]),$([O-])]",  # N=C-OH or N=C-O-
)

# Elements up to calcium (Z <= 20) are always accepted by the element check;
# that range already contains Na (11), Mg (12), K (19) and Ca (20).
_DURRANT_MAX_LIGHT_ATOMIC_NUM = 20

# Heavier elements that are NOT metals and must not be rejected by the
# "metal" check: Se (34), Br (35), I (53).
_DURRANT_HEAVY_NONMETALS = frozenset({34, 35, 53})

# Lazily filled cache of compiled SMARTS queries (None until first use).
_durrant_compiled_queries = None


def _get_durrant_queries():
    """Return the compiled SMARTS queries, compiling them on first use."""
    global _durrant_compiled_queries
    if _durrant_compiled_queries is None:
        compiled = []
        for smarts in _DURRANT_PROBLEMATIC_SMARTS:
            try:
                query = Chem.MolFromSmarts(smarts)
            except Exception:
                # Best effort: a pattern that cannot be parsed is skipped
                # rather than disabling the whole filter.
                query = None
            if query is not None:
                compiled.append(query)
        _durrant_compiled_queries = tuple(compiled)
    return _durrant_compiled_queries


def passes_durrant_lab_filter(smiles: str) -> bool:
    """
    Apply Durrant's lab filter to remove improbable substructures.

    A molecule is rejected when any of the following holds:
      * ``smiles`` is not a string, is empty, or is only whitespace;
      * RDKit cannot parse it;
      * it contains an atom heavier than calcium that is not one of the
        heavy non-metals Se, Br or I;
      * it matches any pattern in ``_DURRANT_PROBLEMATIC_SMARTS``.

    Args:
        smiles: SMILES string of the molecule to check.

    Returns:
        True if the molecule passes the filter (is acceptable), False otherwise.
    """
    # Type check first so that non-string inputs never reach a truth test.
    if not isinstance(smiles, str):
        return False

    # Strip surrounding whitespace, as is_valid_smiles does before parsing.
    smiles = smiles.strip()
    if not smiles:
        return False

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return False

    # Element check: reject heavy (Z > 20) atoms unless they are known
    # non-metals. Without the allow-list, Br and I would be treated as metals.
    for atom in mol.GetAtoms():
        atomic_num = atom.GetAtomicNum()
        if (atomic_num > _DURRANT_MAX_LIGHT_ATOMIC_NUM
                and atomic_num not in _DURRANT_HEAVY_NONMETALS):
            return False

    # Substructure check: a single hit is enough, so stop at the first match.
    for query in _get_durrant_queries():
        try:
            if mol.HasSubstructMatch(query):
                return False
        except Exception:
            # Best effort: if matching one pattern fails, skip that pattern.
            continue

    return True  # Passed all checks
|
| 98 |
+
|
| 99 |
# ========================
|
| 100 |
# SA CLASSIFIER
|
| 101 |
# ========================
|
|
|
|
| 298 |
Dictionary containing individual reward components and total
|
| 299 |
"""
|
| 300 |
smiles = selfies_to_smiles(selfies_str)
|
| 301 |
+
mol = Chem.MolFromSmiles(smiles) if smiles and passes_durrant_lab_filter(smiles) else None
|
| 302 |
|
| 303 |
rewards = {
|
| 304 |
"validity": 1.0 if mol is not None else 0.0,
|
|
|
|
| 329 |
def selfies_to_lipinski_reward(selfies_str: str) -> float:
    """Score a SELFIES string with the Lipinski reward.

    The string is converted to SMILES first. If the conversion yields None,
    or the SMILES is rejected by ``passes_durrant_lab_filter``, the reward
    is 0.0; otherwise the parsed molecule is passed to
    ``compute_lipinski_reward``.
    """
    smiles = selfies_to_smiles(selfies_str)

    # Short-circuits exactly like the original: the filter is only
    # consulted when a SMILES string was actually produced.
    usable = smiles is not None and passes_durrant_lab_filter(smiles)
    if not usable:
        return 0.0

    return compute_lipinski_reward(Chem.MolFromSmiles(smiles))
|
|
|
|
| 562 |
sa_rewards = []
|
| 563 |
|
| 564 |
for selfies_str in selfies_list:
|
| 565 |
+
smiles = selfies_to_smiles(selfies_str)
|
| 566 |
+
passes_filter = passes_durrant_lab_filter(smiles) if smiles else False
|
| 567 |
+
|
| 568 |
if reward_mode == "chemq3":
|
| 569 |
r = compute_comprehensive_reward(selfies_str)
|
| 570 |
validity_vals.append(r.get('validity', 0.0))
|
|
|
|
| 572 |
total_rewards.append(r.get('total', 0.0))
|
| 573 |
|
| 574 |
elif reward_mode == "sa":
|
| 575 |
+
sa = compute_sa_reward(selfies_str) if passes_filter else 0.0
|
| 576 |
sa_rewards.append(sa)
|
| 577 |
total_rewards.append(sa)
|
| 578 |
|
| 579 |
elif reward_mode == "mix":
|
| 580 |
r = compute_comprehensive_reward(selfies_str)
|
| 581 |
+
sa = compute_sa_reward(selfies_str) if passes_filter else 0.0
|
| 582 |
mixed = reward_mix * r.get("total", 0.0) + (1.0 - reward_mix) * sa
|
| 583 |
|
| 584 |
total_rewards.append(mixed)
|
|
|
|
| 648 |
valid_smiles = []
|
| 649 |
for selfies_str in selfies_list:
|
| 650 |
smiles = selfies_to_smiles(selfies_str)
|
| 651 |
+
if smiles and is_valid_smiles(smiles) and passes_durrant_lab_filter(smiles):
|
| 652 |
valid_smiles.append(smiles)
|
| 653 |
|
| 654 |
metrics["num_valid"] = len(valid_smiles)
|