| |
|
| | import os
|
| | import numpy as np
|
| | import onnxruntime as rt
|
| | from fastapi import FastAPI, HTTPException
|
| | from fastapi.responses import HTMLResponse
|
| | from fastapi.staticfiles import StaticFiles
|
| | from pydantic import BaseModel
|
| | from typing import Dict, Any
|
| | from huggingface_hub import hf_hub_download
|
| | from dotenv import load_dotenv
|
| |
|
# Copy settings from a local .env file (if present) into os.environ so that
# later os.getenv(...) calls -- e.g. HUGGINGFACE_TOKEN in the startup hook --
# can see them.
load_dotenv()

# ASGI application object; every route and startup hook in this module
# registers itself on it.
app = FastAPI(
    title="Digital Doctors Assistant ML API",
)
|
| |
|
| |
|
# Static metadata for the two ONNX models served by this API.
#   filename       -- object name of the model file in the Hugging Face Hub repo
#   features       -- input columns, in the order the predict endpoints build
#                     their input rows
#   output_classes -- human-readable labels, indexed by the model's class id
MODELS = dict(
    risk_assessment=dict(
        filename="risk_assessment.onnx",
        features=[
            "age",
            "bmi",
            "systolic_bp",
            "diastolic_bp",
            "chronic_conditions_count",
            "severity_score",
        ],
        output_classes=["Low", "Medium", "High"],
    ),
    treatment_outcome=dict(
        filename="treatment_outcome.onnx",
        features=[
            "patient_age",
            "severity_score",
            "compliance_rate",
            "medication_encoded",
            "condition_encoded",
        ],
        output_classes=["No Success", "Success"],
    ),
)
|
| |
|
| |
|
# ONNX Runtime inference sessions. The startup hook load_models() assigns
# them; until it has run, both remain None.
risk_session = treatment_session = None
|
| |
|
# Hugging Face Hub repository holding both ONNX files. The HF_MODEL_REPO
# environment variable (read at startup) overrides it, e.g. for a fork or a
# staging copy of the models.
DEFAULT_HF_MODEL_REPO = "Tegaconsult/digital-doctors-assistant-ml"


def _load_session(repo_id, filename, token):
    """Download ``filename`` from ``repo_id`` and open it as an ONNX Runtime session."""
    model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=token)
    return rt.InferenceSession(model_path)


@app.on_event("startup")
async def load_models():
    """Download both ONNX models and create their inference sessions.

    Registered as the application's startup hook. It assigns the module-level
    ``risk_session`` and ``treatment_session`` globals. File names come from
    ``MODELS`` so they are declared in a single place.

    Any download or session-creation error is not caught here and propagates
    out of the startup hook.
    """
    global risk_session, treatment_session

    # Treat an empty HUGGINGFACE_TOKEN like an unset one; otherwise "" would be
    # sent to the Hub as an (invalid) explicit token.
    token = os.getenv("HUGGINGFACE_TOKEN") or None
    repo_id = os.getenv("HF_MODEL_REPO") or DEFAULT_HF_MODEL_REPO

    risk_session = _load_session(
        repo_id, MODELS["risk_assessment"]["filename"], token
    )
    treatment_session = _load_session(
        repo_id, MODELS["treatment_outcome"]["filename"], token
    )

    print("Models loaded successfully!")
|
| |
|
# JSON body accepted by POST /predict/risk. Numeric vitals arrive as floats.
class RiskAssessmentRequest(BaseModel):
    age: float
    bmi: float
    systolic_bp: float
    diastolic_bp: float
    # Comma-separated condition names (optional). The risk endpoint turns this
    # into a count of conditions before feeding the model.
    chronic_conditions: str = ""
    severity_score: float
|
| |
|
# JSON body accepted by POST /predict/treatment.
class TreatmentOutcomeRequest(BaseModel):
    patient_age: float
    severity_score: float
    compliance_rate: float
    # Free-text names; the treatment endpoint converts them to the integer
    # codes the model consumes.
    medication: str
    condition: str
|
| |
|
# Single-page demo UI served at "/". It posts JSON to /predict/risk and
# /predict/treatment and renders the replies client-side. Server error details
# are shown to the user, and every server-derived string is HTML-escaped
# before it is inserted with innerHTML.
_INDEX_HTML = """<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Digital Doctors Assistant ML</title>
    <style>
        * { margin: 0; padding: 0; box-sizing: border-box; }
        body { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); min-height: 100vh; padding: 20px; }
        .container { max-width: 1200px; margin: 0 auto; }
        h1 { color: white; text-align: center; margin-bottom: 30px; font-size: 2.5em; }
        /* min(500px, 100%) keeps a card from overflowing viewports narrower than 500px. */
        .cards { display: grid; grid-template-columns: repeat(auto-fit, minmax(min(500px, 100%), 1fr)); gap: 20px; }
        .card { background: white; border-radius: 15px; padding: 30px; box-shadow: 0 10px 30px rgba(0,0,0,0.2); }
        .card h2 { color: #667eea; margin-bottom: 20px; font-size: 1.8em; }
        .form-group { margin-bottom: 15px; }
        label { display: block; margin-bottom: 5px; color: #333; font-weight: 600; }
        input, textarea { width: 100%; padding: 10px; border: 2px solid #e0e0e0; border-radius: 8px; font-size: 14px; transition: border 0.3s; }
        input:focus, textarea:focus { outline: none; border-color: #667eea; }
        button { width: 100%; padding: 12px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border: none; border-radius: 8px; font-size: 16px; font-weight: 600; cursor: pointer; transition: transform 0.2s; }
        button:hover { transform: translateY(-2px); }
        button:active { transform: translateY(0); }
        .result { margin-top: 20px; padding: 20px; background: #f8f9fa; border-radius: 8px; border-left: 4px solid #667eea; }
        .result h3 { color: #667eea; margin-bottom: 10px; }
        .result-item { margin: 8px 0; color: #555; }
        .result-item strong { color: #333; }
        .error { background: #fee; border-left-color: #f44; }
        .error h3 { color: #f44; }
        .hidden { display: none; }
        .risk-low { color: #28a745; font-weight: bold; }
        .risk-medium { color: #ffc107; font-weight: bold; }
        .risk-high { color: #dc3545; font-weight: bold; }
    </style>
</head>
<body>
    <div class="container">
        <h1>Digital Doctors Assistant ML</h1>

        <div class="cards">
            <div class="card">
                <h2>Risk Assessment</h2>
                <form id="riskForm">
                    <div class="form-group">
                        <label for="age">Age</label>
                        <input type="number" id="age" required min="0" max="120" value="45">
                    </div>
                    <div class="form-group">
                        <label for="bmi">BMI</label>
                        <input type="number" id="bmi" required step="0.1" min="10" max="50" value="28.5">
                    </div>
                    <div class="form-group">
                        <label for="systolic_bp">Systolic BP</label>
                        <input type="number" id="systolic_bp" required min="70" max="200" value="140">
                    </div>
                    <div class="form-group">
                        <label for="diastolic_bp">Diastolic BP</label>
                        <input type="number" id="diastolic_bp" required min="40" max="130" value="90">
                    </div>
                    <div class="form-group">
                        <label for="chronic_conditions">Chronic Conditions (comma-separated)</label>
                        <input type="text" id="chronic_conditions" placeholder="e.g., diabetes,hypertension" value="diabetes,hypertension">
                    </div>
                    <div class="form-group">
                        <label for="severity_score">Severity Score (0-10)</label>
                        <input type="number" id="severity_score" required step="0.1" min="0" max="10" value="7.5">
                    </div>
                    <button type="submit">Predict Risk</button>
                </form>
                <div id="riskResult" class="result hidden"></div>
            </div>

            <div class="card">
                <h2>Treatment Outcome</h2>
                <form id="treatmentForm">
                    <div class="form-group">
                        <label for="patient_age">Patient Age</label>
                        <input type="number" id="patient_age" required min="0" max="120" value="55">
                    </div>
                    <div class="form-group">
                        <label for="treatment_severity">Severity Score (0-10)</label>
                        <input type="number" id="treatment_severity" required step="0.1" min="0" max="10" value="6.5">
                    </div>
                    <div class="form-group">
                        <label for="compliance_rate">Compliance Rate (0-1)</label>
                        <input type="number" id="compliance_rate" required step="0.01" min="0" max="1" value="0.85">
                    </div>
                    <div class="form-group">
                        <label for="medication">Medication</label>
                        <input type="text" id="medication" required list="medications" value="Metformin">
                        <datalist id="medications">
                            <option value="Paracetamol">
                            <option value="Ibuprofen">
                            <option value="Amoxicillin">
                            <option value="Ciprofloxacin">
                            <option value="Metformin">
                            <option value="Lisinopril">
                            <option value="Amlodipine">
                            <option value="Omeprazole">
                        </datalist>
                    </div>
                    <div class="form-group">
                        <label for="condition">Condition</label>
                        <input type="text" id="condition" required list="conditions" value="Diabetes Type 2">
                        <datalist id="conditions">
                            <option value="Common Cold">
                            <option value="Influenza">
                            <option value="Pneumonia">
                            <option value="Bronchitis">
                            <option value="Hypertension">
                            <option value="Diabetes Type 2">
                            <option value="Migraine">
                            <option value="Gastroenteritis">
                        </datalist>
                    </div>
                    <button type="submit">Predict Outcome</button>
                </form>
                <div id="treatmentResult" class="result hidden"></div>
            </div>
        </div>
    </div>

    <script>
        // Escape text before interpolating it into innerHTML: server error
        // details can echo user-supplied values.
        function escapeHtml(value) {
            const holder = document.createElement('div');
            holder.textContent = String(value);
            return holder.innerHTML;
        }

        function percent(fraction) {
            return (fraction * 100).toFixed(1) + '%';
        }

        // POST a JSON payload and return the parsed reply. Throws an Error that
        // carries the server's "detail" message for non-2xx or unsuccessful replies.
        async function postJson(url, payload) {
            const response = await fetch(url, {
                method: 'POST',
                headers: { 'Content-Type': 'application/json' },
                body: JSON.stringify(payload)
            });

            let result = null;
            try {
                result = await response.json();
            } catch (parseError) {
                throw new Error('Unexpected response from server (HTTP ' + response.status + ')');
            }

            if (!response.ok || !result || !result.success) {
                const detail = result ? result.detail : null;
                if (typeof detail === 'string') {
                    throw new Error(detail);
                }
                throw new Error(detail ? JSON.stringify(detail) : 'Prediction failed');
            }
            return result;
        }

        function showError(resultDiv, error) {
            resultDiv.className = 'result error';
            resultDiv.innerHTML = '<h3>Error</h3><div class="result-item">' + escapeHtml(error.message) + '</div>';
        }

        document.getElementById('riskForm').addEventListener('submit', async (e) => {
            e.preventDefault();
            const resultDiv = document.getElementById('riskResult');

            const data = {
                age: parseFloat(document.getElementById('age').value),
                bmi: parseFloat(document.getElementById('bmi').value),
                systolic_bp: parseFloat(document.getElementById('systolic_bp').value),
                diastolic_bp: parseFloat(document.getElementById('diastolic_bp').value),
                chronic_conditions: document.getElementById('chronic_conditions').value,
                severity_score: parseFloat(document.getElementById('severity_score').value)
            };

            try {
                const result = await postJson('/predict/risk', data);
                const label = String(result.prediction);
                // Reduce the label to safe characters before using it inside a class attribute.
                const riskClass = label.toLowerCase().split(' ').join('-').replace(/[^a-z0-9-]/g, '');
                const probs = result.probabilities;
                resultDiv.className = 'result';
                resultDiv.innerHTML = `
                    <h3>Prediction Results</h3>
                    <div class="result-item"><strong>Risk Level:</strong> <span class="risk-${riskClass}">${escapeHtml(label)}</span></div>
                    <div class="result-item"><strong>Confidence:</strong> ${percent(result.confidence)}</div>
                    ${probs ? `
                    <div class="result-item"><strong>Probabilities:</strong></div>
                    <div class="result-item">Low: ${percent(probs.Low)}</div>
                    <div class="result-item">Medium: ${percent(probs.Medium)}</div>
                    <div class="result-item">High: ${percent(probs.High)}</div>
                    ` : ''}
                `;
            } catch (error) {
                showError(resultDiv, error);
            }

            resultDiv.classList.remove('hidden');
        });

        document.getElementById('treatmentForm').addEventListener('submit', async (e) => {
            e.preventDefault();
            const resultDiv = document.getElementById('treatmentResult');

            const data = {
                patient_age: parseFloat(document.getElementById('patient_age').value),
                severity_score: parseFloat(document.getElementById('treatment_severity').value),
                compliance_rate: parseFloat(document.getElementById('compliance_rate').value),
                medication: document.getElementById('medication').value,
                condition: document.getElementById('condition').value
            };

            try {
                const result = await postJson('/predict/treatment', data);
                const probs = result.probabilities;
                resultDiv.className = 'result';
                resultDiv.innerHTML = `
                    <h3>Prediction Results</h3>
                    <div class="result-item"><strong>Outcome:</strong> ${result.prediction === 1 ? 'Success' : 'No Success'}</div>
                    <div class="result-item"><strong>Success Probability:</strong> ${escapeHtml(result.success_probability)}%</div>
                    <div class="result-item"><strong>Confidence:</strong> ${percent(result.confidence)}</div>
                    ${probs ? `
                    <div class="result-item"><strong>Probabilities:</strong></div>
                    <div class="result-item">Failure: ${percent(probs.failure)}</div>
                    <div class="result-item">Success: ${percent(probs.success)}</div>
                    ` : ''}
                `;
            } catch (error) {
                showError(resultDiv, error);
            }

            resultDiv.classList.remove('hidden');
        });
    </script>
</body>
</html>"""


@app.get("/", response_class=HTMLResponse)
def root():
    """Serve the single-page demo UI that drives both prediction endpoints."""
    return _INDEX_HTML
|
| |
|
def _count_chronic_conditions(raw):
    """Count the non-blank entries in a comma-separated condition list.

    ``"diabetes, hypertension"`` -> 2. Blank entries, as in ``""``, ``" "``,
    ``"diabetes,"`` or ``"a,,b"``, are ignored.
    """
    return sum(1 for item in raw.split(",") if item.strip())


def _risk_probability_vector(raw, classes):
    """Return the model's per-class probabilities as floats ordered like ``classes``.

    ``raw`` is the first row of the model's second output. ONNX classifiers
    may emit it either as a numeric array or as a ``{label: probability}``
    mapping (skl2onnx's default ZipMap output, for example). With a mapping,
    ``max()`` and positional indexing would act on the keys instead of the
    probabilities, so both forms are normalised here. Mapping keys may be
    class indices or class names.

    Raises:
        ValueError: a class has no probability, or the array length differs
            from ``len(classes)``.
    """
    if isinstance(raw, dict):
        vector = []
        for index, label in enumerate(classes):
            if index in raw:
                vector.append(float(raw[index]))
            elif label in raw:
                vector.append(float(raw[label]))
            else:
                raise ValueError(f"model output has no probability for class {label!r}")
        return vector

    # NOTE(review): assumes array columns follow the order of MODELS output_classes.
    # A model trained on string labels may order classes alphabetically instead -- confirm.
    vector = [float(p) for p in np.asarray(raw, dtype=np.float64).ravel()]
    if len(vector) != len(classes):
        raise ValueError(
            f"model returned {len(vector)} probabilities for {len(classes)} classes"
        )
    return vector


@app.post("/predict/risk")
def predict_risk(request: RiskAssessmentRequest):
    """Predict patient risk level.

    The chronic-condition string is reduced to a count. The model input is one
    float32 row in MODELS['risk_assessment']['features'] order.

    Returns:
        dict with ``success``, ``model``, ``prediction`` (class label),
        ``confidence`` (highest class probability, 0.0 if the model emits no
        probabilities) and ``probabilities`` (label -> probability, or None).

    Raises:
        HTTPException: 503 if the model session has not been loaded yet;
            500 for any failure while building the input or running the model.
    """
    if risk_session is None:
        raise HTTPException(status_code=503, detail="Risk assessment model is not loaded")

    output_classes = MODELS["risk_assessment"]["output_classes"]

    try:
        input_data = np.array([[
            request.age,
            request.bmi,
            request.systolic_bp,
            request.diastolic_bp,
            _count_chronic_conditions(request.chronic_conditions),
            request.severity_score,
        ]], dtype=np.float32)

        input_name = risk_session.get_inputs()[0].name
        result = risk_session.run(None, {input_name: input_data})

        prediction = result[0][0]
        raw_probabilities = result[1][0] if len(result) > 1 else None
        probabilities = (
            _risk_probability_vector(raw_probabilities, output_classes)
            if raw_probabilities is not None
            else None
        )

        if isinstance(prediction, (int, np.integer)):
            # Integer class id -> human-readable label.
            prediction_label = output_classes[int(prediction)]
        elif isinstance(prediction, str):
            # Normalise numpy string scalars to a plain str.
            prediction_label = str(prediction)
        else:
            prediction_label = prediction

        confidence = max(probabilities) if probabilities else 0.0

        return {
            'success': True,
            'model': 'risk_assessment',
            'prediction': prediction_label,
            'confidence': confidence,
            'probabilities': (
                dict(zip(output_classes, probabilities))
                if probabilities is not None
                else None
            ),
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
| |
|
# Integer codes fed to the treatment-outcome model as medication_encoded /
# condition_encoded. Built once at import time instead of on every request.
# NOTE(review): these must match the encoding used when the model was trained -- confirm.
MEDICATION_CODES = {
    'Paracetamol': 0, 'Ibuprofen': 1, 'Amoxicillin': 2, 'Ciprofloxacin': 3,
    'Metformin': 4, 'Lisinopril': 5, 'Amlodipine': 6, 'Omeprazole': 7,
}
CONDITION_CODES = {
    'Common Cold': 0, 'Influenza': 1, 'Pneumonia': 2, 'Bronchitis': 3,
    'Hypertension': 4, 'Diabetes Type 2': 5, 'Migraine': 6, 'Gastroenteritis': 7,
}

# Case-insensitive lookups, so "metformin " matches "Metformin".
_MEDICATION_LOOKUP = {name.casefold(): code for name, code in MEDICATION_CODES.items()}
_CONDITION_LOOKUP = {name.casefold(): code for name, code in CONDITION_CODES.items()}


def _encode_category(value, lookup, allowed, field):
    """Return the integer code for ``value``, or raise a 422 for unknown names.

    Unknown values used to fall back to code 0, which silently fed the model a
    different medication/condition than the caller asked about.
    """
    code = lookup.get(value.strip().casefold())
    if code is None:
        raise HTTPException(
            status_code=422,
            detail=f"Unknown {field} {value!r}; expected one of: {', '.join(allowed)}",
        )
    return code


def _treatment_class_index(prediction, classes):
    """Return the predicted class as an int index, accepting index or label outputs."""
    if isinstance(prediction, str) and prediction in classes:
        return classes.index(prediction)
    return int(prediction)


def _treatment_probability_vector(raw, classes):
    """Return per-class probabilities as floats ordered like ``classes``.

    Accepts either a numeric array or a ``{label: probability}`` mapping keyed
    by class index or class name (e.g. skl2onnx's default ZipMap output).
    Taking ``max()`` of such a mapping would otherwise return a key, not a
    probability.

    Raises:
        ValueError: a class is missing or the array length is wrong.
    """
    if isinstance(raw, dict):
        vector = []
        for index, label in enumerate(classes):
            if index in raw:
                vector.append(float(raw[index]))
            elif label in raw:
                vector.append(float(raw[label]))
            else:
                raise ValueError(f"model output has no probability for class {label!r}")
        return vector

    vector = [float(p) for p in np.asarray(raw, dtype=np.float64).ravel()]
    if len(vector) != len(classes):
        raise ValueError(
            f"model returned {len(vector)} probabilities for {len(classes)} classes"
        )
    return vector


@app.post("/predict/treatment")
def predict_treatment(request: TreatmentOutcomeRequest):
    """Predict treatment outcome.

    The model input is one float32 row in MODELS['treatment_outcome']['features']
    order, with medication and condition replaced by their integer codes.

    Returns:
        dict with ``success``, ``model``, ``prediction`` (0 = No Success,
        1 = Success), ``success_probability`` (percent, 1 decimal),
        ``confidence`` (highest class probability) and ``probabilities``
        (``failure`` / ``success``, or None).

    Raises:
        HTTPException: 503 if the model session is not loaded; 422 for an
            unknown medication or condition; 500 for any failure while running
            the model.
    """
    if treatment_session is None:
        raise HTTPException(status_code=503, detail="Treatment outcome model is not loaded")

    # Validate categorical inputs before the broad try below. Inside it, these
    # 422s would be converted into 500s.
    medication_code = _encode_category(
        request.medication, _MEDICATION_LOOKUP, MEDICATION_CODES, "medication"
    )
    condition_code = _encode_category(
        request.condition, _CONDITION_LOOKUP, CONDITION_CODES, "condition"
    )
    output_classes = MODELS["treatment_outcome"]["output_classes"]

    try:
        input_data = np.array([[
            request.patient_age,
            request.severity_score,
            request.compliance_rate,
            medication_code,
            condition_code,
        ]], dtype=np.float32)

        input_name = treatment_session.get_inputs()[0].name
        result = treatment_session.run(None, {input_name: input_data})

        prediction = _treatment_class_index(result[0][0], output_classes)
        raw_probabilities = result[1][0] if len(result) > 1 else None
        probabilities = (
            _treatment_probability_vector(raw_probabilities, output_classes)
            if raw_probabilities is not None
            else None
        )

        # 0.5 is the fallback this endpoint reports when the model emits no
        # probability output.
        success_probability = probabilities[1] if probabilities is not None else 0.5

        return {
            'success': True,
            'model': 'treatment_outcome',
            'prediction': prediction,
            'success_probability': round(success_probability * 100, 1),
            'confidence': max(probabilities) if probabilities is not None else 0.0,
            'probabilities': {
                'failure': probabilities[0],
                'success': probabilities[1],
            } if probabilities is not None else None,
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
| |
|
@app.get("/health")
def health_check():
    """Report service status plus whether each ONNX session has been created."""
    sessions = {
        "risk_assessment": risk_session,
        "treatment_outcome": treatment_session,
    }
    loaded = {name: session is not None for name, session in sessions.items()}
    return {"status": "healthy", "models_loaded": loaded}
|
| |
|