Update app.py
Browse files
app.py
CHANGED
|
@@ -1,191 +1,202 @@
|
|
| 1 |
-
#
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
import
|
| 5 |
-
import
|
| 6 |
-
|
| 7 |
-
import
|
| 8 |
-
import
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
from rdkit
|
| 13 |
-
from rdkit.Chem import
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
self.
|
| 27 |
-
self.
|
| 28 |
-
self.
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
if
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
st.
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
st.sidebar.
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
st.
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
st.json(res)
|
|
|
|
| 1 |
+
# molecular_demo2.py
|
| 2 |
+
# FINAL VERSION with robust display logic
|
| 3 |
+
|
| 4 |
+
import streamlit as st
|
| 5 |
+
import numpy as np
|
| 6 |
+
import time
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
import json
|
| 9 |
+
import io
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from rdkit import Chem
|
| 13 |
+
from rdkit.Chem import Draw
|
| 14 |
+
from rdkit.Chem import rdMolDraw2D
|
| 15 |
+
RDKIT_AVAILABLE = True
|
| 16 |
+
from PIL import Image
|
| 17 |
+
except ImportError:
|
| 18 |
+
RDKIT_AVAILABLE = False
|
| 19 |
+
|
| 20 |
+
from molecular_constraint_solver import MolecularConstraintEncoder, parse_constraints
|
| 21 |
+
|
| 22 |
+
class SparsePhaseCalciumField3SAT:
|
| 23 |
+
def __init__(self, N_vars, clauses, seed=42, K=0.87, eta=0.045,
|
| 24 |
+
prune_rate=0.005, noise=0.03, DT=0.003, drive=14.28, solver_steps=300):
|
| 25 |
+
np.random.seed(seed)
|
| 26 |
+
self.N, self.M, self.clauses = N_vars, len(clauses), clauses
|
| 27 |
+
self.K, self.eta, self.prune_rate, self.noise, self.DT = K, eta, prune_rate, noise, DT
|
| 28 |
+
self.drive, self.max_steps = drive, solver_steps
|
| 29 |
+
self.phases, self.clause_weights = np.random.uniform(0, 2 * np.pi, N_vars), np.ones(self.M)
|
| 30 |
+
self.W = defaultdict(dict)
|
| 31 |
+
for _ in range(min(self.N * 2, 20000)):
|
| 32 |
+
i, j = np.random.randint(0, self.N, 2)
|
| 33 |
+
if i != j: self.W[i][j] = np.random.uniform(0.01, 0.05)
|
| 34 |
+
self.history = {'satisfaction': []}
|
| 35 |
+
|
| 36 |
+
def get_assignment(self): return np.cos(self.phases) > 0
|
| 37 |
+
|
| 38 |
+
def evaluate_clause(self, clause, assignment):
|
| 39 |
+
for lit in clause:
|
| 40 |
+
idx = abs(lit) - 1
|
| 41 |
+
if idx >= self.N: continue
|
| 42 |
+
val = assignment[idx]
|
| 43 |
+
if (lit > 0 and val) or (lit < 0 and not val): return True
|
| 44 |
+
return False
|
| 45 |
+
|
| 46 |
+
def compute_satisfaction(self, assignment=None):
|
| 47 |
+
if assignment is None: assignment = self.get_assignment()
|
| 48 |
+
if self.M == 0: return 1.0
|
| 49 |
+
return sum(1 for c in self.clauses if self.evaluate_clause(c, assignment)) / self.M
|
| 50 |
+
|
| 51 |
+
def step(self):
|
| 52 |
+
dphi, assignment = np.zeros(self.N), self.get_assignment()
|
| 53 |
+
for idx, clause in enumerate(self.clauses):
|
| 54 |
+
if not self.evaluate_clause(clause, assignment):
|
| 55 |
+
self.clause_weights[idx] = min(self.clause_weights[idx] + 0.02, 5.0)
|
| 56 |
+
lit = clause[np.random.randint(len(clause))]
|
| 57 |
+
idx_var = abs(lit) - 1
|
| 58 |
+
if idx_var >= self.N: continue
|
| 59 |
+
target = 0.0 if lit > 0 else np.pi
|
| 60 |
+
dphi[idx_var] += self.drive * self.clause_weights[idx] * np.sin(target - self.phases[idx_var])
|
| 61 |
+
for i in self.W:
|
| 62 |
+
for j, w in self.W[i].items():
|
| 63 |
+
p_diff = self.phases[j] - self.phases[i]
|
| 64 |
+
dphi[i] += self.K * w * np.sin(p_diff)
|
| 65 |
+
dphi[j] -= self.K * w * np.sin(p_diff)
|
| 66 |
+
dphi += self.noise * np.random.randn(self.N)
|
| 67 |
+
self.phases = np.mod(self.phases + self.DT * dphi, 2 * np.pi)
|
| 68 |
+
if np.random.rand() < 0.1:
|
| 69 |
+
for _ in range(20):
|
| 70 |
+
i, j = np.random.randint(0, self.N, 2)
|
| 71 |
+
if i != j and np.cos(self.phases[i] - self.phases[j]) > 0.98:
|
| 72 |
+
self.W[i][j] = min(1.0, self.W[i].get(j, 0.0) + self.eta)
|
| 73 |
+
if self.W:
|
| 74 |
+
s = np.random.choice(list(self.W.keys()))
|
| 75 |
+
if self.W[s]:
|
| 76 |
+
t = np.random.choice(list(self.W[s].keys()))
|
| 77 |
+
self.W[s][t] *= (1 - self.prune_rate)
|
| 78 |
+
if self.W[s][t] < 0.01: del self.W[s][t]
|
| 79 |
+
self.history['satisfaction'].append(self.compute_satisfaction())
|
| 80 |
+
|
| 81 |
+
def draw_molecule_from_structure(s_dict):
|
| 82 |
+
"""Draw raw graph with atom labels, falling back to text if needed."""
|
| 83 |
+
atoms = s_dict.get('atoms', [])
|
| 84 |
+
bonds = s_dict.get('bonds', [])
|
| 85 |
+
|
| 86 |
+
# Define the text fallback first
|
| 87 |
+
def get_text_fallback():
|
| 88 |
+
if not atoms: return "No atoms in structure."
|
| 89 |
+
adj = {a['id']: [] for a in atoms}
|
| 90 |
+
for b in bonds:
|
| 91 |
+
# The source of the KeyError is here, so we add a check
|
| 92 |
+
if b['from'] in adj and b['to'] in adj:
|
| 93 |
+
adj[b['from']].append(b['to'])
|
| 94 |
+
adj[b['to']].append(b['from'])
|
| 95 |
+
lines = [f"{a['id']:02d} {a['element']:>2} -> {', '.join(map(str, adj.get(a['id'], [])))}" for a in atoms]
|
| 96 |
+
return "\n".join(lines)
|
| 97 |
+
|
| 98 |
+
if not RDKIT_AVAILABLE:
|
| 99 |
+
return get_text_fallback()
|
| 100 |
+
|
| 101 |
+
try:
|
| 102 |
+
mol = Chem.RWMol()
|
| 103 |
+
atom_map = {info['id']: mol.AddAtom(Chem.Atom(info['element'])) for info in atoms}
|
| 104 |
+
for bond in bonds:
|
| 105 |
+
a, b = bond['from'], bond['to']
|
| 106 |
+
if a in atom_map and b in atom_map:
|
| 107 |
+
mol.AddBond(atom_map[a], atom_map[b], Chem.BondType.SINGLE)
|
| 108 |
+
if mol.GetNumAtoms() == 0:
|
| 109 |
+
return None
|
| 110 |
+
|
| 111 |
+
rdkit_idx_to_original_id = {v: k for k, v in atom_map.items()}
|
| 112 |
+
drawer = rdMolDraw2D.MolDraw2DCairo(300, 300)
|
| 113 |
+
opts = drawer.drawOptions()
|
| 114 |
+
for idx in range(mol.GetNumAtoms()):
|
| 115 |
+
original_id = rdkit_idx_to_original_id.get(idx, '?')
|
| 116 |
+
symbol = mol.GetAtomWithIdx(idx).GetSymbol()
|
| 117 |
+
opts.atomLabels[idx] = f"{original_id}:{symbol}"
|
| 118 |
+
rdMolDraw2D.PrepareAndDrawMolecule(drawer, mol)
|
| 119 |
+
drawer.FinishDrawing()
|
| 120 |
+
png = drawer.GetDrawingText()
|
| 121 |
+
return Image.open(io.BytesIO(png))
|
| 122 |
+
except Exception:
|
| 123 |
+
return get_text_fallback() # Return text on any RDKit error
|
| 124 |
+
|
| 125 |
+
st.set_page_config(page_title="Molecular Constraint Solver", layout="wide", page_icon="🧬")
|
| 126 |
+
st.markdown("""<style>.main-header{font-size:3rem;color:#1f77b4;text-align:center}.sub-header{font-size:1.2rem;color:#666;text-align:center;margin-bottom:2rem}</style>""", unsafe_allow_html=True)
|
| 127 |
+
st.markdown('<div class="main-header">🧬 Molecular Constraint Solver</div>', unsafe_allow_html=True)
|
| 128 |
+
st.markdown('<div class="sub-header">Generate molecules satisfying hard constraints via neuromorphic 3-SAT solving</div>', unsafe_allow_html=True)
|
| 129 |
+
|
| 130 |
+
st.sidebar.header("Constraint Configuration")
|
| 131 |
+
st.sidebar.subheader("Chemical Properties")
|
| 132 |
+
aromatic_rings = st.sidebar.slider("Aromatic Rings", 0, 5, 1)
|
| 133 |
+
max_mw = st.sidebar.slider("Maximum Molecular Weight (Da)", 200, 700, 500, step=10)
|
| 134 |
+
forbidden_groups = st.sidebar.multiselect("Forbidden Functional Groups:", ['nitro', 'azide', 'peroxide'], [])
|
| 135 |
+
|
| 136 |
+
st.sidebar.subheader("Additional Constraints")
|
| 137 |
+
min_atoms = st.sidebar.slider("Minimum atom count", 0, 30, 10, help="Forces the molecule to have at least this many atoms.")
|
| 138 |
+
synthesizable = st.sidebar.checkbox("Synthesizable", value=False)
|
| 139 |
+
max_atoms = 30
|
| 140 |
+
|
| 141 |
+
st.sidebar.subheader("Solver Parameters")
|
| 142 |
+
n_molecules = st.sidebar.slider("Number of molecules to generate", 1, 50, 5)
|
| 143 |
+
solver_steps = st.sidebar.slider("Solver Steps", 50, 1000, 300)
|
| 144 |
+
drive_strength = st.sidebar.slider("Drive Strength", 10.0, 100.0, 75.0, step=5.0)
|
| 145 |
+
|
| 146 |
+
if st.sidebar.button("🧬 Generate Molecules", type="primary"):
|
| 147 |
+
with st.spinner("Encoding constraints → Solving 3-SAT → Decoding structures..."):
|
| 148 |
+
try:
|
| 149 |
+
constraints_list = [f"aromatic_rings == {aromatic_rings}", f"molecular_weight < {max_mw}"]
|
| 150 |
+
if min_atoms > 0:
|
| 151 |
+
constraints_list.append(f"min_atoms >= {min_atoms}")
|
| 152 |
+
for group in forbidden_groups: constraints_list.append(f"NOT {group}")
|
| 153 |
+
if synthesizable: constraints_list.append("synthesizable")
|
| 154 |
+
|
| 155 |
+
constraints = parse_constraints(constraints_list)
|
| 156 |
+
encoder = MolecularConstraintEncoder(max_atoms=max_atoms)
|
| 157 |
+
clauses, n_vars = encoder.encode_constraints(constraints)
|
| 158 |
+
|
| 159 |
+
st.info(f"Generated a SAT problem with {n_vars} variables and {len(clauses)} clauses.")
|
| 160 |
+
results = []
|
| 161 |
+
progress_bar = st.progress(0, text="Generating molecules...")
|
| 162 |
+
|
| 163 |
+
for i in range(n_molecules):
|
| 164 |
+
solver = SparsePhaseCalciumField3SAT(
|
| 165 |
+
N_vars=n_vars, clauses=clauses, seed=int(time.time()) + i,
|
| 166 |
+
drive=drive_strength, solver_steps=solver_steps
|
| 167 |
+
)
|
| 168 |
+
for _ in range(solver_steps): solver.step()
|
| 169 |
+
|
| 170 |
+
assignment = solver.get_assignment()
|
| 171 |
+
structure = encoder.decode_solution(assignment)
|
| 172 |
+
structure['satisfaction'] = solver.compute_satisfaction()
|
| 173 |
+
structure['molecule_id'] = i + 1
|
| 174 |
+
results.append(structure)
|
| 175 |
+
progress_bar.progress((i + 1) / n_molecules)
|
| 176 |
+
|
| 177 |
+
st.session_state['results'] = results
|
| 178 |
+
st.success(f"Successfully generated {n_molecules} molecular structures!")
|
| 179 |
+
except Exception as e:
|
| 180 |
+
st.error(f"An error occurred: {e}")
|
| 181 |
+
import traceback
|
| 182 |
+
st.code(traceback.format_exc())
|
| 183 |
+
|
| 184 |
+
if 'results' in st.session_state:
|
| 185 |
+
results = st.session_state['results']
|
| 186 |
+
st.subheader("Generated Molecules")
|
| 187 |
+
cols = st.columns(min(len(results), 5))
|
| 188 |
+
for i, res in enumerate(results):
|
| 189 |
+
with cols[i % 5]:
|
| 190 |
+
st.metric(f"Molecule {res['molecule_id']}", f"{res['satisfaction']:.1%} sat.")
|
| 191 |
+
|
| 192 |
+
# --- FIX: Check the type of the output before displaying ---
|
| 193 |
+
output = draw_molecule_from_structure(res)
|
| 194 |
+
if isinstance(output, str):
|
| 195 |
+
st.code(output)
|
| 196 |
+
elif output is not None:
|
| 197 |
+
st.image(output)
|
| 198 |
+
else:
|
| 199 |
+
st.warning("Could not draw.")
|
| 200 |
+
|
| 201 |
+
with st.expander("Details"):
|
| 202 |
st.json(res)
|