File size: 4,402 Bytes
050c2ea
78a6f01
050c2ea
 
 
 
 
 
003f1c0
99c2921
050c2ea
003f1c0
050c2ea
92c81e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99c2921
92c81e8
003f1c0
050c2ea
92c81e8
003f1c0
99c2921
f8e4d91
050c2ea
99c2921
78a6f01
 
003f1c0
6b14a18
003f1c0
78a6f01
 
050c2ea
99c2921
78a6f01
99c2921
003f1c0
99c2921
92c81e8
 
 
 
f8e4d91
92c81e8
99c2921
92c81e8
99c2921
92c81e8
f8e4d91
003f1c0
 
 
 
 
050c2ea
 
003f1c0
 
 
 
99c2921
f8e4d91
050c2ea
 
6b14a18
050c2ea
99c2921
 
 
050c2ea
99c2921
 
 
 
050c2ea
99c2921
 
 
 
 
 
 
 
f8e4d91
99c2921
 
050c2ea
f8e4d91
 
 
 
 
050c2ea
 
99c2921
050c2ea
003f1c0
 
 
050c2ea
f8e4d91
050c2ea
 
99c2921
 
003f1c0
050c2ea
003f1c0
 
 
 
92c81e8
050c2ea
 
 
99c2921
 
 
050c2ea
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import torch.nn as nn
import numpy as np
import joblib

# 1. Define the Neural Network Architecture
# Using Sequential to match your .pth file structure exactly
class MedCareDDI_Network(nn.Module):
    """Feed-forward severity classifier over a drug-pair feature vector.

    Architecture: three hidden blocks of
    ``Linear -> BatchNorm1d -> ReLU -> Dropout(0.3)`` with widths 512, 256 and
    128, followed by a ``Linear`` head with 3 outputs
    (0=Major, 1=Minor, 2=Moderate).

    All layers live in one ``nn.Sequential`` stored as ``self.network``. The
    state-dict keys (``network.0.weight`` ... ``network.12.bias``) therefore
    line up with the saved ``.pth`` checkpoint. Keep the attribute name and the
    layer order intact, or ``load_state_dict`` will no longer match.

    Args:
        input_dim: Number of input features per sample.
    """

    def __init__(self, input_dim):
        super().__init__()
        layers = []
        width_in = input_dim
        # Each hidden block appends 4 modules, so the indices inside the
        # Sequential come out as 0-3, 4-7, 8-11, and the head lands at 12.
        for width_out in (512, 256, 128):
            layers.extend((
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(),
                nn.Dropout(0.3),
            ))
            width_in = width_out
        # Output: 0=Major, 1=Minor, 2=Moderate
        layers.append(nn.Linear(width_in, 3))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw (pre-softmax) class scores for ``x``.

        ``x`` is presumably shaped (batch, input_dim). The last dimension of
        the result has size 3.
        """
        return self.network(x)

# 2. Initialize FastAPI
app = FastAPI(version="2.8", title="MedCare DDI API")

# CORS policy for the React frontend: requests are accepted from any origin,
# with any method and header, but credentials are not allowed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=False,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# 3. Load Global Resources
device = torch.device("cpu")
input_dim = 12642  # 6321 * 2 -- presumably two 6321-dim drug vectors concatenated; confirm

# Bind both globals up front. A failed load then leaves them as None rather
# than undefined, and request handlers can detect "not ready" instead of
# crashing with NameError.
drug_vectors = None
model = None

print("Initializing AI Engine...")
try:
    # NOTE(review): joblib.load and torch.load both unpickle their input.
    # Only point them at trusted, locally shipped artifacts.
    _vectors = joblib.load('drug_vectors.pkl')
    _net = MedCareDDI_Network(input_dim)
    _net.load_state_dict(torch.load('MedCareDDI_DrugBank_Model.pth', map_location=device))
    # Standard global eval mode
    _net.eval()
    # Publish only after EVERYTHING loaded. Previously `model` was bound before
    # load_state_dict ran. If the checkpoint failed to load, the API kept
    # serving a randomly initialised network and reported its guesses as
    # "success".
    drug_vectors, model = _vectors, _net
    print("Success: Model and Vectors loaded.")
except Exception as e:
    print(f"FAILED TO START: {str(e)}")

# Clinical severity label for each output index of the network
# (0=Major, 1=Minor, 2=Moderate). Note that the order is alphabetical, not
# ranked by severity. It must match the class indices the model was trained on.
SEVERITY_MAP = dict(enumerate(("Major", "Minor", "Moderate")))

class DDIRequest(BaseModel):
    """Request body for ``/predict``: the pair of drug IDs to check.

    The endpoint strips whitespace from both IDs and upper-cases them before
    looking them up in ``drug_vectors``.
    """

    drug_a_id: str  # first drug of the pair (order does not matter; the endpoint sorts the IDs)
    drug_b_id: str  # second drug of the pair

@app.get("/")
def health_check():
    """Liveness/readiness probe.

    ``status`` is always ``"online"`` while the process serves requests (kept
    for backward compatibility). ``model_loaded`` is False when startup failed
    to load the model or the drug vectors, i.e. ``/predict`` cannot serve.
    """
    # Look the globals up defensively: the startup code may have failed before
    # binding them at all.
    ready = (
        globals().get("model") is not None
        and globals().get("drug_vectors") is not None
    )
    return {
        "status": "online",
        "model": "DrugBank_MultiModal_v2.8",
        "model_loaded": ready,
    }

@app.post("/predict")
@app.post("/predict/")
async def predict_interaction(request: DDIRequest):
    """Classify the clinical severity of the interaction between two drugs.

    Args:
        request: Body holding the two drug IDs. Both IDs are whitespace-stripped
            and upper-cased before lookup.

    Returns:
        A dict containing:
        - the two IDs in sorted order (``drug_a`` / ``drug_b``);
        - the predicted ``severity`` label;
        - the winning class probability as a percentage string (``confidence``);
        - every class probability to 4 decimals (``raw_scores``).

    Raises:
        HTTPException: 503 if the model or drug vectors were not loaded at
            startup; 400 if either drug ID is unknown; 500 if inference fails.
    """
    # Resolve resources defensively: startup may have failed before binding them.
    vectors = globals().get("drug_vectors")
    net = globals().get("model")
    if vectors is None or net is None:
        raise HTTPException(status_code=503, detail="Model is not loaded; the service is not ready.")

    # Standardize input IDs (Strip whitespace and uppercase)
    d1_id = request.drug_a_id.strip().upper()
    d2_id = request.drug_b_id.strip().upper()

    # Safety Check: report every unknown ID (deduplicated, in request order),
    # not just the first one.
    missing = [d for d in dict.fromkeys((d1_id, d2_id)) if d not in vectors]
    if missing:
        label = "Drug ID" if len(missing) == 1 else "Drug IDs"
        raise HTTPException(status_code=400, detail=f"{label} {', '.join(missing)} not in database.")

    try:
        # STEP 1: Forced symmetry. Sorting the IDs means (A, B) and (B, A)
        # build exactly the same input vector.
        drug_ids = sorted([d1_id, d2_id])
        v1 = vectors[drug_ids[0]]
        v2 = vectors[drug_ids[1]]

        # STEP 2: Build a one-sample batch of shape (1, len(v1) + len(v2)).
        # Assumes each stored vector is 1-D -- TODO confirm.
        combined = np.concatenate([v1, v2]).astype(np.float32).reshape(1, -1)
        input_tensor = torch.from_numpy(combined).to(device)

        # STEP 3: BatchNorm1d rejects single-sample batches in training mode,
        # so re-assert eval mode right before the forward pass.
        net.eval()

        # STEP 4: Inference
        with torch.no_grad():
            logits = net(input_tensor)
            probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
            predicted_idx = torch.argmax(probabilities).item()
            confidence = probabilities[predicted_idx].item() * 100

        # STEP 5: Format results. Score labels come from SEVERITY_MAP so they
        # cannot drift from the predicted severity label.
        return {
            "status": "success",
            "drug_a": drug_ids[0],
            "drug_b": drug_ids[1],
            "severity": SEVERITY_MAP[predicted_idx],
            "confidence": f"{confidence:.2f}%",
            "raw_scores": {
                f"{name} ({idx})": f"{probabilities[idx].item():.4f}"
                for idx, name in SEVERITY_MAP.items()
            },
        }

    except Exception as e:
        # Log the specific cause of 500 errors to the HF console
        print(f"RUNTIME ERROR during prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal AI Error: {str(e)}") from e

if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")