AbdoIR's picture
Update main.py
5fb77dc verified
"""FastAPI Backend for Drug-Target Binding Affinity Prediction (KC-DTA)"""
import os
import sys
import logging
from pathlib import Path
from functools import lru_cache
import torch
from torch_geometric import data as DATA
from rdkit import Chem
from rdkit.Chem.rdchem import ValenceType
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, ConfigDict
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ============================================================================
# Configuration
# ============================================================================


def parse_allowed_origins(raw: str) -> list:
    """Split a comma-separated origin list into clean entries.

    Surrounding whitespace is trimmed and empty entries are dropped, e.g.
    ``"https://a.com, https://b.com,"`` -> ``["https://a.com", "https://b.com"]``.
    Without the trim, an entry written with a space after the comma keeps the
    leading space and can never compare equal to a request's Origin header.
    """
    return [origin.strip() for origin in raw.split(",") if origin.strip()]


# CORS: allowed origins come from the ALLOWED_ORIGINS env var (comma-separated).
ALLOWED_ORIGINS = parse_allowed_origins(
    os.getenv("ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:5173")
)
# Input validation limits (prevent DoS)
MAX_SMILES_LENGTH = 500
MAX_PROTEIN_LENGTH = 5000
# Add KCDTA to path
sys.path.insert(0, str(Path(__file__).parent / "KCDTA"))
from models.cnn import cnn
# ============================================================================
# Pre-computed Constants
# ============================================================================
SEQ_VOC = "ACDEFGHIKLMNPQRSTVWXY"
SEQ_VOC_SET = frozenset(SEQ_VOC)  # Frozenset for O(1) membership tests
L = len(SEQ_VOC)  # 21 residue letters; derived so it cannot drift from SEQ_VOC
AA_TO_IDX = {aa: idx for idx, aa in enumerate(SEQ_VOC)}
# Element symbols for the atom one-hot: 43 known symbols occupy indices 0-42.
# Index 43 is the catch-all slot for any other element (44 slots in total).
_ATOM_SYMBOLS = ('C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
                 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag',
                 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni',
                 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb')
ATOM_SYMBOL_IDX = {s: i for i, s in enumerate(_ATOM_SYMBOLS)}
# All 6 orderings of a 3-mer's positions, in itertools.permutations(range(3)) order.
PERM_INDICES = ((0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0))
# Pre-allocated reusable tensors (set on startup, on the model's device)
_EMPTY_EDGE = None
_ZERO_Y = None
# ============================================================================
# Optimized Feature Extraction
# ============================================================================
@lru_cache(maxsize=50000)
def _atom_feat(symbol: str, degree: int, num_hs: int, valence: int, aromatic: bool) -> tuple:
    """Return the 78-slot atom descriptor, divided by its own sum (cached).

    Slot layout:
      [0, 44)  element one-hot; symbols absent from ATOM_SYMBOL_IDX use slot 43
      [44, 55) degree one-hot, clipped at 10
      [55, 66) total hydrogen count one-hot, clipped at 10
      [66, 77) valence one-hot, clipped at 10
      77       1.0 when the atom is aromatic, else 0.0
    """
    vec = [0.0] * 78
    hot_slots = (
        ATOM_SYMBOL_IDX.get(symbol, 43),
        44 + min(degree, 10),
        55 + min(num_hs, 10),
        66 + min(valence, 10),
    )
    for slot in hot_slots:
        vec[slot] = 1.0
    vec[77] = float(bool(aromatic))
    total = sum(vec)
    return tuple(value / total for value in vec)
@lru_cache(maxsize=10000)
def smile_to_graph(smile: str) -> tuple:
    """Parse ``smile`` with RDKit and return ``(n_atoms, features, edges)`` (cached).

    * ``n_atoms`` -- number of atoms, equal to ``len(features)``.
    * ``features`` -- one ``_atom_feat`` tuple per atom, in RDKit atom order.
    * ``edges`` -- flat tuple of consecutive ``(src, dst)`` pairs. Each bond
      ``(i, j)`` contributes ``i, j, j, i`` (both directions), so the tuple
      must be read as pairs, i.e. reshaped to ``(-1, 2)``.

    Raises:
        ValueError: if RDKit cannot parse ``smile``.
    """
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smile}")

    features = tuple(
        _atom_feat(
            atom.GetSymbol(),
            atom.GetDegree(),
            atom.GetTotalNumHs(),
            atom.GetValence(ValenceType.IMPLICIT),
            atom.GetIsAromatic(),
        )
        for atom in mol.GetAtoms()
    )

    flat_pairs = []
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        flat_pairs += (begin, end)  # forward direction
        flat_pairs += (end, begin)  # reverse direction
    return len(features), features, tuple(flat_pairs)
@lru_cache(maxsize=2000)
def protein_features(seq: str) -> tuple:
    """Build the cached 2D and 3D protein descriptors for ``seq``.

    Returns ``(flat_2d, flat_3d)`` as flat tuples:

    * ``flat_2d`` has ``L * L`` entries. Entry ``i * L + j`` is
      ``count(aa_i) * count(aa_j)``, i.e. the outer product of the residue
      composition vector with itself (raw counts, not normalized).
    * ``flat_3d`` has ``L ** 3`` entries indexed ``a * L * L + b * L + c``.
      Each distinct overlapping 3-mer adds its occurrence count once for each
      of the six position orderings in ``PERM_INDICES``. Repeated residues
      make some orderings coincide, so those cells receive the count more
      than once. The cube is then divided by its maximum, or left as zeros
      when nothing was counted.

    Characters outside ``SEQ_VOC`` are skipped in the composition, and any
    3-mer containing one is ignored.
    """
    # --- 2D: composition outer product ------------------------------------
    composition = [0] * L
    for residue in seq:
        slot = AA_TO_IDX.get(residue)
        if slot is not None:
            composition[slot] += 1
    flat_2d = tuple(ci * cj for ci in composition for cj in composition)

    # --- 3D: permutation-symmetric 3-mer histogram ------------------------
    trimer_hits = {}
    for start in range(len(seq) - 2):
        window = seq[start:start + 3]
        trimer_hits[window] = trimer_hits.get(window, 0) + 1

    plane = L * L
    cube = [0.0] * (L * L * L)
    for window, hits in trimer_hits.items():
        codes = [AA_TO_IDX.get(ch) for ch in window]
        if None in codes:
            continue  # 3-mer contains a character outside SEQ_VOC
        for first, second, third in PERM_INDICES:
            cube[codes[first] * plane + codes[second] * L + codes[third]] += hits

    peak = max(cube) if cube else 0
    scaled = tuple(v / peak for v in cube) if peak > 0 else tuple(cube)
    return flat_2d, scaled
def flat_pairs_to_edge_index(edges: tuple, device: torch.device) -> torch.Tensor:
    """Turn a flat ``(src0, dst0, src1, dst1, ...)`` sequence into a ``(2, E)`` edge_index.

    Row 0 holds the source atom of each directed edge and row 1 the matching
    destination. The flat sequence is grouped into pairs first (``view(-1, 2)``)
    and then transposed. Reshaping straight to ``(2, -1)`` would put the first
    half of the sequence in row 0 and the second half in row 1, scrambling
    the pairs as soon as there is more than one bond.

    An empty sequence yields an empty ``(2, 0)`` long tensor. The startup-allocated
    one is reused when it lives on ``device``.
    """
    if not edges:
        if _EMPTY_EDGE is not None and _EMPTY_EDGE.device == device:
            return _EMPTY_EDGE
        return torch.empty((2, 0), dtype=torch.long, device=device)
    pairs = torch.tensor(edges, dtype=torch.long, device=device).view(-1, 2)
    return pairs.t().contiguous()


def create_graph_data(smiles: str, protein_seq: str, device: torch.device) -> "DATA.Data":
    """Build a single-graph PyTorch Geometric ``Data`` object directly on ``device``.

    Attributes set:
      x          -- float32 atom features, one row per atom (78 columns from _atom_feat)
      edge_index -- long tensor of shape (2, num_directed_edges)
      y          -- the shared zero label allocated at startup (None before startup)
      dcpro      -- 2D protein descriptor reshaped to (1, L, L)
      target     -- 3D protein descriptor reshaped to (1, L, L, L)
      batch      -- zeros of length n_atoms: every atom belongs to graph 0

    Raises:
        ValueError: from smile_to_graph when ``smiles`` cannot be parsed.
    """
    n_atoms, features, edges = smile_to_graph(smiles)
    flat_2d, flat_3d = protein_features(protein_seq)

    data = DATA.Data(
        x=torch.tensor(features, dtype=torch.float32, device=device),
        edge_index=flat_pairs_to_edge_index(edges, device),
        y=_ZERO_Y,
    )
    data.dcpro = torch.tensor(flat_2d, dtype=torch.float32, device=device).view(1, L, L)
    data.target = torch.tensor(flat_3d, dtype=torch.float32, device=device).view(1, L, L, L)
    data.batch = torch.zeros(n_atoms, dtype=torch.long, device=device)
    return data
# ============================================================================
# FastAPI Application
# ============================================================================
from typing import Optional
class AppState:
    """Mutable, process-wide holder for inference resources.

    The single instance ``state`` is created at import time with every slot
    set to ``None``. ``lifespan`` fills in ``model`` and ``device`` at startup
    and resets ``model`` at shutdown. ``empty_edge`` and ``zero_y`` are not
    assigned by any code shown here. The shared tensors are kept in the module
    globals ``_EMPTY_EDGE`` / ``_ZERO_Y`` instead.
    NOTE(review): these two slots look unused; confirm before removing.
    """

    # No per-instance __dict__: only these four attributes can be set.
    __slots__ = ('model', 'device', 'empty_edge', 'zero_y')

    def __init__(self) -> None:
        self.device: Optional[torch.device] = None
        self.model: Optional[cnn] = None
        self.zero_y: Optional[torch.Tensor] = None
        self.empty_edge: Optional[torch.Tensor] = None


# Module-level singleton shared by the lifespan handler and the endpoints.
state = AppState()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the KC-DTA model once at startup and drop it at shutdown.

    Startup:
      * chooses ``cuda:0`` when CUDA is available, otherwise the CPU;
      * builds ``cnn()`` and loads ``KCDTA/model_cnn_kiba.model`` into it;
      * moves it to the device, switches to eval mode and freezes all params;
      * allocates the shared empty edge_index / zero label tensors on that device.
    Shutdown clears ``state.model``.
    """
    global _EMPTY_EDGE, _ZERO_Y

    # ---- startup -----------------------------------------------------------
    weights_file = Path(__file__).parent / "KCDTA" / "model_cnn_kiba.model"
    device = state.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = state.model = cnn()
    model.load_state_dict(torch.load(weights_file, map_location=device, weights_only=True))
    model.to(device)
    model.eval()  # evaluation mode (disables dropout)
    for weight in model.parameters():
        weight.requires_grad = False  # inference only: no autograd bookkeeping

    # Tensors reused across requests, allocated once on the model's device.
    _EMPTY_EDGE = torch.empty((2, 0), dtype=torch.long, device=device)
    _ZERO_Y = torch.zeros(1, device=device)

    n_params = sum(weight.numel() for weight in model.parameters())
    logger.info(f"Model loaded on {device} with {n_params:,} parameters")

    yield

    # ---- shutdown ----------------------------------------------------------
    state.model = None
# Application object. OpenAPI docs are served at /docs and /redoc; model
# loading and teardown happen in `lifespan`.
app = FastAPI(
    title="Drug-Target Binding Affinity Prediction API",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)

# CORS: browsers may call the API only from ALLOWED_ORIGINS, with the listed
# methods and request headers.
# NOTE(review): credentials are allowed, so keep ALLOWED_ORIGINS an explicit
# list rather than "*".
_CORS_METHODS = ["GET", "POST", "HEAD"]
_CORS_HEADERS = ["Content-Type", "Authorization"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_methods=_CORS_METHODS,
    allow_headers=_CORS_HEADERS,
    allow_credentials=True,
)
class PredictionRequest(BaseModel):
    # Request body for POST /predict. Unknown JSON keys are silently dropped.
    # The length bounds apply to the raw strings; predict() strips whitespace
    # only after this validation has passed.
    model_config = ConfigDict(extra="ignore")

    smiles: str = Field(
        ...,
        min_length=1,
        max_length=MAX_SMILES_LENGTH,
        description="SMILES representation of the drug molecule",
    )
    protein_sequence: str = Field(
        ...,
        min_length=1,
        max_length=MAX_PROTEIN_LENGTH,
        description="Amino acid sequence of the target protein",
    )
class PredictionResponse(BaseModel):
    # Response body for POST /predict; fields serialize in this order.
    smiles: str  # the SMILES as used, i.e. after whitespace stripping
    protein_sequence: str  # the sequence as used: stripped and upper-cased
    binding_affinity: float  # model output, rounded to 4 decimals by predict()
    model_used: str = "KIBA"  # fixed label; lifespan loads the KIBA checkpoint
@app.api_route("/health", methods=["GET", "HEAD"])
async def health():
    """Report service liveness, whether the model is loaded, and its device.

    GET returns the JSON below. HEAD returns the same status code with the body
    omitted, as HEAD responses carry no body. With a HEAD-only route, this
    payload could never be read, and GET probes were rejected with 405.
    """
    return {"status": "healthy", "model_loaded": state.model is not None, "device": str(state.device)}
def _run_inference(smiles: str, seq: str) -> float:
    """Featurize one validated (SMILES, sequence) pair and run one forward pass.

    This is synchronous and CPU/GPU-bound; ``predict`` runs it in a worker
    thread so the event loop stays responsive.
    """
    data = create_graph_data(smiles, seq, state.device)
    with torch.inference_mode():
        return state.model(data).item()


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Predict the KIBA binding affinity for one drug/protein pair.

    Raises:
        HTTPException 503: the model has not been loaded (yet).
        HTTPException 400: empty or invalid protein sequence, or unparseable /
            empty SMILES.
        HTTPException 500: featurization or inference failed unexpectedly.
    """
    from fastapi.concurrency import run_in_threadpool

    if state.model is None:
        raise HTTPException(503, "Model not loaded")

    smiles = request.smiles.strip()
    seq = request.protein_sequence.strip().upper()

    # Whitespace-only input passes the request model's min_length=1 check but
    # is empty once stripped. Without this guard, an empty sequence would be
    # featurized as all-zero protein tensors and still get a "prediction".
    if not seq:
        raise HTTPException(400, "Protein sequence is empty")

    invalid_aa = set(seq) - SEQ_VOC_SET
    if invalid_aa:
        raise HTTPException(400, f"Invalid amino acids found: {invalid_aa}. Valid: {SEQ_VOC}")

    # Parse up front so malformed input is reported as a 400, not a 500 from
    # the feature extractor. This Mol object is not reused: smile_to_graph
    # parses the string again (its result is lru_cached there).
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise HTTPException(400, "Invalid SMILES string: unable to parse molecule")
    if mol.GetNumAtoms() == 0:
        raise HTTPException(400, "SMILES represents an empty molecule")

    try:
        affinity = await run_in_threadpool(_run_inference, smiles, seq)
    except Exception as exc:
        logger.exception("Prediction failed")  # keeps the traceback in the logs
        raise HTTPException(500, "Prediction failed due to internal error") from exc

    return PredictionResponse(smiles=smiles, protein_sequence=seq, binding_affinity=round(affinity, 4))
if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces requires port 7860 (overridable via PORT) and a single
    # worker, which is uvicorn's default. The app object is passed directly
    # rather than the "main:app" import string. The string form re-imports this
    # file under the module name "main", re-running all module-level setup, and
    # it breaks if the file is ever renamed.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.getenv("PORT", "7860")),
        log_level=os.getenv("LOG_LEVEL", "info"),
    )