Spaces:
Sleeping
Sleeping
| """ | |
DeepCRISPR Enterprise API — 2-Stage Pipeline
=============================================
Stage 1: PyTorch CRISPRMegaModel → 256-dim embeddings
Stage 2: AutoGluon TabularPredictor → Safety prediction

Takes sgRNA + off-target sequences, runs them through the trained neural
network to extract learned embeddings, combines them with hand-crafted bio
features, and feeds the full feature vector to AutoGluon for the final
safety confidence score.

Architected by Mujahid

Usage:
    uvicorn api:app --reload
    → Docs: http://127.0.0.1:8000/docs
| """ | |
| import os | |
| import re | |
| import warnings | |
| from datetime import datetime, timezone | |
| import numpy as np | |
| import pandas as pd | |
| from pathlib import Path | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import HTMLResponse | |
| from pydantic import BaseModel, Field | |
# ─────────────────────────── APP INSTANCE ───────────────────────────────────
# Application object served by uvicorn (``uvicorn api:app``). The description
# is Markdown shown in the generated API docs, so it must contain real
# arrows rather than mis-decoded bytes.
app = FastAPI(
    title="DeepCRISPR Enterprise API",
    version="1.0.0",
    description=(
        "2-Stage AI pipeline for CRISPR-Cas9 off-target safety prediction.\n\n"
        "**Stage 1:** PyTorch CRISPRMegaModel (CNN + Transformer + BiLSTM) → "
        "256-dimensional learned embeddings.\n\n"
        "**Stage 2:** AutoGluon TabularPredictor → Final safety confidence.\n\n"
        "**Architected by Mujahid**"
    ),
    contact={"name": "Mujahid"},
)
# ─────────────────────────── MODEL LOADING ──────────────────────────────────
# Directory that contains this module; every artifact path is built from it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Paths — each artifact is looked for next to this file first, then inside
# the "DeepCRISPR_Mega_Model_Full" subfolder. List order is search order.
PTH_CANDIDATES = [
    os.path.join(BASE_DIR, *relative_parts)
    for relative_parts in (
        ("mega_model_best.pth",),
        ("DeepCRISPR_Mega_Model_Full", "mega_model_best.pth"),
    )
]
AG_CANDIDATES = [
    os.path.join(BASE_DIR, *relative_parts)
    for relative_parts in (
        ("autogluon_mega",),
        ("DeepCRISPR_Mega_Model_Full", "autogluon_mega"),
    )
]
# ── Stage 1: PyTorch ──
# Module-level globals read by the inference helpers and the endpoints. Both
# stay None whenever the neural stage cannot be brought up, so a failed load
# degrades the pipeline instead of crashing the import.
torch_model = None
torch_device = None
try:
    import torch
    from core_engine import CRISPRMegaModel, encode_pair, extract_bio_features, cfg

    torch_device = torch.device('cpu')

    # First candidate checkpoint that exists on disk, or None.
    pth_path = next((path for path in PTH_CANDIDATES if os.path.exists(path)), None)

    if pth_path:
        torch_model = CRISPRMegaModel()
        # SECURITY: weights_only=False makes torch.load unpickle arbitrary
        # objects, which can execute code embedded in the file. Only load
        # checkpoints from a trusted source.
        # NOTE(review): consider weights_only=True if the checkpoint holds
        # only tensors and plain containers — confirm against the export.
        checkpoint = torch.load(pth_path, map_location=torch_device, weights_only=False)

        # Checkpoints may wrap the weights with extra metadata. Unwrap under
        # the first known key ('state' is the Kaggle export format);
        # otherwise assume the object already is a bare state dict.
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for wrapper_key in ('state', 'model_state_dict', 'state_dict'):
                if wrapper_key in checkpoint:
                    state_dict = checkpoint[wrapper_key]
                    break

        torch_model.load_state_dict(state_dict)
        torch_model.eval()
        print(f"✅ PyTorch CRISPRMegaModel loaded from: {pth_path}")
    else:
        warnings.warn("⚠️ mega_model_best.pth not found. PyTorch stage disabled.")
except ImportError as e:
    warnings.warn(f"⚠️ PyTorch / core_engine import failed: {e}. Install with: pip install torch")
except Exception as e:
    # The model object may already exist when loading the weights fails;
    # drop it so no half-initialised model is ever used for inference.
    warnings.warn(f"⚠️ PyTorch model load error: {e}. Running without neural embeddings.")
    torch_model = None
# ── Stage 2: AutoGluon ──
# Stays None whenever the tabular stage cannot be brought up.
ag_predictor = None
try:
    from autogluon.tabular import TabularPredictor

    # First candidate predictor directory that exists, or None.
    ag_path = next((path for path in AG_CANDIDATES if os.path.isdir(path)), None)

    if ag_path:
        ag_predictor = TabularPredictor.load(ag_path)
        print(f"✅ AutoGluon predictor loaded from: {ag_path}")
    else:
        warnings.warn("⚠️ autogluon_mega/ directory not found. AutoGluon stage disabled.")
except ImportError:
    warnings.warn("⚠️ AutoGluon not installed. Install with: pip install autogluon.tabular")
except Exception as e:
    warnings.warn(f"⚠️ AutoGluon load error: {e}")
    # Explicit reset, mirroring the PyTorch stage: never keep a predictor
    # around after a failed load.
    ag_predictor = None
# ── Status summary ──
# Per-stage availability; returned by the health check and attached to every
# prediction response.
PIPELINE_STATUS = {
    "pytorch": "loaded" if torch_model is not None else "unavailable",
    "autogluon": "loaded" if ag_predictor is not None else "unavailable",
}

# Explicit None checks, matching PIPELINE_STATUS above. A plain truthiness
# test would misreport a loaded object whose class defines __len__/__bool__
# and happens to evaluate falsy.
if torch_model is not None and ag_predictor is not None:
    PIPELINE_MODE = "live"
    print("🚀 LIVE MODE — Full 2-stage pipeline active.")
elif torch_model is not None:
    PIPELINE_MODE = "partial-pytorch"
    print("⚡ PARTIAL MODE — PyTorch only (no AutoGluon).")
else:
    # NOTE(review): this branch also covers "AutoGluon loaded, PyTorch
    # missing". The prediction endpoint then feeds synthetic embeddings to
    # AutoGluon while reporting mode "demo" — confirm that is intended.
    PIPELINE_MODE = "demo"
    print("⚡ DEMO MODE — Returning synthetic predictions.")
# ─────────────────────────── PYDANTIC SCHEMAS ───────────────────────────────
class GuideRNAInput(BaseModel):
    """Request body: an sgRNA sequence and one candidate off-target site.

    Only the length bounds (10 to 30 characters) are enforced by the schema.
    The allowed alphabet is checked by the prediction endpoint after it has
    normalised the sequences.
    """

    # Guide sequence as submitted by the client.
    sgRNA_seq: str = Field(
        ...,
        min_length=10,
        max_length=30,
        description="The 20–23nt sgRNA guide sequence (A/T/C/G/U/N/-).",
        json_schema_extra={"examples": ["GAGTCCGAGCAGAAGAAGAA"]},
    )
    # Candidate off-target site to compare the guide against.
    off_target_seq: str = Field(
        ...,
        min_length=10,
        max_length=30,
        description="The candidate off-target DNA site (A/T/C/G/N/-).",
        json_schema_extra={"examples": ["GAGTCCAAGCAGAAGAAGAA"]},
    )
class SafetyScoreResponse(BaseModel):
    """Response body returned by the safety prediction endpoint."""

    # Sequences echoed back in normalised form (upper-case, U mapped to T).
    sgRNA_seq: str
    off_target_seq: str
    # The schema rejects values outside [0, 100].
    safety_confidence_percentage: float = Field(
        ..., ge=0, le=100,
        description="AI-predicted safety confidence (0–100%). Higher = safer.",
    )
    # Derived from the percentage with a strict > 80 threshold.
    status: str = Field(
        ..., description="'Safe' (>80%) or 'Risky' (≤80%).",
    )
    n_mismatches: int = Field(
        ..., description="Number of mismatches between sgRNA and off-target.",
    )
    mode: str = Field(
        ..., description="Pipeline mode: 'live', 'partial-pytorch', or 'demo'.",
    )
    pipeline: dict = Field(
        ..., description="Status of each pipeline stage.",
    )
    # ISO-8601 string produced when the response is built.
    timestamp: str
# ─────────────────────────── INFERENCE HELPERS ──────────────────────────────
def _run_pytorch_inference(sgrna: str, offtarget: str) -> np.ndarray:
    """Run Stage 1 and return the model's embedding as a flat NumPy vector.

    The pair is tokenised with ``encode_pair``. Each of the three resulting
    token sequences is wrapped in a batch of one, placed on ``torch_device``
    and passed to ``torch_model`` with gradient tracking disabled.

    Args:
        sgrna: Guide sequence.
        offtarget: Candidate off-target sequence.

    Returns:
        The flattened ``'embedding'`` entry of the model output. The API
        description states it is 256-dimensional.
    """
    token_streams = encode_pair(sgrna, offtarget)
    # Order matters: guide tokens, off-target tokens, mismatch tokens.
    sg_t, off_t, mm_t = (
        torch.tensor([tokens], dtype=torch.long, device=torch_device)
        for tokens in token_streams
    )
    with torch.no_grad():
        embedding = torch_model(sg_t, off_t, mm_t)['embedding']
    return embedding.cpu().numpy().flatten()
def _build_feature_row(sgrna: str, offtarget: str, embeddings: np.ndarray) -> pd.DataFrame:
    """Assemble the one-row feature table passed to the tabular stage.

    Args:
        sgrna: Guide sequence.
        offtarget: Candidate off-target sequence.
        embeddings: Flat embedding vector; element ``i`` becomes column
            ``emb_i``.

    Returns:
        A single-row DataFrame holding the ``emb_*`` columns followed by
        whatever ``extract_bio_features`` returns. On a name clash the
        biological feature overwrites the embedding column.
    """
    features = {f'emb_{idx}': float(value) for idx, value in enumerate(embeddings)}
    features.update(extract_bio_features(sgrna, offtarget))
    return pd.DataFrame([features])
# ─────────────────────────── ENDPOINTS ──────────────────────────────────────
# Without a route decorator FastAPI never serves this function.
# NOTE(review): the path "/" is an assumption — confirm against the deployed
# front end before release.
@app.get("/", response_class=HTMLResponse)
def dashboard():
    """Serve the web dashboard stored at ``templates/dashboard.html``.

    Returns:
        An ``HTMLResponse`` with status 200 containing the template text,
        read as UTF-8 on every request.

    Raises:
        HTTPException: 500 with an explicit message when the template file
            cannot be read.
    """
    html_path = Path(BASE_DIR) / "templates" / "dashboard.html"
    try:
        html = html_path.read_text(encoding="utf-8")
    except OSError as exc:
        # Same status as an unhandled error, but with a message that tells
        # the operator which file is missing or unreadable.
        raise HTTPException(
            status_code=500,
            detail="Dashboard template templates/dashboard.html could not be read.",
        ) from exc
    return HTMLResponse(content=html, status_code=200)
# Without a route decorator FastAPI never serves this function.
# NOTE(review): the path "/health" is an assumption — confirm against any
# existing clients or monitoring probes.
@app.get("/health")
def health_check():
    """Report that the API is up and which pipeline stages were loaded.

    Returns:
        A dict with a fixed ``message``, the pipeline ``mode`` and the
        per-stage ``pipeline`` status, all computed at import time.
    """
    return {
        "message": "DeepCRISPR Enterprise API is Live.",
        "mode": PIPELINE_MODE,
        "pipeline": PIPELINE_STATUS,
    }
# Alphabet accepted after normalisation (upper-case, U already mapped to T).
# Compiled once at import time instead of on every request.
_VALID_SEQ_RE = re.compile(r'^[ATCGN\-]+$')


def _synthetic_rng(sgrna: str, offtarget: str) -> np.random.RandomState:
    """Return a RandomState seeded deterministically from the sequence pair."""
    import hashlib  # local import: only the synthetic fallbacks need it

    seed = int(hashlib.md5((sgrna + offtarget).encode()).hexdigest()[:8], 16)
    return np.random.RandomState(seed)


def _count_mismatches(sgrna: str, offtarget: str) -> int:
    """Count differing positions after fitting both sequences to one length."""
    # `cfg` only exists when core_engine imported successfully. Fall back to
    # the longer input so demo mode works without the engine.
    engine_cfg = globals().get("cfg")
    if engine_cfg is not None:
        seq_len = engine_cfg.SEQ_LEN
    else:
        seq_len = max(len(sgrna), len(offtarget))
    sg_padded = sgrna[:seq_len].ljust(seq_len, 'N')
    off_padded = offtarget[:seq_len].ljust(seq_len, 'N')
    return sum(1 for a, b in zip(sg_padded, off_padded) if a != b)


# Without a route decorator FastAPI never serves this function.
# NOTE(review): the path "/predict" is an assumption — confirm against the
# dashboard and any existing clients.
@app.post("/predict", response_model=SafetyScoreResponse)
def predict_safety_score(payload: GuideRNAInput):
    """
    **2-Stage AI Pipeline:**

    1. The sgRNA + off-target pair is tokenized and passed through the
       PyTorch CRISPRMegaModel (CNN + Transformer + BiLSTM) to extract
       256-dimensional learned embeddings.
    2. The embeddings are combined with 50 hand-crafted biological features
       and fed to the AutoGluon TabularPredictor for the final safety score.

    **Classification:** Safe (>80%) or Risky (≤80%).

    Without AutoGluon the score comes from the PyTorch model alone. Without
    both models a deterministic synthetic score is returned. Errors are
    reported as HTTP 422 for characters outside A/T/C/G/U/N/- and HTTP 500
    when a model stage fails or yields a non-finite score.
    """
    sgrna = payload.sgRNA_seq.strip().upper().replace('U', 'T')
    offtarget = payload.off_target_seq.strip().upper().replace('U', 'T')

    # ── Validate characters ──
    for field_name, sequence in (("sgRNA_seq", sgrna), ("off_target_seq", offtarget)):
        if not _VALID_SEQ_RE.match(sequence):
            raise HTTPException(
                status_code=422,
                detail=f"{field_name} contains invalid characters. Allowed: A, T, C, G, U, N, -",
            )

    # ── Count mismatches for response ──
    n_mm = _count_mismatches(sgrna, offtarget)

    if ag_predictor is not None:
        # ── Stage 1: embeddings (real, or synthetic when PyTorch is missing) ──
        if torch_model is not None:
            try:
                embeddings = _run_pytorch_inference(sgrna, offtarget)
            except Exception as e:
                raise HTTPException(
                    status_code=500,
                    detail=f"PyTorch inference failed: {e}",
                ) from e
        else:
            embeddings = _synthetic_rng(sgrna, offtarget).randn(256).astype(np.float32) * 0.1

        # ── Build feature DataFrame ──
        try:
            df_features = _build_feature_row(sgrna, offtarget, embeddings)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Feature extraction failed: {e}",
            ) from e

        # ── Stage 2: AutoGluon prediction ──
        try:
            proba = ag_predictor.predict_proba(df_features)
            # NOTE(review): column 0 is taken as the "safe" class — confirm
            # against the label order the predictor was trained with.
            if hasattr(proba, 'shape') and len(proba.shape) == 2:
                safety_pct = float(proba.iloc[0, 0] * 100)
            else:
                safety_pct = float(proba.iloc[0] * 100)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"AutoGluon prediction failed: {e}",
            ) from e
    elif torch_model is not None:
        # Partial mode: one forward pass, using the model's off_prob directly.
        try:
            sg_tok, off_tok, mm_tok = encode_pair(sgrna, offtarget)
            sg_t = torch.tensor([sg_tok], dtype=torch.long, device=torch_device)
            off_t = torch.tensor([off_tok], dtype=torch.long, device=torch_device)
            mm_t = torch.tensor([mm_tok], dtype=torch.long, device=torch_device)
            with torch.no_grad():
                output = torch_model(sg_t, off_t, mm_t)
            safety_pct = float((1 - output['off_prob'].item()) * 100)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"PyTorch inference failed: {e}",
            ) from e
    else:
        # Demo mode: hash-based deterministic score.
        safety_pct = float(_synthetic_rng(sgrna, offtarget).uniform(0, 100))

    # A NaN would survive the clamp below as 100.0 (min(100.0, nan) == 100.0)
    # and be reported as "Safe" — refuse to answer instead.
    if not np.isfinite(safety_pct):
        raise HTTPException(
            status_code=500,
            detail="Model produced a non-finite safety score.",
        )

    safety_pct = round(max(0.0, min(100.0, safety_pct)), 2)
    status = "Safe" if safety_pct > 80 else "Risky"

    return SafetyScoreResponse(
        sgRNA_seq=sgrna,
        off_target_seq=offtarget,
        safety_confidence_percentage=safety_pct,
        status=status,
        n_mismatches=n_mm,
        mode=PIPELINE_MODE,
        pipeline=PIPELINE_STATUS,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
# ─────────────────────────── LOCAL SERVER ───────────────────────────────────
# Development entry point: print a status banner, then serve on localhost.
# Production deployments start the app through ``uvicorn api:app`` instead.
if __name__ == "__main__":
    import uvicorn  # imported lazily: only needed when run as a script

    rule = "=" * 60
    print(rule)
    print(" DeepCRISPR Enterprise API — 2-Stage Pipeline")
    print(" Architected by Mujahid")
    print(rule)
    print(f" PyTorch: {PIPELINE_STATUS['pytorch']}")
    print(f" AutoGluon: {PIPELINE_STATUS['autogluon']}")
    print(f" Mode: {PIPELINE_MODE.upper()}")
    print(" Starting server → http://127.0.0.1:8000")
    print(" Swagger UI → http://127.0.0.1:8000/docs")
    print(rule)
    uvicorn.run(app, host="127.0.0.1", port=8000)