"""FastAPI service exposing ONNX models for the Digital Doctors Assistant.

Endpoints:
    GET  /                   minimal landing page
    POST /predict/risk       patient risk level (Low / Medium / High)
    POST /predict/treatment  treatment success prediction
    GET  /health             liveness plus which models are loaded
"""
import logging
import os
from typing import Any, Dict, List, Optional

import numpy as np
import onnxruntime as rt
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import hf_hub_download
from pydantic import BaseModel

load_dotenv()

app = FastAPI(title="Digital Doctors Assistant ML API")

logger = logging.getLogger(__name__)

# Hugging Face Hub repository that hosts both exported ONNX models.
MODEL_REPO_ID = "Tegaconsult/digital-doctors-assistant-ml"

# Model configurations
MODELS = {
    'risk_assessment': {
        'filename': 'risk_assessment.onnx',
        'features': ['age', 'bmi', 'systolic_bp', 'diastolic_bp',
                     'chronic_conditions_count', 'severity_score'],
        'output_classes': ['Low', 'Medium', 'High'],
    },
    'treatment_outcome': {
        'filename': 'treatment_outcome.onnx',
        'features': ['patient_age', 'severity_score', 'compliance_rate',
                     'medication_encoded', 'condition_encoded'],
        'output_classes': ['No Success', 'Success'],
    },
}

# Integer encodings for the treatment-outcome model's categorical inputs.
# Built once at import time instead of on every request.
# NOTE(review): these must match the encoding used at training time — confirm
# against the training pipeline.
MEDICATION_MAPPING = {
    'Paracetamol': 0, 'Ibuprofen': 1, 'Amoxicillin': 2, 'Ciprofloxacin': 3,
    'Metformin': 4, 'Lisinopril': 5, 'Amlodipine': 6, 'Omeprazole': 7,
}
CONDITION_MAPPING = {
    'Common Cold': 0, 'Influenza': 1, 'Pneumonia': 2, 'Bronchitis': 3,
    'Hypertension': 4, 'Diabetes Type 2': 5, 'Migraine': 6, 'Gastroenteritis': 7,
}

# Inference sessions; populated by load_models() at startup, None until then.
risk_session = None
treatment_session = None


def _load_session(model_key: str, token: Optional[str]) -> rt.InferenceSession:
    """Download the ONNX file configured for ``model_key`` and open a session on it."""
    path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODELS[model_key]['filename'],
        token=token,
    )
    return rt.InferenceSession(path)


# NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in favour
# of the ``lifespan`` handler; kept here for compatibility with older versions.
@app.on_event("startup")
async def load_models():
    """Download both ONNX models from the Hugging Face Hub and create sessions.

    The Hub token is read from the ``HUGGINGFACE_TOKEN`` environment variable
    (e.g. a Space secret); ``None`` is passed through when it is unset.
    """
    global risk_session, treatment_session
    token = os.getenv("HUGGINGFACE_TOKEN")
    risk_session = _load_session('risk_assessment', token)
    treatment_session = _load_session('treatment_outcome', token)
    print("Models loaded successfully!")


class RiskAssessmentRequest(BaseModel):
    """Input features for the risk-assessment model."""

    age: float
    bmi: float
    systolic_bp: float
    diastolic_bp: float
    # Comma-separated list of condition names; only the count is used.
    chronic_conditions: str = ""
    severity_score: float


class TreatmentOutcomeRequest(BaseModel):
    """Input features for the treatment-outcome model."""

    patient_age: float
    severity_score: float
    compliance_rate: float
    medication: str
    condition: str


def _require_session(session, model_name: str):
    """Return ``session``, or raise HTTP 503 if the model has not been loaded."""
    if session is None:
        raise HTTPException(status_code=503, detail=f"Model '{model_name}' is not loaded")
    return session


def _count_chronic_conditions(chronic_conditions: str) -> int:
    """Count comma-separated entries, ignoring blank ones.

    ``"diabetes, asthma"`` -> 2, while ``""``, ``"  "`` and ``"diabetes,"``
    no longer yield phantom entries from empty split fragments.
    """
    return sum(1 for item in chronic_conditions.split(',') if item.strip())


def _first_row_probabilities(result: List[Any]) -> Optional[List[float]]:
    """Return the first sample's class probabilities as a list indexed by class.

    ``result[1]`` is either a 2-D array (rows = samples) or, for classifiers
    exported with skl2onnx's ZipMap, a list of ``{class_label: probability}``
    dicts. Calling ``max()`` on such a dict yields its largest *key*, so dicts
    are converted to a value list ordered by sorted key (assumes integer class
    labels 0..k-1 — confirm against the export). Returns ``None`` when the
    model produces no probability output.
    """
    if len(result) < 2:
        return None
    row = result[1][0]
    if isinstance(row, dict):
        return [float(row[key]) for key in sorted(row)]
    return [float(p) for p in row]


@app.get("/", response_class=HTMLResponse)
def root():
    """Serve the landing page."""
    # NOTE(review): this literal contains no HTML tags; confirm the intended markup.
    html_content = """ Digital Doctors Assistant ML

Digital Doctors Assistant ML

Risk Assessment

Treatment Outcome

"""
    return html_content


@app.post("/predict/risk")
def predict_risk(request: RiskAssessmentRequest):
    """Predict patient risk level.

    Returns the predicted label, the top-class probability as ``confidence``
    (0.0 when the model has no probability output) and per-class probabilities.

    Raises:
        HTTPException: 503 if the model is not loaded, 500 on inference errors.
    """
    session = _require_session(risk_session, 'risk_assessment')
    try:
        input_data = np.array([[
            request.age,
            request.bmi,
            request.systolic_bp,
            request.diastolic_bp,
            _count_chronic_conditions(request.chronic_conditions),
            request.severity_score,
        ]], dtype=np.float32)

        input_name = session.get_inputs()[0].name
        result = session.run(None, {input_name: input_data})

        prediction = result[0][0]
        probabilities = _first_row_probabilities(result)
        output_classes = MODELS['risk_assessment']['output_classes']

        if isinstance(prediction, (int, np.integer)):
            prediction_label = output_classes[prediction]
        elif isinstance(prediction, np.generic):
            # Unwrap numpy scalars (e.g. np.str_, np.float32) so they serialise to JSON.
            prediction_label = prediction.item()
        else:
            prediction_label = prediction

        confidence = max(probabilities) if probabilities else 0.0

        return {
            'success': True,
            'model': 'risk_assessment',
            'prediction': prediction_label,
            'confidence': confidence,
            'probabilities': dict(zip(output_classes, probabilities))
            if probabilities is not None else None,
        }
    except Exception as e:
        logger.exception("Risk prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/predict/treatment")
def predict_treatment(request: TreatmentOutcomeRequest):
    """Predict treatment outcome.

    ``success_probability`` is a percentage rounded to one decimal; it defaults
    to 50.0 when the model has no probability output.

    Raises:
        HTTPException: 503 if the model is not loaded, 500 on inference errors.
    """
    session = _require_session(treatment_session, 'treatment_outcome')
    try:
        # NOTE(review): unknown medication/condition names silently fall back to
        # code 0 (Paracetamol / Common Cold), which can skew predictions.
        input_data = np.array([[
            request.patient_age,
            request.severity_score,
            request.compliance_rate,
            MEDICATION_MAPPING.get(request.medication, 0),
            CONDITION_MAPPING.get(request.condition, 0),
        ]], dtype=np.float32)

        input_name = session.get_inputs()[0].name
        result = session.run(None, {input_name: input_data})

        prediction = result[0][0]
        probabilities = _first_row_probabilities(result)

        success_probability = probabilities[1] if probabilities is not None else 0.5

        return {
            'success': True,
            'model': 'treatment_outcome',
            'prediction': int(prediction),
            'success_probability': round(success_probability * 100, 1),
            'confidence': max(probabilities) if probabilities else 0.0,
            'probabilities': {
                'failure': probabilities[0],
                'success': probabilities[1],
            } if probabilities is not None else None,
        }
    except Exception as e:
        logger.exception("Treatment prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/health")
def health_check():
    """Report service status and whether each model session has been created."""
    return {
        "status": "healthy",
        "models_loaded": {
            "risk_assessment": risk_session is not None,
            "treatment_outcome": treatment_session is not None,
        },
    }