Aluode committed on
Commit
d1e65ce
·
verified ·
1 Parent(s): 7059c33

Update molecular_constraint_solver.py

Browse files
Files changed (1) hide show
  1. molecular_constraint_solver.py +232 -227
molecular_constraint_solver.py CHANGED
@@ -1,228 +1,233 @@
1
- # molecular_constraint_solver.py
2
- # FINAL VERSION with corrected decoder
3
-
4
- import numpy as np
5
- from typing import List, Dict, Tuple
6
- from dataclasses import dataclass
7
- import re
8
-
9
@dataclass
class MolecularConstraint:
    """A single user-level constraint on the molecule being searched for."""

    # Constraint name, e.g. 'min_atoms', 'aromatic_rings', 'molecular_weight',
    # 'forbidden_group' or 'synthesizable'.
    constraint_type: str
    # Payload of the constraint (an int threshold, a group name, or True).
    # NOTE(review): `any` is the builtin function used as an annotation;
    # typing.Any is presumably what was intended.
    value: any
    # Comparison operator as written by the user; equality unless stated.
    operator: str = "=="
14
-
15
class MolecularConstraintEncoder:
    """Encode molecular constraints as a 3-SAT instance and decode solutions.

    Variables are 1-based integers; a negative integer denotes the negation
    of that variable. ``__init__`` reserves consecutive blocks of variables
    (atom types, bond existence, connectivity flags, bond types, ring flags,
    aromatic-ring slots, functional groups, molecular-weight thresholds).
    ``var_offset`` always holds the next unused variable id and grows as
    auxiliary variables are allocated by the cardinality encoders and by the
    3-SAT conversion.
    """

    def __init__(self, max_atoms=30):
        """Lay out the fixed variable blocks for up to ``max_atoms`` atoms."""
        self.max_atoms = max_atoms
        self.max_bonds = max_atoms * (max_atoms - 1) // 2
        self.var_offset = 1

        # One variable per (atom slot, element); 'None' marks an empty slot.
        self.atom_types = ['C', 'N', 'O', 'S', 'F', 'Cl', 'Br', 'P', 'H', 'None']
        self.atom_var_start = self.var_offset
        self.var_offset += self.max_atoms * len(self.atom_types)

        # One variable per unordered pair of atom slots.
        self.bond_existence_var_start = self.var_offset
        self.var_offset += self.max_bonds

        # One "reachable from atom 0" flag per atom slot.
        self.conn_var_start = self.var_offset
        self.var_offset += self.max_atoms

        # Bond-type variables are reserved here; no clause in this class
        # refers to them.
        self.bond_types = ['single', 'double', 'triple']
        self.bond_type_var_start = self.var_offset
        self.var_offset += self.max_bonds * len(self.bond_types)

        self.ring_var_start = self.var_offset
        self.var_offset += self.max_atoms

        self.max_rings = 10
        self.aromatic_var_start = self.var_offset
        self.var_offset += self.max_rings

        self.functional_groups = ['nitro', 'azide', 'peroxide', 'aldehyde', 'ketone', 'carboxyl', 'amine', 'amide', 'ester', 'ether', 'thiol', 'sulfone', 'phosphate', 'hydroxyl', 'halogen', 'cyano', 'isocyanate', 'epoxide', 'lactone', 'quinone']
        self.group_var_start = self.var_offset
        self.var_offset += len(self.functional_groups)

        # Thresholds 100, 110, ..., 590.
        self.mw_thresholds = list(range(100, 600, 10))
        self.mw_var_start = self.var_offset
        self.var_offset += len(self.mw_thresholds)

    def atom_type_var(self, atom_idx, atom_type):
        """Return the variable meaning "slot ``atom_idx`` holds ``atom_type``"."""
        return self.atom_var_start + atom_idx * len(self.atom_types) + self.atom_types.index(atom_type)

    def bond_existence_var(self, i, j):
        """Return the variable for a bond between slots ``i`` and ``j``.

        The pair is unordered. Returns -1 when ``i == j`` (no self bonds).
        """
        if i == j:
            return -1
        if i > j:
            i, j = j, i
        # Row-major index into the strict upper triangle of the
        # max_atoms x max_atoms pair matrix.
        idx = int(i * (self.max_atoms - (i + 1) / 2.0) + (j - i - 1))
        return self.bond_existence_var_start + idx

    def conn_var(self, atom_idx):
        """Return the connectivity flag variable of slot ``atom_idx``."""
        return self.conn_var_start + atom_idx

    def atom_exists_lit(self, atom_idx):
        """Return the literal "slot ``atom_idx`` is not of type 'None'"."""
        return -self.atom_type_var(atom_idx, 'None')

    def ring_var(self, idx):
        """Return the ring flag variable with index ``idx``."""
        return self.ring_var_start + idx

    def aromatic_ring_var(self, idx):
        """Return the variable of aromatic-ring slot ``idx``."""
        return self.aromatic_var_start + idx

    def functional_group_var(self, g):
        """Return the variable of functional group ``g`` (must be a known name)."""
        return self.group_var_start + self.functional_groups.index(g)

    def mw_var(self, t):
        """Return the variable of the weight threshold closest to ``t``."""
        return self.mw_var_start + self.mw_thresholds.index(min(self.mw_thresholds, key=lambda x: abs(x - t)))

    def encode_constraints(self, constraints: List["MolecularConstraint"]) -> Tuple[List[List[int]], int]:
        """Build the full formula and return ``(3-SAT clauses, variable count)``.

        Structural validity, valence and connectivity clauses are always
        emitted; each entry of ``constraints`` adds its own clauses.
        """
        all_clauses = self._encode_structural_validity()
        all_clauses.extend(self.encode_valence())
        all_clauses.extend(self._encode_connectivity())
        for constraint in constraints:
            all_clauses.extend(self._encode_single_constraint(constraint))
        return self._convert_to_3sat(all_clauses)

    def _encode_connectivity(self):
        """Emit the connectivity-flag clauses.

        Slot 0 is flagged exactly when it exists, flags propagate across
        bonds in both directions, and every existing slot must be flagged.

        NOTE(review): nothing here forces a flag to be false, so a solver
        may set every flag true; these clauses alone do not appear to
        guarantee a single connected component.
        """
        clauses = []
        clauses.append([self.atom_type_var(0, 'None'), self.conn_var(0)])
        clauses.append([-self.atom_type_var(0, 'None'), -self.conn_var(0)])
        for i in range(self.max_atoms):
            for j in range(i + 1, self.max_atoms):
                bond_var = self.bond_existence_var(i, j)
                clauses.append([-self.conn_var(i), -bond_var, self.conn_var(j)])
                clauses.append([-self.conn_var(j), -bond_var, self.conn_var(i)])
        for i in range(self.max_atoms):
            clauses.append([self.atom_type_var(i, 'None'), self.conn_var(i)])
        return clauses

    def encode_valence(self):
        """Emit clauses tying each element to an exact number of bonds.

        For every slot and element, "slot has this element" implies that
        exactly ``valence_rules[element]`` of the slot's bond variables are
        true. Bond order is not considered.
        """
        clauses = []
        valence_rules = {'C': 4, 'N': 3, 'O': 2, 'S': 2, 'F': 1, 'Cl': 1, 'Br': 1, 'P': 3, 'H': 1}
        for i in range(self.max_atoms):
            bond_vars = [self.bond_existence_var(i, j) for j in range(self.max_atoms) if i != j]
            for atom_type, val in valence_rules.items():
                type_var = self.atom_type_var(i, atom_type)
                if val > len(bond_vars):
                    # Not enough partners exist: the element is impossible here.
                    clauses.append([-type_var])
                    continue
                for cl in self._cardinality_at_least(bond_vars, val) + self._cardinality_at_most(bond_vars, val):
                    if cl:
                        # Make every cardinality clause conditional on the type.
                        clauses.append([-type_var] + cl)
        return clauses

    def _cardinality_at_least(self, V, k):
        """Return CNF clauses forcing at least ``k`` literals of ``V`` true.

        Uses a sequential counter: auxiliary variable ``s[i][j]`` is true
        exactly when at least ``j + 1`` of ``V[0..i]`` are true. Both
        implication directions are emitted; without the "counter implies
        inputs" direction the final assertion could be satisfied by simply
        setting the counter bits, enforcing nothing. Allocates ``n * k``
        auxiliary variables from ``var_offset`` when ``k >= 2``.
        """
        n = len(V)
        if k <= 0:
            return []
        if n < k:
            # Unsatisfiable: two contradicting unit clauses (a single
            # clause [1, -1] would be a tautology).
            return [[1], [-1]]
        if k == 1:
            return [list(V)]

        s = [[self.var_offset + i * k + j for j in range(k)] for i in range(n)]
        self.var_offset += n * k

        clauses = []
        # Row 0: the count over V[0] alone is 1 iff V[0], and never more.
        clauses.append([-V[0], s[0][0]])
        clauses.append([-s[0][0], V[0]])
        for j in range(1, k):
            clauses.append([-s[0][j]])

        for i in range(1, n):
            # s[i][0] <-> V[i] or s[i-1][0]
            clauses.append([-V[i], s[i][0]])
            clauses.append([-s[i - 1][0], s[i][0]])
            clauses.append([-s[i][0], s[i - 1][0], V[i]])
            for j in range(1, k):
                # s[i][j] <-> s[i-1][j] or (V[i] and s[i-1][j-1])
                clauses.append([-V[i], -s[i - 1][j - 1], s[i][j]])
                clauses.append([-s[i - 1][j], s[i][j]])
                clauses.append([-s[i][j], s[i - 1][j], V[i]])
                clauses.append([-s[i][j], s[i - 1][j], s[i - 1][j - 1]])

        # The full prefix must reach a count of k.
        clauses.append([s[n - 1][k - 1]])
        return clauses

    def _cardinality_at_most(self, V, k):
        """Return CNF clauses forcing at most ``k`` literals of ``V`` true.

        Expressed as "at least ``n - k`` of the negated literals are true".
        """
        n = len(V)
        if k < 0:
            # Unsatisfiable, encoded as two contradicting unit clauses.
            return [[1], [-1]]
        if k >= n:
            return []
        return self._cardinality_at_least([-v for v in V], n - k)

    def _encode_structural_validity(self):
        """Emit clauses giving every slot exactly one atom type."""
        clauses = []
        for i in range(self.max_atoms):
            v = [self.atom_type_var(i, t) for t in self.atom_types]
            clauses.append(v)  # at least one type
            for i1 in range(len(v)):
                for i2 in range(i1 + 1, len(v)):
                    clauses.append([-v[i1], -v[i2]])  # pairwise at most one
        return clauses

    def _encode_single_constraint(self, c):
        """Dispatch one constraint to its encoder; unknown types add nothing."""
        if c.constraint_type == 'min_atoms':
            return self._encode_min_atoms(c.value)
        if c.constraint_type == 'aromatic_rings':
            return self._encode_aromatic_rings(c.value, c.operator)
        if c.constraint_type == 'molecular_weight':
            return self._encode_molecular_weight(c.value, c.operator)
        if c.constraint_type == 'forbidden_group':
            return self._encode_forbidden_group(c.value)
        if c.constraint_type == 'synthesizable':
            return self._encode_synthesizability()
        return []

    def _encode_min_atoms(self, k):
        """Emit clauses requiring at least ``k`` non-empty atom slots."""
        if k <= 0:
            return []
        existence_literals = [self.atom_exists_lit(i) for i in range(self.max_atoms)]
        return self._cardinality_at_least(existence_literals, k)

    def _encode_aromatic_rings(self, v, o):
        """Fix the aromatic-ring slots for ``== v``; other operators add nothing.

        The first ``v`` slots are forced true and the remaining ones false.
        """
        if o == '==':
            return [[self.aromatic_ring_var(i)] if i < v else [-self.aromatic_ring_var(i)] for i in range(self.max_rings)]
        return []

    def _encode_molecular_weight(self, v, o):
        """Emit threshold-ordering clauses plus the bound for operator ``<``.

        A true threshold variable implies the next lower one is true. For
        ``<``, every threshold ``>= v`` is forced false. Other operators
        only get the ordering clauses.
        """
        c = []
        for i in range(len(self.mw_thresholds) - 1):
            c.append([-self.mw_var(self.mw_thresholds[i + 1]), self.mw_var(self.mw_thresholds[i])])
        if o == '<':
            for t in self.mw_thresholds:
                if t >= v:
                    c.append([-self.mw_var(t)])
        return c

    def _encode_forbidden_group(self, v):
        """Forbid functional group ``v``; unknown group names add nothing."""
        if v not in self.functional_groups:
            return []
        return [[-self.functional_group_var(v)]]

    def _encode_synthesizability(self):
        """Emit the synthesizability heuristics.

        Aromatic-ring slots with index 3 and above are forced false, and at
        most one of nitro / azide / peroxide / isocyanate may be present.
        """
        c = [[-self.aromatic_ring_var(i)] for i in range(3, self.max_rings)]
        rg = ['nitro', 'azide', 'peroxide', 'isocyanate']
        rv = [self.functional_group_var(g) for g in rg if g in self.functional_groups]
        for i in range(len(rv)):
            for j in range(i + 1, len(rv)):
                c.append([-rv[i], -rv[j]])
        return c

    def _convert_to_3sat(self, cs):
        """Convert arbitrary-width clauses to exactly three literals each.

        Empty clauses are skipped, short clauses are padded by repeating
        their last literal, and long clauses are chained through fresh
        auxiliary variables. The input clause lists are not modified.
        Returns ``(clauses, highest variable id in use)``.
        """
        s3c, nxt = [], self.var_offset
        for c in cs:
            if not c:
                continue
            if len(c) <= 3:
                # Pad a copy so lists owned by the caller stay untouched.
                padded = list(c)
                padded.extend([padded[-1]] * (3 - len(padded)))
                s3c.append(padded)
            else:
                rem = list(c)
                while len(rem) > 3:
                    l1, l2 = rem.pop(0), rem.pop(0)
                    s3c.append([l1, l2, nxt])
                    rem.insert(0, -nxt)
                    nxt += 1
                s3c.append(rem)
        self.var_offset = nxt
        return s3c, self.var_offset - 1

    def decode_solution(self, a):
        """Translate an assignment into a molecule description.

        ``a`` must be a 1-D ``numpy`` array in which position ``var - 1``
        holds the truth value of variable ``var``; any other input yields
        the empty description. Returns a dict with keys 'atoms', 'bonds',
        'aromatic_rings', 'functional_groups' and 'molecular_weight_range'.
        """
        s = {'atoms': [], 'bonds': [], 'aromatic_rings': 0, 'functional_groups': [], 'molecular_weight_range': None}
        if not isinstance(a, np.ndarray) or a.ndim != 1:
            return s

        # Step 1: the first true non-'None' type of each slot names the atom.
        existing_atom_ids = set()
        for i in range(self.max_atoms):
            for t in self.atom_types:
                v = self.atom_type_var(i, t) - 1
                if v < len(a) and a[v] and t != 'None':
                    s['atoms'].append({'id': i, 'element': t})
                    existing_atom_ids.add(i)
                    break

        # Step 2: keep a bond only when both of its end points were decoded.
        for i in range(self.max_atoms):
            for j in range(i + 1, self.max_atoms):
                v = self.bond_existence_var(i, j)
                if v != -1 and v - 1 < len(a) and a[v - 1]:
                    if i in existing_atom_ids and j in existing_atom_ids:
                        s['bonds'].append({'from': i, 'to': j})

        s['aromatic_rings'] = sum(1 for i in range(self.max_rings) if self.aromatic_ring_var(i) - 1 < len(a) and a[self.aromatic_ring_var(i) - 1])
        s['functional_groups'] = [g for g in self.functional_groups if self.functional_group_var(g) - 1 < len(a) and a[self.functional_group_var(g) - 1]]

        # The weight bucket starts at the last threshold of the leading run
        # of true threshold variables (0 when the first one is false).
        mw_min = 0
        for t in self.mw_thresholds:
            v = self.mw_var(t) - 1
            if v < len(a) and a[v]:
                mw_min = t
            else:
                break
        s['molecular_weight_range'] = (mw_min, mw_min + 10)
        return s
217
-
218
def parse_constraints(ss):
    """Turn constraint strings into ``MolecularConstraint`` objects.

    Three forms are recognised: ``<name> <op> <integer>``, ``NOT <group>``
    and the bare word ``synthesizable``. Strings matching none of these are
    ignored.
    """
    comparison = re.compile(r'(\w+)\s*([<>=!]+)\s*(\d+)')
    parsed = []
    for raw in ss:
        text = raw.strip()
        hit = comparison.match(text)
        if hit is not None:
            name, op, val_str = hit.groups()
            parsed.append(MolecularConstraint(name, int(val_str), op))
            continue
        if text.startswith('NOT '):
            # Everything after the keyword is the forbidden group's name.
            parsed.append(MolecularConstraint('forbidden_group', text[4:].strip()))
            continue
        if text == 'synthesizable':
            parsed.append(MolecularConstraint(text, True))
    return parsed
 
1
+ # molecular_constraint_solver.py
2
+ # FINAL VERSION with strict connectivity scaffolding
3
+
4
+ import numpy as np
5
+ from typing import List, Dict, Tuple
6
+ from dataclasses import dataclass
7
+ import re
8
+
9
@dataclass
class MolecularConstraint:
    """A single user-level constraint on the molecule being searched for."""

    # Constraint name, e.g. 'min_atoms', 'aromatic_rings', 'molecular_weight',
    # 'forbidden_group' or 'synthesizable'.
    constraint_type: str
    # Payload of the constraint (an int threshold, a group name, or True).
    # NOTE(review): `any` is the builtin function used as an annotation;
    # typing.Any is presumably what was intended.
    value: any
    # Comparison operator as written by the user; equality unless stated.
    operator: str = "=="
14
+
15
class MolecularConstraintEncoder:
    """Encode molecular constraints as a 3-SAT instance and decode solutions.

    Variables are 1-based integers; a negative integer denotes the negation
    of that variable. ``__init__`` reserves consecutive blocks of variables
    (atom types, bond existence, bond types, ring flags, aromatic-ring
    slots, functional groups, molecular-weight thresholds). ``var_offset``
    always holds the next unused variable id and grows as auxiliary
    variables are allocated by the cardinality encoders and by the 3-SAT
    conversion.
    """

    def __init__(self, max_atoms=30):
        """Lay out the fixed variable blocks for up to ``max_atoms`` atoms."""
        self.max_atoms = max_atoms
        self.max_bonds = max_atoms * (max_atoms - 1) // 2
        self.var_offset = 1

        # One variable per (atom slot, element); 'None' marks an empty slot.
        self.atom_types = ['C', 'N', 'O', 'S', 'F', 'Cl', 'Br', 'P', 'H', 'None']
        self.atom_var_start = self.var_offset
        self.var_offset += self.max_atoms * len(self.atom_types)

        # One variable per unordered pair of atom slots.
        self.bond_existence_var_start = self.var_offset
        self.var_offset += self.max_bonds

        # Bond-type variables are reserved here; no clause in this class
        # refers to them.
        self.bond_types = ['single', 'double', 'triple']
        self.bond_type_var_start = self.var_offset
        self.var_offset += self.max_bonds * len(self.bond_types)

        self.ring_var_start = self.var_offset
        self.var_offset += self.max_atoms

        self.max_rings = 10
        self.aromatic_var_start = self.var_offset
        self.var_offset += self.max_rings

        self.functional_groups = ['nitro', 'azide', 'peroxide', 'aldehyde', 'ketone', 'carboxyl', 'amine', 'amide', 'ester', 'ether', 'thiol', 'sulfone', 'phosphate', 'hydroxyl', 'halogen', 'cyano', 'isocyanate', 'epoxide', 'lactone', 'quinone']
        self.group_var_start = self.var_offset
        self.var_offset += len(self.functional_groups)

        # Thresholds 100, 110, ..., 590.
        self.mw_thresholds = list(range(100, 600, 10))
        self.mw_var_start = self.var_offset
        self.var_offset += len(self.mw_thresholds)

    def atom_type_var(self, atom_idx, atom_type):
        """Return the variable meaning "slot ``atom_idx`` holds ``atom_type``"."""
        return self.atom_var_start + atom_idx * len(self.atom_types) + self.atom_types.index(atom_type)

    def bond_existence_var(self, i, j):
        """Return the variable for a bond between slots ``i`` and ``j``.

        The pair is unordered. Returns -1 when ``i == j`` (no self bonds).
        """
        if i == j:
            return -1
        if i > j:
            i, j = j, i
        # Row-major index into the strict upper triangle of the pair matrix.
        # i * (2n - i - 1) is always even, so integer division is exact.
        idx = i * (2 * self.max_atoms - i - 1) // 2 + (j - i - 1)
        return self.bond_existence_var_start + idx

    def atom_exists_lit(self, atom_idx):
        """Return the literal "slot ``atom_idx`` is not of type 'None'"."""
        return -self.atom_type_var(atom_idx, 'None')

    def ring_var(self, idx):
        """Return the ring flag variable with index ``idx``."""
        return self.ring_var_start + idx

    def aromatic_ring_var(self, idx):
        """Return the variable of aromatic-ring slot ``idx``."""
        return self.aromatic_var_start + idx

    def functional_group_var(self, g):
        """Return the variable of functional group ``g`` (must be a known name)."""
        return self.group_var_start + self.functional_groups.index(g)

    def mw_var(self, t):
        """Return the variable of the weight threshold closest to ``t``."""
        return self.mw_var_start + self.mw_thresholds.index(min(self.mw_thresholds, key=lambda x: abs(x - t)))

    def encode_constraints(self, constraints: List["MolecularConstraint"]) -> Tuple[List[List[int]], int]:
        """Build the full formula and return ``(3-SAT clauses, variable count)``.

        Structural validity and valence clauses are always emitted. The
        first ``min_atoms`` constraint (if any) is turned into a forced
        connected backbone; every other constraint goes through
        ``_encode_single_constraint``.
        """
        all_clauses = self._encode_structural_validity()
        all_clauses.extend(self.encode_valence())

        # min_atoms is handled by the backbone scaffolding, not the dispatcher.
        min_atoms_constraint = next((c for c in constraints if c.constraint_type == 'min_atoms'), None)
        if min_atoms_constraint:
            all_clauses.extend(self._force_connected_backbone(min_atoms_constraint.value))

        for constraint in constraints:
            if constraint.constraint_type != 'min_atoms':
                all_clauses.extend(self._encode_single_constraint(constraint))

        return self._convert_to_3sat(all_clauses)

    def _force_connected_backbone(self, min_atoms):
        """Force slots ``0..min_atoms-1`` to exist and form a bonded path.

        Any further slot that holds an atom must be bonded to at least one
        backbone slot, so every decoded atom hangs off the same component.

        Returns no clauses for ``min_atoms <= 0``.

        Raises:
            ValueError: if ``min_atoms`` exceeds ``max_atoms``; the backbone
                would otherwise index variables outside the atom block.
        """
        if min_atoms <= 0:
            return []
        if min_atoms > self.max_atoms:
            raise ValueError(
                f"min_atoms ({min_atoms}) exceeds max_atoms ({self.max_atoms})"
            )

        # 1. The first `min_atoms` slots must hold a real atom.
        clauses = [[self.atom_exists_lit(i)] for i in range(min_atoms)]

        # 2. Chain them: 0-1, 1-2, 2-3, ...
        clauses.extend([self.bond_existence_var(i, i + 1)] for i in range(min_atoms - 1))

        # 3. A slot beyond the backbone is either empty or bonded to it.
        for i in range(min_atoms, self.max_atoms):
            backbone_bonds = [self.bond_existence_var(i, j) for j in range(min_atoms)]
            clauses.append([-self.atom_exists_lit(i)] + backbone_bonds)

        return clauses

    def encode_valence(self):
        """Emit clauses tying each element to an exact number of bonds.

        For every slot and element, "slot has this element" implies that
        exactly ``valence_rules[element]`` of the slot's bond variables are
        true. Bond order is not considered.
        """
        clauses = []
        valence_rules = {'C': 4, 'N': 3, 'O': 2, 'S': 2, 'F': 1, 'Cl': 1, 'Br': 1, 'P': 3, 'H': 1}
        for i in range(self.max_atoms):
            bond_vars = [self.bond_existence_var(i, j) for j in range(self.max_atoms) if i != j]
            for atom_type, val in valence_rules.items():
                type_var = self.atom_type_var(i, atom_type)
                if val > len(bond_vars):
                    # Not enough partners exist: the element is impossible here.
                    clauses.append([-type_var])
                    continue
                for cl in self._cardinality_at_least(bond_vars, val) + self._cardinality_at_most(bond_vars, val):
                    if cl:
                        # Make every cardinality clause conditional on the type.
                        clauses.append([-type_var] + cl)
        return clauses

    def _cardinality_at_least(self, V, k):
        """Return CNF clauses forcing at least ``k`` literals of ``V`` true.

        Uses a sequential counter: auxiliary variable ``s[i][j]`` is true
        exactly when at least ``j + 1`` of ``V[0..i]`` are true. Both
        implication directions are emitted; without the "counter implies
        inputs" direction the final assertion could be satisfied by simply
        setting the counter bits, enforcing nothing. Allocates ``n * k``
        auxiliary variables from ``var_offset`` when ``k >= 2``.
        """
        n = len(V)
        if k <= 0:
            return []
        if n < k:
            # Unsatisfiable: two contradicting unit clauses (a single
            # clause [1, -1] would be a tautology).
            return [[1], [-1]]
        if k == 1:
            return [list(V)]

        s = [[self.var_offset + i * k + j for j in range(k)] for i in range(n)]
        self.var_offset += n * k

        clauses = []
        # Row 0: the count over V[0] alone is 1 iff V[0], and never more.
        clauses.append([-V[0], s[0][0]])
        clauses.append([-s[0][0], V[0]])
        for j in range(1, k):
            clauses.append([-s[0][j]])

        for i in range(1, n):
            # s[i][0] <-> V[i] or s[i-1][0]
            clauses.append([-V[i], s[i][0]])
            clauses.append([-s[i - 1][0], s[i][0]])
            clauses.append([-s[i][0], s[i - 1][0], V[i]])
            for j in range(1, k):
                # s[i][j] <-> s[i-1][j] or (V[i] and s[i-1][j-1])
                clauses.append([-V[i], -s[i - 1][j - 1], s[i][j]])
                clauses.append([-s[i - 1][j], s[i][j]])
                clauses.append([-s[i][j], s[i - 1][j], V[i]])
                clauses.append([-s[i][j], s[i - 1][j], s[i - 1][j - 1]])

        # The full prefix must reach a count of k.
        clauses.append([s[n - 1][k - 1]])
        return clauses

    def _cardinality_at_most(self, V, k):
        """Return CNF clauses forcing at most ``k`` literals of ``V`` true.

        Expressed as "at least ``n - k`` of the negated literals are true".
        """
        n = len(V)
        if k < 0:
            # Unsatisfiable, encoded as two contradicting unit clauses.
            return [[1], [-1]]
        if k >= n:
            return []
        return self._cardinality_at_least([-v for v in V], n - k)

    def _encode_structural_validity(self):
        """Emit clauses giving every slot exactly one atom type."""
        clauses = []
        for i in range(self.max_atoms):
            v = [self.atom_type_var(i, t) for t in self.atom_types]
            clauses.append(v)  # at least one type
            for i1 in range(len(v)):
                for i2 in range(i1 + 1, len(v)):
                    clauses.append([-v[i1], -v[i2]])  # pairwise at most one
        return clauses

    def _encode_single_constraint(self, c):
        """Dispatch one constraint to its encoder; unknown types add nothing."""
        if c.constraint_type == 'aromatic_rings':
            return self._encode_aromatic_rings(c.value, c.operator)
        if c.constraint_type == 'molecular_weight':
            return self._encode_molecular_weight(c.value, c.operator)
        if c.constraint_type == 'forbidden_group':
            return self._encode_forbidden_group(c.value)
        if c.constraint_type == 'synthesizable':
            return self._encode_synthesizability()
        return []

    def _encode_aromatic_rings(self, v, o):
        """Fix the aromatic-ring slots for ``== v``; other operators add nothing.

        The first ``v`` slots are forced true and the remaining ones false.
        """
        if o == '==':
            return [[self.aromatic_ring_var(i)] if i < v else [-self.aromatic_ring_var(i)] for i in range(self.max_rings)]
        return []

    def _encode_molecular_weight(self, v, o):
        """Emit threshold-ordering clauses plus the bound for operator ``<``.

        A true threshold variable implies the next lower one is true. For
        ``<``, every threshold ``>= v`` is forced false. Other operators
        only get the ordering clauses.
        """
        c = []
        for i in range(len(self.mw_thresholds) - 1):
            c.append([-self.mw_var(self.mw_thresholds[i + 1]), self.mw_var(self.mw_thresholds[i])])
        if o == '<':
            for t in self.mw_thresholds:
                if t >= v:
                    c.append([-self.mw_var(t)])
        return c

    def _encode_forbidden_group(self, v):
        """Forbid functional group ``v``; unknown group names add nothing."""
        if v not in self.functional_groups:
            return []
        return [[-self.functional_group_var(v)]]

    def _encode_synthesizability(self):
        """Emit the synthesizability heuristics.

        Aromatic-ring slots with index 3 and above are forced false, and at
        most one of nitro / azide / peroxide / isocyanate may be present.
        """
        c = [[-self.aromatic_ring_var(i)] for i in range(3, self.max_rings)]
        rg = ['nitro', 'azide', 'peroxide', 'isocyanate']
        rv = [self.functional_group_var(g) for g in rg if g in self.functional_groups]
        for i in range(len(rv)):
            for j in range(i + 1, len(rv)):
                c.append([-rv[i], -rv[j]])
        return c

    def _convert_to_3sat(self, cs):
        """Convert arbitrary-width clauses to exactly three literals each.

        Empty clauses are skipped, short clauses are padded by repeating
        their last literal, and long clauses are chained through fresh
        auxiliary variables. The input clause lists are not modified.
        Returns ``(clauses, highest variable id in use)``.
        """
        s3c, nxt = [], self.var_offset
        for c in cs:
            if not c:
                continue
            if len(c) <= 3:
                # Pad a copy so lists owned by the caller stay untouched.
                padded = list(c)
                padded.extend([padded[-1]] * (3 - len(padded)))
                s3c.append(padded)
            else:
                rem = list(c)
                while len(rem) > 3:
                    l1, l2 = rem.pop(0), rem.pop(0)
                    s3c.append([l1, l2, nxt])
                    rem.insert(0, -nxt)
                    nxt += 1
                s3c.append(rem)
        self.var_offset = nxt
        return s3c, self.var_offset - 1

    def decode_solution(self, a):
        """Translate an assignment into a molecule description.

        ``a`` must be a 1-D ``numpy`` array in which position ``var - 1``
        holds the truth value of variable ``var``; any other input yields
        the empty description. Returns a dict with keys 'atoms', 'bonds',
        'aromatic_rings', 'functional_groups' and 'molecular_weight_range'.
        """
        s = {'atoms': [], 'bonds': [], 'aromatic_rings': 0, 'functional_groups': [], 'molecular_weight_range': None}
        if not isinstance(a, np.ndarray) or a.ndim != 1:
            return s

        # A slot counts as an atom only if its 'None' variable is false; its
        # element is the first true non-'None' type variable.
        existing_atom_ids = set()
        for i in range(self.max_atoms):
            none_var_idx = self.atom_type_var(i, 'None') - 1
            if none_var_idx < len(a) and not a[none_var_idx]:
                for t in self.atom_types:
                    if t == 'None':
                        continue
                    v = self.atom_type_var(i, t) - 1
                    if v < len(a) and a[v]:
                        s['atoms'].append({'id': i, 'element': t})
                        existing_atom_ids.add(i)
                        break

        # Keep a bond only when both of its end points were decoded.
        for i in range(self.max_atoms):
            for j in range(i + 1, self.max_atoms):
                v = self.bond_existence_var(i, j)
                if v != -1 and v - 1 < len(a) and a[v - 1]:
                    if i in existing_atom_ids and j in existing_atom_ids:
                        s['bonds'].append({'from': i, 'to': j})

        s['aromatic_rings'] = sum(1 for i in range(self.max_rings) if self.aromatic_ring_var(i) - 1 < len(a) and a[self.aromatic_ring_var(i) - 1])
        s['functional_groups'] = [g for g in self.functional_groups if self.functional_group_var(g) - 1 < len(a) and a[self.functional_group_var(g) - 1]]

        # The weight bucket starts at the last threshold of the leading run
        # of true threshold variables (0 when the first one is false).
        mw_min = 0
        for t in self.mw_thresholds:
            v = self.mw_var(t) - 1
            if v < len(a) and a[v]:
                mw_min = t
            else:
                break
        s['molecular_weight_range'] = (mw_min, mw_min + 10)
        return s
222
+
223
def parse_constraints(ss):
    """Turn constraint strings into ``MolecularConstraint`` objects.

    Three forms are recognised: ``<name> <op> <integer>``, ``NOT <group>``
    and the bare word ``synthesizable``. Strings matching none of these are
    ignored.
    """
    comparison = re.compile(r'(\w+)\s*([<>=!]+)\s*(\d+)')
    parsed = []
    for raw in ss:
        text = raw.strip()
        hit = comparison.match(text)
        if hit is not None:
            name, op, val_str = hit.groups()
            parsed.append(MolecularConstraint(name, int(val_str), op))
            continue
        if text.startswith('NOT '):
            # Everything after the keyword is the forbidden group's name.
            parsed.append(MolecularConstraint('forbidden_group', text[4:].strip()))
            continue
        if text == 'synthesizable':
            parsed.append(MolecularConstraint(text, True))
    return parsed