File size: 9,376 Bytes
7059c33 0f19e3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
# molecular_demo2.py
# FINAL VERSION with robust display logic
import streamlit as st
import numpy as np
import time
from collections import defaultdict
import json
import io
try:
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import rdMolDraw2D
RDKIT_AVAILABLE = True
from PIL import Image
except ImportError:
RDKIT_AVAILABLE = False
from molecular_constraint_solver import MolecularConstraintEncoder, parse_constraints
class SparsePhaseCalciumField3SAT:
    """Stochastic phase-oscillator local search for CNF (e.g. 3-SAT) instances.

    Variable ``k`` (literal ``k + 1`` / ``-(k + 1)``) is stored as a phase in
    ``[0, 2*pi)`` and reads as True when ``cos(phase) > 0``. Each :meth:`step`:

    * raises the weight of every unsatisfied clause (capped at 5.0) and pulls
      one randomly chosen literal of it towards its satisfying phase
      (0 for a positive literal, pi for a negative one);
    * applies sine coupling along the sparse weighted edges ``W[i][j]``;
    * adds Gaussian noise and takes an Euler step of size ``DT``;
    * with probability 0.1, strengthens edges between nearly in-phase pairs
      and decays (possibly deleting) one random edge.

    All randomness comes from a private ``numpy.random.RandomState(seed)``.
    Its draws match what ``np.random.seed(seed)`` followed by ``np.random.*``
    calls would produce, but NumPy's global RNG is no longer reseeded.
    Solvers sharing a seed therefore evolve identically even when interleaved.

    Args:
        N_vars: Number of variables (phases).
        clauses: Sequence of clauses. Each clause is a sequence of signed,
            1-based int literals. Literals whose variable index falls outside
            ``[0, N_vars)`` are ignored. Empty clauses are never satisfied.
        seed: Seed for the private random generator.
        K: Coupling strength.
        eta: Increment for strengthened edges (weights capped at 1.0).
        prune_rate: Fractional decay applied to the randomly pruned edge.
        noise: Scale of the Gaussian phase noise.
        DT: Integration time step.
        drive: Gain of the clause-driven pull.
        solver_steps: Default number of iterations used by :meth:`run`.
    """

    def __init__(self, N_vars, clauses, seed=42, K=0.87, eta=0.045,
                 prune_rate=0.005, noise=0.03, DT=0.003, drive=14.28,
                 solver_steps=300):
        # Private legacy-API generator: same stream as the old global seeding,
        # without the process-wide side effect.
        self._rng = np.random.RandomState(seed)
        self.N, self.M, self.clauses = N_vars, len(clauses), clauses
        self.K, self.eta, self.prune_rate, self.noise, self.DT = K, eta, prune_rate, noise, DT
        self.drive, self.max_steps = drive, solver_steps
        self.phases = self._rng.uniform(0, 2 * np.pi, N_vars)
        self.clause_weights = np.ones(self.M)
        # Sparse directed edge weights: W[i][j] couples phase i with phase j.
        self.W = defaultdict(dict)
        for _ in range(min(self.N * 2, 20000)):
            i, j = self._rng.randint(0, self.N, 2)
            if i != j:
                self.W[i][j] = self._rng.uniform(0.01, 0.05)
        self.history = {'satisfaction': []}

    def get_assignment(self):
        """Return the Boolean assignment read from the phases (``cos > 0``)."""
        return np.cos(self.phases) > 0

    def evaluate_clause(self, clause, assignment):
        """Return True if any in-range literal of ``clause`` is satisfied."""
        for lit in clause:
            idx = abs(lit) - 1
            # Also rejects literal 0 (idx -1), which would otherwise wrap to
            # the last variable.
            if not 0 <= idx < self.N:
                continue
            val = assignment[idx]
            if (lit > 0 and val) or (lit < 0 and not val):
                return True
        return False

    def compute_satisfaction(self, assignment=None):
        """Fraction of satisfied clauses (1.0 when there are no clauses).

        Uses the current phase-derived assignment when ``assignment`` is None.
        """
        if assignment is None:
            assignment = self.get_assignment()
        if self.M == 0:
            return 1.0
        return sum(1 for c in self.clauses if self.evaluate_clause(c, assignment)) / self.M

    def step(self):
        """Advance all phases by one Euler step and record the satisfaction."""
        rng = self._rng
        dphi = np.zeros(self.N)
        assignment = self.get_assignment()

        # Clause drive: reinforce every clause that is currently unsatisfied.
        for idx, clause in enumerate(self.clauses):
            if self.evaluate_clause(clause, assignment):
                continue
            self.clause_weights[idx] = min(self.clause_weights[idx] + 0.02, 5.0)
            if len(clause) == 0:
                # Unsatisfiable and nothing to drive; randint(0) would raise.
                continue
            lit = clause[rng.randint(len(clause))]
            idx_var = abs(lit) - 1
            if not 0 <= idx_var < self.N:
                continue
            target = 0.0 if lit > 0 else np.pi
            dphi[idx_var] += self.drive * self.clause_weights[idx] * np.sin(target - self.phases[idx_var])

        # Symmetric sine coupling along each stored edge.
        for i, row in self.W.items():
            for j, w in row.items():
                coupling = self.K * w * np.sin(self.phases[j] - self.phases[i])
                dphi[i] += coupling
                dphi[j] -= coupling

        dphi += self.noise * rng.randn(self.N)
        self.phases = np.mod(self.phases + self.DT * dphi, 2 * np.pi)

        # Occasional plasticity pass. The N > 0 guard avoids randint(0, 0, 2).
        # NOTE(review): the pasted source lost its indentation; pruning is
        # assumed to belong to this pass rather than run every step - confirm.
        if rng.rand() < 0.1 and self.N > 0:
            for _ in range(20):
                i, j = rng.randint(0, self.N, 2)
                if i != j and np.cos(self.phases[i] - self.phases[j]) > 0.98:
                    self.W[i][j] = min(1.0, self.W[i].get(j, 0.0) + self.eta)
            if self.W:
                s = rng.choice(list(self.W.keys()))
                row = self.W[s]
                if row:
                    t = rng.choice(list(row.keys()))
                    row[t] *= (1 - self.prune_rate)
                    if row[t] < 0.01:
                        del row[t]

        self.history['satisfaction'].append(self.compute_satisfaction())

    def run(self, steps=None):
        """Call :meth:`step` ``steps`` times (default ``max_steps``).

        Returns:
            The final clause-satisfaction fraction.
        """
        n_steps = self.max_steps if steps is None else steps
        for _ in range(n_steps):
            self.step()
        return self.compute_satisfaction()
def draw_molecule_from_structure(s_dict):
    """Render a decoded structure as an RDKit image, or as text when that fails.

    Args:
        s_dict: Mapping with optional ``'atoms'`` (dicts with ``'id'`` and
            ``'element'``) and ``'bonds'`` (dicts with ``'from'`` and ``'to'``
            atom ids) lists.

    Returns:
        A PIL image when RDKit is available and drawing succeeds. The image
        is 300x300, atoms are labelled ``"<id>:<symbol>"``, and every bond is
        drawn as a single bond.

        Otherwise a text adjacency listing with one line per atom:
        ``"<id> <element> -> <neighbour ids>"``. Returns
        ``"No atoms in structure."`` when there are no atoms.

        Bonds that reference unknown atom ids are ignored on both paths.
    """
    atoms = s_dict.get('atoms', [])
    bonds = s_dict.get('bonds', [])

    def format_id(atom_id):
        # Zero-pad integer-like ids (int, numpy ints); show anything else as
        # text so the fallback cannot fail on an unexpected id type.
        try:
            return f"{atom_id:02d}"
        except (TypeError, ValueError):
            return f"{atom_id!s:>2}"

    def get_text_fallback():
        if not atoms:
            return "No atoms in structure."
        adj = {a['id']: [] for a in atoms}
        for b in bonds:
            # Bonds may reference ids absent from ``atoms``; skip those
            # instead of raising KeyError.
            if b['from'] in adj and b['to'] in adj:
                adj[b['from']].append(b['to'])
                adj[b['to']].append(b['from'])
        lines = [
            f"{format_id(a['id'])} {str(a.get('element', '?')):>2} -> "
            f"{', '.join(map(str, adj.get(a['id'], [])))}"
            for a in atoms
        ]
        return "\n".join(lines)

    # With no atoms, return the same message on both paths instead of None.
    if not RDKIT_AVAILABLE or not atoms:
        return get_text_fallback()
    try:
        mol = Chem.RWMol()
        atom_map = {info['id']: mol.AddAtom(Chem.Atom(info['element'])) for info in atoms}
        for bond in bonds:
            a, b = bond['from'], bond['to']
            if a not in atom_map or b not in atom_map:
                continue
            ia, ib = atom_map[a], atom_map[b]
            # RDKit's AddBond raises on a self-bond or an already existing
            # bond. Skip those so one bad bond does not discard the image.
            if ia == ib or mol.GetBondBetweenAtoms(ia, ib) is not None:
                continue
            mol.AddBond(ia, ib, Chem.BondType.SINGLE)
        idx_to_id = {v: k for k, v in atom_map.items()}
        drawer = rdMolDraw2D.MolDraw2DCairo(300, 300)
        opts = drawer.drawOptions()
        for idx in range(mol.GetNumAtoms()):
            symbol = mol.GetAtomWithIdx(idx).GetSymbol()
            opts.atomLabels[idx] = f"{idx_to_id.get(idx, '?')}:{symbol}"
        rdMolDraw2D.PrepareAndDrawMolecule(drawer, mol)
        drawer.FinishDrawing()
        return Image.open(io.BytesIO(drawer.GetDrawingText()))
    except Exception:
        # Deliberate best-effort: any RDKit/PIL failure degrades to text.
        return get_text_fallback()
# --- Page chrome ---------------------------------------------------------------
_PAGE_CSS = """<style>.main-header{font-size:3rem;color:#1f77b4;text-align:center}.sub-header{font-size:1.2rem;color:#666;text-align:center;margin-bottom:2rem}</style>"""
st.set_page_config(page_title="Molecular Constraint Solver", layout="wide", page_icon="🧬")
for _html in (
    _PAGE_CSS,
    '<div class="main-header">🧬 Molecular Constraint Solver</div>',
    '<div class="sub-header">Generate molecules satisfying hard constraints via neuromorphic 3-SAT solving</div>',
):
    st.markdown(_html, unsafe_allow_html=True)

# --- Sidebar: user-selected constraints and solver settings --------------------
_sidebar = st.sidebar
# Hard cap on molecule size; also the ceiling of the "Minimum atom count" slider.
max_atoms = 30

_sidebar.header("Constraint Configuration")

_sidebar.subheader("Chemical Properties")
aromatic_rings = _sidebar.slider("Aromatic Rings", min_value=0, max_value=5, value=1)
max_mw = _sidebar.slider(
    "Maximum Molecular Weight (Da)", min_value=200, max_value=700, value=500, step=10,
)
forbidden_groups = _sidebar.multiselect(
    "Forbidden Functional Groups:", options=['nitro', 'azide', 'peroxide'], default=[],
)

_sidebar.subheader("Additional Constraints")
min_atoms = _sidebar.slider(
    "Minimum atom count", min_value=0, max_value=max_atoms, value=10,
    help="Forces the molecule to have at least this many atoms.",
)
synthesizable = _sidebar.checkbox("Synthesizable", value=False)

_sidebar.subheader("Solver Parameters")
n_molecules = _sidebar.slider("Number of molecules to generate", min_value=1, max_value=50, value=5)
solver_steps = _sidebar.slider("Solver Steps", min_value=50, max_value=1000, value=300)
drive_strength = _sidebar.slider(
    "Drive Strength", min_value=10.0, max_value=100.0, value=75.0, step=5.0,
)
# --- Generation: encode constraints, solve n_molecules times, decode ----------
if st.sidebar.button("🧬 Generate Molecules", type="primary"):
    with st.spinner("Encoding constraints → Solving 3-SAT → Decoding structures..."):
        try:
            constraints_list = [f"aromatic_rings == {aromatic_rings}", f"molecular_weight < {max_mw}"]
            if min_atoms > 0:
                constraints_list.append(f"min_atoms >= {min_atoms}")
            constraints_list.extend(f"NOT {group}" for group in forbidden_groups)
            if synthesizable:
                constraints_list.append("synthesizable")
            constraints = parse_constraints(constraints_list)
            encoder = MolecularConstraintEncoder(max_atoms=max_atoms)
            clauses, n_vars = encoder.encode_constraints(constraints)
            st.info(f"Generated a SAT problem with {n_vars} variables and {len(clauses)} clauses.")

            results = []
            progress_bar = st.progress(0, text="Generating molecules...")
            for i in range(n_molecules):
                # Offset the time-based seed so each molecule gets its own run.
                solver = SparsePhaseCalciumField3SAT(
                    N_vars=n_vars, clauses=clauses, seed=int(time.time()) + i,
                    drive=drive_strength, solver_steps=solver_steps,
                )
                for _ in range(solver_steps):
                    solver.step()
                structure = encoder.decode_solution(solver.get_assignment())
                structure['satisfaction'] = solver.compute_satisfaction()
                structure['molecule_id'] = i + 1
                results.append(structure)
                # Pass the label again: progress() without `text` drops it.
                progress_bar.progress(
                    (i + 1) / n_molecules,
                    text=f"Generating molecules... ({i + 1}/{n_molecules})",
                )
            # Remove the finished bar instead of leaving it at 100%.
            progress_bar.empty()
            st.session_state['results'] = results
            st.success(f"Successfully generated {n_molecules} molecular structures!")
        except Exception as e:
            # Top-level UI boundary: surface the error and traceback in the app.
            import traceback
            st.error(f"An error occurred: {e}")
            st.code(traceback.format_exc())
# --- Display: one tile per generated molecule, at most five columns ----------
if 'results' in st.session_state:
    results = st.session_state['results']
    st.subheader("Generated Molecules")
    if not results:
        # st.columns(0) raises, so handle an empty result list explicitly.
        st.info("No molecules to display.")
    else:
        cols = st.columns(min(len(results), 5))
        for i, res in enumerate(results):
            # Fill the columns round-robin.
            with cols[i % len(cols)]:
                st.metric(f"Molecule {res['molecule_id']}", f"{res['satisfaction']:.1%} sat.")
                # The renderer returns text (fallback), an image, or None.
                output = draw_molecule_from_structure(res)
                if isinstance(output, str):
                    st.code(output)
                elif output is not None:
                    st.image(output)
                else:
                    st.warning("Could not draw.")
                with st.expander("Details"):
                    st.json(res)