File size: 10,596 Bytes
d1e65ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f19e3e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
# molecular_constraint_solver.py
# FINAL VERSION with strict connectivity scaffolding

import re
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import numpy as np

@dataclass
class MolecularConstraint:
    """A single parsed design constraint.

    Attributes:
        constraint_type: constraint kind, e.g. ``'min_atoms'``,
            ``'aromatic_rings'``, ``'molecular_weight'``,
            ``'forbidden_group'`` or ``'synthesizable'``.
        value: payload -- an int threshold, a functional-group name, or
            ``True`` for flag-style constraints.
        operator: comparison operator string; ``'=='`` by default.
    """
    constraint_type: str
    # Heterogeneous payload (int / str / bool), hence typing.Any rather than
    # the builtin function ``any``.
    value: Any
    operator: str = '=='

class MolecularConstraintEncoder:
    """Encode molecular design constraints as a 3-SAT formula.

    Literals are signed ints: ``v`` asserts variable ``v`` is true and ``-v``
    that it is false, so variable numbering starts at 1.  ``__init__`` lays
    out the fixed variables in this order:

    * atom-type one-hot: ``len(atom_types)`` variables per atom slot
      (type ``'None'`` marks an unused slot);
    * bond existence: one variable per unordered pair of atom slots;
    * bond type: ``len(bond_types)`` variables per pair;
    * ring membership: one variable per atom slot;
    * aromatic rings: ``max_rings`` counter flags;
    * functional groups: one flag per name in ``functional_groups``;
    * molecular weight: one flag per value in ``mw_thresholds``.

    Auxiliary variables (cardinality counters, 3-SAT clause splitting) are
    allocated afterwards by advancing ``self.var_offset``; every encoding
    call therefore mutates the encoder and hands out fresh variables.

    NOTE(review): only atom types, bond existence, valence and connectivity
    are linked by clauses.  Bond-type and ring-membership variables are
    never constrained, and the aromatic-ring, functional-group and
    molecular-weight flags are set directly by constraints rather than
    derived from the atom/bond structure.
    """

    def __init__(self, max_atoms=30):
        self.max_atoms = max_atoms
        self.max_bonds = max_atoms * (max_atoms - 1) // 2
        # Literals are negated ints, and 0 has no negation: start at 1.
        self.var_offset = 1
        self.atom_types = ['C', 'N', 'O', 'S', 'F', 'Cl', 'Br', 'P', 'H', 'None']
        self.atom_var_start = self.var_offset
        self.var_offset += self.max_atoms * len(self.atom_types)
        self.bond_existence_var_start = self.var_offset
        self.var_offset += self.max_bonds
        self.bond_types = ['single', 'double', 'triple']
        self.bond_type_var_start = self.var_offset
        self.var_offset += self.max_bonds * len(self.bond_types)
        self.ring_var_start = self.var_offset
        self.var_offset += self.max_atoms
        self.max_rings = 10
        self.aromatic_var_start = self.var_offset
        self.var_offset += self.max_rings
        self.functional_groups = ['nitro', 'azide', 'peroxide', 'aldehyde', 'ketone', 'carboxyl', 'amine', 'amide', 'ester', 'ether', 'thiol', 'sulfone', 'phosphate', 'hydroxyl', 'halogen', 'cyano', 'isocyanate', 'epoxide', 'lactone', 'quinone']
        self.group_var_start = self.var_offset
        self.var_offset += len(self.functional_groups)
        self.mw_thresholds = list(range(100, 600, 10))
        self.mw_var_start = self.var_offset
        self.var_offset += len(self.mw_thresholds)

    # ------------------------------------------------------------------
    # Variable indexing
    # ------------------------------------------------------------------
    def atom_type_var(self, atom_idx, atom_type):
        """Return the variable for "atom slot ``atom_idx`` has ``atom_type``"."""
        return self.atom_var_start + atom_idx * len(self.atom_types) + self.atom_types.index(atom_type)

    def bond_existence_var(self, i, j):
        """Return the variable for a bond between slots ``i`` and ``j`` (order-free).

        Pairs are numbered row by row over the strict upper triangle.
        Returns -1 when ``i == j`` (no self-bonds); that value must not be
        used as a literal.
        """
        if i == j:
            return -1
        if i > j:
            i, j = j, i
        # Rows 0..i-1 contain i*n - i*(i+1)/2 pairs; i*(2n-i-1) is always even.
        idx = i * (2 * self.max_atoms - i - 1) // 2 + (j - i - 1)
        return self.bond_existence_var_start + idx

    def atom_exists_lit(self, atom_idx):
        """Return the literal that is true iff slot ``atom_idx`` is not 'None'."""
        return -self.atom_type_var(atom_idx, 'None')

    def ring_var(self, idx):
        """Return the ring-membership variable of atom slot ``idx``."""
        return self.ring_var_start + idx

    def aromatic_ring_var(self, idx):
        """Return the ``idx``-th aromatic-ring counter flag."""
        return self.aromatic_var_start + idx

    def functional_group_var(self, g):
        """Return the flag variable of functional group ``g``."""
        return self.group_var_start + self.functional_groups.index(g)

    def mw_var(self, t):
        """Return the flag of the threshold nearest to ``t`` (lower one on ties).

        ``_encode_molecular_weight`` chains these flags so they act as
        "weight >= threshold" indicators.
        """
        return self.mw_var_start + self.mw_thresholds.index(min(self.mw_thresholds, key=lambda x: abs(x - t)))

    # ------------------------------------------------------------------
    # Top-level encoding
    # ------------------------------------------------------------------
    def encode_constraints(self, constraints: "List[MolecularConstraint]") -> Tuple[List[List[int]], int]:
        """Encode structure, valence and user constraints as 3-SAT.

        Args:
            constraints: objects exposing ``constraint_type``, ``value`` and
                ``operator``.  The first ``'min_atoms'`` constraint is
                encoded as a connected backbone (its operator is not
                consulted); further ``'min_atoms'`` entries and unknown
                types contribute nothing.

        Returns:
            ``(clauses, num_vars)``: every clause has exactly three literals
            and ``num_vars`` is the highest variable index allocated so far.
        """
        # Materialise once: the input is scanned twice below.
        constraints = list(constraints)
        all_clauses = self._encode_structural_validity()
        all_clauses.extend(self.encode_valence())

        min_atoms_constraint = next(
            (c for c in constraints if c.constraint_type == 'min_atoms'), None)
        if min_atoms_constraint is not None:
            all_clauses.extend(self._force_connected_backbone(min_atoms_constraint.value))

        for constraint in constraints:
            if constraint.constraint_type != 'min_atoms':
                all_clauses.extend(self._encode_single_constraint(constraint))

        return self._convert_to_3sat(all_clauses)

    def _force_connected_backbone(self, min_atoms):
        """Require at least ``min_atoms`` atoms forming one connected molecule.

        Slots ``0..min_atoms-1`` are forced to exist and are chained by the
        bonds 0-1, 1-2, ...  When ``min_atoms >= 2`` every further slot
        that exists must bond directly to some backbone atom, which keeps
        the whole molecule in one component.  For ``min_atoms == 1`` only
        slot 0 is forced to exist; other slots are not constrained here.

        Returns no clauses for ``min_atoms <= 0`` and an unsatisfiable
        formula (one empty clause) when ``min_atoms > max_atoms``.
        """
        if min_atoms <= 0:
            return []
        if min_atoms > self.max_atoms:
            # More atoms requested than there are slots.
            return [[]]

        clauses = [[self.atom_exists_lit(i)] for i in range(min_atoms)]
        if min_atoms == 1:
            return clauses

        # Backbone path 0-1-2-...-(min_atoms-1).
        for i in range(min_atoms - 1):
            clauses.append([self.bond_existence_var(i, i + 1)])

        # Any extra atom is either absent or attached to the backbone.
        for i in range(min_atoms, self.max_atoms):
            backbone_bonds = [self.bond_existence_var(i, j) for j in range(min_atoms)]
            clauses.append([-self.atom_exists_lit(i)] + backbone_bonds)

        return clauses

    # ------------------------------------------------------------------
    # Valence and cardinality
    # ------------------------------------------------------------------
    def encode_valence(self):
        """Force every atom to have exactly its element's valence in bonds.

        For each slot and element, clauses ``type -> (#bonds == valence)``
        are emitted via sequential-counter cardinality encodings.  Bond
        order is not considered (each bond counts once) and 'None' slots
        get no valence constraint.  Elements whose valence exceeds the
        number of possible partners are forbidden outright.
        """
        clauses = []
        valence_rules = {'C': 4, 'N': 3, 'O': 2, 'S': 2, 'F': 1, 'Cl': 1, 'Br': 1, 'P': 3, 'H': 1}
        for i in range(self.max_atoms):
            bond_vars = [self.bond_existence_var(i, j) for j in range(self.max_atoms) if i != j]
            for atom_type, val in valence_rules.items():
                type_var = self.atom_type_var(i, atom_type)
                if val > len(bond_vars):
                    clauses.append([-type_var])
                    continue
                exact = self._cardinality_at_least(bond_vars, val) + self._cardinality_at_most(bond_vars, val)
                # Guard each clause with the type literal; an empty clause
                # (contradiction) thereby becomes [-type_var].
                clauses.extend([-type_var] + cl for cl in exact)
        return clauses

    def _counter_clauses(self, V, k):
        """Build a sequential counter with ``k`` registers over literals ``V``.

        Returns ``(s, clauses)``.  ``s[i][j]`` are fresh variables that the
        clauses make true iff at least ``j + 1`` of ``V[0..i]`` are true;
        both implication directions are encoded.  Requires a non-empty
        ``V`` and ``k >= 1``.  Advances ``self.var_offset`` by ``len(V) * k``.
        """
        n = len(V)
        first = self.var_offset
        self.var_offset += n * k
        s = [[first + i * k + j for j in range(k)] for i in range(n)]

        # Row 0: s[0][0] <-> V[0]; higher registers impossible.
        clauses = [[-V[0], s[0][0]], [-s[0][0], V[0]]]
        clauses.extend([-s[0][j]] for j in range(1, k))
        for i in range(1, n):
            prev, cur, x = s[i - 1], s[i], V[i]
            # cur[0] <-> prev[0] or x
            clauses.append([-prev[0], cur[0]])
            clauses.append([-x, cur[0]])
            clauses.append([-cur[0], prev[0], x])
            for j in range(1, k):
                # cur[j] <-> prev[j] or (x and prev[j-1])
                clauses.append([-prev[j], cur[j]])
                clauses.append([-x, -prev[j - 1], cur[j]])
                clauses.append([-cur[j], prev[j], x])
                clauses.append([-cur[j], prev[j], prev[j - 1]])
        return s, clauses

    def _cardinality_at_least(self, V, k):
        """Return clauses requiring at least ``k`` of the literals ``V`` to be true.

        For any assignment of ``V`` the clauses are satisfiable (over the
        auxiliary variables) iff the count bound holds.  ``[]`` means
        "always true"; ``[[]]`` (one empty clause) means "impossible".
        """
        n = len(V)
        if k <= 0:
            return []
        if n < k:
            return [[]]
        if k == 1:
            return [list(V)]
        s, clauses = self._counter_clauses(V, k)
        clauses.append([s[n - 1][k - 1]])
        return clauses

    def _cardinality_at_most(self, V, k):
        """Return clauses allowing at most ``k`` of the literals ``V`` to be true.

        Same conventions as ``_cardinality_at_least``.
        """
        n = len(V)
        if k < 0:
            return [[]]
        if k >= n:
            return []
        if k == 0:
            return [[-v] for v in V]
        # k + 1 registers: forbid "at least k + 1 are true".
        s, clauses = self._counter_clauses(V, k + 1)
        clauses.append([-s[n - 1][k]])
        return clauses

    def _encode_structural_validity(self):
        """Exactly one type per atom slot, and bonds only between used slots."""
        clauses = []
        for i in range(self.max_atoms):
            v = [self.atom_type_var(i, t) for t in self.atom_types]
            clauses.append(v)
            for i1 in range(len(v)):
                for i2 in range(i1 + 1, len(v)):
                    clauses.append([-v[i1], -v[i2]])
        # Without this an atom could meet its valence through bonds to
        # 'None' slots, which decode_solution would then silently drop.
        for i in range(self.max_atoms):
            for j in range(i + 1, self.max_atoms):
                bond = self.bond_existence_var(i, j)
                clauses.append([-bond, self.atom_exists_lit(i)])
                clauses.append([-bond, self.atom_exists_lit(j)])
        return clauses

    # ------------------------------------------------------------------
    # User constraints
    # ------------------------------------------------------------------
    def _encode_single_constraint(self, c):
        """Dispatch one constraint to its encoder; unknown types yield no clauses."""
        if c.constraint_type == 'aromatic_rings':
            return self._encode_aromatic_rings(c.value, c.operator)
        if c.constraint_type == 'molecular_weight':
            return self._encode_molecular_weight(c.value, c.operator)
        if c.constraint_type == 'forbidden_group':
            return self._encode_forbidden_group(c.value)
        if c.constraint_type == 'synthesizable':
            return self._encode_synthesizability()
        return []

    def _encode_aromatic_rings(self, v, o):
        """Constrain how many aromatic-ring flags are set.

        A lower bound ``m`` forces flags ``0..m-1`` true; an upper bound
        ``m`` forces flags ``m..max_rings-1`` false.  Supports ``==``,
        ``>=``, ``>``, ``<=`` and ``<``; any other operator adds nothing.
        A count that cannot fit in ``0..max_rings`` yields ``[[]]``
        (unsatisfiable).
        """
        if o == '==':
            lo, hi = v, v
        elif o == '>=':
            lo, hi = v, self.max_rings
        elif o == '>':
            lo, hi = v + 1, self.max_rings
        elif o == '<=':
            lo, hi = 0, v
        elif o == '<':
            lo, hi = 0, v - 1
        else:
            return []
        if lo > self.max_rings or hi < 0 or lo > hi:
            return [[]]
        required = [[self.aromatic_ring_var(i)] for i in range(max(lo, 0))]
        forbidden = [[-self.aromatic_ring_var(i)] for i in range(hi, self.max_rings)]
        return required + forbidden

    def _encode_molecular_weight(self, v, o):
        """Encode a molecular-weight bound over the threshold flags.

        Always emits the chain ``mw(t[i+1]) -> mw(t[i])`` so the flags act
        as monotone "weight >= t" indicators.  For ``o == '<'`` every flag
        with threshold ``>= v`` is forced false; other operators add only
        the chain.

        NOTE(review): precision is limited to the 10-unit threshold grid,
        e.g. ``< 505`` still permits the 500 flag.
        """
        c = []
        for i in range(len(self.mw_thresholds) - 1):
            c.append([-self.mw_var(self.mw_thresholds[i + 1]), self.mw_var(self.mw_thresholds[i])])
        if o == '<':
            for t in self.mw_thresholds:
                if t >= v:
                    c.append([-self.mw_var(t)])
        return c

    def _encode_forbidden_group(self, v):
        """Force the flag of group ``v`` false; unknown group names are ignored."""
        if v not in self.functional_groups:
            return []
        return [[-self.functional_group_var(v)]]

    def _encode_synthesizability(self):
        """Allow at most three aromatic-ring flags and at most one listed reactive group."""
        c = [[-self.aromatic_ring_var(i)] for i in range(3, self.max_rings)]
        rg = ['nitro', 'azide', 'peroxide', 'isocyanate']
        rv = [self.functional_group_var(g) for g in rg if g in self.functional_groups]
        for i in range(len(rv)):
            for j in range(i + 1, len(rv)):
                c.append([-rv[i], -rv[j]])
        return c

    # ------------------------------------------------------------------
    # 3-SAT conversion and decoding
    # ------------------------------------------------------------------
    def _convert_to_3sat(self, cs):
        """Rewrite clauses ``cs`` into an equisatisfiable 3-SAT formula.

        * Clauses shorter than three literals are padded by repeating their
          last literal.
        * Longer clauses are split with fresh chaining variables:
          ``(a b c d)`` -> ``(a b x) (-x c d)``.
        * An empty clause (contradiction) becomes ``(x x x) (-x -x -x)``
          over a fresh variable, keeping the formula unsatisfiable.

        The input clauses are not modified.  Advances ``self.var_offset``
        past any new variables.

        Returns:
            ``(clauses, num_vars)`` where ``num_vars`` is the highest
            variable index allocated so far.
        """
        out = []
        nxt = self.var_offset
        for clause in cs:
            lits = list(clause)
            if not lits:
                out.append([nxt, nxt, nxt])
                out.append([-nxt, -nxt, -nxt])
                nxt += 1
            elif len(lits) <= 3:
                out.append(lits + [lits[-1]] * (3 - len(lits)))
            else:
                head, pos = lits[0], 1
                while len(lits) - pos > 2:
                    out.append([head, lits[pos], nxt])
                    head = -nxt
                    nxt += 1
                    pos += 1
                out.append([head] + lits[pos:])
        self.var_offset = nxt
        return out, nxt - 1

    def decode_solution(self, a):
        """Turn a solver assignment into a readable molecule description.

        Args:
            a: 1-D ``numpy`` array where ``a[v - 1]`` holds the truth value
               of variable ``v``; any other input yields the empty result.
               Variables at or beyond ``len(a)`` read as false, except that
               an atom slot is decoded only when its 'None' variable is
               covered by ``a`` and false.

        Returns:
            dict with ``atoms`` (id/element), ``bonds`` (from/to, only
            between decoded atoms), ``aromatic_rings`` (number of true ring
            flags), ``functional_groups`` (names of true flags) and
            ``molecular_weight_range`` ``(lower, upper)``.  The lower bound
            is the end of the leading run of true weight thresholds (0 if
            none); the upper bound is the next threshold, or lower + 10
            past the last one.
        """
        s = {'atoms': [], 'bonds': [], 'aromatic_rings': 0, 'functional_groups': [], 'molecular_weight_range': None}
        if not isinstance(a, np.ndarray) or a.ndim != 1:
            return s

        def is_true(var):
            return var - 1 < len(a) and bool(a[var - 1])

        existing_atom_ids = set()
        for i in range(self.max_atoms):
            none_var_idx = self.atom_type_var(i, 'None') - 1
            if none_var_idx >= len(a) or a[none_var_idx]:
                continue
            for t in self.atom_types:
                if t != 'None' and is_true(self.atom_type_var(i, t)):
                    s['atoms'].append({'id': i, 'element': t})
                    existing_atom_ids.add(i)
                    break

        for i in range(self.max_atoms):
            for j in range(i + 1, self.max_atoms):
                if is_true(self.bond_existence_var(i, j)) and i in existing_atom_ids and j in existing_atom_ids:
                    s['bonds'].append({'from': i, 'to': j})

        s['aromatic_rings'] = sum(1 for i in range(self.max_rings) if is_true(self.aromatic_ring_var(i)))
        s['functional_groups'] = [g for g in self.functional_groups if is_true(self.functional_group_var(g))]

        thresholds = self.mw_thresholds
        lower, upper = 0, thresholds[0]
        for k, t in enumerate(thresholds):
            if not is_true(self.mw_var(t)):
                break
            lower = t
            upper = thresholds[k + 1] if k + 1 < len(thresholds) else t + 10
        s['molecular_weight_range'] = (lower, upper)
        return s

def parse_constraints(ss):
    """Parse constraint strings into ``MolecularConstraint`` objects.

    Each string is stripped and then matched against, in order:

    * ``<name> <op> <int>`` -- e.g. ``'molecular_weight < 500'``.  ``op`` is
      any run of ``<``, ``>``, ``=``, ``!`` characters; only a leading match
      is required, so trailing text is ignored.
    * ``'NOT <group>'`` -- becomes a ``'forbidden_group'`` constraint.
    * ``'synthesizable'`` -- a flag constraint with value ``True``.

    Strings matching none of these are skipped silently.
    """
    comparison = re.compile(r'(\w+)\s*([<>=!]+)\s*(\d+)')
    parsed = []
    for raw in ss:
        text = raw.strip()
        hit = comparison.match(text)
        if hit is not None:
            name, op, number = hit.groups()
            parsed.append(MolecularConstraint(name, int(number), op))
            continue
        if text.startswith('NOT '):
            parsed.append(MolecularConstraint('forbidden_group', text[4:].strip()))
        elif text == 'synthesizable':
            parsed.append(MolecularConstraint(text, True))
    return parsed