AbdoIR committed on
Commit
5c5b09c
·
verified ·
1 Parent(s): 61b92bb

Upload 2 files

Browse files
Files changed (2) hide show
  1. main.py +288 -0
  2. requirements.txt +15 -0
main.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI Backend for Drug-Target Binding Affinity Prediction (KC-DTA)"""
2
+
3
+ import os
4
+ import sys
5
+ import logging
6
+ from pathlib import Path
7
+ from functools import lru_cache
8
+
9
+ import torch
10
+ from torch_geometric import data as DATA
11
+ from rdkit import Chem
12
+ from rdkit.Chem.rdchem import ValenceType
13
+ from contextlib import asynccontextmanager
14
+
15
+ from fastapi import FastAPI, HTTPException
16
+ from fastapi.middleware.cors import CORSMiddleware
17
+ from pydantic import BaseModel, Field, ConfigDict
18
+
19
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ============================================================================
# Configuration
# ============================================================================


def _parse_origins(raw: str) -> list:
    """Split a comma-separated origin list into clean entries.

    Surrounding whitespace is trimmed and empty entries are dropped, so a
    value such as ``"https://a.example, https://b.example,"`` does not yield
    origins that can never match a browser's ``Origin`` header.
    """
    return [origin.strip() for origin in raw.split(",") if origin.strip()]


# CORS: Use environment variable for allowed origins (comma-separated)
ALLOWED_ORIGINS = _parse_origins(
    os.getenv("ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:5173")
)

# Input validation limits (prevent DoS)
MAX_SMILES_LENGTH = 500
MAX_PROTEIN_LENGTH = 5000
33
+
34
# Add KCDTA to path: put the sibling "KCDTA" directory (two levels up from
# this file) first on sys.path before importing from the `models` package.
_kcdta_dir = Path(__file__).parent.parent / "KCDTA"
sys.path.insert(0, os.fspath(_kcdta_dir))
from models.cnn import cnn  # noqa: E402  (must come after the sys.path tweak)
37
+
38
+ # ============================================================================
39
+ # Pre-computed Constants
40
+ # ============================================================================
41
+
42
# Protein alphabet; a letter's position in this string is its feature index.
SEQ_VOC = "ACDEFGHIKLMNPQRSTVWXY"
SEQ_VOC_SET = frozenset(SEQ_VOC)  # Frozenset for O(1) membership checks
L = len(SEQ_VOC)  # 21
AA_TO_IDX = dict(zip(SEQ_VOC, range(L)))

# Atom symbol lookup: 43 known element symbols occupy indices 0..42.
# `_atom_feat` sends any symbol missing from this table to index 43.
_ATOM_SYMBOLS = (
    'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
    'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag',
    'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni',
    'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb',
)
ATOM_SYMBOL_IDX = dict(zip(_ATOM_SYMBOLS, range(len(_ATOM_SYMBOLS))))

# All 6 orderings of the three positions of a 3-mer
PERM_INDICES = (
    (0, 1, 2),
    (0, 2, 1),
    (1, 0, 2),
    (1, 2, 0),
    (2, 0, 1),
    (2, 1, 0),
)

# Reusable tensors; assigned by the lifespan hook once the device is known.
_EMPTY_EDGE = None
_ZERO_Y = None
60
+
61
+ # ============================================================================
62
+ # Optimized Feature Extraction
63
+ # ============================================================================
64
+
65
@lru_cache(maxsize=50000)
def _atom_feat(symbol: str, degree: int, num_hs: int, valence: int, aromatic: bool) -> tuple:
    """Build the cached, normalized 78-slot feature tuple for one atom.

    Slot layout: 0-43 element symbol (43 is the fallback for symbols absent
    from ``ATOM_SYMBOL_IDX``), 44-54 degree, 55-65 hydrogen count,
    66-76 valence, 77 aromatic flag. Degree, hydrogen count and valence are
    capped at 10. The vector is divided by the sum of its entries.
    """
    active_slots = (
        ATOM_SYMBOL_IDX.get(symbol, 43),
        44 + min(degree, 10),
        55 + min(num_hs, 10),
        66 + min(valence, 10),
    )

    vector = [0.0] * 78
    for slot in active_slots:
        vector[slot] = 1.0
    if aromatic:
        vector[77] = 1.0

    total = sum(vector)
    return tuple(value / total for value in vector)
76
+
77
+
78
@lru_cache(maxsize=10000)
def smile_to_graph(smile: str) -> tuple:
    """Parse a SMILES string into a hashable molecular graph (cached).

    Returns:
        ``(n_atoms, features, edges)`` where ``features`` is a tuple with one
        ``_atom_feat`` tuple per atom and ``edges`` is a flat tuple laid out
        as ``(i, j, j, i, ...)`` — each bond contributes both directions as
        two consecutive (source, destination) pairs.

    Raises:
        ValueError: if RDKit cannot parse ``smile``.

    NOTE(review): ``GetValence(ValenceType.IMPLICIT)`` presumably needs a
    recent RDKit release — confirm against the minimum version pinned in
    requirements.txt.
    """
    molecule = Chem.MolFromSmiles(smile)
    if molecule is None:
        raise ValueError(f"Invalid SMILES: {smile}")

    # One cached feature tuple per atom, in RDKit atom order.
    atom_rows = []
    for atom in molecule.GetAtoms():
        atom_rows.append(
            _atom_feat(
                atom.GetSymbol(),
                atom.GetDegree(),
                atom.GetTotalNumHs(),
                atom.GetValence(ValenceType.IMPLICIT),
                atom.GetIsAromatic(),
            )
        )
    features = tuple(atom_rows)

    # Flat edge list: every bond is recorded in both directions.
    flat_edges = []
    for bond in molecule.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        flat_edges += (begin, end, end, begin)

    return len(features), features, tuple(flat_edges)
98
+
99
+
100
@lru_cache(maxsize=2000)
def protein_features(seq: str) -> tuple:
    """Compute the cached 2D and 3D protein feature vectors for ``seq``.

    Returns:
        ``(flat_2d, flat_3d)``:

        * ``flat_2d`` — ``L * L`` integers, the product ``count[i] * count[j]``
          of per-residue occurrence counts, row-major.
        * ``flat_3d`` — ``L ** 3`` floats. Every 3-mer made only of known
          residues adds its occurrence count to the cells addressed by all
          six orderings of its residue indices; the result is divided by its
          maximum when that maximum is positive.

    Characters missing from ``AA_TO_IDX`` are ignored for the counts, and any
    3-mer containing one is skipped.
    """
    # 2D: outer product of residue occurrence counts.
    counts = [0] * L
    for residue, position in AA_TO_IDX.items():
        counts[position] = seq.count(residue)
    flat_2d = tuple(left * right for left in counts for right in counts)

    # 3D: count each distinct 3-residue window once, in first-seen order.
    window_counts = {}
    for window in zip(seq, seq[1:], seq[2:]):
        window_counts[window] = window_counts.get(window, 0) + 1

    plane = L * L
    cube = [0.0] * (L * L * L)
    for window, hits in window_counts.items():
        codes = [AA_TO_IDX.get(residue) for residue in window]
        if None in codes:
            continue  # window contains a character outside the alphabet
        for first, second, third in PERM_INDICES:
            cube[codes[first] * plane + codes[second] * L + codes[third]] += hits

    # Normalize by the largest cell; an all-zero cube is returned unchanged.
    peak = max(cube, default=0)
    if peak > 0:
        flat_3d = tuple(cell / peak for cell in cube)
    else:
        flat_3d = tuple(cube)

    return flat_2d, flat_3d
140
+
141
+
142
def _edge_index_from_flat(edges, device: torch.device) -> torch.Tensor:
    """Convert a flat ``(src0, dst0, src1, dst1, ...)`` sequence to a ``(2, E)`` tensor.

    ``smile_to_graph`` stores edges as consecutive (source, destination)
    pairs. Reshaping that flat sequence with ``view(2, -1)`` would put the
    first half of the values in row 0 and the second half in row 1, pairing
    unrelated atoms. The pairs are therefore recovered with ``view(-1, 2)``
    and transposed so row 0 holds sources and row 1 holds destinations.
    """
    if not edges:
        return torch.empty((2, 0), dtype=torch.long, device=device)
    pairs = torch.tensor(edges, dtype=torch.long, device=device).view(-1, 2)
    return pairs.t().contiguous()


def create_graph_data(smiles: str, protein_seq: str, device: torch.device) -> "DATA.Data":
    """Create the PyTorch Geometric ``Data`` object for one drug/protein pair.

    Args:
        smiles: SMILES string of the molecule (parsed by ``smile_to_graph``).
        protein_seq: Amino-acid sequence (encoded by ``protein_features``).
        device: Device on which every tensor is created.

    Returns:
        A ``Data`` object with ``x`` (per-atom features), ``edge_index``
        (shape ``(2, E)``), ``y``, ``dcpro`` (``(1, L, L)``), ``target``
        (``(1, L, L, L)``) and ``batch`` (zeros, one entry per atom).

    Raises:
        ValueError: propagated from ``smile_to_graph`` for unparsable SMILES.
    """
    n_atoms, features, edges = smile_to_graph(smiles)
    flat_2d, flat_3d = protein_features(protein_seq)

    # Create tensors on device
    x = torch.tensor(features, dtype=torch.float32, device=device)

    if edges:
        edge_idx = _edge_index_from_flat(edges, device)
    elif _EMPTY_EDGE is not None and _EMPTY_EDGE.device == device:
        # Reuse the pre-allocated empty index when it lives on the right device.
        edge_idx = _EMPTY_EDGE
    else:
        edge_idx = _edge_index_from_flat((), device)

    # _ZERO_Y is None until the lifespan hook has run.
    data = DATA.Data(x=x, edge_index=edge_idx, y=_ZERO_Y)
    data.dcpro = torch.tensor(flat_2d, dtype=torch.float32, device=device).view(1, L, L)
    data.target = torch.tensor(flat_3d, dtype=torch.float32, device=device).view(1, L, L, L)
    data.batch = torch.zeros(n_atoms, dtype=torch.long, device=device)

    return data
161
+
162
+
163
+ # ============================================================================
164
+ # FastAPI Application
165
+ # ============================================================================
166
+
167
+ from typing import Optional
168
+
169
class AppState:
    """Process-wide holder for the loaded model and its device.

    The lifespan hook assigns ``model`` and ``device``; the request handlers
    read them. Every attribute starts out as ``None``.
    """

    # Slots for memory efficiency
    __slots__ = ('model', 'device', 'empty_edge', 'zero_y')

    model: "Optional[cnn]"
    device: "Optional[torch.device]"
    empty_edge: "Optional[torch.Tensor]"
    zero_y: "Optional[torch.Tensor]"

    def __init__(self):
        for attribute in self.__slots__:
            setattr(self, attribute, None)


state = AppState()
178
+
179
+
180
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model before the app starts serving and release it on shutdown.

    Startup picks ``cuda:0`` when CUDA is available (CPU otherwise), loads
    the weights from ``../KCDTA/model_cnn_kiba.model`` relative to this
    file, switches the model to eval mode, freezes its parameters and
    pre-allocates the reusable tensors. ``state.model`` is only assigned once
    the model is fully prepared.

    Shutdown runs in a ``finally`` block so the references are dropped even
    when an exception is thrown into the generator at ``yield``.
    """
    global _EMPTY_EDGE, _ZERO_Y

    # Startup: Load model with optimizations
    model_path = Path(__file__).parent.parent / "KCDTA" / "model_cnn_kiba.model"
    state.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load model
    model = cnn()
    model.load_state_dict(torch.load(model_path, map_location=state.device, weights_only=True))
    model.to(state.device)
    model.eval()  # Set to evaluation mode (disables dropout)

    # Freeze parameters for inference (additional optimization)
    for param in model.parameters():
        param.requires_grad = False

    # Pre-allocate reusable tensors; mirror them on `state`, which declares
    # slots for them.
    _EMPTY_EDGE = torch.empty((2, 0), dtype=torch.long, device=state.device)
    _ZERO_Y = torch.zeros(1, device=state.device)
    state.empty_edge = _EMPTY_EDGE
    state.zero_y = _ZERO_Y

    # Publish last, so a failed load never leaves a half-initialised model
    # visible to /health or /predict.
    state.model = model

    param_count = sum(p.numel() for p in model.parameters())
    logger.info(f"Model loaded on {state.device} with {param_count:,} parameters")
    try:
        yield
    finally:
        # Shutdown
        state.model = None
        state.empty_edge = None
        state.zero_y = None
        _EMPTY_EDGE = None
        _ZERO_Y = None
207
+
208
+
209
# Application instance; `lifespan` handles model loading and teardown.
app = FastAPI(
    lifespan=lifespan,
    title="Drug-Target Binding Affinity Prediction API",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Secure CORS configuration
_cors_options = {
    "allow_origins": ALLOWED_ORIGINS,
    "allow_credentials": True,
    "allow_methods": ["GET", "POST"],
    "allow_headers": ["Content-Type", "Authorization"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
225
+
226
+
227
# Request body for POST /predict. Both fields are required, non-empty strings
# capped by the MAX_*_LENGTH limits; unknown extra fields are ignored.
# (Documented with comments rather than a class docstring so the generated
# schema is unchanged.)
class PredictionRequest(BaseModel):
    model_config = ConfigDict(extra="ignore")

    smiles: str = Field(
        ...,
        min_length=1,
        max_length=MAX_SMILES_LENGTH,
        description="SMILES representation of the drug molecule",
    )
    protein_sequence: str = Field(
        ...,
        min_length=1,
        max_length=MAX_PROTEIN_LENGTH,
        description="Amino acid sequence of the target protein",
    )
231
+
232
+
233
# Response body for POST /predict: the inputs as used by the handler plus the
# predicted affinity. `model_used` defaults to "KIBA".
# (Documented with comments rather than a class docstring so the generated
# schema is unchanged.)
class PredictionResponse(BaseModel):
    smiles: str
    protein_sequence: str
    binding_affinity: float
    model_used: str = "KIBA"
238
+
239
+
240
# Health probe: always reports "healthy", plus whether the model has been
# loaded and which device was selected.
@app.get("/health")
async def health():
    model_ready = state.model is not None
    device_name = str(state.device)
    return {"status": "healthy", "model_loaded": model_ready, "device": device_name}
243
+
244
+
245
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Predict the binding affinity for a drug (SMILES) and a target protein sequence.

    Responds with 400 for an empty or invalid protein sequence or SMILES
    string, 503 while the model is not loaded, and 500 if inference fails.
    """
    if state.model is None:
        raise HTTPException(503, "Model not loaded")

    smiles = request.smiles.strip()
    seq = request.protein_sequence.strip().upper()

    # A whitespace-only sequence passes the min_length check on the raw field
    # but is empty after stripping; reject it instead of scoring an all-zero
    # protein encoding.
    if not seq:
        raise HTTPException(400, "Protein sequence is empty")

    # Fast validation using pre-computed frozenset; sorted so the message is
    # deterministic.
    invalid_aa = set(seq) - SEQ_VOC_SET
    if invalid_aa:
        raise HTTPException(400, f"Invalid amino acids found: {sorted(invalid_aa)}. Valid: {SEQ_VOC}")

    # Validate SMILES up front so parse failures surface as a 400 rather than
    # being reported as an internal error below.
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise HTTPException(400, "Invalid SMILES string: unable to parse molecule")

    # Additional molecule validation
    if mol.GetNumAtoms() == 0:
        raise HTTPException(400, "SMILES represents an empty molecule")

    try:
        data = create_graph_data(smiles, seq, state.device)

        with torch.inference_mode():
            affinity = state.model(data).item()
    except Exception as exc:
        # Keep the traceback in the server log; the client only sees a
        # generic message.
        logger.exception("Prediction failed")
        raise HTTPException(500, "Prediction failed due to internal error") from exc

    return PredictionResponse(smiles=smiles, protein_sequence=seq, binding_affinity=round(affinity, 4))
277
+
278
+
279
if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces requires port 7860 and a single worker
    listen_port = int(os.getenv("PORT", 7860))
    verbosity = os.getenv("LOG_LEVEL", "info")

    uvicorn.run(
        "main:app",  # Entrypoint for Hugging Face Spaces
        host="0.0.0.0",
        port=listen_port,
        log_level=verbosity,
        factory=False,
    )
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Web framework
2
+ fastapi>=0.104.0
3
+ uvicorn[standard]>=0.24.0
4
+ pydantic>=2.0.0
5
+
6
+ # Machine Learning
7
+ torch>=2.0.0
8
+ torch-geometric>=2.4.0
9
+
10
+ # Chemistry
11
+ rdkit>=2023.3.1
12
+
13
+ # Production (optional but recommended)
14
+ # gunicorn>=21.0.0 # Process manager for production; FastAPI is ASGI, so run it with the uvicorn worker class (-k uvicorn.workers.UvicornWorker)
15
+ # python-multipart>=0.0.6 # For form data if needed