File size: 10,596 Bytes
d1e65ce 0f19e3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
# molecular_constraint_solver.py
# FINAL VERSION with strict connectivity scaffolding
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import numpy as np
@dataclass
class MolecularConstraint:
    """One user-level constraint on the molecule being searched for.

    ``constraint_type`` selects which encoder handles the constraint,
    ``value`` is its operand and ``operator`` is the comparison to apply
    to that operand.
    """

    # Name of the constraint, e.g. 'min_atoms', 'aromatic_rings',
    # 'molecular_weight', 'forbidden_group' or 'synthesizable'.
    constraint_type: str
    # Operand: an integer for comparisons, a group name for
    # 'forbidden_group', True for flag-style constraints.
    value: Any
    # Comparison operator as written by the user; defaults to equality.
    operator: str = '=='
class MolecularConstraintEncoder:
    """Translate molecular design constraints into a 3-SAT instance.

    Variables are positive integers starting at 1; a negative integer is the
    negation of that variable.  ``__init__`` reserves contiguous ranges, in
    this order: atom types, bond existence (one per unordered atom pair),
    bond types, ring variables (one per atom slot), aromatic rings,
    functional groups and molecular-weight thresholds.  ``var_offset`` always
    holds the next unused variable; the cardinality and 3-SAT helpers advance
    it whenever they allocate auxiliary variables.
    """

    def __init__(self, max_atoms=30):
        """Lay out the variable ranges for a molecule of at most ``max_atoms`` atoms."""
        self.max_atoms = max_atoms
        self.max_bonds = max_atoms * (max_atoms - 1) // 2
        self.var_offset = 1

        # 'None' marks an unused atom slot.
        self.atom_types = ['C', 'N', 'O', 'S', 'F', 'Cl', 'Br', 'P', 'H', 'None']
        self.atom_var_start = self.var_offset
        self.var_offset += self.max_atoms * len(self.atom_types)

        self.bond_existence_var_start = self.var_offset
        self.var_offset += self.max_bonds

        # Reserved range; no clause in this class constrains bond types yet.
        self.bond_types = ['single', 'double', 'triple']
        self.bond_type_var_start = self.var_offset
        self.var_offset += self.max_bonds * len(self.bond_types)

        self.ring_var_start = self.var_offset
        self.var_offset += self.max_atoms

        self.max_rings = 10
        self.aromatic_var_start = self.var_offset
        self.var_offset += self.max_rings

        self.functional_groups = ['nitro', 'azide', 'peroxide', 'aldehyde', 'ketone', 'carboxyl', 'amine', 'amide', 'ester', 'ether', 'thiol', 'sulfone', 'phosphate', 'hydroxyl', 'halogen', 'cyano', 'isocyanate', 'epoxide', 'lactone', 'quinone']
        self.group_var_start = self.var_offset
        self.var_offset += len(self.functional_groups)

        self.mw_thresholds = list(range(100, 600, 10))
        self.mw_var_start = self.var_offset
        self.var_offset += len(self.mw_thresholds)

    @staticmethod
    def _unsatisfiable():
        """Return two contradictory unit clauses, i.e. an unsatisfiable formula.

        Unit clauses are used instead of an empty clause because empty
        clauses are skipped by ``encode_valence`` and ``_convert_to_3sat``.
        """
        return [[1], [-1]]

    def atom_type_var(self, atom_idx, atom_type):
        """Return the variable meaning "atom slot ``atom_idx`` has type ``atom_type``"."""
        return self.atom_var_start + atom_idx * len(self.atom_types) + self.atom_types.index(atom_type)

    def bond_existence_var(self, i, j):
        """Return the variable for a bond between atoms ``i`` and ``j``.

        The pair is unordered.  ``-1`` is returned as a sentinel when
        ``i == j``; callers must filter it out before building clauses.
        """
        if i == j:
            return -1
        if i > j:
            i, j = j, i
        # Row-major index into the strict upper triangle, in exact integer
        # arithmetic (i * (2n - i - 1) is always even).
        idx = i * (2 * self.max_atoms - i - 1) // 2 + (j - i - 1)
        return self.bond_existence_var_start + idx

    def atom_exists_lit(self, atom_idx):
        """Return the literal that is true when slot ``atom_idx`` is not 'None'."""
        return -self.atom_type_var(atom_idx, 'None')

    def ring_var(self, idx):
        """Return the ring variable with index ``idx``."""
        return self.ring_var_start + idx

    def aromatic_ring_var(self, idx):
        """Return the aromatic-ring variable with index ``idx``."""
        return self.aromatic_var_start + idx

    def functional_group_var(self, g):
        """Return the variable flagging the presence of functional group ``g``."""
        return self.group_var_start + self.functional_groups.index(g)

    def mw_var(self, t):
        """Return the variable of the molecular-weight threshold closest to ``t``."""
        nearest = min(self.mw_thresholds, key=lambda x: abs(x - t))
        return self.mw_var_start + self.mw_thresholds.index(nearest)

    def encode_constraints(self, constraints: "List[MolecularConstraint]") -> Tuple[List[List[int]], int]:
        """Encode structural rules plus ``constraints`` as 3-SAT.

        Returns ``(clauses, num_vars)`` where every clause has exactly three
        literals and ``num_vars`` is the highest variable number in use.
        Only the first 'min_atoms' constraint is honoured; it is routed to
        ``_force_connected_backbone`` regardless of its operator.
        """
        constraints = list(constraints)  # allow one-shot iterables
        all_clauses = self._encode_structural_validity()
        all_clauses.extend(self.encode_valence())

        min_atoms_constraint = next((c for c in constraints if c.constraint_type == 'min_atoms'), None)
        if min_atoms_constraint is not None:
            all_clauses.extend(self._force_connected_backbone(min_atoms_constraint.value))

        for constraint in constraints:
            if constraint.constraint_type != 'min_atoms':
                all_clauses.extend(self._encode_single_constraint(constraint))
        return self._convert_to_3sat(all_clauses)

    def _force_connected_backbone(self, min_atoms):
        """Force at least ``min_atoms`` atoms arranged in one connected piece.

        Atoms ``0 .. min_atoms-1`` must exist and are chained 0-1, 1-2, ...;
        every later atom that exists must bond to at least one backbone atom.
        A request for more atoms than ``max_atoms`` yields an unsatisfiable
        formula instead of indexing past the atom variables.
        """
        if min_atoms <= 0:
            return []
        if min_atoms > self.max_atoms:
            return self._unsatisfiable()

        # 1. The backbone atoms must exist (min_atoms == 1 included).
        clauses = [[self.atom_exists_lit(i)] for i in range(min_atoms)]
        # 2. Simple path through the backbone.
        clauses.extend([self.bond_existence_var(i, i + 1)] for i in range(min_atoms - 1))
        # 3. exists(i) -> bonded to some backbone atom.
        for i in range(min_atoms, self.max_atoms):
            backbone_bonds = [self.bond_existence_var(i, j) for j in range(min_atoms)]
            clauses.append([-self.atom_exists_lit(i)] + backbone_bonds)
        return clauses

    def encode_valence(self):
        """Tie each atom type to an exact number of incident bonds.

        For every atom slot and element, "slot has this type" implies that
        exactly ``valence`` of its bond-existence variables are true.  Bond
        order is not considered.  Allocates auxiliary variables through the
        cardinality helpers.
        """
        clauses = []
        valence_rules = {'C': 4, 'N': 3, 'O': 2, 'S': 2, 'F': 1, 'Cl': 1, 'Br': 1, 'P': 3, 'H': 1}
        for i in range(self.max_atoms):
            bond_vars = [self.bond_existence_var(i, j) for j in range(self.max_atoms) if i != j]
            for atom_type, val in valence_rules.items():
                type_var = self.atom_type_var(i, atom_type)
                if val > len(bond_vars):
                    # Not enough possible neighbours: this type is impossible here.
                    clauses.append([-type_var])
                    continue
                exact = self._cardinality_at_least(bond_vars, val) + self._cardinality_at_most(bond_vars, val)
                # Guard every clause so the count only applies when the type is chosen.
                clauses.extend([-type_var] + cl for cl in exact if cl)
        return clauses

    def _cardinality_at_least(self, V, k):
        """Return clauses satisfiable iff at least ``k`` literals of ``V`` are true.

        Uses a sequential counter with ``len(V) * k`` fresh variables, where
        ``s[i][j]`` may only be true if at least ``j + 1`` of ``V[0..i]`` are
        true.  Each counter bit implies its justification, so asserting
        ``s[n-1][k-1]`` really forces ``k`` true inputs.
        """
        n = len(V)
        if k <= 0:
            return []
        if n < k:
            return self._unsatisfiable()
        if k == 1:
            return [list(V)]

        base = self.var_offset
        self.var_offset += n * k
        s = [[base + i * k + j for j in range(k)] for i in range(n)]

        # After one input the count is at most one.
        clauses = [[-s[0][0], V[0]]]
        clauses.extend([-s[0][j]] for j in range(1, k))
        for i in range(1, n):
            # s[i][0] -> s[i-1][0] or V[i]
            clauses.append([-s[i][0], s[i - 1][0], V[i]])
            for j in range(1, k):
                # s[i][j] -> s[i-1][j] or (V[i] and s[i-1][j-1])
                clauses.append([-s[i][j], s[i - 1][j], V[i]])
                clauses.append([-s[i][j], s[i - 1][j], s[i - 1][j - 1]])
        clauses.append([s[n - 1][k - 1]])
        return clauses

    def _cardinality_at_most(self, V, k):
        """Return clauses satisfiable iff at most ``k`` literals of ``V`` are true."""
        n = len(V)
        if k < 0:
            return self._unsatisfiable()
        if k >= n:
            return []
        # At most k true  <=>  at least n - k false.
        return self._cardinality_at_least([-v for v in V], n - k)

    def _encode_structural_validity(self):
        """Return clauses every assignment must satisfy to describe a molecule.

        Each atom slot gets exactly one type (possibly 'None'), and a bond
        may only exist between two slots that both hold a real atom, so that
        bonds to unused slots cannot be counted towards a valence.
        """
        clauses = []
        for i in range(self.max_atoms):
            v = [self.atom_type_var(i, t) for t in self.atom_types]
            clauses.append(v)
            for i1 in range(len(v)):
                for i2 in range(i1 + 1, len(v)):
                    clauses.append([-v[i1], -v[i2]])
        for i in range(self.max_atoms):
            for j in range(i + 1, self.max_atoms):
                bond = self.bond_existence_var(i, j)
                clauses.append([-bond, self.atom_exists_lit(i)])
                clauses.append([-bond, self.atom_exists_lit(j)])
        return clauses

    def _encode_single_constraint(self, c):
        """Dispatch one constraint to its encoder; unknown types yield no clauses."""
        if c.constraint_type == 'aromatic_rings':
            return self._encode_aromatic_rings(c.value, c.operator)
        if c.constraint_type == 'molecular_weight':
            return self._encode_molecular_weight(c.value, c.operator)
        if c.constraint_type == 'forbidden_group':
            return self._encode_forbidden_group(c.value)
        if c.constraint_type == 'synthesizable':
            return self._encode_synthesizability()
        return []

    def _encode_aromatic_rings(self, v, o):
        """Constrain the number of true aromatic-ring variables.

        Supports ``==``, ``>=``, ``>``, ``<=`` and ``<``; any other operator
        yields no clauses.  Lower bounds force the first ring variables true,
        upper bounds force the trailing ones false.  A count that cannot be
        reached within ``0 .. max_rings`` yields an unsatisfiable formula.
        """
        lo, hi = 0, self.max_rings
        if o == '==':
            lo = hi = v
        elif o == '>=':
            lo = v
        elif o == '>':
            lo = v + 1
        elif o == '<=':
            hi = v
        elif o == '<':
            hi = v - 1
        else:
            return []
        lo = max(lo, 0)
        hi = min(hi, self.max_rings)
        if lo > hi:
            return self._unsatisfiable()
        required = [[self.aromatic_ring_var(i)] for i in range(lo)]
        forbidden = [[-self.aromatic_ring_var(i)] for i in range(hi, self.max_rings)]
        return required + forbidden

    def _encode_molecular_weight(self, v, o):
        """Constrain the molecular-weight threshold variables.

        The variable for threshold ``t`` is read as "weight >= t".  Ordering
        clauses (a higher threshold implies the next lower one) are always
        emitted.  Supports ``<``, ``<=``, ``>``, ``>=`` and ``==`` at the
        granularity of ``mw_thresholds``; other operators add nothing more.
        """
        var_of = {t: self.mw_var_start + i for i, t in enumerate(self.mw_thresholds)}
        c = [[-var_of[upper], var_of[lower]]
             for lower, upper in zip(self.mw_thresholds, self.mw_thresholds[1:])]
        if o in ('>=', '>', '=='):
            c.extend([var_of[t]] for t in self.mw_thresholds if t <= v)
        if o == '<':
            c.extend([-var_of[t]] for t in self.mw_thresholds if t >= v)
        elif o in ('<=', '=='):
            c.extend([-var_of[t]] for t in self.mw_thresholds if t > v)
        return c

    def _encode_forbidden_group(self, v):
        """Forbid functional group ``v``; unknown group names yield no clauses."""
        if v not in self.functional_groups:
            return []
        return [[-self.functional_group_var(v)]]

    def _encode_synthesizability(self):
        """Return the clauses used for the 'synthesizable' constraint.

        Aromatic-ring variables with index 3 and above are forced false, and
        at most one of nitro, azide, peroxide and isocyanate may be present.
        """
        c = [[-self.aromatic_ring_var(i)] for i in range(3, self.max_rings)]
        rg = ['nitro', 'azide', 'peroxide', 'isocyanate']
        rv = [self.functional_group_var(g) for g in rg if g in self.functional_groups]
        for i in range(len(rv)):
            for j in range(i + 1, len(rv)):
                c.append([-rv[i], -rv[j]])
        return c

    def _convert_to_3sat(self, cs):
        """Rewrite ``cs`` so that every clause has exactly three literals.

        Short clauses are padded by repeating their last literal; long ones
        are split into a chain linked by fresh variables.  Empty clauses are
        skipped.  The input clauses are not modified.  Returns
        ``(clauses, num_vars)`` and leaves ``var_offset`` at the next free
        variable.
        """
        s3c, nxt = [], self.var_offset
        for c in cs:
            if not c:
                continue
            lits = list(c)  # never mutate the caller's clause
            if len(lits) <= 3:
                lits.extend([lits[-1]] * (3 - len(lits)))
                s3c.append(lits)
                continue
            head, pos = lits[0], 1
            # Peel off one literal at a time while more than three remain
            # (counting ``head`` plus ``lits[pos:]``).
            while len(lits) - pos + 1 > 3:
                s3c.append([head, lits[pos], nxt])
                head = -nxt
                nxt += 1
                pos += 1
            s3c.append([head] + lits[pos:])
        self.var_offset = nxt
        return s3c, self.var_offset - 1

    def decode_solution(self, a):
        """Turn a satisfying assignment back into a molecule description.

        ``a`` must be a one-dimensional ``numpy`` array in which entry
        ``var - 1`` is truthy when variable ``var`` is true; any other input
        returns the empty description.  Bonds are reported only between atoms
        that were decoded.  The weight range is ``(0, first threshold)`` when
        no threshold variable is set, otherwise ``(t, t + 10)`` for the last
        threshold ``t`` of the leading run of set thresholds.
        """
        s = {'atoms': [], 'bonds': [], 'aromatic_rings': 0, 'functional_groups': [], 'molecular_weight_range': None}
        if not isinstance(a, np.ndarray) or a.ndim != 1:
            return s
        size = len(a)

        def is_set(var):
            # Variables beyond the end of the assignment count as false.
            return var - 1 < size and bool(a[var - 1])

        existing_atom_ids = set()
        for i in range(self.max_atoms):
            none_idx = self.atom_type_var(i, 'None') - 1
            # A slot whose 'None' variable is missing from ``a`` is skipped.
            if none_idx >= size or a[none_idx]:
                continue
            for t in self.atom_types:
                if t != 'None' and is_set(self.atom_type_var(i, t)):
                    s['atoms'].append({'id': i, 'element': t})
                    existing_atom_ids.add(i)
                    break

        for i in range(self.max_atoms):
            for j in range(i + 1, self.max_atoms):
                if is_set(self.bond_existence_var(i, j)) and i in existing_atom_ids and j in existing_atom_ids:
                    s['bonds'].append({'from': i, 'to': j})

        s['aromatic_rings'] = sum(1 for i in range(self.max_rings) if is_set(self.aromatic_ring_var(i)))
        s['functional_groups'] = [g for g in self.functional_groups if is_set(self.functional_group_var(g))]

        mw_min = 0
        for idx, t in enumerate(self.mw_thresholds):
            if not is_set(self.mw_var_start + idx):
                break
            mw_min = t
        if mw_min == 0:
            # No threshold reached: the weight lies below the first threshold.
            s['molecular_weight_range'] = (0, self.mw_thresholds[0])
        else:
            s['molecular_weight_range'] = (mw_min, mw_min + 10)
        return s
def parse_constraints(ss):
    """Parse textual constraints into ``MolecularConstraint`` objects.

    Three forms are recognised after stripping surrounding whitespace:
    ``<name> <op> <integer>`` (``op`` built from ``<``, ``>``, ``=``, ``!``),
    ``NOT <group>`` for a forbidden functional group, and the bare word
    ``synthesizable``.  Strings matching none of these are ignored.
    """
    comparison = re.compile(r'(\w+)\s*([<>=!]+)\s*(\d+)')
    parsed = []
    for raw in ss:
        text = raw.strip()
        hit = comparison.match(text)
        if hit is not None:
            kind, operator, digits = hit.groups()
            parsed.append(MolecularConstraint(kind, int(digits), operator))
            continue
        if text.startswith('NOT '):
            parsed.append(MolecularConstraint('forbidden_group', text[4:].strip()))
            continue
        if text == 'synthesizable':
            parsed.append(MolecularConstraint(text, True))
    return parsed