# medicare-api / app.py
# Author: megamind22
# Commit f8e4d91 "Update app.py" (verified)
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import torch.nn as nn
import numpy as np
import joblib
# 1. Neural network architecture
# All layers live in one nn.Sequential named `network` so the state_dict keys
# (network.<index>.<param>) line up with the saved .pth checkpoint.
class MedCareDDI_Network(nn.Module):
    """Feed-forward classifier for a concatenated drug-pair feature vector.

    Three hidden stages of Linear -> BatchNorm1d -> ReLU -> Dropout(0.3) with
    512, 256 and 128 units, followed by a Linear head with 3 outputs
    (index meaning per this file: 0=Major, 1=Minor, 2=Moderate).

    Args:
        input_dim: Length of each input feature vector.
    """

    def __init__(self, input_dim):
        super().__init__()
        stages = []
        width = input_dim
        # Layer order (and therefore each Sequential index) must not change,
        # or the checkpoint keys will no longer match.
        for hidden in (512, 256, 128):
            stages += [
                nn.Linear(width, hidden),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
                nn.Dropout(0.3),
            ]
            width = hidden
        stages.append(nn.Linear(width, 3))  # raw logits; softmax is applied by the caller
        self.network = nn.Sequential(*stages)

    def forward(self, x):
        """Return the 3 class logits for ``x`` (assumed shape (batch, input_dim))."""
        return self.network(x)
# 2. FastAPI application
app = FastAPI(version="2.8", title="MedCare DDI API")

# Cross-origin settings for the browser frontend: any origin, method and header
# is accepted, and credentialed requests are not allowed.
# NOTE(review): consider restricting allow_origins to the frontend's origin.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=False,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# 3. Load global resources
device = torch.device("cpu")
input_dim = 12642  # 6321 * 2: the two per-drug vectors are concatenated
print("Initializing AI Engine...")
try:
    # NOTE(review): joblib.load and torch.load deserialize pickled data; only
    # deploy trusted artifact files.
    drug_vectors = joblib.load('drug_vectors.pkl')
    model = MedCareDDI_Network(input_dim)
    model.load_state_dict(torch.load('MedCareDDI_DrugBank_Model.pth', map_location=device))
    # Eval mode: BatchNorm uses its running statistics and Dropout is disabled.
    model.eval()
    print("Success: Model and Vectors loaded.")
except Exception as e:
    # Startup failures are logged and the process keeps running (as before),
    # but both globals are reset. Otherwise a failing load_state_dict would leave
    # `model` bound to a randomly initialised network that /predict would serve.
    # A failing joblib.load would leave both names unbound.
    drug_vectors = None
    model = None
    print(f"FAILED TO START: {str(e)}")
# Model output index -> clinical severity label (0=Major, 1=Minor, 2=Moderate).
SEVERITY_MAP = dict(enumerate(("Major", "Minor", "Moderate")))
class DDIRequest(BaseModel):
    # JSON body accepted by POST /predict: one identifier per drug in the pair.
    # The endpoint strips and upper-cases both values before the lookup.
    drug_a_id: str  # first drug identifier
    drug_b_id: str  # second drug identifier
@app.get("/")
def health_check():
    """Report service liveness and whether the AI engine finished loading.

    Returns:
        dict with the fixed ``status``/``model`` fields plus ``model_loaded``.
        ``model_loaded`` is False when startup loading failed, so callers can
        tell "process up" apart from "able to serve predictions".
    """
    # Looked up through globals(): if the startup load failed, these names may
    # be None or may never have been bound.
    ready = globals().get("model") is not None and globals().get("drug_vectors") is not None
    return {"status": "online", "model": "DrugBank_MultiModal_v2.8", "model_loaded": ready}
@app.post("/predict")
@app.post("/predict/")
def predict_interaction(request: DDIRequest):
    """Predict the interaction-severity class for a pair of drug IDs.

    Both IDs are stripped and upper-cased, then looked up in the loaded
    ``drug_vectors`` mapping. The pair is sorted before the two vectors are
    concatenated, so (A, B) and (B, A) build the same model input.

    This is a plain ``def`` rather than ``async def``. The torch forward pass is
    blocking CPU work, and FastAPI runs sync path functions in a worker
    threadpool instead of on the event loop.

    Returns:
        A dict with ``status``, the sorted ``drug_a``/``drug_b`` IDs, the
        predicted ``severity`` label, ``confidence`` as a percentage string,
        and ``raw_scores`` holding each class's softmax probability as a
        four-decimal string.

    Raises:
        HTTPException: 503 if the model or vectors were not loaded at startup,
            400 if a drug ID is not in ``drug_vectors``, 500 for any other
            failure during inference.
    """
    # Looked up through globals(): if the startup load failed, these names may
    # be None or never bound. Answer 503 instead of crashing with NameError.
    vectors = globals().get("drug_vectors")
    net = globals().get("model")
    if vectors is None or net is None:
        raise HTTPException(status_code=503, detail="AI engine not loaded; check server logs.")

    # Standardize input IDs (strip whitespace and uppercase)
    d1_id = request.drug_a_id.strip().upper()
    d2_id = request.drug_b_id.strip().upper()

    # Report every unknown ID (deduplicated, in request order), not just the first.
    missing = [d for d in dict.fromkeys((d1_id, d2_id)) if d not in vectors]
    if len(missing) == 1:
        raise HTTPException(status_code=400, detail=f"Drug ID {missing[0]} not in database.")
    if missing:
        raise HTTPException(status_code=400, detail=f"Drug IDs {', '.join(missing)} not in database.")

    try:
        # Forced symmetry: sorting the IDs makes A+B and B+A produce the same input.
        drug_ids = sorted([d1_id, d2_id])
        combined = np.concatenate([vectors[drug_ids[0]], vectors[drug_ids[1]]])
        # Single-sample batch of shape (1, total_length), float32 for the network.
        combined = combined.astype(np.float32).reshape(1, -1)
        input_tensor = torch.from_numpy(combined).to(device)

        # The model is already in eval mode after loading. Repeating it is a cheap
        # guard, because BatchNorm1d in training mode rejects a batch of one.
        net.eval()
        with torch.no_grad():
            output = net(input_tensor)

        probabilities = torch.nn.functional.softmax(output, dim=1)[0]
        predicted_idx = int(torch.argmax(probabilities).item())
        confidence = probabilities[predicted_idx].item() * 100
        major, minor, moderate = probabilities.tolist()

        return {
            "status": "success",
            "drug_a": drug_ids[0],
            "drug_b": drug_ids[1],
            "severity": SEVERITY_MAP[predicted_idx],
            "confidence": f"{confidence:.2f}%",
            "raw_scores": {
                "Major (0)": f"{major:.4f}",
                "Minor (1)": f"{minor:.4f}",
                "Moderate (2)": f"{moderate:.4f}"
            }
        }
    except Exception as e:
        # Log the specific cause of 500 errors to the console.
        # NOTE(review): the detail exposes raw exception text to clients;
        # consider a generic message once the frontend no longer relies on it.
        print(f"RUNTIME ERROR during prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal AI Error: {str(e)}") from e
if __name__ == "__main__":
    # Direct launch: serve the app with uvicorn on all interfaces, port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")