import os
import numpy as np
import onnxruntime as rt
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import Dict, Any
from huggingface_hub import hf_hub_download
from dotenv import load_dotenv
# Load variables from a .env file into the process environment; variables that
# are already set (e.g. Space secrets) take precedence because override=False.
load_dotenv(override=False)

app = FastAPI(
    title="Digital Doctors Assistant ML API",
)
# Model configurations
MODELS = {
'risk_assessment': {
'filename': 'risk_assessment.onnx',
'features': ['age', 'bmi', 'systolic_bp', 'diastolic_bp',
'chronic_conditions_count', 'severity_score'],
'output_classes': ['Low', 'Medium', 'High']
},
'treatment_outcome': {
'filename': 'treatment_outcome.onnx',
'features': ['patient_age', 'severity_score', 'compliance_rate',
'medication_encoded', 'condition_encoded'],
'output_classes': ['No Success', 'Success']
}
}
# ONNX Runtime inference sessions. They start as None and are assigned by the
# startup hook (load_models); /health reports whether each one is set.
risk_session = treatment_session = None
@app.on_event("startup")
async def load_models():
    """Download both ONNX models from the Hugging Face Hub and open inference sessions.

    The artifact file names are taken from ``MODELS`` so the configuration and
    the loader cannot drift apart. The Hub token is read from the
    ``HUGGINGFACE_TOKEN`` environment variable (``None`` when unset).

    Side effects:
        Assigns the module-level ``risk_session`` and ``treatment_session``.
        Download or session-creation errors are not caught here; they
        propagate out of the startup hook.
    """
    global risk_session, treatment_session

    # Single place that names the model repository.
    repo_id = "Tegaconsult/digital-doctors-assistant-ml"
    # Get token from environment (set as Space secret)
    token = os.getenv("HUGGINGFACE_TOKEN")

    def _open_session(model_key):
        # Fetch the artifact named in MODELS and wrap it in an ORT session.
        path = hf_hub_download(
            repo_id=repo_id,
            filename=MODELS[model_key]["filename"],
            token=token,
        )
        return rt.InferenceSession(path)

    risk_session = _open_session("risk_assessment")
    treatment_session = _open_session("treatment_outcome")
    print("Models loaded successfully!")
class RiskAssessmentRequest(BaseModel):
    """Request body for ``POST /predict/risk``.

    Each numeric field becomes one column of the risk model's input row;
    ``chronic_conditions`` is reduced to a count by the endpoint.
    """
    age: float
    bmi: float
    systolic_bp: float  # NOTE(review): presumably mmHg; units are not validated
    diastolic_bp: float  # NOTE(review): presumably mmHg; units are not validated
    # Comma-separated condition names; the endpoint only uses how many there are.
    chronic_conditions: str = ""
    severity_score: float  # NOTE(review): expected scale not visible here; confirm
class TreatmentOutcomeRequest(BaseModel):
    """Request body for ``POST /predict/treatment``.

    ``medication`` and ``condition`` are free-text names that the endpoint maps
    to integer codes before inference; unrecognised names fall back to code 0.
    """
    patient_age: float
    severity_score: float
    # NOTE(review): presumably a fraction in [0, 1]; it is not validated here.
    compliance_rate: float
    medication: str  # e.g. 'Paracetamol'
    condition: str  # e.g. 'Common Cold'
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve a minimal HTML landing page for the API.

    Returns a complete HTML document (title plus heading) so browsers render
    the page as intended, with links to the auto-generated docs and /health.
    """
    html_content = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>Digital Doctors Assistant ML</title>
</head>
<body>
<h1>Digital Doctors Assistant ML</h1>
<p>API docs: <a href="/docs">/docs</a> &middot; model status: <a href="/health">/health</a></p>
</body>
</html>
"""
    return html_content
def _risk_probability_vector(raw, output_classes):
    """Return class probabilities as floats ordered like ``output_classes``, or None.

    ``raw`` is one row of the model's probability output. Two shapes are
    accepted: an array-like of per-class scores, or a mapping keyed by class
    index or class label (e.g. the dict rows produced by a ZipMap output).

    Raises:
        ValueError: if a class is missing from a mapping, or an array's length
            does not match ``output_classes``.
    """
    if raw is None:
        return None
    if isinstance(raw, dict):
        # max() over a dict compares its *keys*, so resolve values explicitly.
        values = []
        for index, label in enumerate(output_classes):
            if index in raw:
                values.append(float(raw[index]))
            elif label in raw:
                values.append(float(raw[label]))
            else:
                raise ValueError(f"Model probabilities are missing class {label!r}")
        return values
    values = [float(p) for p in np.asarray(raw, dtype=np.float64).ravel()]
    if len(values) != len(output_classes):
        raise ValueError(
            f"Model returned {len(values)} class probabilities, "
            f"expected {len(output_classes)}"
        )
    return values


def _count_chronic_conditions(raw):
    """Count the non-blank entries in a comma-separated condition list."""
    # Blank entries (trailing commas, whitespace-only input) are not conditions.
    return sum(1 for item in raw.split(",") if item.strip())


@app.post("/predict/risk")
def predict_risk(request: RiskAssessmentRequest):
    """Predict patient risk level with the risk-assessment ONNX model.

    Returns:
        dict with the predicted label, ``confidence`` (highest class
        probability, 0.0 when the model yields none) and per-class
        ``probabilities`` (None when the model yields none).

    Raises:
        HTTPException: 503 if the model session has not been created;
            500 for any failure while building the input or running inference.
    """
    if risk_session is None:
        raise HTTPException(status_code=503, detail="Risk assessment model is not loaded")
    try:
        output_classes = MODELS['risk_assessment']['output_classes']
        # Column order mirrors MODELS['risk_assessment']['features'].
        input_data = np.array([[
            request.age,
            request.bmi,
            request.systolic_bp,
            request.diastolic_bp,
            _count_chronic_conditions(request.chronic_conditions),
            request.severity_score,
        ]], dtype=np.float32)

        input_name = risk_session.get_inputs()[0].name
        result = risk_session.run(None, {input_name: input_data})

        prediction = result[0][0]
        probabilities = _risk_probability_vector(
            result[1][0] if len(result) > 1 else None, output_classes
        )

        if isinstance(prediction, (int, np.integer)):
            index = int(prediction)
            # Reject out-of-range ids rather than letting negative indexing
            # silently select a wrong label.
            if not 0 <= index < len(output_classes):
                raise ValueError(f"Model predicted unknown class index {index}")
            prediction_label = output_classes[index]
        else:
            # Non-integer predictions are used as the label directly; str()
            # turns numpy scalars into plain, JSON-serialisable strings.
            prediction_label = str(prediction)

        return {
            'success': True,
            'model': 'risk_assessment',
            'prediction': prediction_label,
            'confidence': max(probabilities) if probabilities is not None else 0.0,
            'probabilities': (
                dict(zip(output_classes, probabilities))
                if probabilities is not None else None
            ),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# Integer codes for the categorical inputs of the treatment-outcome model.
# NOTE(review): these must mirror the encoding used at training time; confirm
# against the training pipeline before editing.
_MEDICATION_CODES = {
    'Paracetamol': 0, 'Ibuprofen': 1, 'Amoxicillin': 2, 'Ciprofloxacin': 3,
    'Metformin': 4, 'Lisinopril': 5, 'Amlodipine': 6, 'Omeprazole': 7,
}
_CONDITION_CODES = {
    'Common Cold': 0, 'Influenza': 1, 'Pneumonia': 2, 'Bronchitis': 3,
    'Hypertension': 4, 'Diabetes Type 2': 5, 'Migraine': 6, 'Gastroenteritis': 7,
}
# Case-insensitive lookup tables built once from the canonical names above.
_MEDICATION_LOOKUP = {name.casefold(): code for name, code in _MEDICATION_CODES.items()}
_CONDITION_LOOKUP = {name.casefold(): code for name, code in _CONDITION_CODES.items()}


def _encode_treatment_category(value, lookup):
    """Map a free-text category name to its integer code (case/whitespace-insensitive)."""
    # NOTE(review): unknown names silently fall back to code 0 (Paracetamol /
    # Common Cold), which yields a prediction for the wrong category; consider
    # rejecting them with a 422 instead.
    return lookup.get(value.strip().casefold(), 0)


def _treatment_probability_vector(raw, output_classes):
    """Return class probabilities as floats ordered like ``output_classes``, or None.

    ``raw`` is one row of the model's probability output: either an array-like
    of per-class scores, or a mapping keyed by class index or class label
    (e.g. the dict rows produced by a ZipMap output).

    Raises:
        ValueError: if a class is missing from a mapping, or an array's length
            does not match ``output_classes``.
    """
    if raw is None:
        return None
    if isinstance(raw, dict):
        # max() over a dict compares its *keys*, so resolve values explicitly.
        values = []
        for index, label in enumerate(output_classes):
            if index in raw:
                values.append(float(raw[index]))
            elif label in raw:
                values.append(float(raw[label]))
            else:
                raise ValueError(f"Model probabilities are missing class {label!r}")
        return values
    values = [float(p) for p in np.asarray(raw, dtype=np.float64).ravel()]
    if len(values) != len(output_classes):
        raise ValueError(
            f"Model returned {len(values)} class probabilities, "
            f"expected {len(output_classes)}"
        )
    return values


@app.post("/predict/treatment")
def predict_treatment(request: TreatmentOutcomeRequest):
    """Predict treatment outcome with the treatment-outcome ONNX model.

    Returns:
        dict with the integer ``prediction``, ``success_probability`` as a
        percentage rounded to one decimal (50.0 when the model yields no
        probabilities), ``confidence`` (highest class probability, 0.0 when
        none) and ``probabilities`` keyed ``failure``/``success`` (or None).

    Raises:
        HTTPException: 503 if the model session has not been created;
            500 for any failure while building the input or running inference.
    """
    if treatment_session is None:
        raise HTTPException(status_code=503, detail="Treatment outcome model is not loaded")
    try:
        # Column order mirrors MODELS['treatment_outcome']['features'].
        input_data = np.array([[
            request.patient_age,
            request.severity_score,
            request.compliance_rate,
            _encode_treatment_category(request.medication, _MEDICATION_LOOKUP),
            _encode_treatment_category(request.condition, _CONDITION_LOOKUP),
        ]], dtype=np.float32)

        input_name = treatment_session.get_inputs()[0].name
        result = treatment_session.run(None, {input_name: input_data})

        prediction = result[0][0]
        # Index 0 is the failure class, index 1 the success class.
        probabilities = _treatment_probability_vector(
            result[1][0] if len(result) > 1 else None,
            MODELS['treatment_outcome']['output_classes'],
        )
        success_probability = probabilities[1] if probabilities is not None else 0.5

        return {
            'success': True,
            'model': 'treatment_outcome',
            'prediction': int(prediction),
            'success_probability': round(success_probability * 100, 1),
            'confidence': max(probabilities) if probabilities is not None else 0.0,
            'probabilities': {
                'failure': probabilities[0],
                'success': probabilities[1],
            } if probabilities is not None else None,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
def health_check():
    """Liveness endpoint that also reports which ONNX sessions exist.

    ``status`` is always ``"healthy"``; ``models_loaded`` maps each model name
    to whether its module-level session has been assigned.
    """
    sessions = {
        "risk_assessment": risk_session,
        "treatment_outcome": treatment_session,
    }
    loaded = {name: session is not None for name, session in sessions.items()}
    return {"status": "healthy", "models_loaded": loaded}