Spaces:
Sleeping
Sleeping
import traceback

import joblib
import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# 1. Define the Neural Network Architecture
# All layers live in one nn.Sequential attribute named `network`, so the
# parameter names (network.0.weight, network.1.running_mean, ...) line up with
# the keys stored in the .pth checkpoint.
class MedCareDDI_Network(nn.Module):
    """Feed-forward classifier over a concatenated drug-pair feature row.

    Three hidden stages (512 -> 256 -> 128 units), each
    Linear -> BatchNorm1d -> ReLU -> Dropout(0.3), then a Linear head that
    emits 3 logits (index 0=Major, 1=Minor, 2=Moderate).

    Args:
        input_dim: number of features in each input row.
    """

    def __init__(self, input_dim):
        super().__init__()
        widths = [input_dim, 512, 256, 128]
        stages = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(0.3),
            ]
        # Output: 0=Major, 1=Minor, 2=Moderate
        stages.append(nn.Linear(widths[-1], 3))
        self.network = nn.Sequential(*stages)

    def forward(self, x):
        """Return raw (pre-softmax) logits, shape (batch, 3) for a (batch, input_dim) input."""
        return self.network(x)
# 2. Initialize FastAPI
app = FastAPI(title="MedCare DDI API", version="2.8")

# CORS settings for the React frontend: any origin, method and header is
# accepted. Credentials stay disabled, since the CORS protocol does not allow
# a wildcard origin on credentialed requests.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# 3. Load Global Resources
device = torch.device("cpu")
input_dim = 12642  # 6321 * 2 — the network input is two drug vectors concatenated


def _load_resources():
    """Load the drug-vector table and the trained network.

    Returns:
        (vectors, net): the object ``joblib.load`` returns for
        ``drug_vectors.pkl`` (the prediction code indexes it by drug ID), and
        a ``MedCareDDI_Network`` with the checkpoint weights loaded, already
        in eval mode.

    Raises:
        Any exception from reading/unpickling either file or from
        ``load_state_dict`` (e.g. missing/unexpected keys, shape mismatch).
    """
    # NOTE(review): joblib.load unpickles arbitrary objects — only point it at
    # a trusted file.
    vectors = joblib.load('drug_vectors.pkl')
    net = MedCareDDI_Network(input_dim)
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state_dict needs.
    state = torch.load('MedCareDDI_DrugBank_Model.pth', map_location=device, weights_only=True)
    net.load_state_dict(state)
    # Eval mode: BatchNorm uses its running statistics (so single-sample
    # batches work) and Dropout is disabled.
    net.eval()
    return vectors, net


# Both globals stay None unless the *whole* load succeeds. Binding them only
# after the weights are in place means a bad checkpoint can never leave a
# randomly initialised network serving predictions.
drug_vectors = None
model = None

print("Initializing AI Engine...")
try:
    drug_vectors, model = _load_resources()
    print("Success: Model and Vectors loaded.")
except Exception as e:
    # Keep the process alive (as before) but log the full traceback, not just
    # the message, so the cause is visible in the console.
    print(f"FAILED TO START: {str(e)}")
    traceback.print_exc()
# Clinical severity label for each output index of the network
# (0=Major, 1=Minor, 2=Moderate).
# NOTE(review): the order is alphabetical rather than by severity, presumably
# the label encoding used at training time — confirm before reordering.
SEVERITY_MAP = dict(enumerate(("Major", "Minor", "Moderate")))
# Request body for predict_interaction: the two drug IDs to pair.
# Field names, order and types form the JSON request contract, so they are
# documented with comments rather than a docstring (pydantic would copy a
# docstring into the generated schema).
class DDIRequest(BaseModel):
    # First drug ID; predict_interaction strips whitespace and upper-cases it
    # before looking it up in drug_vectors.
    drug_a_id: str
    # Second drug ID; normalised the same way as drug_a_id.
    drug_b_id: str
def health_check():
    """Return a static status payload naming the served model build.

    This is presumably used as a health/liveness check. It does not verify
    that the model actually loaded.

    NOTE(review): no route decorator is visible on this function in this file,
    so as shown it is never registered with ``app``. Confirm the intended
    route (e.g. ``@app.get("/")``) was not lost.
    """
    payload = dict(status="online", model="DrugBank_MultiModal_v2.8")
    return payload
async def predict_interaction(request: DDIRequest) -> dict:
    """Predict the clinical severity of an interaction between two drugs.

    NOTE(review): no route decorator is visible on this function in this file,
    so as shown it is never registered with ``app``. Confirm the intended
    route decorator was not lost.
    NOTE(review): this is ``async def`` but runs synchronous CPU-bound
    inference, which blocks the event loop for the duration of the forward
    pass. A plain ``def`` would let FastAPI run it in a worker thread.

    Args:
        request: the two drug IDs. Each is normalised with ``strip().upper()``
            before lookup in ``drug_vectors``.

    Returns:
        dict with ``status``; the two normalised IDs in *sorted* order as
        ``drug_a``/``drug_b`` (not necessarily the request order); the
        predicted ``severity`` label; its ``confidence`` as a percentage
        string; and per-class softmax probabilities in ``raw_scores``.

    Raises:
        HTTPException(503): the model/vector resources are not loaded.
        HTTPException(400): one or both IDs are not in ``drug_vectors``.
        HTTPException(500): any error while building the input or running
            the model.
    """
    # If startup loading failed, these globals may be None or not bound at
    # all. Either way no prediction is possible, so fail with a clear 503
    # instead of an opaque NameError/TypeError.
    vectors = globals().get("drug_vectors")
    net = globals().get("model")
    if vectors is None or net is None:
        raise HTTPException(status_code=503, detail="AI engine not loaded; check server startup logs.")

    # Standardize input IDs (strip whitespace and uppercase).
    d1_id = request.drug_a_id.strip().upper()
    d2_id = request.drug_b_id.strip().upper()

    # Report every unknown ID (deduplicated), not just the first one. With a
    # single unknown ID the message is unchanged.
    missing = [d for d in dict.fromkeys((d1_id, d2_id)) if d not in vectors]
    if missing:
        noun = "Drug ID" if len(missing) == 1 else "Drug IDs"
        raise HTTPException(status_code=400, detail=f"{noun} {', '.join(missing)} not in database.")

    try:
        # Sort the pair so (A, B) and (B, A) build the same input row and
        # therefore receive the same prediction.
        first, second = sorted((d1_id, d2_id))
        # One-row batch: shape (1, len(vec_a) + len(vec_b)).
        combined = np.concatenate([vectors[first], vectors[second]]).astype(np.float32).reshape(1, -1)
        input_tensor = torch.from_numpy(combined).to(device)

        # BatchNorm1d rejects a batch of one in training mode, so make sure
        # the network is in eval mode before every forward pass.
        net.eval()
        with torch.no_grad():
            logits = net(input_tensor)
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
        predicted_idx = torch.argmax(probabilities).item()
        confidence = probabilities[predicted_idx].item() * 100

        return {
            "status": "success",
            "drug_a": first,
            "drug_b": second,
            "severity": SEVERITY_MAP[predicted_idx],
            "confidence": f"{confidence:.2f}%",
            # One entry per class, keyed like "Major (0)".
            "raw_scores": {
                f"{label} ({idx})": f"{probabilities[idx].item():.4f}"
                for idx, label in SEVERITY_MAP.items()
            },
        }
    except Exception as e:
        # Log the specific cause with its full traceback to the console.
        print(f"RUNTIME ERROR during prediction: {str(e)}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Internal AI Error: {str(e)}") from e
if __name__ == "__main__":
    # Direct-run entry point: serve `app` on all interfaces, port 7860.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)