mk6783336 committed on
Commit
0e34be9
·
verified ·
1 Parent(s): 2fcee7a

Upload api.py

Browse files
Files changed (1) hide show
  1. api.py +362 -0
api.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DeepCRISPR Enterprise API — 2-Stage Pipeline
3
+ =============================================
4
+ Stage 1: PyTorch CRISPRMegaModel → 256-dim embeddings
5
+ Stage 2: AutoGluon TabularPredictor → Safety prediction
6
+
7
+ Takes sgRNA + off-target sequences, runs them through the trained neural
8
+ network to extract learned embeddings, combines with hand-crafted bio
9
+ features, and feeds the full feature vector to AutoGluon for the final
10
+ safety confidence score.
11
+
12
+ Architected by Mujahid
13
+
14
+ Usage:
15
+ uvicorn api:app --reload
16
+ → Docs: http://127.0.0.1:8000/docs
17
+ """
18
+
19
+ import os
20
+ import re
21
+ import warnings
22
+ from datetime import datetime, timezone
23
+
24
+ import numpy as np
25
+ import pandas as pd
26
+ from pathlib import Path
27
+ from fastapi import FastAPI, HTTPException
28
+ from fastapi.responses import HTMLResponse
29
+ from pydantic import BaseModel, Field
30
+
31
+ # ─────────────────────────── APP INSTANCE ───────────────────────────────────
32
+
33
# The single FastAPI application object; every endpoint below registers on it.
# The title, version, description and contact are passed to FastAPI as API
# metadata, so the arrows in the description are real Unicode characters
# rather than mis-decoded byte sequences.
app = FastAPI(
    title="DeepCRISPR Enterprise API",
    version="1.0.0",
    description=(
        "2-Stage AI pipeline for CRISPR-Cas9 off-target safety prediction.\n\n"
        "**Stage 1:** PyTorch CRISPRMegaModel (CNN + Transformer + BiLSTM) → "
        "256-dimensional learned embeddings.\n\n"
        "**Stage 2:** AutoGluon TabularPredictor → Final safety confidence.\n\n"
        "**Architected by Mujahid**"
    ),
    contact={"name": "Mujahid"},
)
45
+
46
+
47
+ # ─────────────────────────── MODEL LOADING ──────────────────────────────────
48
+
49
# Directory that contains this module; all model artifacts are looked up
# relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Artifacts may sit next to this file or inside the bundled model folder.
# Order matters: the loaders take the first candidate that exists.
_ARTIFACT_DIRS = (
    BASE_DIR,
    os.path.join(BASE_DIR, "DeepCRISPR_Mega_Model_Full"),
)

# Candidate locations of the PyTorch checkpoint file.
PTH_CANDIDATES = [
    os.path.join(folder, "mega_model_best.pth") for folder in _ARTIFACT_DIRS
]

# Candidate locations of the saved AutoGluon predictor directory.
AG_CANDIDATES = [
    os.path.join(folder, "autogluon_mega") for folder in _ARTIFACT_DIRS
]
60
+
61
+ # ── Stage 1: PyTorch ──
62
# Stage-1 state. Both names stay None whenever torch, core_engine, or the
# checkpoint file is unavailable; the rest of the module treats
# ``torch_model is None`` as "PyTorch stage disabled".
torch_model = None
torch_device = None

try:
    import torch
    from core_engine import CRISPRMegaModel, encode_pair, extract_bio_features, cfg

    torch_device = torch.device('cpu')

    # Use the first candidate checkpoint that exists on disk (None if none do).
    pth_path = next((p for p in PTH_CANDIDATES if os.path.exists(p)), None)

    if pth_path:
        torch_model = CRISPRMegaModel()
        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code while loading. Only point
        # this at checkpoints from a trusted source. It is left unchanged here
        # because the checkpoint may carry non-tensor metadata.
        checkpoint = torch.load(pth_path, map_location=torch_device, weights_only=False)

        # The checkpoint is either a bare state dict or a dict that wraps the
        # state dict under one of these keys; the first key present wins, and
        # a dict with none of them is assumed to be the state dict itself.
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            wrapper_key = next(
                (k for k in ('state', 'model_state_dict', 'state_dict') if k in checkpoint),
                None,
            )
            if wrapper_key is not None:
                state_dict = checkpoint[wrapper_key]

        torch_model.load_state_dict(state_dict)
        torch_model.eval()
        print(f"✅ PyTorch CRISPRMegaModel loaded from: {pth_path}")
    else:
        warnings.warn("⚠️ mega_model_best.pth not found. PyTorch stage disabled.")

except ImportError as e:
    # Raised when either torch or core_engine cannot be imported. In that case
    # encode_pair / extract_bio_features / cfg are NOT defined in this module.
    warnings.warn(f"⚠️ PyTorch / core_engine import failed: {e}. Install with: pip install torch")
except Exception as e:
    # Any failure while building the model or loading weights disables the
    # stage instead of aborting the import of this module.
    warnings.warn(f"⚠️ PyTorch model load error: {e}. Running without neural embeddings.")
    torch_model = None
104
+
105
+
106
+ # ── Stage 2: AutoGluon ──
107
# Stage-2 state. Stays None when AutoGluon is not installed, the saved
# predictor directory is missing, or loading fails; the rest of the module
# treats ``ag_predictor is None`` as "AutoGluon stage disabled".
ag_predictor = None

try:
    from autogluon.tabular import TabularPredictor

    # Use the first candidate directory that exists (None if none do).
    ag_path = next((d for d in AG_CANDIDATES if os.path.isdir(d)), None)

    if ag_path:
        ag_predictor = TabularPredictor.load(ag_path)
        print(f"✅ AutoGluon predictor loaded from: {ag_path}")
    else:
        warnings.warn("⚠️ autogluon_mega/ directory not found. AutoGluon stage disabled.")

except ImportError:
    warnings.warn("⚠️ AutoGluon not installed. Install with: pip install autogluon.tabular")
except Exception as e:
    # A corrupt or incompatible predictor must not stop the API from starting.
    warnings.warn(f"⚠️ AutoGluon load error: {e}")
128
+
129
+
130
+ # ── Status summary ──
131
# Per-stage availability, reported by /health and in every prediction response.
PIPELINE_STATUS = {
    "pytorch": "loaded" if torch_model is not None else "unavailable",
    "autogluon": "loaded" if ag_predictor is not None else "unavailable",
}

# Overall mode. The checks use ``is not None`` (matching PIPELINE_STATUS above)
# rather than truthiness, so the result cannot depend on whether the model or
# predictor objects define __bool__/__len__.
# NOTE(review): "AutoGluon loaded but PyTorch unavailable" also lands in
# "demo" here — confirm that label is intended for that combination.
if torch_model is not None and ag_predictor is not None:
    PIPELINE_MODE = "live"
    print("🚀 LIVE MODE — Full 2-stage pipeline active.")
elif torch_model is not None:
    PIPELINE_MODE = "partial-pytorch"
    print("⚡ PARTIAL MODE — PyTorch only (no AutoGluon).")
else:
    PIPELINE_MODE = "demo"
    print("⚡ DEMO MODE — Returning synthetic predictions.")
145
+
146
+
147
+ # ─────────────────────────── PYDANTIC SCHEMAS ───────────────────────────────
148
+
149
class GuideRNAInput(BaseModel):
    """Input schema: an sgRNA sequence and its candidate off-target site."""

    # Guide sequence as submitted by the client. Length limits (10-30) are
    # enforced here by pydantic; the character set is checked by the endpoint.
    sgRNA_seq: str = Field(
        ...,
        min_length=10,
        max_length=30,
        description="The 20–23nt sgRNA guide sequence (A/T/C/G/U/N/-).",
        json_schema_extra={"examples": ["GAGTCCGAGCAGAAGAAGAA"]},
    )
    # Candidate off-target site, with the same length limits as the guide.
    off_target_seq: str = Field(
        ...,
        min_length=10,
        max_length=30,
        description="The candidate off-target DNA site (A/T/C/G/N/-).",
        json_schema_extra={"examples": ["GAGTCCAAGCAGAAGAAGAA"]},
    )
166
+
167
+
168
class SafetyScoreResponse(BaseModel):
    """Output schema for the safety prediction."""

    # Echo of the two input sequences.
    sgRNA_seq: str
    off_target_seq: str
    # Final score, constrained to the closed interval [0, 100].
    safety_confidence_percentage: float = Field(
        ..., ge=0, le=100,
        description="AI-predicted safety confidence (0–100%). Higher = safer.",
    )
    status: str = Field(
        ..., description="'Safe' (>80%) or 'Risky' (≤80%).",
    )
    n_mismatches: int = Field(
        ..., description="Number of mismatches between sgRNA and off-target.",
    )
    mode: str = Field(
        ..., description="Pipeline mode: 'live', 'partial-pytorch', or 'demo'.",
    )
    pipeline: dict = Field(
        ..., description="Status of each pipeline stage.",
    )
    # Creation time of the response as a string.
    timestamp: str
190
+
191
+
192
+ # ─────────────────────────── INFERENCE HELPERS ──────────────────────────────
193
+
194
def _run_pytorch_inference(sgrna: str, offtarget: str) -> np.ndarray:
    """Stage 1: embed one sgRNA / off-target pair with the PyTorch model.

    The pair is tokenized with ``encode_pair``. Each of the three resulting
    token sequences is wrapped into a batch-of-one ``torch.long`` tensor on
    ``torch_device`` and the model is called with gradient tracking disabled.

    Args:
        sgrna: Guide sequence.
        offtarget: Candidate off-target sequence.

    Returns:
        The model's ``'embedding'`` output as a flattened NumPy array
        (expected to hold 256 values).
    """
    def as_batch(tokens):
        # Wrapping in a list adds the leading batch dimension of size one.
        return torch.tensor([tokens], dtype=torch.long, device=torch_device)

    sg_tokens, off_tokens, mm_tokens = encode_pair(sgrna, offtarget)
    model_inputs = (as_batch(sg_tokens), as_batch(off_tokens), as_batch(mm_tokens))

    with torch.no_grad():
        result = torch_model(*model_inputs)

    embedding = result['embedding']
    return embedding.cpu().numpy().flatten()
206
+
207
+
208
def _build_feature_row(sgrna: str, offtarget: str, embeddings: np.ndarray) -> pd.DataFrame:
    """Assemble the single-row feature table from embeddings and bio features.

    Args:
        sgrna: Guide sequence.
        offtarget: Candidate off-target sequence.
        embeddings: Embedding values; element ``i`` becomes column ``emb_<i>``.

    Returns:
        A one-row DataFrame holding the ``emb_*`` columns plus everything
        returned by ``extract_bio_features``. On a name clash the bio feature
        overwrites the embedding column.
    """
    features = {}
    for index, value in enumerate(embeddings):
        features[f'emb_{index}'] = float(value)

    # Merged after the embeddings so that bio features win on duplicate keys.
    features.update(extract_bio_features(sgrna, offtarget))

    return pd.DataFrame([features])
218
+
219
+
220
+ # ─────────────────────────── ENDPOINTS ──────────────────────────────────────
221
+
222
@app.get("/", response_class=HTMLResponse, tags=["UI"])
def dashboard():
    """Premium web dashboard for DeepCRISPR Enterprise."""
    # The page is read from templates/dashboard.html next to this module on
    # every request.
    html_path = Path(BASE_DIR) / "templates" / "dashboard.html"
    try:
        html = html_path.read_text(encoding="utf-8")
    except OSError as exc:
        # A missing or unreadable template is a deployment problem; report it
        # explicitly instead of letting the raw OSError escape the handler.
        raise HTTPException(
            status_code=500,
            detail="Dashboard template could not be read.",
        ) from exc
    return HTMLResponse(content=html, status_code=200)
227
+
228
+
229
@app.get("/health", tags=["Health"])
def health_check():
    """Health check and pipeline status."""
    # Reports the mode and per-stage status computed once at import time.
    report = dict(
        message="DeepCRISPR Enterprise API is Live.",
        mode=PIPELINE_MODE,
        pipeline=PIPELINE_STATUS,
    )
    return report
237
+
238
+
239
# Compiled once at import time. Sequences are upper-cased and U is replaced by
# T before matching, so only these characters can remain in a valid sequence.
_VALID_SEQ_RE = re.compile(r'^[ATCGN\-]+$')


def _count_mismatches(sgrna: str, offtarget: str) -> int:
    """Count differing positions after trimming/padding both sequences with 'N'.

    The comparison length is ``cfg.SEQ_LEN`` when ``cfg`` was imported from
    core_engine. If that import failed, ``cfg`` does not exist in this module,
    so the length of the longer sequence is used instead of raising NameError.
    """
    seq_len = getattr(globals().get("cfg"), "SEQ_LEN", None)
    if seq_len is None:
        seq_len = max(len(sgrna), len(offtarget))
    sg_padded = sgrna[:seq_len].ljust(seq_len, 'N')
    off_padded = offtarget[:seq_len].ljust(seq_len, 'N')
    return sum(1 for a, b in zip(sg_padded, off_padded) if a != b)


def _demo_rng(sgrna: str, offtarget: str):
    """Return a RandomState seeded from the MD5 of the concatenated pair.

    Used only for synthetic (demo) values, so identical inputs always yield
    identical outputs.
    """
    import hashlib

    seed = int(hashlib.md5((sgrna + offtarget).encode()).hexdigest()[:8], 16)
    return np.random.RandomState(seed)


@app.post(
    "/predict/safety-score",
    response_model=SafetyScoreResponse,
    tags=["Prediction"],
    summary="Predict off-target safety for an sgRNA / off-target pair",
)
def predict_safety_score(payload: GuideRNAInput):
    """
    **2-Stage AI Pipeline:**

    1. The sgRNA + off-target pair is tokenized and passed through the
       PyTorch CRISPRMegaModel (CNN + Transformer + BiLSTM) to extract
       256-dimensional learned embeddings.

    2. The embeddings are combined with 50 hand-crafted biological features
       and fed to the AutoGluon TabularPredictor for the final safety score.

    **Classification:** Safe (>80%) or Risky (≤80%).
    """
    # Implementation notes (kept out of the docstring, which FastAPI publishes):
    #   * AutoGluon loaded  -> embeddings (real, or synthetic when PyTorch is
    #     unavailable) + bio features -> predict_proba.
    #   * PyTorch only      -> one forward pass, score = (1 - off_prob) * 100.
    #   * neither           -> deterministic hash-seeded demo score.
    # Raises HTTPException 422 for invalid characters, 500 for inference
    # failures, and 503 when AutoGluon is loaded but core_engine's feature
    # extractor is missing.
    sgrna = payload.sgRNA_seq.strip().upper().replace('U', 'T')
    offtarget = payload.off_target_seq.strip().upper().replace('U', 'T')

    # ── Validate characters ──
    for field_name, seq in (("sgRNA_seq", sgrna), ("off_target_seq", offtarget)):
        if not _VALID_SEQ_RE.match(seq):
            raise HTTPException(
                status_code=422,
                detail=f"{field_name} contains invalid characters. Allowed: A, T, C, G, U, N, -",
            )

    # ── Count mismatches for response ──
    n_mm = _count_mismatches(sgrna, offtarget)

    if ag_predictor is not None:
        # ── Stage 1: PyTorch embeddings (synthetic when the model is absent) ──
        if torch_model is not None:
            try:
                embeddings = _run_pytorch_inference(sgrna, offtarget)
            except Exception as e:
                raise HTTPException(
                    status_code=500,
                    detail=f"PyTorch inference failed: {e}",
                ) from e
        else:
            embeddings = _demo_rng(sgrna, offtarget).randn(256).astype(np.float32) * 0.1

        # ── Build feature DataFrame ──
        # extract_bio_features only exists if the core_engine import succeeded.
        if "extract_bio_features" not in globals():
            raise HTTPException(
                status_code=503,
                detail="Feature extraction is unavailable: core_engine could not be imported.",
            )
        try:
            df_features = _build_feature_row(sgrna, offtarget, embeddings)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Feature extraction failed: {e}",
            ) from e

        # ── Stage 2: AutoGluon prediction ──
        try:
            proba = ag_predictor.predict_proba(df_features)
            # NOTE(review): for a 2-D result the FIRST column is taken as the
            # "safe" probability — confirm this matches the predictor's class
            # label order.
            if hasattr(proba, 'shape') and len(proba.shape) == 2:
                safety_pct = float(proba.iloc[0, 0] * 100)
            else:
                safety_pct = float(proba.iloc[0] * 100)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"AutoGluon prediction failed: {e}",
            ) from e
    elif torch_model is not None:
        # Partial mode: a single forward pass; only off_prob is needed, so the
        # embedding is not extracted separately.
        try:
            sg_tok, off_tok, mm_tok = encode_pair(sgrna, offtarget)
            sg_t = torch.tensor([sg_tok], dtype=torch.long, device=torch_device)
            off_t = torch.tensor([off_tok], dtype=torch.long, device=torch_device)
            mm_t = torch.tensor([mm_tok], dtype=torch.long, device=torch_device)
            with torch.no_grad():
                output = torch_model(sg_t, off_t, mm_t)
            safety_pct = float((1 - output['off_prob'].item()) * 100)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"PyTorch inference failed: {e}",
            ) from e
    else:
        # Demo mode: hash-based deterministic score
        safety_pct = round(float(_demo_rng(sgrna, offtarget).uniform(0, 100)), 2)

    # Clamp to the range the response schema accepts, then classify.
    safety_pct = round(max(0.0, min(100.0, safety_pct)), 2)
    status = "Safe" if safety_pct > 80 else "Risky"

    return SafetyScoreResponse(
        sgRNA_seq=sgrna,
        off_target_seq=offtarget,
        safety_confidence_percentage=safety_pct,
        status=status,
        n_mismatches=n_mm,
        mode=PIPELINE_MODE,
        pipeline=PIPELINE_STATUS,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
344
+
345
+
346
+ # ─────────────────────────── LOCAL SERVER ───────────────────────────────────
347
+
348
if __name__ == "__main__":
    # Direct execution (``python api.py``): print a status banner, then serve
    # the app on the loopback interface only.
    import uvicorn

    print("=" * 60)
    print(" DeepCRISPR Enterprise API — 2-Stage Pipeline")
    print(" Architected by Mujahid")
    print("=" * 60)
    print(f" PyTorch: {PIPELINE_STATUS['pytorch']}")
    print(f" AutoGluon: {PIPELINE_STATUS['autogluon']}")
    print(f" Mode: {PIPELINE_MODE.upper()}")
    print(" Starting server → http://127.0.0.1:8000")
    print(" Swagger UI → http://127.0.0.1:8000/docs")
    print("=" * 60)

    uvicorn.run(app, host="127.0.0.1", port=8000)