# molecular_demo2.py
# FINAL VERSION with robust display logic
#
# Streamlit front end: turns sidebar settings into constraint strings, encodes
# them as a SAT instance with MolecularConstraintEncoder, runs an oscillator
# phase-based solver for a fixed number of steps, and shows each decoded
# structure as an RDKit image or a plain-text adjacency listing.
#
# NOTE: this header is deliberately written as comments, not a module
# docstring, so nothing extra is rendered on the Streamlit page.

import io
import json
import time
import traceback
from collections import defaultdict

import numpy as np
import streamlit as st

try:
    from rdkit import Chem
    from rdkit.Chem import Draw
    from rdkit.Chem import rdMolDraw2D
    from PIL import Image
    RDKIT_AVAILABLE = True
except ImportError:
    # Either RDKit or Pillow is missing; drawing falls back to text output.
    RDKIT_AVAILABLE = False

from molecular_constraint_solver import MolecularConstraintEncoder, parse_constraints


class SparsePhaseCalciumField3SAT:
    """Heuristic SAT solver that stores one phase angle per boolean variable.

    A variable is read as True when the cosine of its phase is positive.
    Unsatisfied clauses push the phase of one of their literals towards the
    value that would satisfy it, a sparse weight map ``W`` couples pairs of
    phases, and Gaussian noise is added on every step.

    Args:
        N_vars: Number of boolean variables (literals are 1-based, signed).
        clauses: Sequence of clauses, each a sequence of signed literals.
        seed: Seed passed to ``np.random.seed`` (seeds NumPy's global RNG).
        K: Scale factor applied to the pairwise coupling term.
        eta: Increment added to a coupling weight when it is reinforced.
        prune_rate: Fractional decay applied to a randomly chosen weight.
        noise: Standard deviation multiplier of the per-step Gaussian noise.
        DT: Integration step size for the phase update.
        drive: Scale factor of the clause-driven phase push.
        solver_steps: Stored as ``max_steps``; the class itself does not
            loop, callers invoke ``step()`` as many times as they want.
    """

    def __init__(self, N_vars, clauses, seed=42, K=0.87, eta=0.045,
                 prune_rate=0.005, noise=0.03, DT=0.003, drive=14.28,
                 solver_steps=300):
        np.random.seed(seed)
        self.N, self.M, self.clauses = N_vars, len(clauses), clauses
        self.K, self.eta, self.prune_rate = K, eta, prune_rate
        self.noise, self.DT = noise, DT
        self.drive, self.max_steps = drive, solver_steps

        self.phases = np.random.uniform(0, 2 * np.pi, N_vars)
        self.clause_weights = np.ones(self.M)

        # Sparse coupling map: W[i][j] is the weight of the pair (i, j).
        self.W = defaultdict(dict)
        for _ in range(min(self.N * 2, 20000)):
            i, j = np.random.randint(0, self.N, 2)
            if i != j:
                self.W[int(i)][int(j)] = np.random.uniform(0.01, 0.05)

        self.history = {'satisfaction': []}

    def get_assignment(self):
        """Return the boolean assignment implied by the current phases."""
        return np.cos(self.phases) > 0

    def evaluate_clause(self, clause, assignment):
        """Return True if at least one literal of ``clause`` is satisfied.

        Literals that refer to a variable index >= ``N`` are ignored.
        """
        for lit in clause:
            idx = abs(lit) - 1
            if idx >= self.N:
                continue
            val = assignment[idx]
            if (lit > 0 and val) or (lit < 0 and not val):
                return True
        return False

    def compute_satisfaction(self, assignment=None):
        """Return the fraction of satisfied clauses (1.0 when there are none)."""
        if assignment is None:
            assignment = self.get_assignment()
        if self.M == 0:
            return 1.0
        satisfied = sum(
            1 for c in self.clauses if self.evaluate_clause(c, assignment)
        )
        return satisfied / self.M

    def step(self):
        """Advance the phases by one step and record the satisfaction ratio."""
        dphi = np.zeros(self.N)
        assignment = self.get_assignment()

        # Clause drive: every unsatisfied clause gains weight and nudges one
        # randomly chosen literal towards its satisfying phase (0 or pi).
        for idx, clause in enumerate(self.clauses):
            if self.evaluate_clause(clause, assignment):
                continue
            self.clause_weights[idx] = min(self.clause_weights[idx] + 0.02, 5.0)
            if len(clause) == 0:
                # An empty clause has no literal to push, and
                # np.random.randint(0) would raise ValueError.
                continue
            lit = clause[np.random.randint(len(clause))]
            idx_var = abs(lit) - 1
            if idx_var >= self.N:
                continue
            target = 0.0 if lit > 0 else np.pi
            dphi[idx_var] += (
                self.drive
                * self.clause_weights[idx]
                * np.sin(target - self.phases[idx_var])
            )

        # Pairwise coupling: equal and opposite contribution for each pair.
        for i in self.W:
            for j, w in self.W[i].items():
                pull = self.K * w * np.sin(self.phases[j] - self.phases[i])
                dphi[i] += pull
                dphi[j] -= pull

        dphi += self.noise * np.random.randn(self.N)
        self.phases = np.mod(self.phases + self.DT * dphi, 2 * np.pi)

        # Occasional plasticity update (about 10% of steps). The draw happens
        # first so the random stream is consumed exactly as before; the N > 0
        # test only prevents randint(0, 0) from raising on an empty problem.
        if np.random.rand() < 0.1 and self.N > 0:
            # Reinforce couplings between nearly in-phase variable pairs.
            for _ in range(20):
                i, j = np.random.randint(0, self.N, 2)
                if i != j and np.cos(self.phases[i] - self.phases[j]) > 0.98:
                    i, j = int(i), int(j)
                    self.W[i][j] = min(1.0, self.W[i].get(j, 0.0) + self.eta)

            # NOTE(review): the original indentation was lost; pruning is
            # assumed to belong to this occasional update -- confirm.
            if self.W:
                s = np.random.choice(list(self.W.keys()))
                if self.W[s]:
                    t = np.random.choice(list(self.W[s].keys()))
                    self.W[s][t] *= (1 - self.prune_rate)
                    if self.W[s][t] < 0.01:
                        del self.W[s][t]

        self.history['satisfaction'].append(self.compute_satisfaction())


def _structure_as_text(atoms, bonds):
    """Return a plain-text adjacency listing, one line per atom."""
    if not atoms:
        return "No atoms in structure."
    adj = {a['id']: [] for a in atoms}
    for b in bonds:
        # Bonds that mention an unknown atom id are skipped, not a KeyError.
        if b['from'] in adj and b['to'] in adj:
            adj[b['from']].append(b['to'])
            adj[b['to']].append(b['from'])
    lines = [
        f"{a['id']:02d} {a['element']:>2} -> {', '.join(map(str, adj.get(a['id'], [])))}"
        for a in atoms
    ]
    return "\n".join(lines)


def draw_molecule_from_structure(s_dict):
    """Draw raw graph with atom labels, falling back to text if needed.

    Args:
        s_dict: Mapping with optional ``'atoms'`` (dicts with ``'id'`` and
            ``'element'``) and ``'bonds'`` (dicts with ``'from'`` and ``'to'``).

    Returns:
        A PIL image when RDKit rendering succeeds, a ``str`` adjacency listing
        when RDKit is unavailable or raises, or ``None`` when RDKit is
        available but the structure contains no atoms.
    """
    atoms = s_dict.get('atoms', [])
    bonds = s_dict.get('bonds', [])

    if not RDKIT_AVAILABLE:
        return _structure_as_text(atoms, bonds)

    try:
        mol = Chem.RWMol()
        # Map the structure's atom ids to the indices RDKit assigns.
        atom_map = {
            info['id']: mol.AddAtom(Chem.Atom(info['element'])) for info in atoms
        }
        for bond in bonds:
            a, b = bond['from'], bond['to']
            if a in atom_map and b in atom_map:
                mol.AddBond(atom_map[a], atom_map[b], Chem.BondType.SINGLE)

        if mol.GetNumAtoms() == 0:
            return None

        rdkit_idx_to_original_id = {v: k for k, v in atom_map.items()}
        drawer = rdMolDraw2D.MolDraw2DCairo(300, 300)
        opts = drawer.drawOptions()
        for idx in range(mol.GetNumAtoms()):
            original_id = rdkit_idx_to_original_id.get(idx, '?')
            symbol = mol.GetAtomWithIdx(idx).GetSymbol()
            opts.atomLabels[idx] = f"{original_id}:{symbol}"

        rdMolDraw2D.PrepareAndDrawMolecule(drawer, mol)
        drawer.FinishDrawing()
        png = drawer.GetDrawingText()
        return Image.open(io.BytesIO(png))
    except Exception:
        # Deliberate best effort: any RDKit failure degrades to the text view.
        return _structure_as_text(atoms, bonds)


st.set_page_config(page_title="Molecular Constraint Solver", layout="wide", page_icon="🧬")

# NOTE(review): the HTML/CSS that these calls originally carried appears to
# have been stripped from the file. The first call is kept empty and the
# headings use minimal tags; restore the original markup if it is available.
st.markdown("""""", unsafe_allow_html=True)
st.markdown('<h1>🧬 Molecular Constraint Solver</h1>', unsafe_allow_html=True)
st.markdown(
    '<p>Generate molecules satisfying hard constraints via neuromorphic 3-SAT solving</p>',
    unsafe_allow_html=True,
)

st.sidebar.header("Constraint Configuration")
st.sidebar.subheader("Chemical Properties")
aromatic_rings = st.sidebar.slider("Aromatic Rings", 0, 5, 1)
max_mw = st.sidebar.slider("Maximum Molecular Weight (Da)", 200, 700, 500, step=10)
forbidden_groups = st.sidebar.multiselect(
    "Forbidden Functional Groups:", ['nitro', 'azide', 'peroxide'], []
)

st.sidebar.subheader("Additional Constraints")
min_atoms = st.sidebar.slider(
    "Minimum atom count", 0, 30, 10,
    help="Forces the molecule to have at least this many atoms.",
)
synthesizable = st.sidebar.checkbox("Synthesizable", value=False)
max_atoms = 30

st.sidebar.subheader("Solver Parameters")
n_molecules = st.sidebar.slider("Number of molecules to generate", 1, 50, 5)
solver_steps = st.sidebar.slider("Solver Steps", 50, 1000, 300)
drive_strength = st.sidebar.slider("Drive Strength", 10.0, 100.0, 75.0, step=5.0)

if st.sidebar.button("🧬 Generate Molecules", type="primary"):
    with st.spinner("Encoding constraints → Solving 3-SAT → Decoding structures..."):
        try:
            constraints_list = [
                f"aromatic_rings == {aromatic_rings}",
                f"molecular_weight < {max_mw}",
            ]
            if min_atoms > 0:
                constraints_list.append(f"min_atoms >= {min_atoms}")
            for group in forbidden_groups:
                constraints_list.append(f"NOT {group}")
            if synthesizable:
                constraints_list.append("synthesizable")

            constraints = parse_constraints(constraints_list)
            encoder = MolecularConstraintEncoder(max_atoms=max_atoms)
            clauses, n_vars = encoder.encode_constraints(constraints)
            st.info(f"Generated a SAT problem with {n_vars} variables and {len(clauses)} clauses.")

            results = []
            progress_text = "Generating molecules..."
            progress_bar = st.progress(0, text=progress_text)
            for i in range(n_molecules):
                # A different seed per molecule so the runs are not identical.
                solver = SparsePhaseCalciumField3SAT(
                    N_vars=n_vars,
                    clauses=clauses,
                    seed=int(time.time()) + i,
                    drive=drive_strength,
                    solver_steps=solver_steps,
                )
                for _ in range(solver_steps):
                    solver.step()

                assignment = solver.get_assignment()
                structure = encoder.decode_solution(assignment)
                structure['satisfaction'] = solver.compute_satisfaction()
                structure['molecule_id'] = i + 1
                results.append(structure)
                # Pass the text again; otherwise the label is dropped on update.
                progress_bar.progress((i + 1) / n_molecules, text=progress_text)

            st.session_state['results'] = results
            st.success(f"Successfully generated {n_molecules} molecular structures!")
        except Exception as e:
            # Top-level UI boundary: show the failure instead of crashing the app.
            st.error(f"An error occurred: {e}")
            st.code(traceback.format_exc())

if 'results' in st.session_state:
    results = st.session_state['results']
    st.subheader("Generated Molecules")
    if results:
        # At most five columns; further molecules wrap around onto them.
        n_cols = min(len(results), 5)
        cols = st.columns(n_cols)
        for i, res in enumerate(results):
            with cols[i % n_cols]:
                st.metric(f"Molecule {res['molecule_id']}", f"{res['satisfaction']:.1%} sat.")
                # The drawer returns an image, a text listing, or None.
                output = draw_molecule_from_structure(res)
                if isinstance(output, str):
                    st.code(output)
                elif output is not None:
                    st.image(output)
                else:
                    st.warning("Could not draw.")
                with st.expander("Details"):
                    st.json(res)