from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import torch.nn as nn
import numpy as np
import joblib


# 1. Define the Neural Network Architecture
# All layers live in a single nn.Sequential attribute named `network`, so the
# state_dict keys are `network.<index>.<param>` -- the exact layout the saved
# .pth file is loaded into. Do not rename the attribute or reorder layers.
class MedCareDDI_Network(nn.Module):
    """Feed-forward drug-drug-interaction severity classifier.

    Three hidden stages of Linear -> BatchNorm1d -> ReLU -> Dropout(0.3)
    with widths 512, 256 and 128, followed by a Linear head producing
    3 logits (class index 0=Major, 1=Minor, 2=Moderate).

    Args:
        input_dim: number of features in each input row.
    """

    def __init__(self, input_dim):
        super().__init__()
        stages = []
        width_in = input_dim
        for width_out in (512, 256, 128):
            stages.extend([
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(),
                nn.Dropout(0.3),
            ])
            width_in = width_out
        # Classification head, appended last so it lands at Sequential index 12
        # exactly as in the hand-written layout the checkpoint was saved from.
        stages.append(nn.Linear(width_in, 3))  # Output: 0=Major, 1=Minor, 2=Moderate
        self.network = nn.Sequential(*stages)

    def forward(self, x):
        """Return the raw (pre-softmax) logits for the input batch ``x``."""
        return self.network(x)


# 2. Initialize FastAPI
app = FastAPI(title="MedCare DDI API", version="2.8")

# CORS for the React frontend: every origin, method and header is accepted,
# while cross-origin credentials are disabled (allow_credentials=False).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# 3. Load Global Resources
device = torch.device("cpu")
input_dim = 12642  # 6321 * 2 (two per-drug vectors concatenated)

# Text of the exception raised during startup loading, or None on success.
load_error = None


def _load_resources():
    """Load the drug-vector table and the trained network.

    Returns:
        (vectors, model) where ``model`` has its weights loaded and is already
        in eval mode. Nothing is published globally here, so a failure part-way
        through can never leave a half-initialised model in service.

    Raises:
        Whatever joblib.load / torch.load / load_state_dict raise.
    """
    # SECURITY NOTE: joblib.load and torch.load both unpickle their input and
    # can execute code embedded in the file. Only deploy trusted artifacts.
    # Because the .pth is a plain state_dict, consider passing
    # weights_only=True explicitly (supported on torch >= 1.13).
    vectors = joblib.load('drug_vectors.pkl')
    net = MedCareDDI_Network(input_dim)
    net.load_state_dict(torch.load('MedCareDDI_DrugBank_Model.pth', map_location=device))
    # Standard global eval mode: BatchNorm1d uses running stats and accepts a
    # batch of one; Dropout is disabled.
    net.eval()
    return vectors, net


print("Initializing AI Engine...")
try:
    drug_vectors, model = _load_resources()
    print("Success: Model and Vectors loaded.")
except Exception as e:
    # Keep the server up so the failure is visible through the API, but make
    # sure no partially loaded (random-weight, training-mode) model is served.
    drug_vectors, model = None, None
    load_error = str(e)
    print(f"FAILED TO START: {str(e)}")

# Clinical Severity Mapping (0=Major, 1=Minor, 2=Moderate)
SEVERITY_MAP = {
    0: "Major",
    1: "Minor",
    2: "Moderate",
}


class DDIRequest(BaseModel):
    """Request body for /predict: the two drug IDs to check."""

    drug_a_id: str
    drug_b_id: str


@app.get("/")
def health_check():
    """Report liveness and whether the model artifacts loaded.

    Returns the usual {"status": "online", ...} payload when the engine is
    ready. If startup loading failed, returns status "degraded" plus the load
    error, instead of reporting a broken deployment as healthy.
    """
    if model is None or drug_vectors is None:
        return {"status": "degraded", "model": "DrugBank_MultiModal_v2.8", "error": load_error}
    return {"status": "online", "model": "DrugBank_MultiModal_v2.8"}


@app.post("/predict")
@app.post("/predict/")
async def predict_interaction(request: DDIRequest):
    """Predict the interaction severity for a pair of drugs.

    Args:
        request: the two drug IDs. Whitespace is stripped and they are
            upper-cased before lookup.

    Returns:
        A dict with the (alphabetically ordered) drug IDs, the predicted
        severity label, its confidence as a percentage string, and all three
        class probabilities.

    Raises:
        HTTPException 503: the model/vectors failed to load at startup.
        HTTPException 400: one or both drug IDs are not in the vector table.
        HTTPException 500: any failure during feature preparation or inference.
    """
    if model is None or drug_vectors is None:
        raise HTTPException(status_code=503, detail="AI engine not loaded; see server logs.")

    # Standardize input IDs (Strip whitespace and uppercase)
    d1_id = request.drug_a_id.strip().upper()
    d2_id = request.drug_b_id.strip().upper()

    # Safety Check: report every unknown ID (deduplicated, in request order).
    missing = [drug_id for drug_id in dict.fromkeys((d1_id, d2_id)) if drug_id not in drug_vectors]
    if missing:
        raise HTTPException(status_code=400, detail=f"Drug ID {', '.join(missing)} not in database.")

    try:
        # STEP 1: Forced Symmetry
        # Alphabetical sorting ensures A+B always equals B+A mathematically
        drug_ids = sorted([d1_id, d2_id])
        v1 = drug_vectors[drug_ids[0]]
        v2 = drug_vectors[drug_ids[1]]

        # STEP 2: Vector Preparation
        # Reshape to (1, input_dim) to provide a single-sample batch
        combined = np.concatenate([v1, v2]).astype(np.float32).reshape(1, -1)
        if combined.shape[1] != input_dim:
            raise ValueError(
                f"Feature vector has {combined.shape[1]} values, expected {input_dim}."
            )
        input_tensor = torch.from_numpy(combined).to(device)

        # STEP 3: Defensive eval-mode guard. The model is only published after
        # eval() at startup, so this is cheap insurance: BatchNorm1d in
        # training mode rejects a batch of one sample.
        model.eval()

        # STEP 4: Inference
        with torch.no_grad():
            output = model(input_tensor)

        # Apply Softmax to get probabilities
        probabilities = torch.nn.functional.softmax(output, dim=1)[0]
        predicted_idx = torch.argmax(probabilities).item()
        confidence = probabilities[predicted_idx].item() * 100

        # STEP 5: Format Results
        return {
            "status": "success",
            "drug_a": drug_ids[0],
            "drug_b": drug_ids[1],
            "severity": SEVERITY_MAP[predicted_idx],
            "confidence": f"{confidence:.2f}%",
            "raw_scores": {
                "Major (0)": f"{probabilities[0].item():.4f}",
                "Minor (1)": f"{probabilities[1].item():.4f}",
                "Moderate (2)": f"{probabilities[2].item():.4f}",
            },
        }
    except Exception as e:
        # Log the specific cause of 500 errors to the HF console
        print(f"RUNTIME ERROR during prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal AI Error: {str(e)}") from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)