Aluode committed on
Commit
7059c33
·
verified ·
1 Parent(s): 0f19e3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -190
app.py CHANGED
@@ -1,191 +1,202 @@
1
- # app.py
2
-
3
- import streamlit as st
4
- import numpy as np
5
- import time
6
- from collections import defaultdict
7
- import json
8
- import io
9
-
10
- try:
11
- from rdkit import Chem
12
- from rdkit.Chem import Draw
13
- from rdkit.Chem import rdMolDraw2D
14
- RDKIT_AVAILABLE = True
15
- except ImportError:
16
- RDKIT_AVAILABLE = False
17
-
18
- from molecular_constraint_solver import MolecularConstraintEncoder, parse_constraints
19
-
20
- class SparsePhaseCalciumField3SAT:
21
- def __init__(self, N_vars, clauses, seed=42, K=0.87, eta=0.045,
22
- prune_rate=0.005, noise=0.03, DT=0.003, drive=14.28, solver_steps=300):
23
- np.random.seed(seed)
24
- self.N, self.M, self.clauses = N_vars, len(clauses), clauses
25
- self.K, self.eta, self.prune_rate, self.noise, self.DT = K, eta, prune_rate, noise, DT
26
- self.drive, self.max_steps = drive, solver_steps
27
- self.phases, self.clause_weights = np.random.uniform(0, 2 * np.pi, N_vars), np.ones(self.M)
28
- self.W = defaultdict(dict)
29
- for _ in range(min(self.N * 2, 20000)):
30
- i, j = np.random.randint(0, self.N, 2)
31
- if i != j: self.W[i][j] = np.random.uniform(0.01, 0.05)
32
- self.history = {'satisfaction': []}
33
-
34
- def get_assignment(self): return np.cos(self.phases) > 0
35
-
36
- def evaluate_clause(self, clause, assignment):
37
- for lit in clause:
38
- idx = abs(lit) - 1
39
- if idx >= self.N: continue
40
- val = assignment[idx]
41
- if (lit > 0 and val) or (lit < 0 and not val): return True
42
- return False
43
-
44
- def compute_satisfaction(self, assignment=None):
45
- if assignment is None: assignment = self.get_assignment()
46
- if self.M == 0: return 1.0
47
- return sum(1 for c in self.clauses if self.evaluate_clause(c, assignment)) / self.M
48
-
49
- def step(self):
50
- dphi, assignment = np.zeros(self.N), self.get_assignment()
51
- for idx, clause in enumerate(self.clauses):
52
- if not self.evaluate_clause(clause, assignment):
53
- self.clause_weights[idx] = min(self.clause_weights[idx] + 0.02, 5.0)
54
- lit = clause[np.random.randint(len(clause))]
55
- idx_var = abs(lit) - 1
56
- if idx_var >= self.N: continue
57
- target = 0.0 if lit > 0 else np.pi
58
- dphi[idx_var] += self.drive * self.clause_weights[idx] * np.sin(target - self.phases[idx_var])
59
- for i in self.W:
60
- for j, w in self.W[i].items():
61
- p_diff = self.phases[j] - self.phases[i]
62
- dphi[i] += self.K * w * np.sin(p_diff)
63
- dphi[j] -= self.K * w * np.sin(p_diff)
64
- dphi += self.noise * np.random.randn(self.N)
65
- self.phases = np.mod(self.phases + self.DT * dphi, 2 * np.pi)
66
- if np.random.rand() < 0.1:
67
- for _ in range(20):
68
- i, j = np.random.randint(0, self.N, 2)
69
- if i != j and np.cos(self.phases[i] - self.phases[j]) > 0.98:
70
- self.W[i][j] = min(1.0, self.W[i].get(j, 0.0) + self.eta)
71
- if self.W:
72
- s = np.random.choice(list(self.W.keys()))
73
- if self.W[s]:
74
- t = np.random.choice(list(self.W[s].keys()))
75
- self.W[s][t] *= (1 - self.prune_rate)
76
- if self.W[s][t] < 0.01: del self.W[s][t]
77
- self.history['satisfaction'].append(self.compute_satisfaction())
78
-
79
- def draw_molecule_from_structure(s_dict):
80
- if not RDKIT_AVAILABLE:
81
- atoms = s_dict.get('atoms', [])
82
- bonds = s_dict.get('bonds', [])
83
- if not atoms: return "No atoms to draw."
84
- adj = {a['id']: [] for a in atoms}
85
- for b in bonds:
86
- adj[b['from']].append(b['to'])
87
- adj[b['to']].append(b['from'])
88
- lines = [f"{a['id']:02d} {a['element']:>2} -> {', '.join(map(str, adj[a['id']]))}" for a in atoms]
89
- return "\n".join(lines)
90
- try:
91
- mol = Chem.RWMol()
92
- atom_map = {}
93
- for info in s_dict.get('atoms', []):
94
- atom = Chem.Atom(info['element'])
95
- idx = mol.AddAtom(atom)
96
- atom_map[info['id']] = idx
97
- for bond in s_dict.get('bonds', []):
98
- a, b = bond['from'], bond['to']
99
- if a in atom_map and b in atom_map:
100
- mol.AddBond(atom_map[a], atom_map[b], Chem.BondType.SINGLE)
101
- if mol.GetNumAtoms() == 0: return None
102
- rdkit_idx_to_original_id = {v: k for k, v in atom_map.items()}
103
- drawer = rdMolDraw2D.MolDraw2DCairo(300, 300)
104
- opts = drawer.drawOptions()
105
- for idx in range(mol.GetNumAtoms()):
106
- original_id = rdkit_idx_to_original_id.get(idx, '?')
107
- symbol = mol.GetAtomWithIdx(idx).GetSymbol()
108
- opts.atomLabels[idx] = f"{original_id}:{symbol}"
109
- rdMolDraw2D.PrepareAndDrawMolecule(drawer, mol)
110
- drawer.FinishDrawing()
111
- png = drawer.GetDrawingText()
112
- from PIL import Image
113
- return Image.open(io.BytesIO(png))
114
- except Exception as e:
115
- return f"RDKit drawing failed: {e}"
116
-
117
- st.set_page_config(page_title="Molecular Constraint Solver", layout="wide", page_icon="🧬")
118
- st.markdown("""<style>.main-header{font-size:3rem;color:#1f77b4;text-align:center}.sub-header{font-size:1.2rem;color:#666;text-align:center;margin-bottom:2rem}</style>""", unsafe_allow_html=True)
119
- st.markdown('<div class="main-header">🧬 Molecular Constraint Solver</div>', unsafe_allow_html=True)
120
- st.markdown('<div class="sub-header">Generate molecular graphs satisfying hard constraints via neuromorphic 3-SAT solving</div>', unsafe_allow_html=True)
121
-
122
- st.sidebar.header("Constraint Configuration")
123
- st.sidebar.subheader("Chemical Properties")
124
- aromatic_rings = st.sidebar.slider("Aromatic Rings", 0, 5, 1)
125
- max_mw = st.sidebar.slider("Maximum Molecular Weight (Da)", 200, 700, 500, step=10)
126
- forbidden_groups = st.sidebar.multiselect("Forbidden Functional Groups:", ['nitro', 'azide', 'peroxide'], [])
127
-
128
- st.sidebar.subheader("Additional Constraints")
129
- min_atoms = st.sidebar.slider("Minimum atom count", 0, 30, 10, help="Forces the molecule to have at least this many atoms.")
130
- synthesizable = st.sidebar.checkbox("Synthesizable", value=False)
131
- max_atoms = 30
132
-
133
- st.sidebar.subheader("Solver Parameters")
134
- n_molecules = st.sidebar.slider("Number of molecules to generate", 1, 50, 5)
135
- solver_steps = st.sidebar.slider("Solver Steps", 50, 1000, 300)
136
- drive_strength = st.sidebar.slider("Drive Strength", 10.0, 100.0, 75.0, step=5.0)
137
-
138
- if st.sidebar.button("🧬 Generate Molecules", type="primary"):
139
- with st.spinner("Encoding constraints → Solving 3-SAT → Decoding structures..."):
140
- try:
141
- constraints_list = [f"aromatic_rings == {aromatic_rings}", f"molecular_weight < {max_mw}"]
142
- if min_atoms > 0:
143
- constraints_list.append(f"min_atoms >= {min_atoms}")
144
- for group in forbidden_groups: constraints_list.append(f"NOT {group}")
145
- if synthesizable: constraints_list.append("synthesizable")
146
-
147
- constraints = parse_constraints(constraints_list)
148
- encoder = MolecularConstraintEncoder(max_atoms=max_atoms)
149
- clauses, n_vars = encoder.encode_constraints(constraints)
150
-
151
- st.info(f"Generated a SAT problem with {n_vars} variables and {len(clauses)} clauses.")
152
- results = []
153
- progress_bar = st.progress(0, text="Generating molecules...")
154
-
155
- for i in range(n_molecules):
156
- solver = SparsePhaseCalciumField3SAT(
157
- N_vars=n_vars, clauses=clauses, seed=int(time.time()) + i,
158
- drive=drive_strength, solver_steps=solver_steps
159
- )
160
- for _ in range(solver_steps): solver.step()
161
-
162
- assignment = solver.get_assignment()
163
- structure = encoder.decode_solution(assignment)
164
- structure['satisfaction'] = solver.compute_satisfaction()
165
- structure['molecule_id'] = i + 1
166
- results.append(structure)
167
- progress_bar.progress((i + 1) / n_molecules)
168
-
169
- st.session_state['results'] = results
170
- st.success(f"Successfully generated {n_molecules} molecular structures!")
171
- except Exception as e:
172
- st.error(f"An error occurred: {e}")
173
- import traceback
174
- st.code(traceback.format_exc())
175
-
176
- if 'results' in st.session_state:
177
- results = st.session_state['results']
178
- st.subheader("Generated Molecules")
179
- cols = st.columns(min(len(results), 5))
180
- for i, res in enumerate(results):
181
- with cols[i % 5]:
182
- st.metric(f"Molecule {res['molecule_id']}", f"{res['satisfaction']:.1%} sat.")
183
- output = draw_molecule_from_structure(res)
184
- if isinstance(output, str):
185
- st.code(output)
186
- elif output is not None:
187
- st.image(output)
188
- else:
189
- st.warning("Could not draw.")
190
- with st.expander("Details"):
 
 
 
 
 
 
 
 
 
 
 
191
  st.json(res)
 
1
+ # app.py
2
+ # FINAL VERSION with robust display logic
3
+
4
+ import streamlit as st
5
+ import numpy as np
6
+ import time
7
+ from collections import defaultdict
8
+ import json
9
+ import io
10
+
11
+ try:
12
+ from rdkit import Chem
13
+ from rdkit.Chem import Draw
14
+ from rdkit.Chem import rdMolDraw2D
15
+ RDKIT_AVAILABLE = True
16
+ from PIL import Image
17
+ except ImportError:
18
+ RDKIT_AVAILABLE = False
19
+
20
+ from molecular_constraint_solver import MolecularConstraintEncoder, parse_constraints
21
+
22
class SparsePhaseCalciumField3SAT:
    """Heuristic CNF/3-SAT solver built on coupled phase oscillators.

    Every Boolean variable is represented by one phase in ``[0, 2*pi)``; the
    variable reads as True when ``cos(phase) > 0``. Each call to ``step``
    pushes one randomly chosen literal of every unsatisfied clause towards the
    phase that would satisfy it, applies sine coupling along a sparse weight
    map ``W``, adds Gaussian noise, and occasionally strengthens or decays
    individual couplings.

    Clauses are sequences of non-zero integers: literal ``k`` refers to
    variable ``abs(k) - 1`` and is negated when ``k < 0``. Literals that are
    zero or refer to a variable outside ``[0, N_vars)`` are ignored.

    Note: the constructor reseeds NumPy's *global* random state with ``seed``.
    """

    def __init__(self, N_vars, clauses, seed=42, K=0.87, eta=0.045,
                 prune_rate=0.005, noise=0.03, DT=0.003, drive=14.28, solver_steps=300):
        """Initialise phases, clause weights and the sparse coupling map.

        Args:
            N_vars: Number of Boolean variables (oscillators).
            clauses: Sequence of clauses, each a sequence of signed literals.
            seed: Seed passed to ``np.random.seed``.
            K: Global multiplier on the coupling term.
            eta: Increment applied to a coupling when two phases are aligned.
            prune_rate: Fractional decay applied to one random coupling per step.
            noise: Standard deviation multiplier of the per-step Gaussian noise.
            DT: Integration step applied to the accumulated phase velocity.
            drive: Strength of the push exerted by unsatisfied clauses.
            solver_steps: Stored as ``max_steps``; the caller runs the loop.
        """
        np.random.seed(seed)
        self.N, self.M, self.clauses = N_vars, len(clauses), clauses
        self.K, self.eta, self.prune_rate, self.noise, self.DT = K, eta, prune_rate, noise, DT
        self.drive, self.max_steps = drive, solver_steps
        self.phases = np.random.uniform(0, 2 * np.pi, N_vars)
        self.clause_weights = np.ones(self.M)
        # W[i][j] is the coupling weight of the directed pair (i, j).
        self.W = defaultdict(dict)
        for _ in range(min(self.N * 2, 20000)):
            i, j = np.random.randint(0, self.N, 2)
            if i != j:
                self.W[i][j] = np.random.uniform(0.01, 0.05)
        self.history = {'satisfaction': []}

    def get_assignment(self):
        """Return the Boolean assignment implied by the current phases."""
        return np.cos(self.phases) > 0

    def evaluate_clause(self, clause, assignment):
        """Return True if at least one literal of ``clause`` holds.

        Literals equal to zero or pointing past the last variable never
        satisfy a clause; an empty clause is therefore unsatisfied.
        """
        for lit in clause:
            idx = abs(lit) - 1
            # idx < 0 only for the invalid literal 0; skip it explicitly
            # instead of letting it index the last variable.
            if idx < 0 or idx >= self.N:
                continue
            val = assignment[idx]
            if (lit > 0 and val) or (lit < 0 and not val):
                return True
        return False

    def compute_satisfaction(self, assignment=None):
        """Return the fraction of satisfied clauses (1.0 when there are none)."""
        if assignment is None:
            assignment = self.get_assignment()
        if self.M == 0:
            return 1.0
        return sum(1 for c in self.clauses if self.evaluate_clause(c, assignment)) / self.M

    def step(self):
        """Advance the oscillator field by one time step and log satisfaction."""
        dphi = np.zeros(self.N)
        assignment = self.get_assignment()

        # Clause drive: every unsatisfied clause gains weight and pushes one
        # of its literals (picked at random) towards its satisfying phase.
        for idx, clause in enumerate(self.clauses):
            if self.evaluate_clause(clause, assignment):
                continue
            self.clause_weights[idx] = min(self.clause_weights[idx] + 0.02, 5.0)
            if len(clause) == 0:
                # Nothing to drive; np.random.randint(0) would raise here.
                continue
            lit = clause[np.random.randint(len(clause))]
            idx_var = abs(lit) - 1
            if idx_var < 0 or idx_var >= self.N:
                continue
            target = 0.0 if lit > 0 else np.pi
            dphi[idx_var] += self.drive * self.clause_weights[idx] * np.sin(target - self.phases[idx_var])

        # Sine coupling: each stored pair pulls its two phases together.
        for i in self.W:
            for j, w in self.W[i].items():
                p_diff = self.phases[j] - self.phases[i]
                dphi[i] += self.K * w * np.sin(p_diff)
                dphi[j] -= self.K * w * np.sin(p_diff)

        dphi += self.noise * np.random.randn(self.N)
        self.phases = np.mod(self.phases + self.DT * dphi, 2 * np.pi)

        # Plasticity (about one step in ten). Skipped when there are no
        # variables because randint(0, 0, ...) is invalid.
        if self.N > 0 and np.random.rand() < 0.1:
            # Strengthen couplings between nearly phase-aligned random pairs.
            for _ in range(20):
                i, j = np.random.randint(0, self.N, 2)
                if i != j and np.cos(self.phases[i] - self.phases[j]) > 0.98:
                    self.W[i][j] = min(1.0, self.W[i].get(j, 0.0) + self.eta)
            # Decay one random coupling and drop it once it becomes negligible.
            if self.W:
                s = np.random.choice(list(self.W.keys()))
                if self.W[s]:
                    t = np.random.choice(list(self.W[s].keys()))
                    self.W[s][t] *= (1 - self.prune_rate)
                    if self.W[s][t] < 0.01:
                        del self.W[s][t]

        self.history['satisfaction'].append(self.compute_satisfaction())
80
+
81
def draw_molecule_from_structure(s_dict):
    """Render a decoded structure as an image, or as text when drawing fails.

    Args:
        s_dict: Mapping with optional ``'atoms'`` (dicts carrying ``'id'`` and
            ``'element'``) and ``'bonds'`` (dicts carrying ``'from'`` and
            ``'to'`` atom ids). Missing or ``None`` entries count as empty.

    Returns:
        A PIL image when RDKit is available and drawing succeeds; an adjacency
        listing (``str``) when RDKit is unavailable or raises; ``None`` when
        RDKit is available but the structure has no atoms.
    """
    atoms = s_dict.get('atoms') or []
    bonds = s_dict.get('bonds') or []

    def format_atom_id(atom_id):
        # Integer ids keep the zero-padded layout; anything else is shown
        # verbatim instead of crashing the ':02d' format.
        try:
            return f"{atom_id:02d}"
        except (TypeError, ValueError):
            return str(atom_id)

    def get_text_fallback():
        """Build a plain-text adjacency listing of the structure."""
        if not atoms:
            return "No atoms in structure."
        adj = {a['id']: [] for a in atoms}
        for b in bonds:
            # Bonds that reference unknown atoms are skipped rather than
            # raising KeyError.
            if b['from'] in adj and b['to'] in adj:
                adj[b['from']].append(b['to'])
                adj[b['to']].append(b['from'])
        return "\n".join(
            f"{format_atom_id(a['id'])} {a['element']:>2} -> {', '.join(map(str, adj.get(a['id'], [])))}"
            for a in atoms
        )

    if not RDKIT_AVAILABLE:
        return get_text_fallback()

    try:
        mol = Chem.RWMol()
        atom_map = {info['id']: mol.AddAtom(Chem.Atom(info['element'])) for info in atoms}

        # RDKit refuses self-bonds and a second bond between the same pair of
        # atoms; filter them so one redundant entry does not discard the
        # whole drawing.
        seen_pairs = set()
        for bond in bonds:
            a, b = bond['from'], bond['to']
            if a not in atom_map or b not in atom_map:
                continue
            begin, end = atom_map[a], atom_map[b]
            pair = frozenset((begin, end))
            if begin == end or pair in seen_pairs:
                continue
            seen_pairs.add(pair)
            mol.AddBond(begin, end, Chem.BondType.SINGLE)

        if mol.GetNumAtoms() == 0:
            return None

        # Label every atom as "<original id>:<element symbol>".
        rdkit_idx_to_original_id = {v: k for k, v in atom_map.items()}
        drawer = rdMolDraw2D.MolDraw2DCairo(300, 300)
        opts = drawer.drawOptions()
        for idx in range(mol.GetNumAtoms()):
            original_id = rdkit_idx_to_original_id.get(idx, '?')
            symbol = mol.GetAtomWithIdx(idx).GetSymbol()
            # NOTE(review): atomLabels may be unavailable in some RDKit
            # releases; if so this raises and the text fallback is used.
            opts.atomLabels[idx] = f"{original_id}:{symbol}"
        rdMolDraw2D.PrepareAndDrawMolecule(drawer, mol)
        drawer.FinishDrawing()
        png = drawer.GetDrawingText()
        return Image.open(io.BytesIO(png))
    except Exception:
        # Deliberate best effort: any RDKit failure degrades to the listing.
        return get_text_fallback()
124
+
125
# --- Page chrome ---
st.set_page_config(page_title="Molecular Constraint Solver", layout="wide", page_icon="🧬")
st.markdown("""<style>.main-header{font-size:3rem;color:#1f77b4;text-align:center}.sub-header{font-size:1.2rem;color:#666;text-align:center;margin-bottom:2rem}</style>""", unsafe_allow_html=True)
st.markdown('<div class="main-header">🧬 Molecular Constraint Solver</div>', unsafe_allow_html=True)
st.markdown('<div class="sub-header">Generate molecules satisfying hard constraints via neuromorphic 3-SAT solving</div>', unsafe_allow_html=True)

# --- Sidebar: chemical constraints ---
st.sidebar.header("Constraint Configuration")
st.sidebar.subheader("Chemical Properties")
aromatic_rings = st.sidebar.slider("Aromatic Rings", 0, 5, 1)
max_mw = st.sidebar.slider("Maximum Molecular Weight (Da)", 200, 700, 500, step=10)
forbidden_groups = st.sidebar.multiselect("Forbidden Functional Groups:", ['nitro', 'azide', 'peroxide'], [])

st.sidebar.subheader("Additional Constraints")
# Upper bound on atoms handed to the encoder; the minimum-atom slider is
# derived from it so the two limits cannot drift apart.
max_atoms = 30
min_atoms = st.sidebar.slider(
    "Minimum atom count", 0, max_atoms, min(10, max_atoms),
    help="Forces the molecule to have at least this many atoms.",
)
synthesizable = st.sidebar.checkbox("Synthesizable", value=False)

# --- Sidebar: solver settings ---
st.sidebar.subheader("Solver Parameters")
n_molecules = st.sidebar.slider("Number of molecules to generate", 1, 50, 5)
solver_steps = st.sidebar.slider("Solver Steps", 50, 1000, 300)
drive_strength = st.sidebar.slider("Drive Strength", 10.0, 100.0, 75.0, step=5.0)
145
+
146
# Button handler: encode the sidebar constraints as SAT, run one solver per
# requested molecule, decode each assignment and stash the results in session
# state so they survive Streamlit reruns.
if st.sidebar.button("🧬 Generate Molecules", type="primary"):
    with st.spinner("Encoding constraints → Solving 3-SAT → Decoding structures..."):
        try:
            constraints_list = [f"aromatic_rings == {aromatic_rings}", f"molecular_weight < {max_mw}"]
            if min_atoms > 0:
                constraints_list.append(f"min_atoms >= {min_atoms}")
            constraints_list.extend(f"NOT {group}" for group in forbidden_groups)
            if synthesizable:
                constraints_list.append("synthesizable")

            constraints = parse_constraints(constraints_list)
            encoder = MolecularConstraintEncoder(max_atoms=max_atoms)
            clauses, n_vars = encoder.encode_constraints(constraints)

            st.info(f"Generated a SAT problem with {n_vars} variables and {len(clauses)} clauses.")
            results = []
            progress_text = "Generating molecules..."
            progress_bar = st.progress(0, text=progress_text)

            for i in range(n_molecules):
                # A time-based seed offset by the index gives every molecule
                # in the batch its own random stream.
                solver = SparsePhaseCalciumField3SAT(
                    N_vars=n_vars, clauses=clauses, seed=int(time.time()) + i,
                    drive=drive_strength, solver_steps=solver_steps
                )
                for _ in range(solver_steps):
                    solver.step()

                assignment = solver.get_assignment()
                structure = encoder.decode_solution(assignment)
                structure['satisfaction'] = solver.compute_satisfaction()
                structure['molecule_id'] = i + 1
                results.append(structure)
                # Pass the label on every update so it stays with the bar.
                progress_bar.progress((i + 1) / n_molecules, text=progress_text)

            st.session_state['results'] = results
            st.success(f"Successfully generated {n_molecules} molecular structures!")
        except Exception as e:
            # Top-level UI boundary: report the failure instead of crashing the app.
            st.error(f"An error occurred: {e}")
            import traceback
            st.code(traceback.format_exc())
183
+
184
# Results grid: show each stored molecule in one of up to five columns, with
# its satisfaction score, a drawing (or text listing) and the raw JSON.
if 'results' in st.session_state:
    results = st.session_state['results']
    st.subheader("Generated Molecules")
    grid = st.columns(min(len(results), 5))
    for position, molecule in enumerate(results):
        with grid[position % 5]:
            st.metric(f"Molecule {molecule['molecule_id']}", f"{molecule['satisfaction']:.1%} sat.")

            # The renderer yields None, a text listing, or an image object.
            rendered = draw_molecule_from_structure(molecule)
            if rendered is None:
                st.warning("Could not draw.")
            elif isinstance(rendered, str):
                st.code(rendered)
            else:
                st.image(rendered)

            with st.expander("Details"):
                st.json(molecule)