Spaces:
Running
Running
File size: 13,108 Bytes
7fb740b e87fea1 7fb740b e87fea1 7fb740b a712e78 7fb740b a712e78 7fb740b a712e78 7fb740b a712e78 7fb740b a712e78 7fb740b a712e78 7fb740b a712e78 7fb740b a712e78 7fb740b a712e78 7fb740b a712e78 7fb740b a712e78 7fb740b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 |
"""
Phase 2: FastAPI Backend for BioDiscovery Search
Fixes applied:
- Shared config import (no duplication)
- Model caching at startup (not per-request)
- Proper error handling
- Uses pre-computed PCA from Qdrant payloads
- Valid dummy sequences instead of "M" * 10
"""
import os
os.environ["DGL_DISABLE_GRAPHBOLT"] = "1"
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import warnings
import pickle
from typing import Optional, List
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from qdrant_client import QdrantClient
from DeepPurpose import utils, DTI as dp_models
warnings.filterwarnings("ignore")
# Import shared config
from config import (
BEST_MODEL_RUN, MODEL_CONFIG,
QDRANT_HOST, QDRANT_PORT, COLLECTION_NAME, METRICS,
VALID_DUMMY_DRUG, VALID_DUMMY_TARGET
)
# FastAPI application instance; the route decorators in this file register on it.
app = FastAPI(title="BioDiscovery API", version="2.0")
# CORS for frontend
# Allow generic access for deployment - in production restrict this to your Vercel domain
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # wildcard: any origin may call the API
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- GLOBAL STATE (loaded once at startup) ---
# All three start as None and are assigned by load_resources().
_model = None  # DeepPurpose model wrapper
_qdrant = None  # QdrantClient; reset to None when the connection attempt fails
_device = None  # torch.device the model is moved to
class SearchRequest(BaseModel):
    # Request body accepted by POST /api/search.
    # Raw query text; how it is interpreted depends on `type`.
    query: str
    type: str # "drug" (SMILES) or "target" (Sequence) or "text" (plain text search)
    # Maximum number of results to return.
    limit: int = 20
class PointsRequest(BaseModel):
    # Parameters describing a visualization-points request.
    # NOTE(review): no endpoint visible in this file takes this model;
    # GET /api/points reads `limit` and `view` as plain query parameters.
    limit: int = 500
    view: str = "combined" # "drug", "target", or "combined"
@app.on_event("startup")
async def load_resources():
    """Startup hook: build the DeepPurpose model and open the Qdrant client.

    Fills the module globals ``_model``, ``_device`` and ``_qdrant`` so that
    request handlers can reuse them. A missing ``model.pt`` or a failed
    Qdrant connection is only reported via ``print``; in the Qdrant case
    ``_qdrant`` ends up as ``None``.
    """
    global _model, _qdrant, _device

    print("[STARTUP] Loading DeepPurpose model...")

    # Build the model configuration.
    pickled_cfg = os.path.join(BEST_MODEL_RUN, "config.pkl")
    if not os.path.exists(pickled_cfg):
        # No saved config: rebuild one from the shared MODEL_CONFIG values.
        model_cfg = utils.generate_config(
            drug_encoding=MODEL_CONFIG["drug_encoding"],
            target_encoding=MODEL_CONFIG["target_encoding"],
            cls_hidden_dims=MODEL_CONFIG["cls_hidden_dims"],
            train_epoch=1,
            LR=1e-4,
            batch_size=256,
            result_folder=BEST_MODEL_RUN,
        )
    else:
        # NOTE(review): pickle.load can execute arbitrary code from the file;
        # config.pkl must only ever come from a trusted training run.
        with open(pickled_cfg, "rb") as handle:
            model_cfg = pickle.load(handle)
        # The stored result_folder may be stale, so point it at the current run.
        model_cfg["result_folder"] = BEST_MODEL_RUN

    _model = dp_models.model_initialize(**model_cfg)

    # Load trained weights when they are available.
    weights_file = os.path.join(BEST_MODEL_RUN, "model.pt")
    if not os.path.exists(weights_file):
        print(f"[WARNING] No model.pt found at {weights_file}")
    else:
        _model.load_pretrained(weights_file)
        print(f"[STARTUP] Model loaded from {weights_file}")

    # Per the original author's note: DeepPurpose.encoders keeps a
    # module-level `device` that its MLP forward pass uses when moving
    # inputs, so that global is overwritten with the device chosen here.
    import DeepPurpose.encoders as dp_encoders

    if torch.cuda.is_available():
        _device = torch.device("cuda")
    else:
        _device = torch.device("cpu")
    dp_encoders.device = _device
    print(f"[STARTUP] Using device: {_device}")

    # Move the network to the chosen device and switch it to eval mode.
    network = _model.model.to(_device)
    _model.model = network
    network.eval()

    print("[STARTUP] Connecting to Qdrant...")
    try:
        client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT, timeout=10)
        # Listing collections doubles as a connectivity check.
        listing = client.get_collections()
        names = [c.name for c in listing.collections]
        print(f"[STARTUP] Connected. Collections: {names}")
        _qdrant = client
    except Exception as exc:
        print(f"[WARNING] Qdrant connection failed: {exc}")
        _qdrant = None

    print("[STARTUP] Ready!")
def encode_query(query: str, query_type: str) -> List[float]:
    """Encode one drug or target query into an embedding vector.

    Args:
        query: A SMILES string when ``query_type`` is ``"drug"``, or a
            protein sequence when it is ``"target"``.
        query_type: ``"drug"`` or ``"target"``.

    Returns:
        The first row of the matching encoder's output as a list of floats.

    Raises:
        HTTPException: 503 when the model is not loaded; 400 for an
            unsupported ``query_type``, an empty query or an unparsable
            SMILES string; 500 for any other failure while encoding.
    """
    if _model is None:
        raise HTTPException(status_code=503, detail="Model not initialized")
    if query_type not in ("drug", "target"):
        raise HTTPException(status_code=400, detail="type must be 'drug' or 'target'")
    if not query or not query.strip():
        raise HTTPException(status_code=400, detail="query must not be empty")

    try:
        import numpy as np

        if query_type == "drug":
            # Direct Morgan fingerprint encoding (avoid data_process)
            from DeepPurpose.utils import smiles2morgan
            from rdkit import Chem

            # An unparsable SMILES is a client error, not a server fault.
            if Chem.MolFromSmiles(query) is None:
                raise HTTPException(status_code=400, detail=f"Invalid SMILES: {query}")

            morgan_fp = smiles2morgan(query, radius=2, nBits=1024)
            if morgan_fp is None:
                raise ValueError(f"Failed to compute Morgan fingerprint for: {query}")

            features = torch.tensor(np.array([morgan_fp]), dtype=torch.float32)
            encoder = _model.model.model_drug
        else:
            # Direct CNN target encoding
            from DeepPurpose.utils import protein_2_embed, trans_protein

            # trans_protein yields residue *characters*, not integer ids, so
            # feeding it straight into a long tensor cannot work. The CNN
            # encoder consumes the one-hot matrix built by protein_2_embed.
            residues = trans_protein(query)
            if residues is None:
                raise ValueError("Failed to encode protein sequence")

            # CNN expects a fixed sequence length, max_len=1000 in default config
            MAX_SEQ_LEN = 1000
            residues = list(residues)[:MAX_SEQ_LEN]
            if len(residues) < MAX_SEQ_LEN:
                # Defensive: DeepPurpose already pads with its unknown-residue
                # token "?"; repeat that here in case it did not.
                residues = residues + ["?"] * (MAX_SEQ_LEN - len(residues))

            one_hot = protein_2_embed(residues)
            features = torch.tensor(np.array([one_hot]), dtype=torch.float32)
            encoder = _model.model.model_protein

        # The network was moved to _device at startup; inputs must follow it.
        if _device is not None:
            features = features.to(_device)

        with torch.no_grad():
            return encoder(features).cpu().numpy()[0].tolist()
    except HTTPException:
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Encoding failed: {str(e)}") from e
@app.post("/api/search")
async def search_vectors(req: SearchRequest):
    """Search for similar drugs/targets.

    ``type == "text"`` goes straight to the payload text search. Any other
    type is encoded with ``encode_query`` and searched against the named
    vector ``req.type``. If encoding raises anything (including the
    HTTPException for an unsupported type), the request falls back to the
    text search instead of failing.

    Args:
        req: Query string, query type and result limit.

    Returns:
        A dict with ``results``, ``query_type`` and ``count``.

    Raises:
        HTTPException: 503 when Qdrant is not connected; 500 when the
            vector search itself fails.
    """
    if not _qdrant:
        raise HTTPException(status_code=503, detail="Qdrant not connected")

    # Text search - just filter by payload, no encoding needed
    if req.type == "text":
        return await text_search(req.query, req.limit)

    # Vector search - encode and search
    try:
        vector = encode_query(req.query, req.type)
    except Exception as e:
        # Fallback to text search if encoding fails
        print(f"Encoding failed ({e}), falling back to text search")
        return await text_search(req.query, req.limit)

    try:
        hits = _qdrant.search(
            collection_name=COLLECTION_NAME,
            query_vector=(req.type, vector),  # Named vector
            limit=req.limit,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") from e

    results = []
    for hit in hits:
        # A hit may carry no payload, or a null target_seq; neither should
        # turn an otherwise successful search into an unhandled error.
        payload = hit.payload or {}
        results.append({
            "id": hit.id,
            "score": hit.score,
            "smiles": payload.get("smiles"),
            "target_seq": (payload.get("target_seq") or "")[:100] + "...",
            "label": payload.get("label_true"),
            "affinity_class": payload.get("affinity_class"),
        })
    return {"results": results, "query_type": req.type, "count": len(results)}
async def text_search(query: str, limit: int = 20):
    """Text-based search through payloads (fallback when encoding fails).

    Scrolls one page of up to 500 points and keeps those whose ``smiles``
    payload contains ``query`` (case-insensitive substring match).

    Args:
        query: Text to look for inside each point's SMILES string.
        limit: Maximum number of matches to return; values <= 0 yield an
            empty result list.

    Returns:
        A dict with ``results``, ``query_type`` (always ``"text"``) and
        ``count``. An exact match scores 0.95, any other match 0.8.

    Raises:
        HTTPException: 503 when Qdrant is not connected; 500 when the
            scroll or filtering fails.
    """
    if not _qdrant:
        raise HTTPException(status_code=503, detail="Qdrant not connected")

    try:
        # Only this first page is examined; later points are never searched.
        res, _ = _qdrant.scroll(
            collection_name=COLLECTION_NAME,
            limit=500,  # Get more to filter through
            with_payload=True,
            with_vectors=False,
        )

        query_lower = query.lower()
        results = []
        for point in res:
            # Check the cap before appending so limit <= 0 returns nothing.
            if len(results) >= limit:
                break
            # Points without a payload or with a null SMILES are treated as
            # having an empty SMILES string instead of aborting the search.
            payload = point.payload or {}
            smiles = (payload.get("smiles") or "").lower()
            if query_lower in smiles:
                results.append({
                    "id": point.id,
                    "score": 0.95 if query_lower == smiles else 0.8,  # Higher score for exact match
                    "smiles": payload.get("smiles"),
                    "target_seq": (payload.get("target_seq") or "")[:100] + "...",
                    "label": payload.get("label_true"),
                    "affinity_class": payload.get("affinity_class"),
                })

        return {"results": results, "query_type": "text", "count": len(results)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Text search failed: {str(e)}") from e
@app.get("/api/points")
async def get_visualization_points(limit: int = 500, view: str = "combined"):
    """Return scatter-plot points built from PCA coordinates stored in payloads.

    Args:
        limit: Number of points requested from the Qdrant scroll.
        view: "drug", "target" or "combined"; selects the ``pca_<view>``
            payload key. Any other value reads ``pca_combined``.

    Returns:
        A dict with ``points``, summary ``metrics`` and the ``view`` echoed
        back exactly as received.

    Raises:
        HTTPException: 503 when Qdrant is not connected; 500 when the
            scroll fails.
    """
    if not _qdrant:
        raise HTTPException(status_code=503, detail="Qdrant not connected")

    try:
        # Scroll without raw vectors; the coordinates live in the payload.
        records, _ = _qdrant.scroll(
            collection_name=COLLECTION_NAME,
            limit=limit,
            with_vectors=False,
            with_payload=True,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Scroll failed: {exc!s}")

    # Pick the payload key holding the coordinates for the requested view.
    if view in ("drug", "target", "combined"):
        pca_key = f"pca_{view}"
    else:
        pca_key = "pca_combined"

    # Affinity class -> marker color (green / amber / slate).
    palette = {
        "high": "#10b981",
        "medium": "#f59e0b",
        "low": "#64748b",
    }
    fallback_color = "#64748b"

    def _axis(coords, index):
        # Missing components default to 0.
        return coords[index] if len(coords) > index else 0

    def _as_plot_point(record):
        # Convert one Qdrant record into the dict sent to the client.
        data = record.payload
        coords = data.get(pca_key, [0, 0, 0])
        tier = data.get("affinity_class", "low")
        shade = palette.get(tier, fallback_color)
        return {
            "id": record.id,
            "x": _axis(coords, 0),
            "y": _axis(coords, 1),
            "z": _axis(coords, 2),
            "color": shade,
            "name": (data.get("smiles") or "Unknown")[:15] + "...",
            "affinity": data.get("label_true", 0),
            "affinity_class": tier,
            "smiles": data.get("smiles"),
        }

    points = [_as_plot_point(record) for record in records]

    metrics = {
        "activeMolecules": len(points),
        "clusters": 3,  # high/medium/low
        "avgConfidence": METRICS.get("BindingDB_Kd", {}).get("CI", 0.80),
    }
    return {"points": points, "metrics": metrics, "view": view}
@app.get("/health")
def health():
    """Report liveness plus whether the model and Qdrant client are set."""
    model_ready = _model is not None
    qdrant_ready = _qdrant is not None
    return dict(
        status="ok",
        model_loaded=model_ready,
        qdrant_connected=qdrant_ready,
        metrics=METRICS,
    )
@app.get("/api/stats")
async def get_collection_stats():
    """Get real statistics from Qdrant collection for the data page.

    Combines collection-level info with counts computed over a sample of up
    to 1000 scrolled points (unique SMILES, unique target ids, and the
    affinity-class distribution).

    Returns:
        A dict with ``total_vectors``, ``sample_size``,
        ``unique_drugs_sampled``, ``unique_targets_sampled``,
        ``affinity_distribution``, ``collection_name`` and ``status``.

    Raises:
        HTTPException: 503 when Qdrant is not connected; 500 when fetching
            or summarizing the data fails.
    """
    if not _qdrant:
        raise HTTPException(status_code=503, detail="Qdrant not connected")

    from collections import Counter

    try:
        collection_info = _qdrant.get_collection(collection_name=COLLECTION_NAME)

        # vectors_count is deprecated in qdrant-client and may be missing or
        # None; fall back to points_count so the field stays populated.
        total_vectors = getattr(collection_info, "vectors_count", None)
        if total_vectors is None:
            total_vectors = getattr(collection_info, "points_count", None)

        # Sample to count affinity classes
        sample, _ = _qdrant.scroll(
            collection_name=COLLECTION_NAME,
            limit=1000,
            with_payload=["affinity_class", "smiles", "target_id"],
            with_vectors=False,
        )

        # Treat a missing payload as empty so one bare point cannot fail the
        # whole request.
        payloads = [p.payload or {} for p in sample]

        unique_drugs = len({pl["smiles"] for pl in payloads if pl.get("smiles")})
        unique_targets = len({pl["target_id"] for pl in payloads if pl.get("target_id")})

        # Counter keeps first-seen order, matching a manual counting loop.
        affinity_counts = dict(
            Counter(pl.get("affinity_class", "unknown") for pl in payloads)
        )

        # status is normally an enum; fall back to str() when it is not.
        status = collection_info.status
        status_text = status.value if hasattr(status, "value") else str(status)

        return {
            "total_vectors": total_vectors,
            "sample_size": len(sample),
            "unique_drugs_sampled": unique_drugs,
            "unique_targets_sampled": unique_targets,
            "affinity_distribution": affinity_counts,
            "collection_name": COLLECTION_NAME,
            "status": status_text,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Stats fetch failed: {str(e)}") from e
# Script entry point: serve the app with uvicorn when this file is run directly.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when running as a script.
    import uvicorn
    # Listen on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|