Upload 2 files
Browse files- app.py +191 -0
- molecular_constraint_solver.py +228 -0
app.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import numpy as np
|
| 5 |
+
import time
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
import json
|
| 8 |
+
import io
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from rdkit import Chem
|
| 12 |
+
from rdkit.Chem import Draw
|
| 13 |
+
from rdkit.Chem import rdMolDraw2D
|
| 14 |
+
RDKIT_AVAILABLE = True
|
| 15 |
+
except ImportError:
|
| 16 |
+
RDKIT_AVAILABLE = False
|
| 17 |
+
|
| 18 |
+
from molecular_constraint_solver import MolecularConstraintEncoder, parse_constraints
|
| 19 |
+
|
| 20 |
+
class SparsePhaseCalciumField3SAT:
    """Heuristic SAT solver that evolves one phase angle per Boolean variable.

    A variable is read as True when the cosine of its phase is positive.
    Each call to :meth:`step` pushes variables of unsatisfied clauses towards
    the phase that would satisfy a randomly chosen literal, applies sparse
    pairwise phase coupling, adds Gaussian noise and occasionally rewires the
    coupling graph.

    Clauses are sequences of non-zero integers: literal ``k`` refers to
    variable ``k`` (1-based) and ``-k`` to its negation.

    Args:
        N_vars: Number of Boolean variables.
        clauses: Sequence of clauses (each a sequence of literals).
        seed: Seed for the solver's private random generator.
        K: Gain applied to the pairwise coupling term.
        eta: Increment added to a coupling weight when two phases align.
        prune_rate: Fractional decay applied to one random coupling weight.
        noise: Standard deviation multiplier of the per-step Gaussian noise.
        DT: Integration step size.
        drive: Gain of the clause-driven forcing term.
        solver_steps: Stored as ``max_steps``; this class never reads it, the
            caller decides how many times to call :meth:`step`.
    """

    def __init__(self, N_vars, clauses, seed=42, K=0.87, eta=0.045,
                 prune_rate=0.005, noise=0.03, DT=0.003, drive=14.28, solver_steps=300):
        # A private RandomState produces the same stream that
        # np.random.seed(seed) followed by np.random.* calls would, but does
        # not reseed the process-wide generator as a side effect.
        self.rng = np.random.RandomState(seed)
        self.N, self.M, self.clauses = N_vars, len(clauses), clauses
        self.K, self.eta, self.prune_rate = K, eta, prune_rate
        self.noise, self.DT = noise, DT
        self.drive, self.max_steps = drive, solver_steps
        self.phases = self.rng.uniform(0, 2 * np.pi, N_vars)
        self.clause_weights = np.ones(self.M)
        # Sparse directed coupling graph: W[i][j] -> weight.
        self.W = defaultdict(dict)
        for _ in range(min(self.N * 2, 20000)):
            i, j = self.rng.randint(0, self.N, 2)
            if i != j:
                self.W[i][j] = self.rng.uniform(0.01, 0.05)
        self.history = {'satisfaction': []}

    def get_assignment(self):
        """Return the Boolean assignment implied by the current phases."""
        return np.cos(self.phases) > 0

    def evaluate_clause(self, clause, assignment):
        """Return True if at least one in-range literal of ``clause`` holds."""
        for lit in clause:
            idx = abs(lit) - 1
            # Literal 0 (idx == -1) and out-of-range variables carry no value.
            if idx < 0 or idx >= self.N:
                continue
            val = assignment[idx]
            if (lit > 0 and val) or (lit < 0 and not val):
                return True
        return False

    def compute_satisfaction(self, assignment=None):
        """Return the fraction of satisfied clauses (1.0 when there are none)."""
        if assignment is None:
            assignment = self.get_assignment()
        if self.M == 0:
            return 1.0
        satisfied = sum(1 for c in self.clauses if self.evaluate_clause(c, assignment))
        return satisfied / self.M

    def step(self):
        """Advance the dynamics by one time step and record the satisfaction."""
        dphi = np.zeros(self.N)
        assignment = self.get_assignment()

        # Clause forcing: every unsatisfied clause gains weight and pulls one
        # randomly chosen variable towards the phase satisfying its literal.
        for idx, clause in enumerate(self.clauses):
            if self.evaluate_clause(clause, assignment):
                continue
            self.clause_weights[idx] = min(self.clause_weights[idx] + 0.02, 5.0)
            if len(clause) == 0:
                # Nothing to drive; randint(0) would raise.
                continue
            lit = clause[self.rng.randint(len(clause))]
            idx_var = abs(lit) - 1
            if idx_var < 0 or idx_var >= self.N:
                continue
            target = 0.0 if lit > 0 else np.pi
            dphi[idx_var] += (self.drive * self.clause_weights[idx]
                              * np.sin(target - self.phases[idx_var]))

        # Pairwise coupling along the sparse graph (equal and opposite pulls).
        for i, row in self.W.items():
            for j, w in row.items():
                pull = self.K * w * np.sin(self.phases[j] - self.phases[i])
                dphi[i] += pull
                dphi[j] -= pull

        dphi += self.noise * self.rng.randn(self.N)
        self.phases = np.mod(self.phases + self.DT * dphi, 2 * np.pi)

        # Plasticity event, taken on roughly one step in ten.
        # NOTE(review): the listing this was recovered from had no
        # indentation; pruning is assumed to belong inside this branch.
        if self.rng.rand() < 0.1:
            if self.N > 0:  # randint(0, 0) raises for an empty variable set
                for _ in range(20):
                    i, j = self.rng.randint(0, self.N, 2)
                    if i != j and np.cos(self.phases[i] - self.phases[j]) > 0.98:
                        self.W[i][j] = min(1.0, self.W[i].get(j, 0.0) + self.eta)
            if self.W:
                s = self.rng.choice(list(self.W.keys()))
                if self.W[s]:
                    t = self.rng.choice(list(self.W[s].keys()))
                    self.W[s][t] *= (1 - self.prune_rate)
                    if self.W[s][t] < 0.01:
                        del self.W[s][t]

        self.history['satisfaction'].append(self.compute_satisfaction())
|
| 78 |
+
|
| 79 |
+
def draw_molecule_from_structure(s_dict):
    """Render a decoded structure as an image, or as text when RDKit is absent.

    Args:
        s_dict: Mapping with an ``'atoms'`` list of ``{'id', 'element'}``
            dicts and a ``'bonds'`` list of ``{'from', 'to'}`` dicts. Missing
            keys are treated as empty lists.

    Returns:
        * Without RDKit: a string, either an adjacency listing (one line per
          atom) or ``"No atoms to draw."``.
        * With RDKit: a PIL image, ``None`` if there are no atoms, or an
          error string if drawing raised.
    """
    if not RDKIT_AVAILABLE:
        atoms = s_dict.get('atoms', [])
        bonds = s_dict.get('bonds', [])
        if not atoms:
            return "No atoms to draw."
        adj = {a['id']: [] for a in atoms}
        for b in bonds:
            src, dst = b['from'], b['to']
            # Skip bonds touching unknown atoms, mirroring the RDKit branch,
            # instead of failing with KeyError.
            if src in adj and dst in adj:
                adj[src].append(dst)
                adj[dst].append(src)
        lines = [f"{a['id']:02d} {a['element']:>2} -> {', '.join(map(str, adj[a['id']]))}" for a in atoms]
        return "\n".join(lines)
    try:
        mol = Chem.RWMol()
        atom_map = {}  # structure atom id -> RDKit atom index
        for info in s_dict.get('atoms', []):
            atom_map[info['id']] = mol.AddAtom(Chem.Atom(info['element']))
        for bond in s_dict.get('bonds', []):
            a, b = bond['from'], bond['to']
            if a in atom_map and b in atom_map:
                # Every bond is drawn as a single bond.
                mol.AddBond(atom_map[a], atom_map[b], Chem.BondType.SINGLE)
        if mol.GetNumAtoms() == 0:
            return None
        rdkit_idx_to_original_id = {v: k for k, v in atom_map.items()}
        drawer = rdMolDraw2D.MolDraw2DCairo(300, 300)
        opts = drawer.drawOptions()
        # Label each atom as "<structure id>:<element symbol>".
        for idx in range(mol.GetNumAtoms()):
            original_id = rdkit_idx_to_original_id.get(idx, '?')
            symbol = mol.GetAtomWithIdx(idx).GetSymbol()
            opts.atomLabels[idx] = f"{original_id}:{symbol}"
        rdMolDraw2D.PrepareAndDrawMolecule(drawer, mol)
        drawer.FinishDrawing()
        png = drawer.GetDrawingText()
        from PIL import Image
        return Image.open(io.BytesIO(png))
    except Exception as e:
        # Best effort: the caller shows returned strings verbatim.
        return f"RDKit drawing failed: {e}"
|
| 116 |
+
|
| 117 |
+
# ---------------------------------------------------------------------------
# Streamlit page: sidebar controls -> constraint encoding -> solver runs ->
# gallery of decoded structures.
# NOTE(review): block nesting below was rebuilt from a listing that had lost
# its indentation; confirm against the original file.
# ---------------------------------------------------------------------------
st.set_page_config(page_title="Molecular Constraint Solver", layout="wide", page_icon="🧬")
st.markdown("""<style>.main-header{font-size:3rem;color:#1f77b4;text-align:center}.sub-header{font-size:1.2rem;color:#666;text-align:center;margin-bottom:2rem}</style>""", unsafe_allow_html=True)
st.markdown('<div class="main-header">🧬 Molecular Constraint Solver</div>', unsafe_allow_html=True)
st.markdown('<div class="sub-header">Generate molecular graphs satisfying hard constraints via neuromorphic 3-SAT solving</div>', unsafe_allow_html=True)

# --- Sidebar: constraint inputs -------------------------------------------
st.sidebar.header("Constraint Configuration")
st.sidebar.subheader("Chemical Properties")
aromatic_rings = st.sidebar.slider("Aromatic Rings", 0, 5, 1)
max_mw = st.sidebar.slider("Maximum Molecular Weight (Da)", 200, 700, 500, step=10)
forbidden_groups = st.sidebar.multiselect("Forbidden Functional Groups:", ['nitro', 'azide', 'peroxide'], [])

st.sidebar.subheader("Additional Constraints")
min_atoms = st.sidebar.slider("Minimum atom count", 0, 30, 10, help="Forces the molecule to have at least this many atoms.")
synthesizable = st.sidebar.checkbox("Synthesizable", value=False)
max_atoms = 30

# --- Sidebar: solver inputs -----------------------------------------------
st.sidebar.subheader("Solver Parameters")
n_molecules = st.sidebar.slider("Number of molecules to generate", 1, 50, 5)
solver_steps = st.sidebar.slider("Solver Steps", 50, 1000, 300)
drive_strength = st.sidebar.slider("Drive Strength", 10.0, 100.0, 75.0, step=5.0)

if st.sidebar.button("🧬 Generate Molecules", type="primary"):
    with st.spinner("Encoding constraints → Solving 3-SAT → Decoding structures..."):
        try:
            # Build the textual constraints understood by parse_constraints.
            constraints_list = [f"aromatic_rings == {aromatic_rings}", f"molecular_weight < {max_mw}"]
            if min_atoms > 0:
                constraints_list.append(f"min_atoms >= {min_atoms}")
            constraints_list.extend(f"NOT {group}" for group in forbidden_groups)
            if synthesizable:
                constraints_list.append("synthesizable")

            constraints = parse_constraints(constraints_list)
            encoder = MolecularConstraintEncoder(max_atoms=max_atoms)
            clauses, n_vars = encoder.encode_constraints(constraints)

            st.info(f"Generated a SAT problem with {n_vars} variables and {len(clauses)} clauses.")
            results = []
            progress_bar = st.progress(0, text="Generating molecules...")

            for i in range(n_molecules):
                # A different seed per molecule so runs explore different states.
                solver = SparsePhaseCalciumField3SAT(
                    N_vars=n_vars, clauses=clauses, seed=int(time.time()) + i,
                    drive=drive_strength, solver_steps=solver_steps
                )
                for _ in range(solver_steps):
                    solver.step()

                assignment = solver.get_assignment()
                structure = encoder.decode_solution(assignment)
                structure['satisfaction'] = solver.compute_satisfaction()
                structure['molecule_id'] = i + 1
                results.append(structure)
                progress_bar.progress((i + 1) / n_molecules)

            st.session_state['results'] = results
            st.success(f"Successfully generated {n_molecules} molecular structures!")
        except Exception as e:
            # Top-level UI boundary: show the failure instead of crashing the page.
            st.error(f"An error occurred: {e}")
            import traceback
            st.code(traceback.format_exc())

# --- Results gallery ------------------------------------------------------
# An empty stored list is skipped: st.columns() needs at least one column.
if 'results' in st.session_state and st.session_state['results']:
    results = st.session_state['results']
    st.subheader("Generated Molecules")
    cols = st.columns(min(len(results), 5))
    for i, res in enumerate(results):
        with cols[i % 5]:
            st.metric(f"Molecule {res['molecule_id']}", f"{res['satisfaction']:.1%} sat.")
            output = draw_molecule_from_structure(res)
            if isinstance(output, str):
                st.code(output)
            elif output is not None:
                st.image(output)
            else:
                st.warning("Could not draw.")
            with st.expander("Details"):
                st.json(res)
|
molecular_constraint_solver.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# molecular_constraint_solver.py
|
| 2 |
+
# FINAL VERSION with corrected decoder
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from typing import List, Dict, Tuple
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
import re
|
| 8 |
+
|
| 9 |
+
@dataclass
class MolecularConstraint:
    """One parsed constraint on the molecule to generate."""

    # Constraint name, e.g. 'min_atoms', 'aromatic_rings', 'forbidden_group'.
    constraint_type: str
    # Constraint payload; `object` rather than the builtin function `any`,
    # which is not a type.
    value: object
    # Comparison operator text; defaults to equality.
    operator: str = '=='
|
| 14 |
+
|
| 15 |
+
class MolecularConstraintEncoder:
    """Encode molecular-graph constraints as 3-SAT clauses and decode assignments.

    Variables are positive integers starting at 1; a negative integer is the
    negated literal. The constructor reserves consecutive variable ranges for
    atom types, bond existence, connectivity flags, bond types, ring flags,
    aromatic-ring flags, functional groups and molecular-weight thresholds.
    ``var_offset`` always holds the next unused variable number; auxiliary
    variables created while encoding are taken from it.
    """

    def __init__(self, max_atoms=30):
        """Lay out the variable ranges for a graph of up to ``max_atoms`` atoms."""
        self.max_atoms = max_atoms
        self.max_bonds = max_atoms * (max_atoms - 1) // 2
        self.var_offset = 1

        # One variable per (atom slot, element); 'None' marks an empty slot.
        self.atom_types = ['C', 'N', 'O', 'S', 'F', 'Cl', 'Br', 'P', 'H', 'None']
        self.atom_var_start = self.var_offset
        self.var_offset += self.max_atoms * len(self.atom_types)

        # One variable per unordered atom pair.
        self.bond_existence_var_start = self.var_offset
        self.var_offset += self.max_bonds

        self.conn_var_start = self.var_offset
        self.var_offset += self.max_atoms

        self.bond_types = ['single', 'double', 'triple']
        self.bond_type_var_start = self.var_offset
        self.var_offset += self.max_bonds * len(self.bond_types)

        self.ring_var_start = self.var_offset
        self.var_offset += self.max_atoms

        self.max_rings = 10
        self.aromatic_var_start = self.var_offset
        self.var_offset += self.max_rings

        self.functional_groups = ['nitro', 'azide', 'peroxide', 'aldehyde', 'ketone', 'carboxyl', 'amine', 'amide', 'ester', 'ether', 'thiol', 'sulfone', 'phosphate', 'hydroxyl', 'halogen', 'cyano', 'isocyanate', 'epoxide', 'lactone', 'quinone']
        self.group_var_start = self.var_offset
        self.var_offset += len(self.functional_groups)

        self.mw_thresholds = list(range(100, 600, 10))
        self.mw_var_start = self.var_offset
        self.var_offset += len(self.mw_thresholds)

    # ------------------------------------------------------------------
    # Variable lookup
    # ------------------------------------------------------------------
    def atom_type_var(self, atom_idx, atom_type):
        """Return the variable meaning "slot ``atom_idx`` holds ``atom_type``"."""
        return self.atom_var_start + atom_idx * len(self.atom_types) + self.atom_types.index(atom_type)

    def bond_existence_var(self, i, j):
        """Return the variable for a bond between atoms ``i`` and ``j``.

        The pair is unordered. Returns ``-1`` when ``i == j``.
        """
        if i == j:
            return -1
        if i > j:
            i, j = j, i
        # Row-major index into the strict upper triangle, in exact integer
        # arithmetic (i * (2n - i - 1) is always even).
        idx = i * (2 * self.max_atoms - i - 1) // 2 + (j - i - 1)
        return self.bond_existence_var_start + idx

    def conn_var(self, atom_idx):
        """Return the connectivity-flag variable of atom ``atom_idx``."""
        return self.conn_var_start + atom_idx

    def atom_exists_lit(self, atom_idx):
        """Return the literal "slot ``atom_idx`` is not 'None'"."""
        return -self.atom_type_var(atom_idx, 'None')

    def ring_var(self, idx):
        """Return the ring-flag variable number ``idx``."""
        return self.ring_var_start + idx

    def aromatic_ring_var(self, idx):
        """Return the aromatic-ring variable number ``idx``."""
        return self.aromatic_var_start + idx

    def functional_group_var(self, g):
        """Return the variable of functional group ``g`` (ValueError if unknown)."""
        return self.group_var_start + self.functional_groups.index(g)

    def mw_var(self, t):
        """Return the variable of the weight threshold closest to ``t``."""
        nearest = min(self.mw_thresholds, key=lambda x: abs(x - t))
        return self.mw_var_start + self.mw_thresholds.index(nearest)

    # ------------------------------------------------------------------
    # Encoding
    # ------------------------------------------------------------------
    def encode_constraints(self, constraints: List["MolecularConstraint"]) -> Tuple[List[List[int]], int]:
        """Encode built-in rules plus ``constraints`` as 3-literal clauses.

        Returns:
            ``(clauses, n_vars)`` where every clause has exactly three
            literals and ``n_vars`` is the highest variable number allocated.
        """
        all_clauses = self._encode_structural_validity()
        all_clauses.extend(self.encode_valence())
        all_clauses.extend(self._encode_connectivity())
        for constraint in constraints:
            all_clauses.extend(self._encode_single_constraint(constraint))
        return self._convert_to_3sat(all_clauses)

    def _encode_connectivity(self):
        """Return clauses tying the connectivity flags to atoms and bonds."""
        clauses = []
        # Atom 0: its flag is set exactly when the slot is occupied.
        clauses.append([self.atom_type_var(0, 'None'), self.conn_var(0)])
        clauses.append([-self.atom_type_var(0, 'None'), -self.conn_var(0)])
        # A bond propagates the flag in both directions.
        for i in range(self.max_atoms):
            for j in range(i + 1, self.max_atoms):
                bond_var = self.bond_existence_var(i, j)
                clauses.append([-self.conn_var(i), -bond_var, self.conn_var(j)])
                clauses.append([-self.conn_var(j), -bond_var, self.conn_var(i)])
        # Every occupied slot must carry the flag.
        for i in range(self.max_atoms):
            clauses.append([self.atom_type_var(i, 'None'), self.conn_var(i)])
        return clauses

    def encode_valence(self):
        """Return clauses forcing each typed atom to have exactly its valence in bonds.

        Every clause is guarded by the negated atom-type variable, so it only
        binds when the slot actually holds that element.
        """
        clauses = []
        valence_rules = {'C': 4, 'N': 3, 'O': 2, 'S': 2, 'F': 1, 'Cl': 1, 'Br': 1, 'P': 3, 'H': 1}
        for i in range(self.max_atoms):
            bond_vars = [self.bond_existence_var(i, j) for j in range(self.max_atoms) if i != j]
            for atom_type, val in valence_rules.items():
                type_var = self.atom_type_var(i, atom_type)
                if val > len(bond_vars):
                    # Not enough partners: this element is impossible here.
                    clauses.append([-type_var])
                    continue
                exact = self._cardinality_at_least(bond_vars, val) + self._cardinality_at_most(bond_vars, val)
                for cl in exact:
                    if cl:
                        clauses.append([-type_var] + cl)
        return clauses

    @staticmethod
    def _contradiction(V):
        """Return two unit clauses that cannot both hold."""
        pivot = V[0] if V else 1
        return [[pivot], [-pivot]]

    def _cardinality_at_least(self, V, k):
        """Return clauses satisfiable iff at least ``k`` literals of ``V`` hold.

        Uses a sequential counter: auxiliary ``s[i][j]`` may be true only if
        at least ``j + 1`` of ``V[0..i]`` are true, and ``s[n-1][k-1]`` is
        asserted. The implications run from the counter bits to the inputs;
        the opposite direction would let the solver set every counter bit
        freely and make the constraint vacuous. Allocates ``len(V) * k``
        fresh variables from ``var_offset`` when ``k >= 2``.
        """
        n = len(V)
        if k <= 0:
            return []
        if n < k:
            # Impossible demand: must be unsatisfiable, not a tautology.
            return self._contradiction(V)
        if k == 1:
            return [list(V)]
        s = [[self.var_offset + i * k + j for j in range(k)] for i in range(n)]
        self.var_offset += n * k
        clauses = [[-s[0][0], V[0]]]
        # A single literal can never account for two or more.
        clauses.extend([-s[0][j]] for j in range(1, k))
        for i in range(1, n):
            clauses.append([-s[i][0], s[i - 1][0], V[i]])
            for j in range(1, k):
                # s[i][j] -> s[i-1][j] or (V[i] and s[i-1][j-1])
                clauses.append([-s[i][j], s[i - 1][j], V[i]])
                clauses.append([-s[i][j], s[i - 1][j], s[i - 1][j - 1]])
        clauses.append([s[n - 1][k - 1]])
        return clauses

    def _cardinality_at_most(self, V, k):
        """Return clauses satisfiable iff at most ``k`` literals of ``V`` hold."""
        n = len(V)
        if k < 0:
            return self._contradiction(V)
        if k >= n:
            return []
        # At most k true  <=>  at least n - k false.
        return self._cardinality_at_least([-v for v in V], n - k)

    def _encode_structural_validity(self):
        """Return clauses giving every atom slot exactly one type."""
        clauses = []
        for i in range(self.max_atoms):
            v = [self.atom_type_var(i, t) for t in self.atom_types]
            clauses.append(v)
            for a in range(len(v)):
                for b in range(a + 1, len(v)):
                    clauses.append([-v[a], -v[b]])
        return clauses

    def _encode_single_constraint(self, c):
        """Dispatch on ``c.constraint_type``; unknown types yield no clauses."""
        if c.constraint_type == 'min_atoms':
            return self._encode_min_atoms(c.value)
        if c.constraint_type == 'aromatic_rings':
            return self._encode_aromatic_rings(c.value, c.operator)
        if c.constraint_type == 'molecular_weight':
            return self._encode_molecular_weight(c.value, c.operator)
        if c.constraint_type == 'forbidden_group':
            return self._encode_forbidden_group(c.value)
        if c.constraint_type == 'synthesizable':
            return self._encode_synthesizability()
        return []

    def _encode_min_atoms(self, k):
        """Return clauses requiring at least ``k`` occupied atom slots."""
        if k <= 0:
            return []
        existence_literals = [self.atom_exists_lit(i) for i in range(self.max_atoms)]
        return self._cardinality_at_least(existence_literals, k)

    def _encode_aromatic_rings(self, v, o):
        """Fix the first ``v`` aromatic flags true, the rest false (only ``==``)."""
        if o != '==':
            return []
        return [[self.aromatic_ring_var(i)] if i < v else [-self.aromatic_ring_var(i)]
                for i in range(self.max_rings)]

    def _encode_molecular_weight(self, v, o):
        """Return the threshold chain plus, for ``<``, a cap at ``v``."""
        c = []
        # A higher threshold flag implies the next lower one.
        for lower, upper in zip(self.mw_thresholds, self.mw_thresholds[1:]):
            c.append([-self.mw_var(upper), self.mw_var(lower)])
        if o == '<':
            for t in self.mw_thresholds:
                if t >= v:
                    c.append([-self.mw_var(t)])
        return c

    def _encode_forbidden_group(self, v):
        """Forbid functional group ``v``; unknown names yield no clauses."""
        if v not in self.functional_groups:
            return []
        return [[-self.functional_group_var(v)]]

    def _encode_synthesizability(self):
        """Return the clauses used for the 'synthesizable' constraint.

        Aromatic flags with index 3 and above are forced false, and at most
        one of the listed reactive groups may be present.
        """
        c = [[-self.aromatic_ring_var(i)] for i in range(3, self.max_rings)]
        rg = ['nitro', 'azide', 'peroxide', 'isocyanate']
        rv = [self.functional_group_var(g) for g in rg if g in self.functional_groups]
        for i in range(len(rv)):
            for j in range(i + 1, len(rv)):
                c.append([-rv[i], -rv[j]])
        return c

    def _convert_to_3sat(self, cs):
        """Rewrite ``cs`` so every clause has exactly three literals.

        Short clauses are padded by repeating their last literal; long ones
        are chained through fresh variables. Empty clauses are dropped.
        Input clauses are not modified.

        Returns:
            ``(clauses, n_vars)`` with ``n_vars`` the highest variable used.
        """
        s3c, nxt = [], self.var_offset
        for c in cs:
            if not c:
                continue
            if len(c) <= 3:
                s3c.append(list(c) + [c[-1]] * (3 - len(c)))
                continue
            rem = list(c)
            while len(rem) > 3:
                # (l1 v l2 v rest)  ->  (l1 v l2 v x) and (-x v rest)
                s3c.append([rem[0], rem[1], nxt])
                rem = [-nxt] + rem[2:]
                nxt += 1
            s3c.append(rem)
        self.var_offset = nxt
        return s3c, self.var_offset - 1

    # ------------------------------------------------------------------
    # Decoding
    # ------------------------------------------------------------------
    def decode_solution(self, a):
        """Turn a Boolean assignment into a structure dictionary.

        Args:
            a: 1-D ``numpy`` array; ``a[v - 1]`` is the value of variable ``v``.
                Anything else yields the empty structure.

        Returns:
            Dict with keys ``'atoms'``, ``'bonds'``, ``'aromatic_rings'``,
            ``'functional_groups'`` and ``'molecular_weight_range'``.
        """
        s = {'atoms': [], 'bonds': [], 'aromatic_rings': 0, 'functional_groups': [], 'molecular_weight_range': None}
        if not isinstance(a, np.ndarray) or a.ndim != 1:
            return s

        def is_set(var):
            # Variables beyond the assignment count as false.
            return var - 1 < len(a) and a[var - 1]

        # Step 1: the first true non-'None' type of each slot makes an atom.
        existing_atom_ids = set()
        for i in range(self.max_atoms):
            for t in self.atom_types:
                if t != 'None' and is_set(self.atom_type_var(i, t)):
                    s['atoms'].append({'id': i, 'element': t})
                    existing_atom_ids.add(i)
                    break

        # Step 2: keep a bond only when both of its endpoints were decoded.
        for i in range(self.max_atoms):
            if i not in existing_atom_ids:
                continue
            for j in range(i + 1, self.max_atoms):
                if j in existing_atom_ids and is_set(self.bond_existence_var(i, j)):
                    s['bonds'].append({'from': i, 'to': j})

        s['aromatic_rings'] = sum(1 for i in range(self.max_rings) if is_set(self.aromatic_ring_var(i)))
        s['functional_groups'] = [g for g in self.functional_groups if is_set(self.functional_group_var(g))]

        # Highest threshold reached through an unbroken run of true flags.
        mw_min = 0
        for t in self.mw_thresholds:
            if not is_set(self.mw_var(t)):
                break
            mw_min = t
        s['molecular_weight_range'] = (mw_min, mw_min + 10)
        return s
|
| 217 |
+
|
| 218 |
+
def parse_constraints(ss):
    """Convert textual constraints into ``MolecularConstraint`` objects.

    Three forms are recognised, checked in this order after stripping
    surrounding whitespace:

    * ``<name> <op> <integer>`` (matched at the start of the string) becomes
      ``MolecularConstraint(name, int(value), op)``;
    * ``NOT <group>`` becomes a ``'forbidden_group'`` constraint;
    * the bare word ``synthesizable`` becomes a flag constraint.

    Strings in any other form are skipped.
    """
    comparison = re.compile(r'(\w+)\s*([<>=!]+)\s*(\d+)')
    parsed = []
    for raw in ss:
        text = raw.strip()
        hit = comparison.match(text)
        if hit is not None:
            name, op, val_str = hit.groups()
            parsed.append(MolecularConstraint(name, int(val_str), op))
            continue
        if text.startswith('NOT '):
            parsed.append(MolecularConstraint('forbidden_group', text[4:].strip()))
        elif text == 'synthesizable':
            parsed.append(MolecularConstraint(text, True))
    return parsed
|