File size: 9,376 Bytes
7059c33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f19e3e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
# molecular_demo2.py
# FINAL VERSION with robust display logic

import streamlit as st
import numpy as np
import time
from collections import defaultdict
import json
import io

try:
    from rdkit import Chem
    from rdkit.Chem import Draw
    from rdkit.Chem import rdMolDraw2D
    RDKIT_AVAILABLE = True
    from PIL import Image
except ImportError:
    RDKIT_AVAILABLE = False

from molecular_constraint_solver import MolecularConstraintEncoder, parse_constraints

class SparsePhaseCalciumField3SAT:
    """Heuristic phase-oscillator solver for CNF (3-SAT style) formulas.

    Every Boolean variable is represented by a phase angle in ``[0, 2*pi)``;
    a variable reads as ``True`` while the cosine of its phase is positive.
    Each :meth:`step` pushes one randomly chosen literal of every unsatisfied
    clause toward the phase that would satisfy it, applies sparse pairwise
    sine coupling between variables, adds Gaussian noise and occasionally
    reinforces or prunes coupling weights.

    Clauses are sequences of non-zero integers: literal ``k`` refers to
    variable ``k`` (1-based) being true, ``-k`` to it being false. Literals
    that do not map to a variable in ``1..N_vars`` (including ``0``) are
    ignored.

    All randomness comes from a per-instance ``numpy.random.RandomState``
    seeded with ``seed``, so building or stepping a solver never disturbs
    NumPy's global random state.

    Args:
        N_vars: Number of Boolean variables.
        clauses: Sequence of clauses (each a sequence of integer literals).
        seed: Seed for the instance's random generator.
        K: Global multiplier applied to every coupling term.
        eta: Amount added to a coupling weight when two sampled variables are
            nearly phase-aligned.
        prune_rate: Fractional decay applied to one randomly chosen coupling
            weight during a plasticity update.
        noise: Scale of the Gaussian noise added to every phase derivative.
        DT: Integration time step.
        drive: Multiplier of the forcing applied by unsatisfied clauses.
        solver_steps: Stored as ``max_steps``; the class itself does not use
            it, callers decide how many times to call :meth:`step`.

    Attributes:
        phases: Current phase of every variable.
        clause_weights: Per-clause weight, grown while a clause is unsatisfied.
        W: Sparse directed coupling weights as ``W[i][j]``.
        history: ``{'satisfaction': [...]}`` with one entry per step.
    """

    def __init__(self, N_vars, clauses, seed=42, K=0.87, eta=0.045,
                 prune_rate=0.005, noise=0.03, DT=0.003, drive=14.28, solver_steps=300):
        # A private RandomState yields the same stream that
        # np.random.seed(seed) + the module-level functions would, without
        # reseeding the process-wide generator as a side effect.
        self._rng = np.random.RandomState(seed)
        self.N, self.M, self.clauses = N_vars, len(clauses), clauses
        self.K, self.eta, self.prune_rate = K, eta, prune_rate
        self.noise, self.DT = noise, DT
        self.drive, self.max_steps = drive, solver_steps
        self.phases = self._rng.uniform(0, 2 * np.pi, N_vars)
        self.clause_weights = np.ones(self.M)
        # Seed a sparse random coupling graph (self-pairs are skipped).
        self.W = defaultdict(dict)
        for _ in range(min(self.N * 2, 20000)):
            i, j = self._rng.randint(0, self.N, 2)
            if i != j:
                self.W[i][j] = self._rng.uniform(0.01, 0.05)
        self.history = {'satisfaction': []}

    def get_assignment(self):
        """Return the Boolean assignment encoded by the current phases."""
        return np.cos(self.phases) > 0

    def evaluate_clause(self, clause, assignment):
        """Return True if at least one in-range literal of ``clause`` holds."""
        for lit in clause:
            idx = abs(lit) - 1
            # Skip literals that do not name a variable; this also covers 0,
            # which would otherwise index the last variable via -1.
            if not 0 <= idx < self.N:
                continue
            if bool(assignment[idx]) == (lit > 0):
                return True
        return False

    def compute_satisfaction(self, assignment=None):
        """Return the fraction of satisfied clauses (1.0 if there are none).

        Args:
            assignment: Optional truth values indexed by variable; defaults to
                the assignment encoded by the current phases.
        """
        if assignment is None:
            assignment = self.get_assignment()
        if self.M == 0:
            return 1.0
        satisfied = sum(1 for c in self.clauses if self.evaluate_clause(c, assignment))
        return satisfied / self.M

    def step(self):
        """Advance the dynamics by one time step and record satisfaction."""
        dphi, assignment = np.zeros(self.N), self.get_assignment()

        # Clause forcing: each unsatisfied clause gains weight and pulls one
        # random literal toward the phase that satisfies it (0 or pi).
        for idx, clause in enumerate(self.clauses):
            if self.evaluate_clause(clause, assignment):
                continue
            self.clause_weights[idx] = min(self.clause_weights[idx] + 0.02, 5.0)
            if not clause:
                # An empty clause has no literal to push (and randint(0)
                # would raise).
                continue
            lit = clause[self._rng.randint(len(clause))]
            idx_var = abs(lit) - 1
            if not 0 <= idx_var < self.N:
                continue
            target = 0.0 if lit > 0 else np.pi
            dphi[idx_var] += (self.drive * self.clause_weights[idx]
                              * np.sin(target - self.phases[idx_var]))

        # Pairwise coupling along the stored edges: equal and opposite terms.
        for i in self.W:
            for j, w in self.W[i].items():
                pull = self.K * w * np.sin(self.phases[j] - self.phases[i])
                dphi[i] += pull
                dphi[j] -= pull

        dphi += self.noise * self._rng.randn(self.N)
        self.phases = np.mod(self.phases + self.DT * dphi, 2 * np.pi)

        # Plasticity, applied on roughly one step in ten.
        if self._rng.rand() < 0.1:
            if self.N > 0:  # randint needs a non-empty range
                for _ in range(20):
                    i, j = self._rng.randint(0, self.N, 2)
                    if i != j and np.cos(self.phases[i] - self.phases[j]) > 0.98:
                        # Strengthen the link between nearly aligned variables.
                        self.W[i][j] = min(1.0, self.W[i].get(j, 0.0) + self.eta)
            if self.W:
                # Decay one random edge and drop it once it becomes negligible.
                s = self._rng.choice(list(self.W.keys()))
                if self.W[s]:
                    t = self._rng.choice(list(self.W[s].keys()))
                    self.W[s][t] *= (1 - self.prune_rate)
                    if self.W[s][t] < 0.01:
                        del self.W[s][t]

        self.history['satisfaction'].append(self.compute_satisfaction())

def draw_molecule_from_structure(s_dict):
    """Render a decoded structure as an image, or as a text adjacency list.

    The structure is drawn as a raw graph: every bond is added as a single
    bond and every atom is labelled ``"<id>:<element symbol>"`` on a
    300x300 canvas.

    Args:
        s_dict: Mapping with optional ``'atoms'`` and ``'bonds'`` lists.
            Each atom is read through its ``'id'`` and ``'element'`` keys,
            each bond through its ``'from'`` and ``'to'`` keys (atom ids).

    Returns:
        A PIL image when RDKit rendered the graph; a ``str`` adjacency
        listing when RDKit is unavailable or raised while building or
        drawing; ``None`` when RDKit is available but there are no atoms.
    """
    atoms = s_dict.get('atoms', [])
    bonds = s_dict.get('bonds', [])

    # Text fallback is defined first so every exit path below can use it.
    def get_text_fallback():
        if not atoms:
            return "No atoms in structure."
        adj = {a['id']: [] for a in atoms}
        for b in bonds:
            # Ignore bonds whose endpoints are not declared atoms; indexing
            # adj with an unknown id would raise KeyError.
            if b['from'] in adj and b['to'] in adj:
                adj[b['from']].append(b['to'])
                adj[b['to']].append(b['from'])
        lines = [
            f"{a['id']:02d} {a['element']:>2} -> {', '.join(map(str, adj.get(a['id'], [])))}"
            for a in atoms
        ]
        return "\n".join(lines)

    if not RDKIT_AVAILABLE:
        return get_text_fallback()

    try:
        mol = Chem.RWMol()
        atom_map = {info['id']: mol.AddAtom(Chem.Atom(info['element'])) for info in atoms}

        # RDKit rejects self-bonds and a second bond between the same pair of
        # atoms; filter them here so one redundant edge does not discard the
        # whole drawing.
        added_pairs = set()
        for bond in bonds:
            a, b = bond['from'], bond['to']
            if a not in atom_map or b not in atom_map:
                continue
            begin, end = atom_map[a], atom_map[b]
            pair = (min(begin, end), max(begin, end))
            if begin == end or pair in added_pairs:
                continue
            added_pairs.add(pair)
            mol.AddBond(begin, end, Chem.BondType.SINGLE)

        if mol.GetNumAtoms() == 0:
            return None

        # Label each atom with its original id so the picture can be matched
        # against the raw structure data.
        rdkit_idx_to_original_id = {v: k for k, v in atom_map.items()}
        drawer = rdMolDraw2D.MolDraw2DCairo(300, 300)
        opts = drawer.drawOptions()
        for idx in range(mol.GetNumAtoms()):
            original_id = rdkit_idx_to_original_id.get(idx, '?')
            symbol = mol.GetAtomWithIdx(idx).GetSymbol()
            opts.atomLabels[idx] = f"{original_id}:{symbol}"
        rdMolDraw2D.PrepareAndDrawMolecule(drawer, mol)
        drawer.FinishDrawing()
        png = drawer.GetDrawingText()
        return Image.open(io.BytesIO(png))
    except Exception:
        # Best effort: any RDKit/PIL failure degrades to the text listing.
        return get_text_fallback()

# --- Page chrome: browser tab settings, inline CSS and the two banners ---
st.set_page_config(
    page_title="Molecular Constraint Solver",
    layout="wide",
    page_icon="🧬",
)
st.markdown(
    """<style>.main-header{font-size:3rem;color:#1f77b4;text-align:center}.sub-header{font-size:1.2rem;color:#666;text-align:center;margin-bottom:2rem}</style>""",
    unsafe_allow_html=True,
)
st.markdown(
    '<div class="main-header">🧬 Molecular Constraint Solver</div>',
    unsafe_allow_html=True,
)
st.markdown(
    '<div class="sub-header">Generate molecules satisfying hard constraints via neuromorphic 3-SAT solving</div>',
    unsafe_allow_html=True,
)

# --- Sidebar: every user-tunable input lives here ---
with st.sidebar:
    st.header("Constraint Configuration")

    st.subheader("Chemical Properties")
    aromatic_rings = st.slider("Aromatic Rings", min_value=0, max_value=5, value=1)
    max_mw = st.slider(
        "Maximum Molecular Weight (Da)",
        min_value=200,
        max_value=700,
        value=500,
        step=10,
    )
    forbidden_groups = st.multiselect(
        "Forbidden Functional Groups:",
        options=['nitro', 'azide', 'peroxide'],
        default=[],
    )

    st.subheader("Additional Constraints")
    min_atoms = st.slider(
        "Minimum atom count",
        min_value=0,
        max_value=30,
        value=10,
        help="Forces the molecule to have at least this many atoms.",
    )
    synthesizable = st.checkbox("Synthesizable", value=False)
    # Fixed upper bound on atoms handed to the encoder (not user-tunable).
    max_atoms = 30

    st.subheader("Solver Parameters")
    n_molecules = st.slider(
        "Number of molecules to generate", min_value=1, max_value=50, value=5
    )
    solver_steps = st.slider("Solver Steps", min_value=50, max_value=1000, value=300)
    drive_strength = st.slider(
        "Drive Strength", min_value=10.0, max_value=100.0, value=75.0, step=5.0
    )

# Run the pipeline when the sidebar button is pressed: build the textual
# constraints, encode them as clauses, run one solver per requested molecule
# and stash the decoded structures in the session state.
generate_clicked = st.sidebar.button("🧬 Generate Molecules", type="primary")
if generate_clicked:
    with st.spinner("Encoding constraints → Solving 3-SAT → Decoding structures..."):
        try:
            constraint_specs = [
                f"aromatic_rings == {aromatic_rings}",
                f"molecular_weight < {max_mw}",
            ]
            if min_atoms > 0:
                constraint_specs.append(f"min_atoms >= {min_atoms}")
            constraint_specs.extend(f"NOT {group}" for group in forbidden_groups)
            if synthesizable:
                constraint_specs.append("synthesizable")

            constraints = parse_constraints(constraint_specs)
            encoder = MolecularConstraintEncoder(max_atoms=max_atoms)
            clauses, n_vars = encoder.encode_constraints(constraints)

            st.info(f"Generated a SAT problem with {n_vars} variables and {len(clauses)} clauses.")
            progress_bar = st.progress(0, text="Generating molecules...")
            results = []

            for mol_number in range(1, n_molecules + 1):
                # Wall-clock seed, offset per molecule so runs within the
                # same second still differ.
                solver = SparsePhaseCalciumField3SAT(
                    N_vars=n_vars,
                    clauses=clauses,
                    seed=int(time.time()) + mol_number - 1,
                    drive=drive_strength,
                    solver_steps=solver_steps,
                )
                for _ in range(solver_steps):
                    solver.step()

                structure = encoder.decode_solution(solver.get_assignment())
                structure['satisfaction'] = solver.compute_satisfaction()
                structure['molecule_id'] = mol_number
                results.append(structure)
                progress_bar.progress(mol_number / n_molecules)

            st.session_state['results'] = results
            st.success(f"Successfully generated {n_molecules} molecular structures!")
        except Exception as exc:
            # Top-level UI boundary: surface the failure and its traceback
            # in the page instead of crashing the app.
            st.error(f"An error occurred: {exc}")
            import traceback
            st.code(traceback.format_exc())

# Show the most recently generated molecules, laid out in at most five
# columns; further molecules wrap around and stack inside those columns.
if 'results' in st.session_state:
    results = st.session_state['results']
    st.subheader("Generated Molecules")
    column_count = min(len(results), 5)
    cols = st.columns(column_count)
    for position, res in enumerate(results):
        with cols[position % column_count]:
            st.metric(f"Molecule {res['molecule_id']}", f"{res['satisfaction']:.1%} sat.")

            # The drawing helper returns None, a text listing (str) or an
            # image; pick the widget that matches what came back.
            output = draw_molecule_from_structure(res)
            if output is None:
                st.warning("Could not draw.")
            elif isinstance(output, str):
                st.code(output)
            else:
                st.image(output)

            with st.expander("Details"):
                st.json(res)