# Last change by asritha22bce: "Update main.py" (commit c689044, verified)
"""
FastAPI Backend for Drug-Food Interaction Prediction System (Updated)
This backend integrates with:
1. NCI Chemical Identifier Resolver for drug canonical SMILES
2. RDKit for molecular descriptor calculation
3. USDA FoodData Central API for food nutrients
4. Pre-trained XGBoost models for interaction prediction
Updated to include additional nutrient features:
- Vitamin C (mg), Vitamin D (µg), Vitamin B12 (µg), Vitamin B6 (mg)
- Vitamin A (µg), Vitamin E (mg), Vitamin K (µg), Folate (µg)
- Saturated Fat (g), Monounsaturated Fat (g), Polyunsaturated Fat (g)
- Cholesterol (mg)
Required models: xgb_model.joblib, label_encoder.joblib, feature_order.joblib
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import os
from dotenv import load_dotenv
load_dotenv()
import requests
import numpy as np
import pandas as pd
import joblib
import asyncio
import logging
from typing import Dict, List, Optional
import aiohttp
from urllib.parse import quote
from fastapi.staticfiles import StaticFiles
from huggingface_hub import hf_hub_download
import joblib
import logging
# Configure logging once for the whole process, then create this module's logger.
# (The model registry `loaded_models` is defined once, next to load_models().)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(
    title="Drug-Food Interaction Prediction API",
    description="API for predicting drug-food interactions using molecular descriptors and comprehensive nutritional data",
    version="2.0.0"
)

# Static frontend served from ./frontend (index.html for directory requests).
# A Mount at "/" matches every path, and Starlette dispatches to the FIRST
# matching route. Mounting it here, before the @app.get/@app.post handlers
# further down are registered, made every API endpoint unreachable: POST /api/*
# got a 405 from StaticFiles.
# The StaticFiles app is still built at import time, so a missing "frontend"
# directory fails fast as before. It is only mounted at startup, once all routes
# exist. As a result, GET "/" is answered by the JSON root() endpoint, while the
# page itself stays reachable at /index.html.
_frontend_files = StaticFiles(directory="frontend", html=True)


@app.on_event("startup")
async def mount_frontend():
    """Mount the static frontend at "/" behind every already-registered route."""
    # Guard against a repeated startup (e.g. reused test clients) mounting it twice.
    if not any(getattr(route, "name", None) == "static" for route in app.router.routes):
        app.mount("/", _frontend_files, name="static")


# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# API Configuration
# Upstream service settings. The timeout values look like milliseconds
# (5000 / 8000 vs. the 5 s / 8 s client timeouts used by the request code);
# NOTE(review): confirm the unit before relying on it elsewhere.
NCI_CONFIG = dict(
    baseUrl="https://cactus.nci.nih.gov/chemical/structure",
    timeout=5000,
)
# The USDA FoodData Central key comes from the API_KEY environment variable
# (a .env file is loaded at import time); it is None when unset.
USDA_CONFIG = dict(
    key=os.getenv("API_KEY"),
    baseUrl="https://api.nal.usda.gov/fdc/v1/",
    searchEndpoint="foods/search",
    detailEndpoint="food",
    timeout=8000,
)
# Request/Response Models
# Request body for the drug-only endpoints (/api/drug/canonical, /api/drug/descriptors).
class DrugRequest(BaseModel):
    drug_name: str  # free-text name, resolved to canonical SMILES via the NCI resolver
# Request body for /api/food/nutrients.
class FoodRequest(BaseModel):
    food_name: str  # free-text query sent to the USDA FoodData Central search endpoint
# Request body for /api/predict: one drug and one food to evaluate together.
class InteractionRequest(BaseModel):
    drug_name: str  # resolved through the NCI resolver, as for DrugRequest
    food_name: str  # searched in USDA FoodData Central, as for FoodRequest
# Scalar RDKit descriptors echoed back in InteractionResult.drug_properties.
# Field names match the keys produced by calculate_molecular_descriptors.
class MolecularDescriptors(BaseModel):
    MolWt: float  # molecular weight (see calculate_molecular_descriptors for the exact RDKit call)
    LogP: float  # Crippen logP
    HBA: int  # hydrogen-bond acceptor count
    HBD: int  # hydrogen-bond donor count
    TPSA: float  # topological polar surface area
    RotBonds: int  # rotatable bond count
    RingCount: int  # number of rings
    FractionCSP3: float  # fraction of sp3-hybridised carbons
    BalabanJ: float  # Balaban J topological index
    BertzCT: float  # Bertz complexity index
    # Fingerprint features FP_0 to FP_2047 would be added here
    # (they are used as model features but are not part of this response model)
# Nutrient profile of a food, echoed back in InteractionResult.food_nutrients.
# Field names match the keys produced by extract_comprehensive_nutrients and
# get_mock_comprehensive_nutrients. Amounts are passed through from the USDA
# response without unit conversion.
class FoodNutrients(BaseModel):
    # Basic macronutrients (units as reported by USDA; presumably g -- confirm)
    Fat: float
    Carbohydrates: float
    Protein: float
    # Vitamins (with proper units)
    Vitamin_C_mg: float # Vitamin C (mg)
    Vitamin_D_ug: float # Vitamin D (µg)
    Vitamin_B12_ug: float # Vitamin B12 (µg)
    Vitamin_B6_mg: float # Vitamin B6 (mg)
    Vitamin_A_ug: float # Vitamin A (µg)
    Vitamin_E_mg: float # Vitamin E (mg)
    Vitamin_K_ug: float # Vitamin K (µg)
    Folate_ug: float # Folate (µg)
    # Minerals (units as reported by USDA; presumably mg -- confirm)
    Calcium: float
    Iron: float
    Magnesium: float
    Potassium: float
    Sodium: float
    Zinc: float
    # Fat breakdown
    Saturated_Fat_g: float # Saturated Fat (g)
    Monounsaturated_Fat_g: float # Monounsaturated Fat (g)
    Polyunsaturated_Fat_g: float # Polyunsaturated Fat (g)
    Cholesterol_mg: float # Cholesterol (mg)
# Response body of POST /api/predict.
class InteractionResult(BaseModel):
    effect: str # 'harmful', 'negative', 'no effect', 'positive', 'possible'
    # Highest class probability from the model, or a heuristic value from the fallback rules
    confidence: float
    explanation: str  # human-readable sentence built from effect and confidence
    drug_properties: MolecularDescriptors  # scalar descriptors of the resolved drug
    food_nutrients: FoodNutrients  # nutrient values used as model features
# Global variables for loaded models
# Registry filled by load_models() with 'xgb_model', 'label_encoder' and
# 'feature_order'; it stays empty when loading fails.
loaded_models: Dict[str, object] = dict()
def load_models():
    """Download the pre-trained artifacts from the Hugging Face Hub and register them.

    On success the module-level ``loaded_models`` dict receives the keys
    ``'xgb_model'``, ``'label_encoder'`` and ``'feature_order'``.

    All three artifacts are loaded first and only then published together. A
    failure part-way through therefore never leaves ``loaded_models``
    half-populated; for example, a model and encoder without their feature order
    would otherwise be used with a guessed feature layout.

    Returns:
        True if every artifact loaded; False otherwise (the error is logged).
    """
    repo_id = "asritha22bce/FoodDrugInteraction"
    artifacts = {
        'xgb_model': "models/xgb_model.joblib",
        'label_encoder': "models/label_encoder.joblib",
        'feature_order': "models/feature_order.joblib",
    }
    try:
        # NOTE(security): joblib.load unpickles the downloaded files, which can run
        # arbitrary code. This is only acceptable because the Hub repo is trusted.
        loaded = {
            key: joblib.load(hf_hub_download(repo_id=repo_id, filename=filename))
            for key, filename in artifacts.items()
        }
    except Exception:
        # Best effort: the caller falls back to rule-based predictions.
        logger.exception("Error loading models")
        return False
    loaded_models.update(loaded)
    logger.info("Models loaded successfully from Hugging Face Hub")
    return True
# Load models on startup
@app.on_event("startup")
async def startup_event():
    """Load the prediction artifacts once when the application starts.

    A failed load is not fatal. A warning is logged, and predictions then take
    the fallback path.
    """
    if load_models():
        return
    logger.warning("Failed to load models. Predictions will use fallback logic.")
async def get_canonical_smiles(drug_name: str) -> str:
    """Resolve a drug name to canonical SMILES via the NCI Chemical Identifier Resolver.

    Args:
        drug_name: Free-text drug name or identifier. Surrounding whitespace is ignored.

    Returns:
        The canonical SMILES string returned by the resolver, stripped of whitespace.

    Raises:
        HTTPException:
            400 for a blank name.
            404 when the resolver does not answer 200 or returns an empty body.
            408 on timeout.
            500 for any other failure.
    """
    name = drug_name.strip() if drug_name else ""
    if not name:
        raise HTTPException(status_code=400, detail="Drug name must not be empty")
    # safe="" also percent-encodes "/", which would otherwise be read as a path separator.
    url = f"{NCI_CONFIG['baseUrl']}/{quote(name, safe='')}/canonical_smiles"
    # NCI_CONFIG['timeout'] appears to be in ms (5000 -> the previous hard-coded 5 s).
    timeout = aiohttp.ClientTimeout(total=NCI_CONFIG['timeout'] / 1000)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url) as response:
                if response.status != 200:
                    raise HTTPException(status_code=404, detail=f"Drug '{drug_name}' not found in NCI database")
                smiles = (await response.text()).strip()
        if not smiles:
            raise HTTPException(status_code=404, detail=f"Drug '{drug_name}' not found in NCI database")
        return smiles
    except HTTPException:
        # Deliberate errors raised above must reach the client unchanged, not as a 500.
        raise
    except asyncio.TimeoutError:
        raise HTTPException(status_code=408, detail="NCI API timeout")
    except Exception as e:
        logger.exception("Error getting canonical SMILES for %s", drug_name)
        raise HTTPException(status_code=500, detail=f"Error retrieving drug information: {str(e)}") from e
def calculate_molecular_descriptors(smiles: str) -> Dict:
    """Compute RDKit descriptors plus a 2048-bucket hashed Morgan fingerprint for a SMILES.

    Args:
        smiles: SMILES string, typically the canonical SMILES from the NCI resolver.

    Returns:
        A dict with ten scalar descriptors ('MolWt', 'LogP', 'HBA', 'HBD', 'TPSA',
        'RotBonds', 'RingCount', 'FractionCSP3', 'BalabanJ', 'BertzCT') followed by
        'FP_0' .. 'FP_2047'.

    Best-effort fallback: random mock descriptors from
    get_mock_molecular_descriptors() are returned (and the cause is logged) in
    three cases:
      * RDKit is not installed;
      * the SMILES cannot be parsed;
      * any descriptor call fails.
    """
    n_bits = 2048
    try:
        from rdkit import Chem
        from rdkit.Chem import Crippen, Descriptors, Lipinski, rdMolDescriptors
    except ImportError:
        # Fallback if RDKit is not available
        logger.warning("RDKit not available, using mock descriptors")
        return get_mock_molecular_descriptors()

    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError("Invalid SMILES string")
        descriptors = {
            # NOTE(review): ExactMolWt is the monoisotopic mass, not the average
            # MolWt. It is kept because the model was presumably trained on it -- confirm.
            'MolWt': Descriptors.ExactMolWt(mol),
            'LogP': Crippen.MolLogP(mol),
            'HBA': Lipinski.NumHAcceptors(mol),
            'HBD': Lipinski.NumHDonors(mol),
            'TPSA': Descriptors.TPSA(mol),
            'RotBonds': Lipinski.NumRotatableBonds(mol),
            'RingCount': Lipinski.RingCount(mol),
            # The old `Lipinski.FractionCsp3` spelling is not an RDKit attribute,
            # so every call used to fall through to the mock descriptors.
            'FractionCSP3': rdMolDescriptors.CalcFractionCSP3(mol),
            'BalabanJ': Descriptors.BalabanJ(mol),
            'BertzCT': Descriptors.BertzCT(mol),
        }
        # GetMorganFingerprint accepts no nBits argument (the old call raised). The
        # hashed variant folds radius-2 environment counts into n_bits buckets.
        # NOTE(review): this stores counts, as the original loop intended. If the
        # model was trained on 0/1 bit vectors, clip to 1 -- confirm against training.
        fingerprint = rdMolDescriptors.GetHashedMorganFingerprint(mol, 2, nBits=n_bits)
        fp_array = np.zeros(n_bits)
        for idx, count in fingerprint.GetNonzeroElements().items():
            fp_array[idx] = count
    except Exception:
        logger.exception("Error calculating molecular descriptors")
        return get_mock_molecular_descriptors()

    descriptors.update({f'FP_{i}': fp_array[i] for i in range(n_bits)})
    return descriptors
def get_mock_molecular_descriptors() -> Dict:
    """Return random placeholder descriptors shaped like real RDKit output.

    The ten scalar descriptors are drawn from fixed plausible ranges with the
    legacy ``np.random`` API. The keys 'FP_0' .. 'FP_2047' are random 0/1 values.
    Results are reproducible only through ``np.random.seed``. Intended for
    testing and as a fallback, not for real predictions.
    """
    rng = np.random
    mock = dict(
        MolWt=rng.uniform(100, 500),
        LogP=rng.uniform(-2, 5),
        HBA=rng.randint(1, 10),
        HBD=rng.randint(0, 5),
        TPSA=rng.uniform(20, 140),
        RotBonds=rng.randint(0, 10),
        RingCount=rng.randint(0, 4),
        FractionCSP3=rng.uniform(0, 1),
        BalabanJ=rng.uniform(0.5, 2.5),
        BertzCT=rng.uniform(100, 1000),
    )
    # Mock fingerprint bits, drawn in bit order after the scalars.
    mock.update((f'FP_{bit}', rng.randint(0, 2)) for bit in range(2048))
    return mock
async def search_food_nutrients(food_name: str) -> Dict:
    """Search USDA FoodData Central for a food and return its nutrient features.

    The first search hit is fetched in full, and its ``foodNutrients`` list is
    mapped by ``extract_comprehensive_nutrients``.

    Args:
        food_name: Free-text query sent to the FDC search endpoint.

    Returns:
        Dict of nutrient feature name -> amount.

    Raises:
        HTTPException:
            404 if the search returns no foods or the detail record is not found.
            502 if the USDA API answers with another non-200 status.
            408 if the request times out.

    Any other failure (network error, unexpected payload, ...) is logged, and
    random values from ``get_mock_comprehensive_nutrients`` are returned instead.
    """
    search_url = f"{USDA_CONFIG['baseUrl']}{USDA_CONFIG['searchEndpoint']}"
    search_params = {
        "api_key": USDA_CONFIG['key'],
        "query": food_name,
        "pageSize": 1
    }
    # USDA_CONFIG['timeout'] appears to be in ms (8000 -> the previous hard-coded 8 s).
    timeout = aiohttp.ClientTimeout(total=USDA_CONFIG['timeout'] / 1000)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(search_url, params=search_params) as response:
                if response.status != 200:
                    raise HTTPException(status_code=502, detail=f"USDA search failed with status {response.status}")
                search_data = await response.json()
            foods = search_data.get('foods')
            if not foods:
                raise HTTPException(status_code=404, detail=f"Food '{food_name}' not found")
            # Get detailed nutrition information for the first hit
            fdc_id = foods[0]['fdcId']
            detail_url = f"{USDA_CONFIG['baseUrl']}{USDA_CONFIG['detailEndpoint']}/{fdc_id}"
            detail_params = {"api_key": USDA_CONFIG['key']}
            async with session.get(detail_url, params=detail_params) as detail_response:
                if detail_response.status == 404:
                    raise HTTPException(status_code=404, detail="Food details not found")
                if detail_response.status != 200:
                    raise HTTPException(status_code=502, detail=f"USDA food lookup failed with status {detail_response.status}")
                detail_data = await detail_response.json()
        return extract_comprehensive_nutrients(detail_data.get('foodNutrients', []))
    except HTTPException:
        # Deliberate errors from above: surface them instead of masking them with random data.
        raise
    except asyncio.TimeoutError:
        raise HTTPException(status_code=408, detail="USDA API timeout")
    except Exception:
        logger.exception("Error getting food nutrients for %s", food_name)
        # Best-effort fallback for unexpected failures
        return get_mock_comprehensive_nutrients()
def extract_comprehensive_nutrients(food_nutrients: List) -> Dict:
    """Map a USDA FoodData Central ``foodNutrients`` list to the model's nutrient features.

    Args:
        food_nutrients: Entries from an FDC food record. Two shapes are accepted:
            full-format ``{'nutrient': {'id': 1003, ...}, 'amount': 12.5}`` and
            search-result ``{'nutrientId': 1003, 'value': 12.5}``. ``None`` or an
            empty list yields all zeros.

    Returns:
        One float per feature in ``nutrient_map`` (21 keys, in a fixed order).
        Features the food does not report are 0.0.

    When a feature lists several USDA nutrient IDs, they are tried in the listed
    order, so the first ID is the preferred source. Previously whichever ID
    appeared first in the payload won, which could pick Vitamin A in IU over
    Vitamin A RAE in µg. Entries with a missing or non-numeric amount are ignored.
    """
    # Enhanced nutrient mapping with USDA nutrient IDs (first ID = preferred)
    nutrient_map = {
        # Basic macronutrients
        'Fat': [1004],  # Total lipid (fat)
        'Carbohydrates': [1005],  # Carbohydrate, by difference
        'Protein': [1003],  # Protein
        # Vitamins with proper units
        'Vitamin_C_mg': [1162],  # Vitamin C, total ascorbic acid (mg)
        'Vitamin_D_ug': [1114],  # Vitamin D (D2 + D3) (µg)
        'Vitamin_B12_ug': [1178],  # Vitamin B-12 (µg)
        'Vitamin_B6_mg': [1175],  # Vitamin B-6 (mg)
        # 1106 = Vitamin A, RAE (µg). 1104 is Vitamin A in IU and is NOT converted;
        # NOTE(review): only used when RAE is absent -- its unit does not match "_ug".
        'Vitamin_A_ug': [1106, 1104],
        'Vitamin_E_mg': [1109],  # Vitamin E (alpha-tocopherol) (mg)
        'Vitamin_K_ug': [1185],  # Vitamin K (phylloquinone) (µg)
        # 1177 = Folate, total (µg); 1186 is presumably folic acid -- confirm the fallback
        'Folate_ug': [1177, 1186],
        # Minerals
        'Calcium': [1087],  # Calcium, Ca
        'Iron': [1089],  # Iron, Fe
        'Magnesium': [1090],  # Magnesium, Mg
        'Potassium': [1092],  # Potassium, K
        'Sodium': [1093],  # Sodium, Na
        'Zinc': [1095],  # Zinc, Zn
        # Fat breakdown
        'Saturated_Fat_g': [1258],  # Fatty acids, total saturated (g)
        'Monounsaturated_Fat_g': [1292],  # Fatty acids, total monounsaturated (g)
        'Polyunsaturated_Fat_g': [1293],  # Fatty acids, total polyunsaturated (g)
        'Cholesterol_mg': [1253],  # Cholesterol (mg)
    }

    # Index the first usable amount per nutrient ID in one pass, instead of
    # rescanning the whole list for every feature.
    amounts_by_id: Dict[int, float] = {}
    for entry in food_nutrients or []:
        if not isinstance(entry, dict):
            continue
        nutrient = entry.get('nutrient')
        if not isinstance(nutrient, dict):
            nutrient = {}
        nutrient_id = nutrient.get('id', entry.get('nutrientId'))
        amount = entry.get('amount', entry.get('value'))
        if nutrient_id is None or amount is None or nutrient_id in amounts_by_id:
            continue
        try:
            amounts_by_id[nutrient_id] = float(amount)
        except (TypeError, ValueError):
            continue  # non-numeric amount: treat as not reported

    return {
        name: next((amounts_by_id[i] for i in ids if i in amounts_by_id), 0.0)
        for name, ids in nutrient_map.items()
    }
def get_mock_comprehensive_nutrients() -> Dict:
    """Return random placeholder nutrient values keyed like the real USDA extraction.

    Each of the 21 features is drawn independently from a uniform range starting
    at 0, via the legacy ``np.random`` API, in a fixed order. Values are only
    reproducible through ``np.random.seed``. Intended for testing and as a
    fallback.
    """
    draw = np.random.uniform
    return dict(
        # Basic macronutrients
        Fat=draw(0, 15),
        Carbohydrates=draw(0, 40),
        Protein=draw(0, 20),
        # Vitamins: upper bounds in the unit named by each key
        Vitamin_C_mg=draw(0, 200),
        Vitamin_D_ug=draw(0, 25),
        Vitamin_B12_ug=draw(0, 10),
        Vitamin_B6_mg=draw(0, 5),
        Vitamin_A_ug=draw(0, 1500),
        Vitamin_E_mg=draw(0, 30),
        Vitamin_K_ug=draw(0, 500),
        Folate_ug=draw(0, 300),
        # Minerals
        Calcium=draw(0, 250),
        Iron=draw(0, 15),
        Magnesium=draw(0, 150),
        Potassium=draw(0, 800),
        Sodium=draw(0, 200),
        Zinc=draw(0, 8),
        # Fat breakdown
        Saturated_Fat_g=draw(0, 10),
        Monounsaturated_Fat_g=draw(0, 8),
        Polyunsaturated_Fat_g=draw(0, 6),
        Cholesterol_mg=draw(0, 300),
    )
def predict_interaction(drug_descriptors: Dict, food_nutrients: Dict) -> Dict:
    """Predict the interaction class for one drug/food pair.

    Feature layout:
      * Features follow ``loaded_models['feature_order']`` when it is loaded;
        otherwise drug descriptor keys come first, then nutrient keys.
      * A name present in both dicts takes the drug value.
      * Names found in neither dict are filled with 0.0.

    Args:
        drug_descriptors: Output of calculate_molecular_descriptors (or its mock).
        food_nutrients: Nutrient feature dict from the USDA lookup (or its mock).

    Returns:
        ``{'effect': str, 'confidence': float, 'explanation': str}``, where
        confidence is the highest class probability. If the model or label
        encoder is not loaded, or prediction raises, the result of
        ``get_fallback_prediction`` is returned instead (the error is logged
        with its traceback).
    """
    try:
        feature_order = loaded_models.get('feature_order')
        if feature_order is None:
            # Default feature order if not available
            feature_order = list(drug_descriptors) + list(food_nutrients)
        features = [
            drug_descriptors[name] if name in drug_descriptors else food_nutrients.get(name, 0.0)
            for name in feature_order
        ]
        feature_array = np.array(features).reshape(1, -1)

        model = loaded_models.get('xgb_model')
        label_encoder = loaded_models.get('label_encoder')
        if model is None or label_encoder is None:
            return get_fallback_prediction(drug_descriptors, food_nutrients)

        probabilities = model.predict_proba(feature_array)[0]
        predicted_class_idx = int(np.argmax(probabilities))
        confidence = float(probabilities[predicted_class_idx])
        # inverse_transform yields NumPy scalars; keep the payload plain Python types.
        predicted_effect = str(label_encoder.inverse_transform([predicted_class_idx])[0])
        return {
            'effect': predicted_effect,
            'confidence': confidence,
            'explanation': get_explanation(predicted_effect, confidence)
        }
    except Exception:
        logger.exception("Error making prediction")
        return get_fallback_prediction(drug_descriptors, food_nutrients)
def get_fallback_prediction(drug_descriptors: Dict, food_nutrients: Dict) -> Dict:
    """Rule-based stand-in used when the trained model cannot produce a prediction.

    Only ``food_nutrients`` is consulted; ``drug_descriptors`` is accepted for
    signature parity but not read. The base rules are:
      * Vitamin_K_ug > 100            -> 'possible', confidence 0.68
      * else Calcium > 150            -> 'possible', confidence 0.75
      * else                          -> 'no effect', confidence 0.75
    Then, when ``np.random.random()`` exceeds 0.7 (about 30% of calls), the
    result is replaced by a random effect from 'no effect', 'possible',
    'positive' and 'harmful', with confidence drawn from U(0.6, 0.92).

    NOTE(review): the random override means this fallback can report 'harmful' or
    'positive' for arbitrary inputs. The code labels it "for demonstration".
    """
    # Both rules are evaluated up front; missing nutrients default to 0.
    vitamin_k_high = food_nutrients.get('Vitamin_K_ug', 0) > 100
    calcium_high = food_nutrients.get('Calcium', 0) > 150

    if vitamin_k_high:
        # Vitamin K interactions are common with anticoagulants
        effect, confidence = 'possible', 0.68
    elif calcium_high:
        # High calcium may affect absorption
        effect, confidence = 'possible', 0.75
    else:
        effect, confidence = 'no effect', 0.75

    # Random variation for demonstration
    roll = np.random.random()
    if roll > 0.7:
        effect = np.random.choice(['no effect', 'possible', 'positive', 'harmful'])
        confidence = np.random.uniform(0.6, 0.92)

    return dict(
        effect=effect,
        confidence=confidence,
        explanation=get_explanation(effect, confidence),
    )
def get_explanation(effect: str, confidence: float) -> str:
    """Return a human-readable sentence for a predicted effect.

    Known effects ('harmful', 'negative', 'no effect', 'positive', 'possible')
    get a tailored message; any other value gets a generic one. The confidence
    is rendered with two decimals.
    """
    templates = {
        'harmful': "Significant interaction detected (confidence: {c}). This food may interfere with drug efficacy or cause adverse effects. Consult your healthcare provider immediately.",
        'negative': "Minor negative interaction possible (confidence: {c}). The food may slightly reduce drug effectiveness or absorption.",
        'no effect': "No significant interaction expected (confidence: {c}). The food is unlikely to affect drug absorption or metabolism significantly.",
        'positive': "Beneficial interaction detected (confidence: {c}). This food may enhance drug absorption, stability, or therapeutic effects.",
        'possible': "Potential interaction identified (confidence: {c}). Monitor for changes in drug effectiveness or side effects.",
    }
    shown = format(confidence, ".2f")
    template = templates.get(effect, "Interaction analysis completed with {c} confidence.")
    return template.format(c=shown)
# API Endpoints
@app.get("/")
async def root():
    """Describe the service: name, status, version and a one-line feature summary."""
    info = dict(
        message="Drug-Food Interaction Prediction API",
        status="active",
        version="2.0.0",
        features="Enhanced with comprehensive nutritional analysis",
    )
    return info
@app.post("/api/drug/canonical")
async def get_drug_canonical(request: DrugRequest):
    """Resolve ``request.drug_name`` to canonical SMILES and echo the name back."""
    drug = request.drug_name
    return {"drug_name": drug, "canonical_smiles": await get_canonical_smiles(drug)}
@app.post("/api/drug/descriptors")
async def get_drug_descriptors(request: DrugRequest):
    """Resolve the drug to canonical SMILES, then return it with its computed descriptors."""
    drug = request.drug_name
    canonical = await get_canonical_smiles(drug)
    return dict(
        drug_name=drug,
        canonical_smiles=canonical,
        descriptors=calculate_molecular_descriptors(canonical),
    )
@app.post("/api/food/nutrients")
async def get_food_nutrients_endpoint(request: FoodRequest):
    """Return the nutrient profile found for ``request.food_name`` in USDA FoodData Central."""
    food = request.food_name
    profile = await search_food_nutrients(food)
    return dict(food_name=food, nutrients=profile)
@app.post("/api/predict")
async def predict_drug_food_interaction(request: InteractionRequest):
    """Predict how a food interacts with a drug.

    Pipeline:
      * drug name -> canonical SMILES (NCI) -> RDKit descriptors;
      * food name -> USDA nutrients;
      * both feed predict_interaction.
    The response echoes the scalar drug descriptors and all nutrient values
    alongside the prediction.

    Raises:
        HTTPException: a status raised by a lookup step (e.g. 404 unknown drug,
            408 upstream timeout) is passed through unchanged. Any other error
            becomes a 500 with the error text as detail.
    """
    try:
        # Get drug information
        smiles = await get_canonical_smiles(request.drug_name)
        drug_descriptors = calculate_molecular_descriptors(smiles)
        # Get comprehensive food nutrients
        food_nutrients = await search_food_nutrients(request.food_name)
        prediction = predict_interaction(drug_descriptors, food_nutrients)
        # Only the scalar descriptors are returned; the FP_* fingerprint keys are not.
        scalar_names = ('MolWt', 'LogP', 'HBA', 'HBD', 'TPSA', 'RotBonds',
                        'RingCount', 'FractionCSP3', 'BalabanJ', 'BertzCT')
        return InteractionResult(
            effect=prediction['effect'],
            confidence=prediction['confidence'],
            explanation=prediction['explanation'],
            drug_properties=MolecularDescriptors(
                **{name: drug_descriptors[name] for name in scalar_names}
            ),
            food_nutrients=FoodNutrients(**food_nutrients)
        )
    except HTTPException:
        # Previously re-wrapped as 500 "404: Drug ... not found"; keep the real status.
        raise
    except Exception as e:
        logger.exception("Error in prediction endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/api/health")
async def health_check():
    """Report liveness, whether the model and label encoder are loaded, and the API version."""
    required = ('xgb_model', 'label_encoder')
    ready = all(name in loaded_models for name in required)
    return dict(
        status="healthy",
        models_loaded=ready,
        version="2.0.0",
        features="Comprehensive nutritional analysis with vitamins, minerals, and fat breakdown",
    )
@app.get("/api/nutrients/list")
async def list_supported_nutrients():
    """List the supported nutrient feature names by category, with their total count."""
    categories = dict(
        macronutrients=["Fat", "Carbohydrates", "Protein"],
        vitamins=[
            "Vitamin_C_mg", "Vitamin_D_ug", "Vitamin_B12_ug", "Vitamin_B6_mg",
            "Vitamin_A_ug", "Vitamin_E_mg", "Vitamin_K_ug", "Folate_ug",
        ],
        minerals=["Calcium", "Iron", "Magnesium", "Potassium", "Sodium", "Zinc"],
        fat_breakdown=[
            "Saturated_Fat_g", "Monounsaturated_Fat_g",
            "Polyunsaturated_Fat_g", "Cholesterol_mg",
        ],
    )
    total = 0
    for names in categories.values():
        total += len(names)
    return {"total_nutrients": total, "categories": categories}
if __name__ == "__main__":
    import uvicorn

    # uvicorn.run() refuses reload=True (or workers > 1) when given an app object
    # rather than a "module:attr" import string. It logs a warning and exits, so
    # the previous call never served anything. Run without the reloader here.
    # For auto-reload during development, use the CLI, e.g.
    # `uvicorn main:app --reload` (module name assumed from the file name).
    uvicorn.run(app, host="0.0.0.0", port=8000)