|
|
|
|
|
|
|
|
|
|
|
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import numpy as np
|
|
|
|
|
@dataclass
class MolecularConstraint:
    """A single parsed design constraint.

    Attributes:
        constraint_type: Constraint name, e.g. ``'min_atoms'``,
            ``'aromatic_rings'``, ``'molecular_weight'``,
            ``'forbidden_group'`` or ``'synthesizable'``.
        value: Operand of the constraint -- a number, a functional-group
            name, or ``True`` for flag constraints.
        operator: Comparison operator applied to ``value``; ``'=='`` by
            default.
    """

    constraint_type: str
    value: Any  # typing.Any; the builtin ``any`` is a function, not a type
    operator: str = '=='
|
|
|
|
|
class MolecularConstraintEncoder:
    """Encode molecular-design constraints as a CNF / 3-SAT instance.

    Boolean variables are numbered from 1 (DIMACS style) and laid out in
    contiguous regions, in this order:

    * atom types -- ``max_atoms * len(atom_types)`` one-hot variables; the
      ``'None'`` type marks an unused atom slot;
    * bond existence -- one variable per unordered atom pair;
    * bond types -- ``len(bond_types)`` variables per pair (allocated, but no
      method of this class constrains them);
    * ring variables -- one per atom slot (allocated, unconstrained here);
    * aromatic rings -- ``max_rings`` slot variables; the ring count is the
      number of true slots;
    * functional groups -- one flag per entry of ``functional_groups``;
    * molecular weight -- one variable per threshold ``t``, read as
      "MW >= t" (no clause in this class ties them to the atoms).

    Auxiliary variables (cardinality counters, 3-SAT splitting) are allocated
    after these regions by bumping ``var_offset``.
    """

    def __init__(self, max_atoms=30):
        self.max_atoms = max_atoms
        # Number of unordered atom pairs (i < j).
        self.max_bonds = max_atoms * (max_atoms - 1) // 2
        self.var_offset = 1  # next free variable id

        self.atom_types = ['C', 'N', 'O', 'S', 'F', 'Cl', 'Br', 'P', 'H', 'None']
        self.atom_var_start = self.var_offset
        self.var_offset += self.max_atoms * len(self.atom_types)

        self.bond_existence_var_start = self.var_offset
        self.var_offset += self.max_bonds

        self.bond_types = ['single', 'double', 'triple']
        self.bond_type_var_start = self.var_offset
        self.var_offset += self.max_bonds * len(self.bond_types)

        self.ring_var_start = self.var_offset
        self.var_offset += self.max_atoms

        self.max_rings = 10
        self.aromatic_var_start = self.var_offset
        self.var_offset += self.max_rings

        self.functional_groups = ['nitro', 'azide', 'peroxide', 'aldehyde', 'ketone', 'carboxyl', 'amine', 'amide', 'ester', 'ether', 'thiol', 'sulfone', 'phosphate', 'hydroxyl', 'halogen', 'cyano', 'isocyanate', 'epoxide', 'lactone', 'quinone']
        self.group_var_start = self.var_offset
        self.var_offset += len(self.functional_groups)

        self.mw_thresholds = list(range(100, 600, 10))
        self.mw_var_start = self.var_offset
        self.var_offset += len(self.mw_thresholds)

    # ------------------------------------------------------------------
    # Variable numbering
    # ------------------------------------------------------------------

    def atom_type_var(self, atom_idx, atom_type):
        """Return the variable for "atom slot ``atom_idx`` has type ``atom_type``"."""
        return self.atom_var_start + atom_idx * len(self.atom_types) + self.atom_types.index(atom_type)

    def bond_existence_var(self, i, j):
        """Return the variable for "atoms ``i`` and ``j`` are bonded".

        The pair is unordered.  Returns -1 when ``i == j``.
        """
        if i == j:
            return -1
        if i > j:
            i, j = j, i
        # Row-major index into the strict upper triangle, (0,1), (0,2), ...,
        # (1,2), ...; i * (2n - i - 1) is always even, so // is exact.
        idx = i * (2 * self.max_atoms - i - 1) // 2 + (j - i - 1)
        return self.bond_existence_var_start + idx

    def atom_exists_lit(self, atom_idx):
        """Return the literal that is true iff atom slot ``atom_idx`` is used.

        This is the negation of the slot's ``'None'`` type variable.
        """
        return -self.atom_type_var(atom_idx, 'None')

    def ring_var(self, idx):
        """Return the ``idx``-th ring variable (one per atom slot)."""
        return self.ring_var_start + idx

    def aromatic_ring_var(self, idx):
        """Return the variable of aromatic-ring slot ``idx``."""
        return self.aromatic_var_start + idx

    def functional_group_var(self, g):
        """Return the flag variable of functional group ``g`` (ValueError if unknown)."""
        return self.group_var_start + self.functional_groups.index(g)

    def mw_var(self, t):
        """Return the "MW >= threshold" variable for the threshold nearest ``t``."""
        nearest = min(self.mw_thresholds, key=lambda x: abs(x - t))
        return self.mw_var_start + self.mw_thresholds.index(nearest)

    def _first_aux_var(self):
        """Return the first variable id after the fixed variable regions."""
        return self.mw_var_start + len(self.mw_thresholds)

    def _new_var(self):
        """Allocate and return a fresh auxiliary variable id."""
        var = self.var_offset
        self.var_offset += 1
        return var

    def _unsat_clauses(self):
        """Return clauses no assignment satisfies: ``[x]`` and ``[-x]``, x fresh.

        Unlike an empty clause these survive ``_convert_to_3sat``.  They also
        still work when a caller prepends a guard literal, because the result
        then forces the guard false.
        """
        x = self._new_var()
        return [[x], [-x]]

    # ------------------------------------------------------------------
    # Top-level encoding
    # ------------------------------------------------------------------

    def encode_constraints(self, constraints: 'List[MolecularConstraint]') -> Tuple[List[List[int]], int]:
        """Encode ``constraints`` plus structural and valence rules as 3-SAT.

        Returns ``(clauses, num_vars)``.  Every clause has exactly three
        literals, and ``num_vars`` is the largest variable id that may appear.
        Auxiliary numbering restarts on every call, so repeated calls on the
        same encoder produce identical formulas.
        """
        self.var_offset = self._first_aux_var()
        all_clauses = self._encode_structural_validity()
        all_clauses.extend(self.encode_valence())

        # min_atoms shapes the backbone instead of being encoded on its own;
        # when several are given, the largest (strictest) one applies.
        min_atoms_values = [c.value for c in constraints if c.constraint_type == 'min_atoms']
        if min_atoms_values:
            all_clauses.extend(self._force_connected_backbone(max(min_atoms_values)))

        for constraint in constraints:
            if constraint.constraint_type != 'min_atoms':
                all_clauses.extend(self._encode_single_constraint(constraint))

        return self._convert_to_3sat(all_clauses)

    def _force_connected_backbone(self, min_atoms):
        """Require at least ``min_atoms`` atoms arranged around a bonded chain.

        Atom slots ``0 .. min_atoms-1`` must exist.  When ``min_atoms >= 2``:

        * those slots are bonded as the path ``0-1-...-(min_atoms-1)``;
        * every other atom that exists must bond directly to at least one
          path atom.

        If ``min_atoms`` exceeds ``max_atoms``, the result is unsatisfiable.
        """
        if min_atoms <= 0:
            return []
        if min_atoms > self.max_atoms:
            # Slot indices past max_atoms would alias other variable regions.
            return self._unsat_clauses()

        clauses = [[self.atom_exists_lit(i)] for i in range(min_atoms)]
        if min_atoms == 1:
            return clauses

        # Consecutive backbone atoms are bonded, forming a simple path.
        clauses.extend([self.bond_existence_var(i, i + 1)] for i in range(min_atoms - 1))

        # exists(i) -> bonded to some backbone atom.  -atom_exists_lit(i) is
        # the slot's 'None' variable, so the clause reads None_i v bonds...
        for i in range(min_atoms, self.max_atoms):
            backbone_bonds = [self.bond_existence_var(i, j) for j in range(min_atoms)]
            clauses.append([-self.atom_exists_lit(i)] + backbone_bonds)
        return clauses

    # ------------------------------------------------------------------
    # Built-in rules
    # ------------------------------------------------------------------

    def encode_valence(self):
        """Tie each atom's number of bonds to its element's valence.

        An atom slot typed as element ``X`` must have exactly ``valence[X]``
        true bond-existence variables.  Only bond existence is counted; the
        bond-type variables are ignored.  An element whose valence exceeds the
        number of possible partners is forbidden for every slot.
        """
        clauses = []
        valence_rules = {'C': 4, 'N': 3, 'O': 2, 'S': 2, 'F': 1, 'Cl': 1, 'Br': 1, 'P': 3, 'H': 1}
        for i in range(self.max_atoms):
            bond_vars = [self.bond_existence_var(i, j) for j in range(self.max_atoms) if i != j]
            for atom_type, val in valence_rules.items():
                type_var = self.atom_type_var(i, atom_type)
                if val > len(bond_vars):
                    clauses.append([-type_var])
                    continue
                # Each counter clause is guarded by the type, i.e. type -> clause.
                for cl in self._cardinality_at_least(bond_vars, val) + self._cardinality_at_most(bond_vars, val):
                    if cl:
                        clauses.append([-type_var] + cl)
        return clauses

    def _cardinality_at_least(self, V, k):
        """Return clauses forcing at least ``k`` literals of ``V`` to be true.

        This is a sequential counter with fresh variables ``s[i][j]``.  The
        clauses only allow ``s[i][j]`` to be true when at least ``j + 1`` of
        ``V[0..i]`` are true, and ``s[n-1][k-1]`` is asserted.  Hence every
        model has at least ``k`` true literals.  Conversely, any assignment with
        at least ``k`` true literals extends to the counter by setting
        ``s[i][j] = (count(V[0..i]) >= j + 1)``.
        """
        n = len(V)
        if k <= 0:
            return []
        if n < k:
            return self._unsat_clauses()
        if k == 1:
            return [list(V)]

        s = [[self.var_offset + i * k + j for j in range(k)] for i in range(n)]
        self.var_offset += n * k

        # Row 0 has seen only V[0]: at most one literal can be counted.
        clauses = [[-s[0][0], V[0]]]
        clauses.extend([-s[0][j]] for j in range(1, k))
        for i in range(1, n):
            # s[i][0] -> s[i-1][0] v V[i]
            clauses.append([-s[i][0], s[i - 1][0], V[i]])
            for j in range(1, k):
                # s[i][j] -> s[i-1][j] v (V[i] ^ s[i-1][j-1]), as two clauses.
                clauses.append([-s[i][j], s[i - 1][j], V[i]])
                clauses.append([-s[i][j], s[i - 1][j], s[i - 1][j - 1]])
        clauses.append([s[n - 1][k - 1]])
        return clauses

    def _cardinality_at_most(self, V, k):
        """Return clauses forcing at most ``k`` literals of ``V`` to be true.

        "At most k of V" is encoded as "at least n - k of the negations".
        """
        n = len(V)
        if k < 0:
            return self._unsat_clauses()
        if k >= n:
            return []
        return self._cardinality_at_least([-v for v in V], n - k)

    def _encode_structural_validity(self):
        """Return the basic well-formedness clauses.

        * Each atom slot has exactly one type (``'None'`` included).
        * A bond can only join two slots that are in use.  Without this,
          valence could be met through bonds that ``decode_solution`` drops.
        """
        clauses = []
        for i in range(self.max_atoms):
            one_hot = [self.atom_type_var(i, t) for t in self.atom_types]
            clauses.append(one_hot)  # at least one type
            for pos, first in enumerate(one_hot):
                for second in one_hot[pos + 1:]:
                    clauses.append([-first, -second])  # at most one type
        for i in range(self.max_atoms):
            for j in range(i + 1, self.max_atoms):
                bond = self.bond_existence_var(i, j)
                clauses.append([-bond, self.atom_exists_lit(i)])
                clauses.append([-bond, self.atom_exists_lit(j)])
        return clauses

    # ------------------------------------------------------------------
    # User constraints
    # ------------------------------------------------------------------

    def _encode_single_constraint(self, c):
        """Dispatch one constraint to its encoder; unknown types add no clauses."""
        if c.constraint_type == 'aromatic_rings':
            return self._encode_aromatic_rings(c.value, c.operator)
        if c.constraint_type == 'molecular_weight':
            return self._encode_molecular_weight(c.value, c.operator)
        if c.constraint_type == 'forbidden_group':
            return self._encode_forbidden_group(c.value)
        if c.constraint_type == 'synthesizable':
            return self._encode_synthesizability()
        return []

    def _encode_aromatic_rings(self, v, o):
        """Constrain the number of true aromatic-ring slots relative to ``v``.

        Supported operators are ``==``, ``<=``, ``<``, ``>=`` and ``>``, with
        ``v`` an integer count; any other operator adds no clauses.  Slots
        below the lower bound are forced true and slots at or above the upper
        bound are forced false.  Bounds outside ``0..max_rings`` make the
        formula unsatisfiable.
        """
        if o == '==':
            lo, hi = v, v
        elif o == '<=':
            lo, hi = 0, v
        elif o == '<':
            lo, hi = 0, v - 1
        elif o == '>=':
            lo, hi = v, self.max_rings
        elif o == '>':
            lo, hi = v + 1, self.max_rings
        else:
            return []
        if lo > self.max_rings or hi < 0:
            return self._unsat_clauses()
        clauses = []
        for i in range(self.max_rings):
            var = self.aromatic_ring_var(i)
            if i < lo:
                clauses.append([var])
            elif i >= hi:
                clauses.append([-var])
        return clauses

    def _encode_molecular_weight(self, v, o):
        """Constrain the "MW >= t" threshold variables relative to ``v``.

        The variables are always chained so that (MW >= t_next) implies
        (MW >= t).  The operator then adds:

        * ``<``: MW >= t is false for every t >= v;
        * ``<=`` and ``==``: MW >= t is false for every t > v;
        * ``>``, ``>=`` and ``==``: MW >= t is true for every t <= v.

        Any other operator adds only the chaining clauses.
        """
        thresholds = self.mw_thresholds
        clauses = [[-self.mw_var(upper), self.mw_var(lower)] for lower, upper in zip(thresholds, thresholds[1:])]
        if o == '<':
            clauses.extend([-self.mw_var(t)] for t in thresholds if t >= v)
        elif o in ('<=', '=='):
            clauses.extend([-self.mw_var(t)] for t in thresholds if t > v)
        if o in ('>', '>=', '=='):
            clauses.extend([self.mw_var(t)] for t in thresholds if t <= v)
        return clauses

    def _encode_forbidden_group(self, v):
        """Forbid functional group ``v``; unknown group names add no clauses."""
        if v not in self.functional_groups:
            return []
        return [[-self.functional_group_var(v)]]

    def _encode_synthesizability(self):
        """Apply the synthesizability heuristic.

        Aromatic-ring slots from index 3 upward are forced false.  At most one
        of the reactive groups (nitro, azide, peroxide, isocyanate) may be
        present.
        """
        clauses = [[-self.aromatic_ring_var(i)] for i in range(3, self.max_rings)]
        reactive = ['nitro', 'azide', 'peroxide', 'isocyanate']
        rv = [self.functional_group_var(g) for g in reactive if g in self.functional_groups]
        for pos, first in enumerate(rv):
            for second in rv[pos + 1:]:
                clauses.append([-first, -second])
        return clauses

    # ------------------------------------------------------------------
    # CNF -> 3-SAT and decoding
    # ------------------------------------------------------------------

    def _convert_to_3sat(self, cs):
        """Rewrite clauses so that each has exactly three literals.

        * Short clauses are padded by repeating their last literal.
        * Long clauses are split into a chain linked by fresh variables:
          ``(l1 v l2 v x1), (-x1 v l3 v x2), ...``.
        * An empty clause (always false) becomes an unsatisfiable pair rather
          than being dropped.

        The input lists are not modified.  Returns ``(clauses, max_var_id)``.
        """
        out = []
        nxt = self.var_offset
        for clause in cs:
            lits = list(clause)  # copy: never pad the caller's lists in place
            if not lits:
                out.append([nxt, nxt, nxt])
                out.append([-nxt, -nxt, -nxt])
                nxt += 1
                continue
            while len(lits) > 3:
                out.append([lits[0], lits[1], nxt])
                lits = [-nxt] + lits[2:]
                nxt += 1
            lits += [lits[-1]] * (3 - len(lits))
            out.append(lits)
        self.var_offset = nxt
        return out, self.var_offset - 1

    def decode_solution(self, a):
        """Translate a solver model back into a molecule description.

        ``a`` must be a 1-D NumPy array in which ``a[v - 1]`` is the truth
        value of variable ``v``; any other input yields the empty skeleton.

        * Only atoms that are in use and carry a concrete element are
          reported, and bonds only between such atoms.
        * ``molecular_weight_range`` is ``(lower, upper)``, read from the
          "MW >= t" variables.  ``lower`` is the highest threshold in the
          leading run of true variables (0 if the first is false).  ``upper``
          is the next threshold, or ``inf`` past the last one.
        """
        s = {'atoms': [], 'bonds': [], 'aromatic_rings': 0, 'functional_groups': [], 'molecular_weight_range': None}
        if not isinstance(a, np.ndarray) or a.ndim != 1:
            return s

        def is_true(var):
            # Variables beyond the end of the model count as false.
            return var - 1 < len(a) and bool(a[var - 1])

        existing_atom_ids = set()
        for i in range(self.max_atoms):
            none_idx = self.atom_type_var(i, 'None') - 1
            # A slot is in use only if its 'None' variable is present and false.
            if none_idx >= len(a) or a[none_idx]:
                continue
            for t in self.atom_types:
                if t != 'None' and is_true(self.atom_type_var(i, t)):
                    s['atoms'].append({'id': i, 'element': t})
                    existing_atom_ids.add(i)
                    break

        for i in range(self.max_atoms):
            for j in range(i + 1, self.max_atoms):
                if i in existing_atom_ids and j in existing_atom_ids and is_true(self.bond_existence_var(i, j)):
                    s['bonds'].append({'from': i, 'to': j})

        s['aromatic_rings'] = sum(1 for i in range(self.max_rings) if is_true(self.aromatic_ring_var(i)))
        s['functional_groups'] = [g for g in self.functional_groups if is_true(self.functional_group_var(g))]

        thresholds = self.mw_thresholds
        lower = 0
        upper = thresholds[0] if thresholds else float('inf')
        for pos, t in enumerate(thresholds):
            if not is_true(self.mw_var(t)):
                break
            lower = t
            upper = thresholds[pos + 1] if pos + 1 < len(thresholds) else float('inf')
        s['molecular_weight_range'] = (lower, upper)
        return s
|
|
|
|
|
def parse_constraints(ss):
    """Parse textual constraint specifications into ``MolecularConstraint`` objects.

    Recognised forms (surrounding whitespace is ignored):

    * ``"<name> <op> <integer>"``, where ``<op>`` is one of ``==``, ``!=``,
      ``<=``, ``>=``, ``<``, ``>`` or ``=``.  A bare ``=`` is normalised to
      ``==``, the spelling the constraint encoders compare against.
    * ``"NOT <group>"`` gives a ``'forbidden_group'`` constraint.
    * ``"synthesizable"`` gives a flag constraint with value ``True``.

    Strings matching none of these forms are skipped.
    """
    # Longer operators come first so that '<=' is not read as '<'.
    comparison = re.compile(r'(\w+)\s*(==|!=|<=|>=|<|>|=)\s*(\d+)')
    constraints = []
    for raw in ss:
        text = raw.strip()
        match = comparison.match(text)
        if match:
            name, op, digits = match.groups()
            if op == '=':
                op = '=='
            constraints.append(MolecularConstraint(name, int(digits), op))
        elif text.startswith('NOT '):
            constraints.append(MolecularConstraint('forbidden_group', text[4:].strip()))
        elif text == 'synthesizable':
            constraints.append(MolecularConstraint(text, True))
    return constraints