# Aluode's picture
# Update app.py
# 7059c33 verified
# molecular_demo2.py
# FINAL VERSION with robust display logic
import streamlit as st
import numpy as np
import time
from collections import defaultdict
import json
import io
try:
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import rdMolDraw2D
RDKIT_AVAILABLE = True
from PIL import Image
except ImportError:
RDKIT_AVAILABLE = False
from molecular_constraint_solver import MolecularConstraintEncoder, parse_constraints
class SparsePhaseCalciumField3SAT:
    """Heuristic SAT solver that stores each Boolean variable as a phase angle.

    A variable reads as True when the cosine of its phase is positive. Every
    call to :meth:`step` nudges one randomly chosen literal of each
    unsatisfied clause towards the phase that satisfies it, applies sine
    coupling over the sparse weighted graph ``W``, adds Gaussian noise and
    occasionally strengthens or prunes couplings.

    Clauses are sequences of non-zero integers: literal ``k`` refers to
    variable ``abs(k) - 1`` and is negated when ``k < 0``. Literals that do
    not map to a variable in ``[0, N_vars)`` are ignored.

    Attributes:
        N: number of variables.
        M: number of clauses.
        clauses: the clause list passed to the constructor (not copied).
        phases: array of ``N`` phases in ``[0, 2*pi)``.
        clause_weights: per-clause weight, grown while a clause is unsatisfied.
        W: ``defaultdict(dict)`` mapping ``i -> {j: weight}`` couplings.
        history: ``{'satisfaction': [...]}`` with one entry per :meth:`step`.
    """

    def __init__(self, N_vars, clauses, seed=42, K=0.87, eta=0.045,
                 prune_rate=0.005, noise=0.03, DT=0.003, drive=14.28, solver_steps=300):
        """Initialise random phases and a random sparse coupling graph.

        Args:
            N_vars: number of Boolean variables.
            clauses: sequence of clauses (sequences of signed 1-based literals).
            seed: seed for this solver's private random stream.
            K: global coupling strength.
            eta: increment applied when a coupling is strengthened.
            prune_rate: fractional decay applied to a randomly chosen coupling.
            noise: standard deviation scale of the per-step phase noise.
            DT: integration time step.
            drive: strength of the clause-driven forcing term.
            solver_steps: stored as ``max_steps``; the class itself never reads
                it, callers decide how many times to call :meth:`step`.
        """
        # A private RandomState produces the same stream as the legacy
        # ``np.random.seed(seed)`` + module-level calls, but does not reseed
        # the process-wide generator shared with every other NumPy user.
        self._rng = np.random.RandomState(seed)
        self.N, self.M, self.clauses = N_vars, len(clauses), clauses
        self.K, self.eta, self.prune_rate, self.noise, self.DT = K, eta, prune_rate, noise, DT
        self.drive, self.max_steps = drive, solver_steps
        self.phases = self._rng.uniform(0, 2 * np.pi, N_vars)
        self.clause_weights = np.ones(self.M)
        self.W = defaultdict(dict)
        for _ in range(min(self.N * 2, 20000)):
            i, j = self._rng.randint(0, self.N, 2)
            if i != j:
                self.W[i][j] = self._rng.uniform(0.01, 0.05)
        self.history = {'satisfaction': []}

    def get_assignment(self):
        """Return the Boolean assignment: True where ``cos(phase) > 0``."""
        return np.cos(self.phases) > 0

    def evaluate_clause(self, clause, assignment):
        """Return True if at least one in-range literal of ``clause`` holds."""
        for lit in clause:
            idx = abs(lit) - 1
            # Skip literals without a variable; this also keeps a stray 0
            # literal from wrapping around to index -1.
            if not 0 <= idx < self.N:
                continue
            val = assignment[idx]
            if (lit > 0 and val) or (lit < 0 and not val):
                return True
        return False

    def compute_satisfaction(self, assignment=None):
        """Return the fraction of satisfied clauses (1.0 when there are none).

        Args:
            assignment: Boolean sequence indexed by variable; defaults to the
                current :meth:`get_assignment`.
        """
        if assignment is None:
            assignment = self.get_assignment()
        if self.M == 0:
            return 1.0
        return sum(1 for c in self.clauses if self.evaluate_clause(c, assignment)) / self.M

    def step(self):
        """Advance the dynamics by one time step and record the satisfaction."""
        dphi, assignment = np.zeros(self.N), self.get_assignment()

        # Clause forcing: each unsatisfied clause gets heavier and pulls one
        # of its literals towards phase 0 (True) or pi (False).
        for idx, clause in enumerate(self.clauses):
            if self.evaluate_clause(clause, assignment):
                continue
            self.clause_weights[idx] = min(self.clause_weights[idx] + 0.02, 5.0)
            if len(clause) == 0:
                # An empty clause has no literal to drive; randint(0) would raise.
                continue
            lit = clause[self._rng.randint(len(clause))]
            idx_var = abs(lit) - 1
            if not 0 <= idx_var < self.N:
                continue
            target = 0.0 if lit > 0 else np.pi
            dphi[idx_var] += self.drive * self.clause_weights[idx] * np.sin(target - self.phases[idx_var])

        # Sine coupling along every stored edge, equal and opposite on both ends.
        for i in self.W:
            for j, w in self.W[i].items():
                p_diff = self.phases[j] - self.phases[i]
                dphi[i] += self.K * w * np.sin(p_diff)
                dphi[j] -= self.K * w * np.sin(p_diff)

        dphi += self.noise * self._rng.randn(self.N)
        self.phases = np.mod(self.phases + self.DT * dphi, 2 * np.pi)

        # Plasticity, on roughly one step in ten.
        # NOTE(review): the source had lost its indentation; pruning is assumed
        # to belong to this branch together with the growth loop - confirm.
        if self._rng.rand() < 0.1:
            if self.N > 0:  # randint(0, 0, ...) raises when there are no variables
                for _ in range(20):
                    i, j = self._rng.randint(0, self.N, 2)
                    # Strengthen couplings between nearly phase-aligned variables.
                    if i != j and np.cos(self.phases[i] - self.phases[j]) > 0.98:
                        self.W[i][j] = min(1.0, self.W[i].get(j, 0.0) + self.eta)
            if self.W:
                s = self._rng.choice(list(self.W.keys()))
                if self.W[s]:
                    t = self._rng.choice(list(self.W[s].keys()))
                    self.W[s][t] *= (1 - self.prune_rate)
                    # Drop couplings that have decayed below the floor.
                    if self.W[s][t] < 0.01:
                        del self.W[s][t]

        self.history['satisfaction'].append(self.compute_satisfaction())
def draw_molecule_from_structure(s_dict):
    """Draw the raw atom/bond graph with atom labels, falling back to text.

    Args:
        s_dict: mapping with optional ``'atoms'`` (dicts carrying ``'id'`` and
            ``'element'``) and ``'bonds'`` (dicts carrying ``'from'`` and
            ``'to'`` atom ids). Missing keys are treated as empty lists.

    Returns:
        * a PIL image when RDKit is available and drawing succeeds,
        * ``None`` when RDKit is available but the structure has no atoms,
        * a plain-text adjacency listing (``str``) when RDKit is unavailable
          or any step of the RDKit path raises.
    """
    atoms = s_dict.get('atoms', [])
    bonds = s_dict.get('bonds', [])

    def get_text_fallback():
        """Return one line per atom: id, element and its bonded neighbour ids."""
        if not atoms:
            return "No atoms in structure."
        adj = {a['id']: [] for a in atoms}
        for b in bonds:
            # .get() rather than indexing: this helper also runs from the
            # except handler below, so a malformed bond must not raise here.
            src, dst = b.get('from'), b.get('to')
            if src in adj and dst in adj:
                adj[src].append(dst)
                adj[dst].append(src)
        lines = [
            f"{a['id']:02d} {a['element']:>2} -> {', '.join(map(str, adj[a['id']]))}"
            for a in atoms
        ]
        return "\n".join(lines)

    if not RDKIT_AVAILABLE:
        return get_text_fallback()

    try:
        mol = Chem.RWMol()
        atom_map = {info['id']: mol.AddAtom(Chem.Atom(info['element'])) for info in atoms}

        # RWMol.AddBond rejects self-loops and repeated atom pairs; skip those
        # instead of letting one redundant bond downgrade the drawing to text.
        bonded_pairs = set()
        for bond in bonds:
            a, b = bond.get('from'), bond.get('to')
            if a not in atom_map or b not in atom_map:
                continue
            begin, end = atom_map[a], atom_map[b]
            pair = frozenset((begin, end))
            if begin == end or pair in bonded_pairs:
                continue
            bonded_pairs.add(pair)
            mol.AddBond(begin, end, Chem.BondType.SINGLE)

        if mol.GetNumAtoms() == 0:
            return None

        # Label every drawn atom as "<original id>:<element symbol>".
        rdkit_idx_to_original_id = {v: k for k, v in atom_map.items()}
        drawer = rdMolDraw2D.MolDraw2DCairo(300, 300)
        opts = drawer.drawOptions()
        for idx in range(mol.GetNumAtoms()):
            original_id = rdkit_idx_to_original_id.get(idx, '?')
            symbol = mol.GetAtomWithIdx(idx).GetSymbol()
            opts.atomLabels[idx] = f"{original_id}:{symbol}"

        rdMolDraw2D.PrepareAndDrawMolecule(drawer, mol)
        drawer.FinishDrawing()
        png = drawer.GetDrawingText()
        return Image.open(io.BytesIO(png))
    except Exception:
        # Deliberate best effort: any RDKit/PIL failure degrades to text.
        return get_text_fallback()
# ---------------------------------------------------------------------------
# Page header
# ---------------------------------------------------------------------------
st.set_page_config(page_title="Molecular Constraint Solver", layout="wide", page_icon="🧬")
st.markdown("""<style>.main-header{font-size:3rem;color:#1f77b4;text-align:center}.sub-header{font-size:1.2rem;color:#666;text-align:center;margin-bottom:2rem}</style>""", unsafe_allow_html=True)
st.markdown('<div class="main-header">🧬 Molecular Constraint Solver</div>', unsafe_allow_html=True)
st.markdown('<div class="sub-header">Generate molecules satisfying hard constraints via neuromorphic 3-SAT solving</div>', unsafe_allow_html=True)

# ---------------------------------------------------------------------------
# Sidebar: constraint and solver controls (widgets are created in this order)
# ---------------------------------------------------------------------------
sidebar = st.sidebar
sidebar.header("Constraint Configuration")

sidebar.subheader("Chemical Properties")
aromatic_rings = sidebar.slider("Aromatic Rings", min_value=0, max_value=5, value=1)
max_mw = sidebar.slider("Maximum Molecular Weight (Da)", min_value=200, max_value=700, value=500, step=10)
forbidden_groups = sidebar.multiselect(
    "Forbidden Functional Groups:",
    options=['nitro', 'azide', 'peroxide'],
    default=[],
)

sidebar.subheader("Additional Constraints")
min_atoms = sidebar.slider(
    "Minimum atom count",
    min_value=0,
    max_value=30,
    value=10,
    help="Forces the molecule to have at least this many atoms.",
)
synthesizable = sidebar.checkbox("Synthesizable", value=False)
max_atoms = 30  # fixed size handed to the encoder; not user-configurable

sidebar.subheader("Solver Parameters")
n_molecules = sidebar.slider("Number of molecules to generate", min_value=1, max_value=50, value=5)
solver_steps = sidebar.slider("Solver Steps", min_value=50, max_value=1000, value=300)
drive_strength = sidebar.slider("Drive Strength", min_value=10.0, max_value=100.0, value=75.0, step=5.0)

# ---------------------------------------------------------------------------
# Generation: encode the constraints once, then run one solver per molecule
# ---------------------------------------------------------------------------
if sidebar.button("🧬 Generate Molecules", type="primary"):
    with st.spinner("Encoding constraints → Solving 3-SAT → Decoding structures..."):
        try:
            # Translate the widget values into the textual constraint list
            # consumed by parse_constraints.
            constraints_list = [
                f"aromatic_rings == {aromatic_rings}",
                f"molecular_weight < {max_mw}",
            ]
            if min_atoms > 0:
                constraints_list.append(f"min_atoms >= {min_atoms}")
            constraints_list.extend(f"NOT {group}" for group in forbidden_groups)
            if synthesizable:
                constraints_list.append("synthesizable")

            constraints = parse_constraints(constraints_list)
            encoder = MolecularConstraintEncoder(max_atoms=max_atoms)
            clauses, n_vars = encoder.encode_constraints(constraints)
            st.info(f"Generated a SAT problem with {n_vars} variables and {len(clauses)} clauses.")

            results = []
            progress_bar = st.progress(0, text="Generating molecules...")
            for run_index in range(n_molecules):
                # Wall-clock seed plus the run index, so each run differs.
                solver = SparsePhaseCalciumField3SAT(
                    N_vars=n_vars,
                    clauses=clauses,
                    seed=int(time.time()) + run_index,
                    drive=drive_strength,
                    solver_steps=solver_steps,
                )
                for _ in range(solver_steps):
                    solver.step()

                structure = encoder.decode_solution(solver.get_assignment())
                structure['satisfaction'] = solver.compute_satisfaction()
                structure['molecule_id'] = run_index + 1
                results.append(structure)
                progress_bar.progress((run_index + 1) / n_molecules)

            # Persist across Streamlit reruns so the gallery below survives
            # later widget interactions.
            st.session_state['results'] = results
            st.success(f"Successfully generated {n_molecules} molecular structures!")
        except Exception as err:
            # Top-level UI boundary: show the failure and its traceback
            # in the page instead of crashing the app.
            st.error(f"An error occurred: {err}")
            import traceback
            st.code(traceback.format_exc())

# ---------------------------------------------------------------------------
# Results gallery: at most five columns, extra molecules wrap around
# ---------------------------------------------------------------------------
if 'results' in st.session_state:
    results = st.session_state['results']
    st.subheader("Generated Molecules")
    grid = st.columns(min(len(results), 5))
    for position, molecule in enumerate(results):
        with grid[position % 5]:
            st.metric(f"Molecule {molecule['molecule_id']}", f"{molecule['satisfaction']:.1%} sat.")
            # The drawing helper returns an image, a text listing, or None.
            rendering = draw_molecule_from_structure(molecule)
            if rendering is None:
                st.warning("Could not draw.")
            elif isinstance(rendering, str):
                st.code(rendering)
            else:
                st.image(rendering)
            with st.expander("Details"):
                st.json(molecule)