|
|
"""FastAPI Backend for Drug-Target Binding Affinity Prediction (KC-DTA)""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from functools import lru_cache |
|
|
|
|
|
import torch |
|
|
from torch_geometric import data as DATA |
|
|
from rdkit import Chem |
|
|
from rdkit.Chem.rdchem import ValenceType |
|
|
from contextlib import asynccontextmanager |
|
|
|
|
|
from fastapi import FastAPI, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel, Field, ConfigDict |
|
|
|
|
|
|
|
|
# Application log level follows the LOG_LEVEL environment variable (the same
# variable handed to uvicorn in ``__main__``). Names the logging module does not
# recognise fall back to INFO instead of failing at import time.
_LOG_LEVEL = logging.getLevelName(os.getenv("LOG_LEVEL", "info").upper())
logging.basicConfig(level=_LOG_LEVEL if isinstance(_LOG_LEVEL, int) else logging.INFO)

logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Comma-separated list of origins accepted by the CORS middleware. Entries are
# whitespace-trimmed and blank entries dropped, so "a, b," yields ["a", "b"];
# an untrimmed " b" would never match a browser's Origin header.
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:5173").split(",")
    if origin.strip()
]
|
|
|
|
|
|
|
|
# Upper bounds on request field lengths; enforced by PredictionRequest's Field limits.
MAX_SMILES_LENGTH: int = 500
MAX_PROTEIN_LENGTH: int = 5_000
|
|
|
|
|
|
|
|
# Make the "KCDTA" directory next to this file importable. It is prepended so its
# top-level ``models`` package wins over any other ``models`` on sys.path; this
# must run before the import below.
sys.path.insert(0, str(Path(__file__).parent / "KCDTA"))


from models.cnn import cnn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Protein alphabet: the 20 standard one-letter amino-acid codes plus "X".
SEQ_VOC = "ACDEFGHIKLMNPQRSTVWXY"
# Set form for fast membership checks during request validation.
SEQ_VOC_SET = frozenset(SEQ_VOC)
# Alphabet size (21). Derived from SEQ_VOC so the two cannot drift apart; it sets
# the side length of the 2D (L x L) and 3D (L x L x L) protein feature grids.
L = len(SEQ_VOC)
# Residue letter -> position in SEQ_VOC.
AA_TO_IDX = {aa: idx for idx, aa in enumerate(SEQ_VOC)}
|
|
|
|
|
|
|
|
# Element symbols that get a dedicated one-hot slot in the atom feature vector.
# Any other element shares the single "unknown" slot placed right after these
# 43 entries (index 43 in _atom_feat).
_ATOM_SYMBOLS = (
    'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
    'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag',
    'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni',
    'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb',
)

# Element symbol -> one-hot position.
ATOM_SYMBOL_IDX = dict(zip(_ATOM_SYMBOLS, range(len(_ATOM_SYMBOLS))))
|
|
|
|
|
|
|
|
# All six orderings of three positions, in lexicographic order. protein_features
# uses them to add each 3-mer count to every permutation of its (i, j, k) cell.
PERM_INDICES = (
    (0, 1, 2),
    (0, 2, 1),
    (1, 0, 2),
    (1, 2, 0),
    (2, 0, 1),
    (2, 1, 0),
)

# Device-resident placeholders; lifespan replaces them once the device is chosen.
_EMPTY_EDGE = None  # becomes a (2, 0) long tensor, reused for bond-free molecules
_ZERO_Y = None  # becomes torch.zeros(1), attached as ``y`` on every inference graph
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache(maxsize=50000)
def _atom_feat(symbol: str, degree: int, num_hs: int, valence: int, aromatic: bool) -> tuple:
    """Return the cached, L1-normalised 78-dim feature tuple for one atom.

    Raw one-hot layout before normalisation:
        [0, 44)   element symbol via ATOM_SYMBOL_IDX; slot 43 = unknown element
        [44, 55)  degree, clamped to 10
        [55, 66)  total hydrogen count, clamped to 10
        [66, 77)  valence (the caller passes the implicit valence), clamped to 10
        [77]      1.0 when aromatic
    The vector is then divided by its own sum, so its entries sum to ~1.
    """
    onehot = [0.0] * 78
    onehot[ATOM_SYMBOL_IDX.get(symbol, 43)] = 1.0
    for offset, value in ((44, degree), (55, num_hs), (66, valence)):
        onehot[offset + min(value, 10)] = 1.0
    if aromatic:
        onehot[77] = 1.0
    total = sum(onehot)
    return tuple(x / total for x in onehot)
|
|
|
|
|
|
|
|
@lru_cache(maxsize=10000)
def smile_to_graph(smile: str) -> tuple:
    """Parse *smile* with RDKit and return ``(n_atoms, features, edges)`` (cached).

    ``features`` holds one ``_atom_feat`` tuple per atom, in RDKit atom order.
    ``edges`` is a flat tuple of directed (src, dst) pairs, two per bond:
    ``(u, v, v, u, ...)``; reshape it as (-1, 2) to recover the pairs.

    Raises:
        ValueError: if RDKit cannot parse the string.
    """
    molecule = Chem.MolFromSmiles(smile)
    if molecule is None:
        raise ValueError(f"Invalid SMILES: {smile}")

    atom_features = tuple(
        _atom_feat(
            atom.GetSymbol(),
            atom.GetDegree(),
            atom.GetTotalNumHs(),
            atom.GetValence(ValenceType.IMPLICIT),
            atom.GetIsAromatic(),
        )
        for atom in molecule.GetAtoms()
    )

    flat_pairs = []
    for bond in molecule.GetBonds():
        src, dst = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Emit both directions so each undirected bond becomes two directed edges.
        flat_pairs += (src, dst, dst, src)

    return len(atom_features), atom_features, tuple(flat_pairs)
|
|
|
|
|
|
|
|
@lru_cache(maxsize=2000)
def protein_features(seq: str) -> tuple:
    """Compute the cached ``(flat_2d, flat_3d)`` protein descriptors for *seq*.

    flat_2d: ``L * L`` ints, row-major. Entry (i, j) is count_i * count_j, where
        count_k is the number of occurrences of residue ``SEQ_VOC[k]`` in *seq*.
    flat_3d: ``L ** 3`` floats, row-major. Each overlapping 3-mer whose letters
        are all in AA_TO_IDX adds its occurrence count to the cell of every one
        of its six position permutations (PERM_INDICES). The cube is then
        divided by its maximum when that maximum is positive.
    """
    # 2D: outer product of the residue-composition vector with itself.
    composition = [0] * L
    for residue in seq:
        slot = AA_TO_IDX.get(residue)
        if slot is not None:
            composition[slot] += 1
    flat_2d = tuple(ci * cj for ci in composition for cj in composition)

    # 3D: count each distinct overlapping 3-mer once, then scatter the counts.
    kmer_hist = {}
    for start in range(len(seq) - 2):
        kmer = seq[start:start + 3]
        kmer_hist[kmer] = kmer_hist.get(kmer, 0) + 1

    plane = L * L
    cube = [0.0] * (plane * L)
    for kmer, freq in kmer_hist.items():
        if any(ch not in AA_TO_IDX for ch in kmer):
            continue  # skip 3-mers containing letters outside the vocabulary
        coords = [AA_TO_IDX[ch] for ch in kmer]
        for order in PERM_INDICES:
            x, y, z = (coords[k] for k in order)
            cube[x * plane + y * L + z] += freq
    # NOTE(review): 3-mers with repeated residues (e.g. "AAC") hit the same cell
    # more than once across permutations; presumably this mirrors the training
    # featurisation — confirm before changing.

    peak = max(cube, default=0)
    flat_3d = tuple(v / peak for v in cube) if peak > 0 else tuple(cube)
    return flat_2d, flat_3d
|
|
|
|
|
|
|
|
def create_graph_data(smiles: str, protein_seq: str, device: torch.device) -> DATA.Data:
    """Build a single-graph PyTorch Geometric ``Data`` object directly on *device*.

    Args:
        smiles: SMILES string of the drug (featurised via ``smile_to_graph``).
        protein_seq: protein sequence (featurised via ``protein_features``).
        device: device on which every tensor is allocated.

    Returns:
        ``Data`` with ``x`` (one 78-dim row per atom), ``edge_index`` of shape
        (2, n_directed_edges), ``y`` set to the shared ``_ZERO_Y`` placeholder,
        ``dcpro`` of shape (1, L, L), ``target`` of shape (1, L, L, L) and an
        all-zero ``batch`` vector (a single graph).

    Raises:
        ValueError: propagated from ``smile_to_graph`` for unparseable SMILES.
    """
    n_atoms, features, edges = smile_to_graph(smiles)
    flat_2d, flat_3d = protein_features(protein_seq)

    # NOTE(review): with zero atoms this is a 1-D empty tensor rather than (0, 78);
    # predict() rejects empty molecules before calling this.
    x = torch.tensor(features, dtype=torch.float32, device=device)

    if edges:
        # ``edges`` is a flat sequence of (src, dst) pairs: (u, v, v, u, ...).
        # Reshape to (n_edges, 2) first and then transpose to PyG's (2, n_edges)
        # COO layout. A direct view(2, -1) would split the flat list in half and
        # pair the wrong endpoints for any molecule with more than one bond.
        edge_idx = torch.tensor(edges, dtype=torch.long, device=device).view(-1, 2).t().contiguous()
    elif _EMPTY_EDGE is not None and _EMPTY_EDGE.device == device:
        # Reuse the preallocated empty edge tensor when it lives on this device.
        edge_idx = _EMPTY_EDGE
    else:
        edge_idx = torch.empty((2, 0), dtype=torch.long, device=device)

    data = DATA.Data(x=x, edge_index=edge_idx, y=_ZERO_Y)
    data.dcpro = torch.tensor(flat_2d, dtype=torch.float32, device=device).view(1, L, L)
    data.target = torch.tensor(flat_3d, dtype=torch.float32, device=device).view(1, L, L, L)
    # Every node belongs to graph 0: this Data object is a batch of one.
    data.batch = torch.zeros(n_atoms, dtype=torch.long, device=device)
    return data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional |
|
|
|
|
|
class AppState:
    """Process-wide container for objects created during application startup.

    ``__slots__`` restricts instances to exactly the four attributes below,
    all of which start out as ``None``.
    """

    __slots__ = ("model", "device", "empty_edge", "zero_y")

    def __init__(self) -> None:
        # Populated by ``lifespan`` at startup; ``model`` is reset to None on shutdown.
        self.model: Optional[cnn] = None
        self.device: Optional[torch.device] = None
        # NOTE(review): these two appear unused — lifespan stores the equivalent
        # tensors in the module globals _EMPTY_EDGE / _ZERO_Y instead.
        self.empty_edge: Optional[torch.Tensor] = None
        self.zero_y: Optional[torch.Tensor] = None


# Single shared instance read by the lifespan hook and the request handlers.
state = AppState()
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the KIBA model at startup and release inference state at shutdown.

    Startup steps:
      * choose ``cuda:0`` when CUDA is available, otherwise CPU;
      * load the ``model_cnn_kiba.model`` state dict into a fresh ``cnn``;
      * switch it to eval mode with gradients disabled;
      * build the device-resident placeholder tensors used by ``create_graph_data``.

    ``state.model`` is published only after loading has fully succeeded, so it
    never points at a network with uninitialised weights. Cleanup runs in a
    ``finally`` block, so the references are dropped even when the serving
    phase ends with an exception.
    """
    global _EMPTY_EDGE, _ZERO_Y

    model_path = Path(__file__).parent / "KCDTA" / "model_cnn_kiba.model"
    state.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = cnn()
    # weights_only=True restricts torch.load to tensors and plain containers
    # instead of unpickling arbitrary objects.
    model.load_state_dict(torch.load(model_path, map_location=state.device, weights_only=True))
    model.to(state.device)
    model.eval()
    # Inference only: no parameter needs autograd tracking.
    for param in model.parameters():
        param.requires_grad = False
    state.model = model

    _EMPTY_EDGE = torch.empty((2, 0), dtype=torch.long, device=state.device)
    _ZERO_Y = torch.zeros(1, device=state.device)

    n_params = sum(p.numel() for p in model.parameters())
    logger.info("Model loaded on %s with %s parameters", state.device, f"{n_params:,}")

    try:
        yield
    finally:
        state.model = None
        _EMPTY_EDGE = None
        _ZERO_Y = None
|
|
|
|
|
|
|
|
# ASGI application; ``lifespan`` loads the model when the server starts.
app = FastAPI(
    title="Drug-Target Binding Affinity Prediction API",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)

# Cross-origin requests are accepted only from ALLOWED_ORIGINS, with the methods
# and request headers listed here; credentials (cookies/auth) are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_methods=["GET", "POST", "HEAD"],
    allow_headers=["Content-Type", "Authorization"],
    allow_credentials=True,
)
|
|
|
|
|
|
|
|
class PredictionRequest(BaseModel):
    # Request body for POST /predict. Documented with comments rather than a
    # docstring so the generated OpenAPI schema stays unchanged.

    # Unknown JSON keys are silently dropped instead of causing a 422.
    model_config = ConfigDict(extra="ignore")

    # Length limits apply to the raw strings; predict() strips whitespace only
    # after this validation has run.
    smiles: str = Field(..., min_length=1, max_length=MAX_SMILES_LENGTH, description="SMILES representation of the drug molecule")
    protein_sequence: str = Field(..., min_length=1, max_length=MAX_PROTEIN_LENGTH, description="Amino acid sequence of the target protein")
|
|
|
|
|
|
|
|
class PredictionResponse(BaseModel):
    # Response body for POST /predict (comments, not a docstring, so the
    # generated OpenAPI schema stays unchanged).
    smiles: str  # the SMILES after whitespace stripping
    protein_sequence: str  # the stripped, upper-cased sequence that was scored
    binding_affinity: float  # model output rounded to 4 decimal places
    model_used: str = "KIBA"  # constant label; weights come from model_cnn_kiba.model
|
|
|
|
|
|
|
|
@app.get("/health")
@app.head("/health")
async def health():
    """Report liveness, whether the model is loaded, and the inference device.

    Served for GET (returns this JSON body) and HEAD (headers only, for cheap
    probes). Registering HEAD alone meant no client could ever read the body,
    and GET requests were answered with 405.
    """
    return {"status": "healthy", "model_loaded": state.model is not None, "device": str(state.device)}
|
|
|
|
|
|
|
|
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Predict the binding affinity for one drug (SMILES) / protein pair.

    Both inputs are whitespace-trimmed, and the protein is upper-cased before
    validation.

    Raises:
        HTTPException(503): the model is not loaded.
        HTTPException(400): the protein is empty or contains letters outside
            SEQ_VOC, the SMILES cannot be parsed, or it has no atoms.
        HTTPException(500): any failure during featurisation or inference.
    """
    if state.model is None:
        raise HTTPException(503, "Model not loaded")

    smiles = request.smiles.strip()
    seq = request.protein_sequence.strip().upper()

    # The request model's min_length is checked before stripping, so a
    # whitespace-only sequence arrives here as "" and must be rejected explicitly.
    if not seq:
        raise HTTPException(400, "Protein sequence is empty")

    invalid_aa = set(seq) - SEQ_VOC_SET
    if invalid_aa:
        # Sorted so the message is deterministic (set order varies per process).
        raise HTTPException(400, f"Invalid amino acids found: {sorted(invalid_aa)}. Valid: {SEQ_VOC}")

    # Parse up front so bad input maps to 400 rather than the generic 500 below.
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise HTTPException(400, "Invalid SMILES string: unable to parse molecule")
    if mol.GetNumAtoms() == 0:
        raise HTTPException(400, "SMILES represents an empty molecule")

    # NOTE(review): featurisation and the forward pass run synchronously on the
    # event-loop thread inside this ``async def``; long inputs block other requests.
    try:
        data = create_graph_data(smiles, seq, state.device)
        with torch.inference_mode():
            affinity = state.model(data).item()
    except Exception as e:
        # Boundary handler: log the full traceback, return a generic 500.
        logger.exception("Prediction failed: %s", e)
        raise HTTPException(500, "Prediction failed due to internal error") from e

    return PredictionResponse(smiles=smiles, protein_sequence=seq, binding_affinity=round(affinity, 4))
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Serve the app object directly rather than an import string like "main:app".
    # That string only resolves when this file is named main.py, and it would
    # import this module a second time under another name.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.getenv("PORT", 7860)),
        # uvicorn looks log levels up by lower-case name, so "INFO" would fail.
        log_level=os.getenv("LOG_LEVEL", "info").lower(),
        factory=False,
    )
|
|
|