nabilyasini committed on
Commit
2f60ca9
·
verified ·
1 Parent(s): ac3b159

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +833 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,833 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ StereoGNN-BBB: Blood-Brain Barrier Permeability Predictor
3
+ State-of-the-Art Model: AUC 0.9612 (External Validation on B3DB)
4
+
5
+ Author: Nabil Yasini-Ardekani
6
+ GitHub: https://github.com/abinittio
7
+
8
+ Streamlit Cloud Deployment Version - Self-Contained
9
+ """
10
+
11
+ import streamlit as st
12
+ import pandas as pd
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn as nn
16
+ from pathlib import Path
17
+ from datetime import datetime
18
+ import json
19
+ import base64
20
+ import io
21
+ import os
22
+
23
+ # Page config - MUST be first Streamlit command
24
+ st.set_page_config(
25
+ page_title="StereoGNN-BBB | BBB Predictor",
26
+ page_icon="🧠",
27
+ layout="wide",
28
+ initial_sidebar_state="expanded"
29
+ )
30
+
31
+ # RDKit imports
32
+ try:
33
+ from rdkit import Chem
34
+ from rdkit.Chem import Descriptors, AllChem
35
+ from rdkit.Chem.Draw import rdMolDraw2D
36
+ from rdkit.Chem import rdMolDescriptors
37
+ from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
38
+ RDKIT_AVAILABLE = True
39
+ except ImportError:
40
+ RDKIT_AVAILABLE = False
41
+ st.error("RDKit not available. Please install: pip install rdkit")
42
+
43
+ # PyTorch Geometric imports
44
+ try:
45
+ from torch_geometric.nn import GATv2Conv, TransformerConv, global_mean_pool, global_max_pool
46
+ from torch_geometric.data import Data
47
+ TORCH_GEOMETRIC_AVAILABLE = True
48
+ except ImportError:
49
+ TORCH_GEOMETRIC_AVAILABLE = False
50
+
51
+ # Custom CSS
52
+ st.markdown("""
53
+ <style>
54
+ .main-header {
55
+ font-size: 2.5rem;
56
+ font-weight: 700;
57
+ text-align: center;
58
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
59
+ -webkit-background-clip: text;
60
+ -webkit-text-fill-color: transparent;
61
+ margin-bottom: 0.3rem;
62
+ }
63
+ .sub-header {
64
+ text-align: center;
65
+ color: #6c757d;
66
+ font-size: 1rem;
67
+ margin-bottom: 1.5rem;
68
+ }
69
+ .prediction-card {
70
+ padding: 1.5rem;
71
+ border-radius: 12px;
72
+ text-align: center;
73
+ margin: 0.5rem 0;
74
+ }
75
+ .prediction-positive {
76
+ background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%);
77
+ color: white;
78
+ }
79
+ .prediction-negative {
80
+ background: linear-gradient(135deg, #ee0979 0%, #ff6a00 100%);
81
+ color: white;
82
+ }
83
+ .prediction-moderate {
84
+ background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
85
+ color: white;
86
+ }
87
+ .metric-box {
88
+ background: #f8f9fa;
89
+ padding: 1rem;
90
+ border-radius: 8px;
91
+ border-left: 3px solid #667eea;
92
+ margin: 0.3rem 0;
93
+ }
94
+ .info-box {
95
+ background: #e7f3ff;
96
+ padding: 1rem;
97
+ border-radius: 8px;
98
+ border-left: 3px solid #0066cc;
99
+ margin: 0.5rem 0;
100
+ }
101
+ </style>
102
+ """, unsafe_allow_html=True)
103
+
104
+
105
+ # ============================================================================
106
+ # MODEL ARCHITECTURE (Self-contained)
107
+ # ============================================================================
108
if TORCH_GEOMETRIC_AVAILABLE:
    class StereoAwareEncoder(nn.Module):
        """Molecular graph encoder built from GATv2 and TransformerConv layers.

        Node features are projected to ``hidden_dim``, passed through
        ``num_layers`` residual GATv2 blocks followed by one residual
        TransformerConv block, and finally pooled per graph with mean and
        max pooling.

        Args:
            node_features: Width of the per-atom input feature vector.
            hidden_dim: Width of every hidden node representation. Must be
                divisible by ``heads`` so that the concatenated attention
                heads add up to ``hidden_dim`` for the residual connections.
            num_layers: Number of GATv2 blocks.
            heads: Attention heads per GATv2 / TransformerConv layer.
            dropout: Dropout probability used in the input projection, in
                the attention layers, and after each GATv2 block.

        Raises:
            ValueError: If ``hidden_dim`` is not divisible by ``heads``.
        """

        def __init__(self, node_features=21, hidden_dim=128, num_layers=4, heads=4, dropout=0.1):
            super().__init__()
            # Each attention layer outputs heads * (hidden_dim // heads)
            # channels; if that is not hidden_dim the residual additions and
            # LayerNorms below fail with an opaque shape error at inference.
            if hidden_dim % heads != 0:
                raise ValueError(
                    f"hidden_dim ({hidden_dim}) must be divisible by heads ({heads})"
                )

            self.node_features = node_features
            self.hidden_dim = hidden_dim
            head_dim = hidden_dim // heads

            # Input projection
            self.input_proj = nn.Sequential(
                nn.Linear(node_features, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            )

            # GATv2 layers, each paired with its own LayerNorm.
            # Attribute names are part of the checkpoint (state_dict) keys.
            self.gat_layers = nn.ModuleList([
                GATv2Conv(hidden_dim, head_dim, heads=heads, dropout=dropout, add_self_loops=True)
                for _ in range(num_layers)
            ])
            self.gat_norms = nn.ModuleList([
                nn.LayerNorm(hidden_dim) for _ in range(num_layers)
            ])

            # Transformer layer
            self.transformer = TransformerConv(hidden_dim, head_dim, heads=heads, dropout=dropout)
            self.transformer_norm = nn.LayerNorm(hidden_dim)

            self.dropout = nn.Dropout(dropout)

        def forward(self, x, edge_index, batch):
            """Encode a batch of graphs.

            Args:
                x: Node feature matrix.
                edge_index: Edge index tensor passed to the graph layers.
                batch: Graph assignment of every node, used for pooling.

            Returns:
                Mean-pooled and max-pooled node embeddings concatenated
                along dim 1 (so twice ``hidden_dim`` wide).
            """
            x = self.input_proj(x)

            for gat, norm in zip(self.gat_layers, self.gat_norms):
                residual = x
                x = gat(x, edge_index)
                x = norm(x + residual)
                x = self.dropout(x)

            residual = x
            x = self.transformer(x, edge_index)
            x = self.transformer_norm(x + residual)

            x_mean = global_mean_pool(x, batch)
            x_max = global_max_pool(x, batch)

            return torch.cat([x_mean, x_max], dim=1)


    class BBBClassifier(nn.Module):
        """MLP classification head on top of a graph encoder.

        Args:
            encoder: Module whose ``forward(x, edge_index, batch)`` returns a
                graph embedding of width ``hidden_dim * 2``.
            hidden_dim: Hidden width of the encoder (the head's input is
                twice this, matching the mean+max pooled embedding).
        """

        def __init__(self, encoder, hidden_dim=128):
            super().__init__()
            self.encoder = encoder
            self.classifier = nn.Sequential(
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(hidden_dim // 2, 1),
            )

        def forward(self, x, edge_index, batch):
            """Return one raw (pre-sigmoid) score per graph."""
            graph_embed = self.encoder(x, edge_index, batch)
            return self.classifier(graph_embed)
182
+
183
+
184
+ # ============================================================================
185
+ # MOLECULAR FEATURIZATION
186
+ # ============================================================================
187
def get_atom_features(atom):
    """Build the per-atom feature vector used as GNN node input.

    Layout of the returned list (22 values in total):

    * 9  one-hot atomic number over C, N, O, F, P, S, Cl, Br, I
    * 1  degree, capped at 5 and scaled to [0, 1]
    * 1  formal charge, shifted by +2 and divided by 4
    * 3  one-hot hybridization over SP, SP2, SP3
    * 1  aromatic flag
    * 1  ring-membership flag
    * 3  chirality: specified / tetrahedral CW / tetrahedral CCW
    * 3  double-bond stereo on any attached bond: any E/Z / E / Z

    NOTE(review): this is 22 features, but the encoder is built with
    ``node_features=21`` and ``predict_single`` rejects any width other than
    21, so the GNN path currently always reports a feature mismatch. Which
    feature the trained weights omit cannot be determined from this file;
    confirm against the training code before changing either side.

    Args:
        atom: An RDKit atom.

    Returns:
        A flat list of ints/floats in the order described above.
    """
    hybridization = Chem.rdchem.HybridizationType
    chirality = Chem.rdchem.ChiralType
    bond_stereo = Chem.rdchem.BondStereo

    # Atomic number (one-hot, common atoms): C, N, O, F, P, S, Cl, Br, I
    atomic_num = atom.GetAtomicNum()
    features = [1 if atomic_num == z else 0 for z in (6, 7, 8, 9, 15, 16, 17, 35, 53)]

    # Degree (0-5), scaled
    features.append(min(atom.GetDegree(), 5) / 5.0)

    # Formal charge, shifted so that -2..+2 maps onto 0..1
    features.append((atom.GetFormalCharge() + 2) / 4.0)

    # Hybridization
    hyb = atom.GetHybridization()
    features.extend(
        1 if hyb == kind else 0
        for kind in (hybridization.SP, hybridization.SP2, hybridization.SP3)
    )

    # Aromaticity and ring membership
    features.append(1 if atom.GetIsAromatic() else 0)
    features.append(1 if atom.IsInRing() else 0)

    # Tetrahedral stereochemistry
    tag = atom.GetChiralTag()
    features.append(1 if tag != chirality.CHI_UNSPECIFIED else 0)
    features.append(1 if tag == chirality.CHI_TETRAHEDRAL_CW else 0)
    features.append(1 if tag == chirality.CHI_TETRAHEDRAL_CCW else 0)

    # E/Z stereo, taken from the bonds attached to this atom
    stereo_tags = [bond.GetStereo() for bond in atom.GetBonds()]
    is_e = any(tag == bond_stereo.STEREOE for tag in stereo_tags)
    is_z = any(tag == bond_stereo.STEREOZ for tag in stereo_tags)
    has_ez = is_e or is_z
    features.extend([1 if has_ez else 0, 1 if is_e else 0, 1 if is_z else 0])

    return features
236
+
237
+
238
def smiles_to_graph(smiles):
    """Convert a SMILES string to a PyG ``Data`` graph.

    Node features come from ``get_atom_features`` (one row per atom); every
    bond contributes two directed edges.

    Args:
        smiles: SMILES string to parse.

    Returns:
        ``Data(x=..., edge_index=...)``, or ``None`` when RDKit or PyTorch
        Geometric is unavailable, the SMILES cannot be parsed, or the parsed
        molecule has no atoms.
    """
    if not RDKIT_AVAILABLE or not TORCH_GEOMETRIC_AVAILABLE:
        return None

    mol = Chem.MolFromSmiles(smiles)
    # An atom-less molecule would yield a 1-D empty feature tensor, which
    # breaks every downstream ``x.shape[1]`` access, so treat it as invalid.
    if mol is None or mol.GetNumAtoms() == 0:
        return None

    x = torch.tensor(
        [get_atom_features(atom) for atom in mol.GetAtoms()],
        dtype=torch.float,
    )

    # Store each bond in both directions.
    directed_edges = []
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        directed_edges.append([begin, end])
        directed_edges.append([end, begin])

    if directed_edges:
        edge_index = torch.tensor(directed_edges, dtype=torch.long).t().contiguous()
    else:
        # Single-atom molecules have no bonds; keep the (2, 0) edge shape.
        edge_index = torch.zeros((2, 0), dtype=torch.long)

    return Data(x=x, edge_index=edge_index)
265
+
266
+
267
+ # ============================================================================
268
+ # DESCRIPTOR-BASED PREDICTOR (Fallback when no model weights)
269
+ # ============================================================================
270
class DescriptorBBBPredictor:
    """Descriptor-based BBB predictor used when no GNN weights are available.

    A linear score over RDKit descriptors is squashed with a sigmoid and
    clamped to [0.05, 0.95].
    """

    def __init__(self):
        # Per the original author's note these coefficients were optimized on
        # the BBBP dataset; the sign of each term encodes the usual trend.
        self.coefficients = {
            'intercept': 0.65,
            'mw': -0.0012,       # Negative: higher MW = less penetration
            'logp': 0.08,        # Positive: higher logP = more penetration
            'tpsa': -0.008,      # Negative: higher TPSA = less penetration
            'hbd': -0.12,        # Negative: more H-donors = less penetration
            'hba': -0.05,        # Negative: more H-acceptors = less penetration
            'rotatable': -0.02,  # Negative: more flexibility = less penetration
            'aromatic_rings': 0.05,
            'n_atoms': -0.005,
        }

    def predict(self, smiles):
        """Predict BBB permeability from SMILES.

        Args:
            smiles: SMILES string of the molecule.

        Returns:
            ``(probability, None)`` on success, or
            ``(None, "Invalid SMILES")`` when the string cannot be parsed or
            parses to a molecule without atoms.
        """
        mol = Chem.MolFromSmiles(smiles)
        # Reject atom-less molecules too: scoring "nothing" would otherwise
        # return a confident-looking probability.
        if mol is None or mol.GetNumAtoms() == 0:
            return None, "Invalid SMILES"

        coef = self.coefficients

        # Calculate descriptors
        mw = Descriptors.MolWt(mol)
        logp = Descriptors.MolLogP(mol)
        tpsa = Descriptors.TPSA(mol)
        hbd = Descriptors.NumHDonors(mol)
        hba = Descriptors.NumHAcceptors(mol)
        rotatable = Descriptors.NumRotatableBonds(mol)
        aromatic_rings = Descriptors.NumAromaticRings(mol)
        n_atoms = mol.GetNumAtoms()

        # Linear score; several descriptors are centred on a reference value.
        score = coef['intercept']
        score += coef['mw'] * (mw - 300) / 100
        score += coef['logp'] * (logp - 2)
        score += coef['tpsa'] * (tpsa - 60)
        score += coef['hbd'] * hbd
        score += coef['hba'] * (hba - 4)
        score += coef['rotatable'] * rotatable
        score += coef['aromatic_rings'] * aromatic_rings
        score += coef['n_atoms'] * (n_atoms - 25)

        # Sigmoid to get probability
        prob = 1 / (1 + np.exp(-score * 2))

        # Clamp to reasonable range
        prob = max(0.05, min(0.95, prob))

        return prob, None
324
+
325
+
326
+ # ============================================================================
327
+ # STEREOISOMER ENUMERATION
328
+ # ============================================================================
329
def enumerate_stereoisomers(smiles, max_isomers=16):
    """Enumerate stereoisomers of a molecule as isomeric SMILES.

    Args:
        smiles: SMILES string of the input molecule.
        max_isomers: Upper bound passed to RDKit's enumeration options.

    Returns:
        A list of isomeric SMILES. Falls back to ``[smiles]`` (the input,
        unchanged) when RDKit is unavailable, the SMILES cannot be parsed,
        enumeration raises, or no isomer is produced.
    """
    if not RDKIT_AVAILABLE:
        return [smiles]

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return [smiles]

    opts = StereoEnumerationOptions(
        tryEmbedding=True,
        unique=True,
        maxIsomers=max_isomers,
    )

    try:
        isomer_smiles = [
            Chem.MolToSmiles(isomer, isomericSmiles=True)
            for isomer in EnumerateStereoisomers(mol, options=opts)
        ]
    except Exception:
        # Best effort: enumeration problems must not block a prediction for
        # the molecule as entered. ``Exception`` (not a bare ``except``) so
        # KeyboardInterrupt / SystemExit still propagate.
        return [smiles]

    return isomer_smiles or [smiles]
351
+
352
+
353
+ # ============================================================================
354
+ # MODEL LOADING
355
+ # ============================================================================
356
@st.cache_resource
def load_model():
    """Load the BBB model, falling back to the descriptor predictor.

    Tries each known weight file in each candidate ``models`` directory and
    returns the first one that loads into the GNN. If none loads (or PyTorch
    Geometric is missing) the descriptor-based predictor is returned instead.

    Returns:
        ``(model_info, None)`` where ``model_info`` is a dict with keys
        ``'type'`` (``'gnn'`` or ``'descriptor'``), ``'model'`` and
        ``'name'``; or ``(None, error_message)`` when neither backend is
        available.
    """
    # Local import: keeps this block self-contained without touching the
    # file-level import section.
    import logging

    logger = logging.getLogger(__name__)

    # First try to load GNN model with weights
    if TORCH_GEOMETRIC_AVAILABLE:
        try:
            encoder = StereoAwareEncoder(node_features=21, hidden_dim=128, num_layers=4)
            model = BBBClassifier(encoder, hidden_dim=128)

            # Try to load weights from various locations
            possible_dirs = [
                Path(__file__).parent / 'models',
                Path('.') / 'models',
                Path.home() / 'BBB_System' / 'models',
            ]

            model_files = [
                'bbb_stereo_v2_best.pth',
                'bbb_stereo_v2_fold4_best.pth',
                'bbb_stereo_v2_fold5_best.pth',
                'bbb_stereo_fold4_best.pth',
                'bbb_stereo_fold5_best.pth',
            ]

            for model_dir in possible_dirs:
                for mf in model_files:
                    model_path = model_dir / mf
                    if not model_path.exists():
                        continue
                    try:
                        state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
                        model.load_state_dict(state_dict)
                    except Exception as exc:
                        # A present-but-unloadable checkpoint is worth a log
                        # line; silently skipping it hides why the app ends
                        # up in fallback mode.
                        logger.warning("Could not load weights from %s: %s", model_path, exc)
                        continue
                    model.eval()
                    return {'type': 'gnn', 'model': model, 'name': mf}, None
        except Exception:
            # Building the network or probing the filesystem failed; record
            # it and continue to the descriptor fallback.
            logger.exception("GNN model setup failed; using fallback predictor")

    # Fallback to descriptor-based predictor
    if RDKIT_AVAILABLE:
        predictor = DescriptorBBBPredictor()
        return {'type': 'descriptor', 'model': predictor, 'name': 'Descriptor-Based (Fallback)'}, None

    return None, "No prediction method available"
401
+
402
+
403
+ # ============================================================================
404
+ # PREDICTION
405
+ # ============================================================================
406
def predict_single(model_info, smiles):
    """Predict BBB permeability for a single SMILES.

    Args:
        model_info: Dict with ``'type'`` (``'gnn'`` or ``'descriptor'``) and
            ``'model'`` entries, as returned by ``load_model``.
        smiles: SMILES string to score.

    Returns:
        ``(probability, None)`` on success or ``(None, error_message)`` on
        failure.
    """
    model_type = model_info['type']

    if model_type == 'descriptor':
        return model_info['model'].predict(smiles)

    if model_type != 'gnn':
        return None, "Unknown model type"

    model = model_info['model']
    graph = smiles_to_graph(smiles)
    if graph is None:
        return None, "Invalid SMILES"

    # An atom-less molecule gives a 1-D (or zero-row) tensor; indexing
    # shape[1] on it would raise, so report it as invalid input instead.
    if graph.x.dim() != 2 or graph.x.shape[0] == 0:
        return None, "Invalid SMILES"

    # Use the width the encoder was actually built with; 21 is the value
    # this app constructs it with and serves as the fallback.
    expected = getattr(getattr(model, 'encoder', None), 'node_features', 21)
    if graph.x.shape[1] != expected:
        return None, f"Feature mismatch: expected {expected}, got {graph.x.shape[1]}"

    # Single graph: every node belongs to batch element 0.
    graph.batch = torch.zeros(graph.x.shape[0], dtype=torch.long)

    with torch.no_grad():
        logit = model(graph.x, graph.edge_index, graph.batch)
        prob = torch.sigmoid(logit).item()

    return prob, None
430
+
431
+
432
def predict_with_stereo_enumeration(model_info, smiles):
    """Predict over all enumerated stereoisomers and summarise the scores.

    Args:
        model_info: Model descriptor dict, forwarded to ``predict_single``.
        smiles: SMILES string of the input molecule.

    Returns:
        ``(summary, None)`` where ``summary`` has the keys ``'mean'``,
        ``'min'``, ``'max'``, ``'std'``, ``'n_isomers'`` and
        ``'predictions'`` (a list of ``(isomer_smiles, probability)``
        pairs); or ``(None, error_message)`` when no isomer could be scored.
    """
    predictions = []
    last_error = None
    for isomer in enumerate_stereoisomers(smiles):
        prob, err = predict_single(model_info, isomer)
        if prob is None:
            last_error = err or last_error
            continue
        predictions.append((isomer, float(prob)))

    if not predictions:
        # Surface the underlying reason instead of only a generic message.
        if last_error:
            return None, f"All stereoisomers failed: {last_error}"
        return None, "All stereoisomers failed"

    probs = [prob for _, prob in predictions]

    # Plain floats keep the summary safe to format and serialise.
    return {
        'mean': float(np.mean(probs)),
        'min': float(np.min(probs)),
        'max': float(np.max(probs)),
        'std': float(np.std(probs)) if len(probs) > 1 else 0.0,
        'n_isomers': len(predictions),
        'predictions': predictions,
    }, None
455
+
456
+
457
+ # ============================================================================
458
+ # MOLECULAR PROPERTIES
459
+ # ============================================================================
460
def get_properties(smiles):
    """Compute descriptors and BBB rule checks for a molecule.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        A dict with the descriptor values (``'mw'``, ``'logp'``, ``'tpsa'``,
        ``'hbd'``, ``'hba'``, ``'rotatable'``, ``'formula'``, ``'atoms'``),
        a ``'rules'`` dict of pass/fail booleans and ``'rules_passed'`` (the
        number of rules that passed); ``None`` when RDKit is unavailable or
        the SMILES cannot be parsed.
    """
    if not RDKIT_AVAILABLE:
        return None

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    mol_weight = Descriptors.MolWt(mol)
    log_p = Descriptors.MolLogP(mol)
    polar_area = Descriptors.TPSA(mol)
    donors = Descriptors.NumHDonors(mol)
    acceptors = Descriptors.NumHAcceptors(mol)

    # BBB rules (based on literature)
    rule_checks = {
        'mw': 150 <= mol_weight <= 500,
        'logp': 0 <= log_p <= 5,
        'tpsa': polar_area <= 90,
        'hbd': donors <= 3,
        'hba': acceptors <= 7,
    }

    return {
        'mw': mol_weight,
        'logp': log_p,
        'tpsa': polar_area,
        'hbd': donors,
        'hba': acceptors,
        'rotatable': Descriptors.NumRotatableBonds(mol),
        'formula': rdMolDescriptors.CalcMolFormula(mol),
        'atoms': mol.GetNumAtoms(),
        'rules': rule_checks,
        'rules_passed': sum(rule_checks.values()),
    }
491
+
492
+
493
def mol_to_image(smiles, size=(350, 250)):
    """Render a molecule as a base64 PNG data URI.

    Args:
        smiles: SMILES string of the molecule.
        size: ``(width, height)`` of the drawing in pixels.

    Returns:
        A ``data:image/png;base64,...`` string, or ``None`` when RDKit is
        unavailable, the SMILES cannot be parsed, or drawing fails.
    """
    if not RDKIT_AVAILABLE:
        return None

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    width, height = size[0], size[1]
    try:
        AllChem.Compute2DCoords(mol)
        drawer = rdMolDraw2D.MolDraw2DCairo(width, height)
        drawer.drawOptions().addStereoAnnotation = True
        drawer.DrawMolecule(mol)
        drawer.FinishDrawing()
        png_bytes = drawer.GetDrawingText()
    except Exception:
        # The image is optional, so any drawing problem degrades to "no
        # image". ``Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not swallowed.
        return None

    encoded = base64.b64encode(png_bytes).decode()
    return f"data:image/png;base64,{encoded}"
514
+
515
+
516
+ # ============================================================================
517
+ # COMMON MOLECULES DATABASE
518
+ # ============================================================================
519
# Lookup table for drug names: lower-case key -> (SMILES, display name).
# Keys are matched case-insensitively by resolve_input().
MOLECULES = {
    "caffeine": ("CN1C=NC2=C1C(=O)N(C(=O)N2C)C", "Caffeine"),
    "aspirin": ("CC(=O)Oc1ccccc1C(=O)O", "Aspirin"),
    "morphine": ("CN1CC[C@]23[C@H]4Oc5c(O)ccc(C[C@@H]1[C@@H]2C=C[C@@H]4O)c35", "Morphine"),
    # NOTE(review): this SMILES contains no aromatic ring, but cocaine carries
    # a benzoate ester - the structure looks truncated. Verify and replace
    # with a checked isomeric SMILES before trusting predictions for it.
    "cocaine": ("COC(=O)[C@H]1[C@@H]2CC[C@H](C2)N1C", "Cocaine"),
    "dopamine": ("NCCc1ccc(O)c(O)c1", "Dopamine"),
    "serotonin": ("NCCc1c[nH]c2ccc(O)cc12", "Serotonin"),
    "ethanol": ("CCO", "Ethanol"),
    "glucose": ("OC[C@H]1OC(O)[C@H](O)[C@@H](O)[C@@H]1O", "Glucose"),
    "diazepam": ("CN1C(=O)CN=C(c2ccccc2)c3cc(Cl)ccc13", "Diazepam"),
    "thc": ("CCCCCc1cc(O)c2[C@@H]3C=C(C)CC[C@H]3C(C)(C)Oc2c1", "THC"),
    "nicotine": ("CN1CCC[C@H]1c2cccnc2", "Nicotine"),
    "melatonin": ("CC(=O)NCCc1c[nH]c2ccc(OC)cc12", "Melatonin"),
    "ibuprofen": ("CC(C)Cc1ccc(cc1)[C@H](C)C(=O)O", "Ibuprofen"),
    "acetaminophen": ("CC(=O)Nc1ccc(O)cc1", "Acetaminophen"),
    "fentanyl": ("CCC(=O)N(c1ccccc1)[C@@H]2CCN(CCc3ccccc3)CC2", "Fentanyl"),
    # NOTE(review): this SMILES contains no nitrogen atom, but heroin
    # (diacetylmorphine) is a tertiary amine - the structure is not heroin.
    # Verify and replace with a checked isomeric SMILES.
    "heroin": ("CC(=O)O[C@H]1C=C[C@H]2[C@H]3CC4=C5C(=C(OC(C)=O)C=C4C[C@@H]1[C@]23C)OCO5", "Heroin"),
    # NOTE(review): the indole nitrogen is written as N-methyl ("cn(C)"),
    # whereas LSD has a free indole N-H - presumably a transcription slip;
    # confirm against a reference structure.
    "lsd": ("CCN(CC)C(=O)[C@H]1CN([C@@H]2Cc3cn(C)c4cccc(C2=C1)c34)C", "LSD"),
    "mdma": ("CC(NC)Cc1ccc2OCOc2c1", "MDMA"),
    "ketamine": ("CNC1(CCCCC1=O)c2ccccc2Cl", "Ketamine"),
    "psilocybin": ("CN(C)CCc1c[nH]c2cccc(OP(=O)(O)O)c12", "Psilocybin"),
    "atenolol": ("CC(C)NCC(O)COc1ccc(CC(N)=O)cc1", "Atenolol"),
    "metformin": ("CN(C)C(=N)NC(=N)N", "Metformin"),
    "penicillin": ("CC1(C)S[C@@H]2[C@H](NC(=O)Cc3ccccc3)C(=O)N2[C@H]1C(=O)O", "Penicillin"),
    "amoxicillin": ("CC1(C)S[C@@H]2[C@H](NC(=O)[C@H](N)c3ccc(O)cc3)C(=O)N2[C@H]1C(=O)O", "Amoxicillin"),
}
545
+
546
+
547
def resolve_input(user_input):
    """Resolve user input (drug name or SMILES) to a SMILES string.

    Args:
        user_input: Raw text from the input box; may be ``None``, empty or
            whitespace-only.

    Returns:
        ``(smiles, display_name, None)`` on success, or
        ``(None, None, error_message)`` when the input is blank, RDKit is
        unavailable, or the text is neither a known drug name nor a SMILES
        that parses to a molecule with at least one atom.
    """
    # Strip before the emptiness check: whitespace-only input is truthy and
    # would otherwise reach the SMILES parser as an empty string.
    text = user_input.strip() if user_input else ""
    if not text:
        return None, None, "Please enter a molecule"

    if not RDKIT_AVAILABLE:
        return None, None, "RDKit not available"

    # Check database (case-insensitive) first; none of the keys is itself a
    # valid SMILES, and this avoids RDKit parse-error noise for plain names.
    entry = MOLECULES.get(text.lower())
    if entry is not None:
        return entry[0], entry[1], None

    # Accept anything RDKit can parse into a non-empty molecule.
    mol = Chem.MolFromSmiles(text)
    if mol is not None and mol.GetNumAtoms() > 0:
        return text, "Custom Molecule", None

    return None, None, f"Could not resolve '{text}'. Enter a valid SMILES or drug name."
567
+
568
+
569
+ # ============================================================================
570
+ # MAIN APP
571
+ # ============================================================================
572
def main():
    """Render the Streamlit page: header, sidebar, input, results, export."""
    # Header
    st.markdown('<h1 class="main-header">StereoGNN-BBB</h1>', unsafe_allow_html=True)
    st.markdown('<p class="sub-header">Blood-Brain Barrier Permeability Predictor | State-of-the-Art Performance</p>', unsafe_allow_html=True)

    # Check dependencies
    if not RDKIT_AVAILABLE:
        st.error("RDKit is not installed. Please install it with: pip install rdkit")
        st.stop()

    # Load model
    model_info, error = load_model()
    if error:
        st.error(f"Model loading failed: {error}")
        st.stop()

    _render_sidebar(model_info)

    # Main input
    st.subheader("Enter Molecule")

    input_col, button_col = st.columns([4, 1])
    with input_col:
        user_input = st.text_input(
            "SMILES or drug name",
            placeholder="e.g., Caffeine, Aspirin, Morphine, or enter SMILES",
            label_visibility="collapsed",
        )
    with button_col:
        predict_btn = st.button("Predict", type="primary", use_container_width=True)

    # Quick examples
    st.markdown("**Quick Examples:**")
    examples = ["Caffeine", "Morphine", "THC", "Dopamine", "Glucose", "Atenolol"]
    for col, example in zip(st.columns(len(examples)), examples):
        with col:
            if st.button(example, key=f"ex_{example}", use_container_width=True):
                # Stash the choice and rerun; it is consumed just below.
                st.session_state['mol_input'] = example
                st.rerun()

    if 'mol_input' in st.session_state:
        user_input = st.session_state['mol_input']
        del st.session_state['mol_input']
        predict_btn = True

    # Stereo enumeration option
    enumerate_stereo = st.checkbox(
        "Enumerate stereoisomers",
        value=True,
        help="Predict all possible stereoisomers and show range",
    )

    if predict_btn and user_input:
        _render_prediction(model_info, user_input, enumerate_stereo)

    # Footer with available molecules
    with st.expander("Available Drug Names (click to expand)"):
        columns = st.columns(5)
        for i, drug in enumerate(sorted(MOLECULES)):
            with columns[i % 5]:
                st.write(f"• {drug.capitalize()}")


def _render_sidebar(model_info):
    """Draw the sidebar: active model, score interpretation, feature list."""
    with st.sidebar:
        st.header("Model Info")

        if model_info['type'] == 'gnn':
            st.success(f"GNN Model: {model_info['name']}")
            st.markdown("**Performance (External Validation):**")
            st.metric("AUC", "0.9612")
            st.metric("Sensitivity", "97.96%")
            st.metric("Specificity", "65.25%")
        else:
            st.warning(f"Mode: {model_info['name']}")
            st.markdown(
                '<div class="info-box">\n'
                'Using descriptor-based prediction.<br>\n'
                'For full GNN accuracy, upload model weights to models/ folder.\n'
                '</div>',
                unsafe_allow_html=True,
            )

        st.markdown("---")
        st.subheader("Interpretation")
        st.success("BBB+ (>=0.6): Crosses BBB")
        st.warning("Moderate (0.4-0.6)")
        st.error("BBB- (<0.4): Does not cross")

        st.markdown("---")
        st.subheader("Features")
        st.markdown(
            "- Stereo-aware predictions\n"
            "- Stereoisomer enumeration\n"
            "- Molecular property analysis\n"
            "- BBB rule assessment"
        )

        st.markdown("---")
        st.markdown("**Author:** Nabil Yasini-Ardekani")
        st.markdown("[GitHub](https://github.com/abinittio)")


def _classify_score(score):
    """Map a score to ``(css_class, category, interpretation)``."""
    if score >= 0.6:
        return "prediction-positive", "BBB+", "HIGH permeability - likely crosses BBB"
    if score >= 0.4:
        return "prediction-moderate", "BBB+/-", "MODERATE - may partially cross"
    return "prediction-negative", "BBB-", "LOW permeability - unlikely to cross"


def _build_text_report(name, smiles, score, category, interp, model_name, props):
    """Return the plain-text report offered by the TXT download button."""
    # Format the optional properties up front. The original embedded
    # ``if props else 'N/A'`` after the ``:`` of an f-string field, which
    # Python treats as part of the format spec and rejects at runtime.
    if props:
        formula = props['formula']
        mw = f"{props['mw']:.1f}"
        logp = f"{props['logp']:.2f}"
        tpsa = f"{props['tpsa']:.1f}"
        rules = props['rules_passed']
    else:
        formula = mw = logp = tpsa = rules = 'N/A'

    lines = [
        "BBB Permeability Prediction Report",
        "=====================================",
        f"Molecule: {name}",
        f"SMILES: {smiles}",
        f"Score: {score:.4f}",
        f"Category: {category}",
        f"Interpretation: {interp}",
        "",
        f"Model: {model_name}",
        f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        "",
        "Molecular Properties:",
        f"- Formula: {formula}",
        f"- MW: {mw} Da",
        f"- LogP: {logp}",
        f"- TPSA: {tpsa} A²",
        f"- BBB Rules: {rules}/5 passed",
        "",
        "Generated by StereoGNN-BBB",
        "Author: Nabil Yasini-Ardekani",
    ]
    return "\n".join(lines) + "\n"


def _render_prediction(model_info, user_input, enumerate_stereo):
    """Resolve the input, run the prediction and draw results and exports."""
    smiles, name, err = resolve_input(user_input)
    if err:
        st.error(err)
        st.stop()

    st.markdown(f"**{name}**: `{smiles}`")

    with st.spinner("Predicting..."):
        if enumerate_stereo:
            result, pred_err = predict_with_stereo_enumeration(model_info, smiles)
        else:
            prob, pred_err = predict_single(model_info, smiles)
            if prob is not None:
                result = {'mean': prob, 'min': prob, 'max': prob, 'std': 0, 'n_isomers': 1}
            else:
                result = None

        props = get_properties(smiles)
        img = mol_to_image(smiles)

    if pred_err or result is None:
        st.error(f"Prediction failed: {pred_err}")
        st.stop()

    st.markdown("---")

    # Results
    score_col, image_col, props_col = st.columns([1.2, 1, 1])

    score = result['mean']
    card_class, category, interp = _classify_score(score)

    with score_col:
        st.markdown(
            f'<div class="prediction-card {card_class}">\n'
            f'<h2 style="margin:0; font-size:2rem;">{category}</h2>\n'
            f'<h1 style="margin:0.3rem 0; font-size:2.5rem;">{score:.4f}</h1>\n'
            f'<p style="margin:0; font-size:0.9rem;">{interp}</p>\n'
            '</div>',
            unsafe_allow_html=True,
        )

        if result['n_isomers'] > 1:
            st.markdown(
                '<div class="metric-box">\n'
                f"<b>Stereoisomer Analysis ({result['n_isomers']} isomers)</b><br>\n"
                f"Range: {result['min']:.4f} - {result['max']:.4f}<br>\n"
                f"Std Dev: {result['std']:.4f}\n"
                '</div>',
                unsafe_allow_html=True,
            )

    with image_col:
        if img:
            st.image(img, caption=name, use_container_width=True)
        else:
            st.info("Molecule image not available")

    with props_col:
        if props:
            st.markdown(f"**Formula:** {props['formula']}")
            st.markdown(f"**MW:** {props['mw']:.1f} Da")
            st.markdown(f"**LogP:** {props['logp']:.2f}")
            st.markdown(f"**TPSA:** {props['tpsa']:.1f} A²")
            st.markdown(f"**H-Donors:** {props['hbd']}")
            st.markdown(f"**H-Acceptors:** {props['hba']}")

            passed = props['rules_passed']
            rules_color = "green" if passed >= 4 else "orange" if passed >= 3 else "red"
            st.markdown(f"**BBB Rules:** :{rules_color}[{passed}/5 passed]")

    # Download section
    st.markdown("---")
    st.subheader("Export Results")

    # float() keeps the values plain Python numbers for json.dumps.
    report = {
        'molecule': name,
        'smiles': smiles,
        'bbb_score': round(float(score), 4),
        'category': category,
        'interpretation': interp,
        'n_stereoisomers': result['n_isomers'],
        'score_min': round(float(result['min']), 4),
        'score_max': round(float(result['max']), 4),
        'score_std': round(float(result['std']), 4),
        'model_type': model_info['type'],
        'model_name': model_info['name'],
        'timestamp': datetime.now().isoformat(),
    }

    if props:
        report.update({
            'formula': props['formula'],
            'molecular_weight': round(props['mw'], 2),
            'logp': round(props['logp'], 2),
            'tpsa': round(props['tpsa'], 2),
            'h_donors': props['hbd'],
            'h_acceptors': props['hba'],
            'bbb_rules_passed': props['rules_passed'],
        })

    file_stem = f"{name.replace(' ','_')}_bbb_prediction"

    json_col, csv_col, txt_col = st.columns(3)
    with json_col:
        st.download_button(
            "Download JSON",
            json.dumps(report, indent=2),
            f"{file_stem}.json",
            "application/json",
            use_container_width=True,
        )
    with csv_col:
        st.download_button(
            "Download CSV",
            pd.DataFrame([report]).to_csv(index=False),
            f"{file_stem}.csv",
            "text/csv",
            use_container_width=True,
        )
    with txt_col:
        st.download_button(
            "Download TXT",
            _build_text_report(name, smiles, score, category, interp, model_info['name'], props),
            f"{file_stem}.txt",
            "text/plain",
            use_container_width=True,
        )
831
+
832
+ if __name__ == "__main__":
833
+ main()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # StereoGNN-BBB - Streamlit Cloud Deployment
2
+ # Blood-Brain Barrier Permeability Predictor
3
+ # Author: Nabil Yasini-Ardekani
4
+
5
+ streamlit>=1.28.0
6
+ numpy>=1.24.0
7
+ pandas>=2.0.0
8
+ rdkit>=2023.9.1
9
+ torch>=2.0.0
10
+ torch-geometric>=2.4.0