|
|
|
|
|
|
|
|
|
|
|
import streamlit as st |
|
|
import numpy as np |
|
|
import time |
|
|
from collections import defaultdict |
|
|
import json |
|
|
import io |
|
|
|
|
|
try: |
|
|
from rdkit import Chem |
|
|
from rdkit.Chem import Draw |
|
|
from rdkit.Chem import rdMolDraw2D |
|
|
RDKIT_AVAILABLE = True |
|
|
from PIL import Image |
|
|
except ImportError: |
|
|
RDKIT_AVAILABLE = False |
|
|
|
|
|
from molecular_constraint_solver import MolecularConstraintEncoder, parse_constraints |
|
|
|
|
|
class SparsePhaseCalciumField3SAT:
    """Oscillator-network heuristic for SAT problems given as clause lists.

    Every Boolean variable is a phase on the unit circle and reads as True
    when ``cos(phase) > 0``. Each :meth:`step`:

    1. pulls one randomly chosen literal of every unsatisfied clause towards
       its satisfying phase (0 for a positive literal, pi for a negative one),
       scaled by ``drive`` and by a per-clause weight that grows while the
       clause stays unsatisfied;
    2. couples phases along the edges of a sparse weighted graph ``W``;
    3. adds Gaussian phase noise and integrates with step size ``DT``;
    4. occasionally strengthens edges between nearly aligned phases and decays
       (eventually deletes) one random edge.

    Literal convention: ``k > 0`` refers to variable ``k - 1`` and ``-k`` to its
    negation. Literals that name no variable (``0`` or ``|k| > N_vars``) are
    ignored.

    Note: the constructor seeds NumPy's *global* RNG, so creating a solver
    resets ``np.random`` state for the whole process.
    """

    def __init__(self, N_vars, clauses, seed=42, K=0.87, eta=0.045,
                 prune_rate=0.005, noise=0.03, DT=0.003, drive=14.28, solver_steps=300):
        """Initialise random phases, unit clause weights and a sparse coupling graph.

        Args:
            N_vars: number of Boolean variables.
            clauses: sequence of clauses, each a sequence of integer literals.
            seed: seed passed to ``np.random.seed`` (global RNG).
            K: coupling strength along edges of ``W``.
            eta: increment added to an edge when its two phases are nearly aligned.
            prune_rate: multiplicative decay applied to one random edge per step.
            noise: scale of the per-step Gaussian phase noise.
            DT: integration step size.
            drive: strength of the pull exerted by unsatisfied clauses.
            solver_steps: default step count for :meth:`run` (stored as ``max_steps``).
        """
        np.random.seed(seed)
        self.N, self.M, self.clauses = N_vars, len(clauses), clauses
        self.K, self.eta, self.prune_rate, self.noise, self.DT = K, eta, prune_rate, noise, DT
        self.drive, self.max_steps = drive, solver_steps
        self.phases, self.clause_weights = np.random.uniform(0, 2 * np.pi, N_vars), np.ones(self.M)
        # Sparse adjacency: W[i][j] is the coupling weight of the directed edge i -> j.
        self.W = defaultdict(dict)
        # Seed roughly 2*N random edges (capped at 20000); self-loops are skipped.
        for _ in range(min(self.N * 2, 20000)):
            i, j = np.random.randint(0, self.N, 2)
            if i != j:
                self.W[i][j] = np.random.uniform(0.01, 0.05)
        self.history = {'satisfaction': []}

    def _var_index(self, lit):
        """Return the 0-based variable index named by ``lit``, or None if out of range."""
        idx = abs(lit) - 1
        # Literal 0 gives -1; without this check it would wrap to the last variable.
        return idx if 0 <= idx < self.N else None

    def get_assignment(self):
        """Return the current Boolean assignment: True where ``cos(phase) > 0``."""
        return np.cos(self.phases) > 0

    def evaluate_clause(self, clause, assignment):
        """Return True if some in-range literal of ``clause`` is satisfied by ``assignment``.

        Out-of-range literals are skipped; an empty clause is never satisfied.
        """
        for lit in clause:
            idx = self._var_index(lit)
            if idx is None:
                continue
            val = assignment[idx]
            if (lit > 0 and val) or (lit < 0 and not val):
                return True
        return False

    def compute_satisfaction(self, assignment=None):
        """Return the fraction of satisfied clauses (1.0 when there are no clauses).

        Uses the current phase-derived assignment when ``assignment`` is None.
        """
        if assignment is None:
            assignment = self.get_assignment()
        if self.M == 0:
            return 1.0
        return sum(1 for c in self.clauses if self.evaluate_clause(c, assignment)) / self.M

    def step(self):
        """Advance the network by one Euler step and append the satisfaction to history."""
        dphi, assignment = np.zeros(self.N), self.get_assignment()

        # 1) Clause drive: every unsatisfied clause gains weight (capped at 5.0)
        #    and pulls one random literal towards its satisfying phase.
        for c_idx, clause in enumerate(self.clauses):
            if self.evaluate_clause(clause, assignment):
                continue
            self.clause_weights[c_idx] = min(self.clause_weights[c_idx] + 0.02, 5.0)
            # An empty clause has no literal to drive (and randint(0) would raise).
            # len() rather than truthiness so NumPy-array clauses also work.
            if len(clause) == 0:
                continue
            lit = clause[np.random.randint(len(clause))]
            var = self._var_index(lit)
            if var is None:
                continue
            target = 0.0 if lit > 0 else np.pi
            dphi[var] += self.drive * self.clause_weights[c_idx] * np.sin(target - self.phases[var])

        # 2) Symmetric Kuramoto-style coupling along every edge of W.
        for i in self.W:
            for j, w in self.W[i].items():
                p_diff = self.phases[j] - self.phases[i]
                dphi[i] += self.K * w * np.sin(p_diff)
                dphi[j] -= self.K * w * np.sin(p_diff)

        # 3) Noise, then integrate and wrap phases into [0, 2*pi).
        dphi += self.noise * np.random.randn(self.N)
        self.phases = np.mod(self.phases + self.DT * dphi, 2 * np.pi)

        # 4) Hebbian growth on ~10% of steps: strengthen (up to 1.0) edges between
        #    nearly aligned phases. rand() is drawn first so the RNG stream is
        #    unchanged; N > 0 is required because randint(0, 0) raises.
        if np.random.rand() < 0.1 and self.N > 0:
            for _ in range(20):
                i, j = np.random.randint(0, self.N, 2)
                if i != j and np.cos(self.phases[i] - self.phases[j]) > 0.98:
                    self.W[i][j] = min(1.0, self.W[i].get(j, 0.0) + self.eta)

        # 5) Pruning: decay one random edge; drop it once it falls below 0.01.
        if self.W:
            s = np.random.choice(list(self.W.keys()))
            if self.W[s]:
                t = np.random.choice(list(self.W[s].keys()))
                self.W[s][t] *= (1 - self.prune_rate)
                if self.W[s][t] < 0.01:
                    del self.W[s][t]

        self.history['satisfaction'].append(self.compute_satisfaction())

    def run(self, n_steps=None):
        """Run ``n_steps`` steps (default ``self.max_steps``) and return the final satisfaction."""
        for _ in range(self.max_steps if n_steps is None else n_steps):
            self.step()
        return self.compute_satisfaction()
|
|
|
|
|
def draw_molecule_from_structure(s_dict):
    """Draw a decoded structure's raw atom graph, falling back to text if needed.

    Args:
        s_dict: mapping with optional ``'atoms'`` (dicts with ``'id'`` and
            ``'element'``) and ``'bonds'`` (dicts with ``'from'`` / ``'to'``
            atom ids).

    Returns:
        A 300x300 PIL image with each atom labelled ``"<id>:<symbol>"`` when
        RDKit is available and drawing succeeds. Otherwise returns a plain-text
        adjacency listing, which is also used when the structure has no atoms.
    """
    atoms = s_dict.get('atoms', [])
    bonds = s_dict.get('bonds', [])

    def get_text_fallback():
        """Return one ``"<id> <element> -> <neighbour ids>"`` line per atom."""
        if not atoms:
            return "No atoms in structure."
        adj = {a['id']: [] for a in atoms}
        for b in bonds:
            # Ignore bonds that reference unknown atom ids.
            if b['from'] in adj and b['to'] in adj:
                adj[b['from']].append(b['to'])
                adj[b['to']].append(b['from'])
        lines = [f"{a['id']:02d} {a['element']:>2} -> {', '.join(map(str, adj.get(a['id'], [])))}" for a in atoms]
        return "\n".join(lines)

    # Nothing to draw: report that as text rather than returning None.
    if not RDKIT_AVAILABLE or not atoms:
        return get_text_fallback()

    try:
        mol = Chem.RWMol()
        atom_map = {info['id']: mol.AddAtom(Chem.Atom(info['element'])) for info in atoms}
        for bond in bonds:
            a, b = bond['from'], bond['to']
            # Skip dangling references and self-loops; RDKit rejects self-bonds.
            if a not in atom_map or b not in atom_map or a == b:
                continue
            i, j = atom_map[a], atom_map[b]
            # RWMol.AddBond raises if the bond already exists. An edge listed in
            # both directions would otherwise force the whole drawing into the
            # text fallback.
            if mol.GetBondBetweenAtoms(i, j) is None:
                mol.AddBond(i, j, Chem.BondType.SINGLE)

        # Label each atom with its original structure id so the picture can be
        # cross-referenced with the raw JSON.
        rdkit_idx_to_original_id = {v: k for k, v in atom_map.items()}
        drawer = rdMolDraw2D.MolDraw2DCairo(300, 300)
        opts = drawer.drawOptions()
        for idx in range(mol.GetNumAtoms()):
            original_id = rdkit_idx_to_original_id.get(idx, '?')
            symbol = mol.GetAtomWithIdx(idx).GetSymbol()
            opts.atomLabels[idx] = f"{original_id}:{symbol}"
        rdMolDraw2D.PrepareAndDrawMolecule(drawer, mol)
        drawer.FinishDrawing()
        png = drawer.GetDrawingText()
        return Image.open(io.BytesIO(png))
    except Exception:
        # Deliberate best-effort: any RDKit failure (unknown element, drawing
        # error, ...) degrades to the text listing instead of breaking the page.
        return get_text_fallback()
|
|
|
|
|
# Page setup must be the first Streamlit command of the script.
st.set_page_config(
    layout="wide",
    page_title="Molecular Constraint Solver",
    page_icon="🧬",
)

# Inline CSS for the two custom header classes rendered right below.
st.markdown(
    "<style>"
    ".main-header{font-size:3rem;color:#1f77b4;text-align:center}"
    ".sub-header{font-size:1.2rem;color:#666;text-align:center;margin-bottom:2rem}"
    "</style>",
    unsafe_allow_html=True,
)
st.markdown(
    '<div class="main-header">🧬 Molecular Constraint Solver</div>',
    unsafe_allow_html=True,
)
st.markdown(
    '<div class="sub-header">Generate molecules satisfying hard constraints via neuromorphic 3-SAT solving</div>',
    unsafe_allow_html=True,
)
|
|
|
|
|
# Sidebar: user-facing constraint and solver configuration.
st.sidebar.header("Constraint Configuration")

# Chemical property constraints.
st.sidebar.subheader("Chemical Properties")
aromatic_rings = st.sidebar.slider(
    "Aromatic Rings", min_value=0, max_value=5, value=1
)
max_mw = st.sidebar.slider(
    "Maximum Molecular Weight (Da)", min_value=200, max_value=700, value=500, step=10
)
forbidden_groups = st.sidebar.multiselect(
    "Forbidden Functional Groups:",
    options=['nitro', 'azide', 'peroxide'],
    default=[],
)

# Size / synthesizability constraints; the encoder's atom budget is fixed at 30,
# matching the upper bound of the minimum-atom slider.
st.sidebar.subheader("Additional Constraints")
min_atoms = st.sidebar.slider(
    "Minimum atom count",
    min_value=0,
    max_value=30,
    value=10,
    help="Forces the molecule to have at least this many atoms.",
)
synthesizable = st.sidebar.checkbox("Synthesizable", value=False)
max_atoms = 30

# Solver knobs forwarded to SparsePhaseCalciumField3SAT.
st.sidebar.subheader("Solver Parameters")
n_molecules = st.sidebar.slider(
    "Number of molecules to generate", min_value=1, max_value=50, value=5
)
solver_steps = st.sidebar.slider(
    "Solver Steps", min_value=50, max_value=1000, value=300
)
drive_strength = st.sidebar.slider(
    "Drive Strength", min_value=10.0, max_value=100.0, value=75.0, step=5.0
)
|
|
|
|
|
if st.sidebar.button("🧬 Generate Molecules", type="primary"):
    with st.spinner("Encoding constraints → Solving 3-SAT → Decoding structures..."):
        try:
            # Translate sidebar selections into the textual constraint language
            # consumed by parse_constraints.
            constraints_list = [f"aromatic_rings == {aromatic_rings}", f"molecular_weight < {max_mw}"]
            if min_atoms > 0:
                constraints_list.append(f"min_atoms >= {min_atoms}")
            constraints_list.extend(f"NOT {group}" for group in forbidden_groups)
            if synthesizable:
                constraints_list.append("synthesizable")

            constraints = parse_constraints(constraints_list)
            encoder = MolecularConstraintEncoder(max_atoms=max_atoms)
            clauses, n_vars = encoder.encode_constraints(constraints)

            st.info(f"Generated a SAT problem with {n_vars} variables and {len(clauses)} clauses.")
            results = []
            progress_bar = st.progress(0, text="Generating molecules...")

            # One independent 32-bit seed per molecule, drawn from OS entropy.
            # The previous int(time.time()) + i scheme repeated seeds for clicks
            # within the same second and overlapped between consecutive batches,
            # which produced identical molecules.
            seeds = np.random.SeedSequence().generate_state(n_molecules)

            for i in range(n_molecules):
                seed = int(seeds[i])
                solver = SparsePhaseCalciumField3SAT(
                    N_vars=n_vars, clauses=clauses, seed=seed,
                    drive=drive_strength, solver_steps=solver_steps,
                )
                for _ in range(solver_steps):
                    solver.step()

                assignment = solver.get_assignment()
                structure = encoder.decode_solution(assignment)
                structure['satisfaction'] = solver.compute_satisfaction()
                structure['molecule_id'] = i + 1
                # Seeds are no longer derivable from the clock, so record them
                # to keep each molecule reproducible.
                structure['seed'] = seed
                results.append(structure)
                progress_bar.progress((i + 1) / n_molecules)

            st.session_state['results'] = results
            st.success(f"Successfully generated {n_molecules} molecular structures!")
        except Exception as e:
            # Top-level UI boundary: show the error and full traceback in the page
            # rather than aborting the Streamlit rerun.
            st.error(f"An error occurred: {e}")
            import traceback
            st.code(traceback.format_exc())
|
|
|
|
|
# Render the most recent batch (persisted across reruns in session state)
# as a grid of at most five columns.
if 'results' in st.session_state:
    molecules = st.session_state['results']
    st.subheader("Generated Molecules")
    grid = st.columns(min(len(molecules), 5))

    for position, molecule in enumerate(molecules):
        with grid[position % 5]:
            st.metric(
                f"Molecule {molecule['molecule_id']}",
                f"{molecule['satisfaction']:.1%} sat.",
            )

            # draw_molecule_from_structure yields an image, a text listing, or None.
            drawing = draw_molecule_from_structure(molecule)
            if drawing is None:
                st.warning("Could not draw.")
            elif isinstance(drawing, str):
                st.code(drawing)
            else:
                st.image(drawing)

            with st.expander("Details"):
                st.json(molecule)