# Author: Tegaconsult — last commit 9a647e6 "Update ml.py" (verified).
import os
import numpy as np
import onnxruntime as rt
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import Dict, Any
from huggingface_hub import hf_hub_download
from dotenv import load_dotenv
# Populate os.environ from a local .env file (python-dotenv); by default,
# variables already present in the process environment are left untouched.
load_dotenv()

APP_TITLE = "Digital Doctors Assistant ML API"
app = FastAPI(title=APP_TITLE)
# Per-model configuration table:
#   filename       - ONNX file name inside the Hugging Face Hub repo
#   features       - input column names, in the order the prediction
#                    endpoints assemble the input vector
#   output_classes - human-readable label for each integer class index
MODELS = dict(
    risk_assessment=dict(
        filename='risk_assessment.onnx',
        features=[
            'age',
            'bmi',
            'systolic_bp',
            'diastolic_bp',
            'chronic_conditions_count',
            'severity_score',
        ],
        output_classes=['Low', 'Medium', 'High'],
    ),
    treatment_outcome=dict(
        filename='treatment_outcome.onnx',
        features=[
            'patient_age',
            'severity_score',
            'compliance_rate',
            'medication_encoded',
            'condition_encoded',
        ],
        output_classes=['No Success', 'Success'],
    ),
)
# ONNX Runtime inference sessions. Both start as None; the startup handler
# replaces them with loaded sessions, and the endpoints read them per request.
risk_session = treatment_session = None
@app.on_event("startup")
async def load_models():
    """Download both ONNX models from the Hugging Face Hub and open sessions.

    Runs once at application startup and assigns the module-level
    ``risk_session`` and ``treatment_session`` globals.  Model filenames are
    taken from ``MODELS`` so the loader cannot drift from the configuration
    table.  Any download or load error propagates and aborts startup.
    """
    global risk_session, treatment_session

    # Optional access token (set as a Space secret); None when unset.
    token = os.getenv("HUGGINGFACE_TOKEN")
    repo_id = "Tegaconsult/digital-doctors-assistant-ml"

    def _open_session(model_key):
        # Fetch the model file named in MODELS and wrap it in an ORT session.
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=MODELS[model_key]['filename'],
            token=token,
        )
        return rt.InferenceSession(model_path)

    risk_session = _open_session('risk_assessment')
    treatment_session = _open_session('treatment_outcome')
    print("Models loaded successfully!")
class RiskAssessmentRequest(BaseModel):
    """JSON request body for ``POST /predict/risk``.

    The numeric fields are fed to the risk model as-is; no range validation is
    applied at the API level (the web form's min/max limits are client-side
    only).
    """
    age: float
    bmi: float
    systolic_bp: float
    diastolic_bp: float
    # Comma-separated condition names; the endpoint reduces this to a count
    # before inference. The empty-string default means "no conditions".
    chronic_conditions: str = ""
    # The web form labels this field as 0-10; the API does not enforce it.
    severity_score: float
class TreatmentOutcomeRequest(BaseModel):
    """JSON request body for ``POST /predict/treatment``."""
    patient_age: float
    # The web form labels this field as 0-10; the API does not enforce it.
    severity_score: float
    # The web form labels this as a 0-1 fraction; the API does not enforce it.
    compliance_rate: float
    # Free-text names; the endpoint converts them to integer codes using
    # fixed lookup tables before inference.
    medication: str
    condition: str
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve the single-page web UI for both prediction endpoints.

    The page posts its two forms as JSON to ``/predict/risk`` and
    ``/predict/treatment`` and renders the responses in place.  On failure it
    shows the server's ``detail`` message (the body shape FastAPI uses for
    HTTPException and validation errors) instead of a generic message, and it
    inserts that message as text, not HTML.
    """
    html_content = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Digital Doctors Assistant ML</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); min-height: 100vh; padding: 20px; }
.container { max-width: 1200px; margin: 0 auto; }
h1 { color: white; text-align: center; margin-bottom: 30px; font-size: 2.5em; }
.cards { display: grid; grid-template-columns: repeat(auto-fit, minmax(500px, 1fr)); gap: 20px; }
.card { background: white; border-radius: 15px; padding: 30px; box-shadow: 0 10px 30px rgba(0,0,0,0.2); }
.card h2 { color: #667eea; margin-bottom: 20px; font-size: 1.8em; }
.form-group { margin-bottom: 15px; }
label { display: block; margin-bottom: 5px; color: #333; font-weight: 600; }
input, textarea { width: 100%; padding: 10px; border: 2px solid #e0e0e0; border-radius: 8px; font-size: 14px; transition: border 0.3s; }
input:focus, textarea:focus { outline: none; border-color: #667eea; }
button { width: 100%; padding: 12px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border: none; border-radius: 8px; font-size: 16px; font-weight: 600; cursor: pointer; transition: transform 0.2s; }
button:hover { transform: translateY(-2px); }
button:active { transform: translateY(0); }
.result { margin-top: 20px; padding: 20px; background: #f8f9fa; border-radius: 8px; border-left: 4px solid #667eea; }
.result h3 { color: #667eea; margin-bottom: 10px; }
.result-item { margin: 8px 0; color: #555; }
.result-item strong { color: #333; }
.error { background: #fee; border-left-color: #f44; }
.error h3 { color: #f44; }
.hidden { display: none; }
.risk-low { color: #28a745; font-weight: bold; }
.risk-medium { color: #ffc107; font-weight: bold; }
.risk-high { color: #dc3545; font-weight: bold; }
</style>
</head>
<body>
<div class="container">
<h1>Digital Doctors Assistant ML</h1>
<div class="cards">
<div class="card">
<h2>Risk Assessment</h2>
<form id="riskForm">
<div class="form-group">
<label>Age</label>
<input type="number" id="age" required min="0" max="120" value="45">
</div>
<div class="form-group">
<label>BMI</label>
<input type="number" id="bmi" required step="0.1" min="10" max="50" value="28.5">
</div>
<div class="form-group">
<label>Systolic BP</label>
<input type="number" id="systolic_bp" required min="70" max="200" value="140">
</div>
<div class="form-group">
<label>Diastolic BP</label>
<input type="number" id="diastolic_bp" required min="40" max="130" value="90">
</div>
<div class="form-group">
<label>Chronic Conditions (comma-separated)</label>
<input type="text" id="chronic_conditions" placeholder="e.g., diabetes,hypertension" value="diabetes,hypertension">
</div>
<div class="form-group">
<label>Severity Score (0-10)</label>
<input type="number" id="severity_score" required step="0.1" min="0" max="10" value="7.5">
</div>
<button type="submit">Predict Risk</button>
</form>
<div id="riskResult" class="result hidden"></div>
</div>
<div class="card">
<h2>Treatment Outcome</h2>
<form id="treatmentForm">
<div class="form-group">
<label>Patient Age</label>
<input type="number" id="patient_age" required min="0" max="120" value="55">
</div>
<div class="form-group">
<label>Severity Score (0-10)</label>
<input type="number" id="treatment_severity" required step="0.1" min="0" max="10" value="6.5">
</div>
<div class="form-group">
<label>Compliance Rate (0-1)</label>
<input type="number" id="compliance_rate" required step="0.01" min="0" max="1" value="0.85">
</div>
<div class="form-group">
<label>Medication</label>
<input type="text" id="medication" required list="medications" value="Metformin">
<datalist id="medications">
<option value="Paracetamol">
<option value="Ibuprofen">
<option value="Amoxicillin">
<option value="Ciprofloxacin">
<option value="Metformin">
<option value="Lisinopril">
<option value="Amlodipine">
<option value="Omeprazole">
</datalist>
</div>
<div class="form-group">
<label>Condition</label>
<input type="text" id="condition" required list="conditions" value="Diabetes Type 2">
<datalist id="conditions">
<option value="Common Cold">
<option value="Influenza">
<option value="Pneumonia">
<option value="Bronchitis">
<option value="Hypertension">
<option value="Diabetes Type 2">
<option value="Migraine">
<option value="Gastroenteritis">
</datalist>
</div>
<button type="submit">Predict Outcome</button>
</form>
<div id="treatmentResult" class="result hidden"></div>
</div>
</div>
</div>
<script>
// FastAPI error bodies are {"detail": "..."} for HTTPException and
// {"detail": [{"msg": ...}, ...]} for request validation errors.
function describeError(result) {
if (result && typeof result.detail === 'string') return result.detail;
if (result && Array.isArray(result.detail)) return result.detail.map(d => d.msg).join('; ');
return 'Prediction failed';
}
// Render an error message as plain text so server text is never parsed as HTML.
function showError(resultDiv, message) {
resultDiv.className = 'result error';
resultDiv.innerHTML = '<h3>Error</h3><div class="result-item"></div>';
resultDiv.querySelector('.result-item').textContent = message;
}
document.getElementById('riskForm').addEventListener('submit', async (e) => {
e.preventDefault();
const resultDiv = document.getElementById('riskResult');
const data = {
age: parseFloat(document.getElementById('age').value),
bmi: parseFloat(document.getElementById('bmi').value),
systolic_bp: parseFloat(document.getElementById('systolic_bp').value),
diastolic_bp: parseFloat(document.getElementById('diastolic_bp').value),
chronic_conditions: document.getElementById('chronic_conditions').value,
severity_score: parseFloat(document.getElementById('severity_score').value)
};
try {
const response = await fetch('/predict/risk', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(data)
});
const result = await response.json();
if (response.ok && result.success) {
const riskClass = result.prediction.toLowerCase().replace(' ', '-');
resultDiv.className = 'result';
resultDiv.innerHTML = `
<h3>Prediction Results</h3>
<div class="result-item"><strong>Risk Level:</strong> <span class="risk-${riskClass}">${result.prediction}</span></div>
<div class="result-item"><strong>Confidence:</strong> ${(result.confidence * 100).toFixed(1)}%</div>
${result.probabilities ? `
<div class="result-item"><strong>Probabilities:</strong></div>
<div class="result-item">Low: ${(result.probabilities.Low * 100).toFixed(1)}%</div>
<div class="result-item">Medium: ${(result.probabilities.Medium * 100).toFixed(1)}%</div>
<div class="result-item">High: ${(result.probabilities.High * 100).toFixed(1)}%</div>
` : ''}
`;
} else {
throw new Error(describeError(result));
}
} catch (error) {
showError(resultDiv, error.message);
}
resultDiv.classList.remove('hidden');
});
document.getElementById('treatmentForm').addEventListener('submit', async (e) => {
e.preventDefault();
const resultDiv = document.getElementById('treatmentResult');
const data = {
patient_age: parseFloat(document.getElementById('patient_age').value),
severity_score: parseFloat(document.getElementById('treatment_severity').value),
compliance_rate: parseFloat(document.getElementById('compliance_rate').value),
medication: document.getElementById('medication').value,
condition: document.getElementById('condition').value
};
try {
const response = await fetch('/predict/treatment', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(data)
});
const result = await response.json();
if (response.ok && result.success) {
resultDiv.className = 'result';
resultDiv.innerHTML = `
<h3>Prediction Results</h3>
<div class="result-item"><strong>Outcome:</strong> ${result.prediction === 1 ? 'Success' : 'No Success'}</div>
<div class="result-item"><strong>Success Probability:</strong> ${result.success_probability}%</div>
<div class="result-item"><strong>Confidence:</strong> ${(result.confidence * 100).toFixed(1)}%</div>
${result.probabilities ? `
<div class="result-item"><strong>Probabilities:</strong></div>
<div class="result-item">Failure: ${(result.probabilities.failure * 100).toFixed(1)}%</div>
<div class="result-item">Success: ${(result.probabilities.success * 100).toFixed(1)}%</div>
` : ''}
`;
} else {
throw new Error(describeError(result));
}
} catch (error) {
showError(resultDiv, error.message);
}
resultDiv.classList.remove('hidden');
});
</script>
</body>
</html>"""
    return html_content
def _count_chronic_conditions(raw):
    """Return how many non-blank, comma-separated entries *raw* contains.

    Blank entries (from a trailing comma, ``", ,"`` or whitespace-only input)
    are ignored, so ``"diabetes,hypertension,"`` counts as 2, not 3.
    """
    return sum(1 for item in raw.split(',') if item.strip())


def _class_probabilities(raw, labels):
    """Return class probabilities as a list of floats ordered like *labels*.

    An ONNX classifier may emit probabilities either as a plain array (one
    value per class index) or, when exported with skl2onnx's default ZipMap
    operator, as a ``{label: probability}`` mapping.  ``max()`` over such a
    mapping would return its largest *key*, so both shapes are normalised to a
    list here.  Mapping keys may be the label strings or integer indices.

    Returns None when *raw* is None (the model produced no probability output).
    """
    if raw is None:
        return None
    if isinstance(raw, dict):
        # Prefer the label as key; fall back to the positional class index.
        return [
            float(raw[label] if label in raw else raw[index])
            for index, label in enumerate(labels)
        ]
    return [float(p) for p in raw]


# Integer codes for the treatment model's categorical inputs.
_MEDICATION_CODES = {
    'Paracetamol': 0, 'Ibuprofen': 1, 'Amoxicillin': 2, 'Ciprofloxacin': 3,
    'Metformin': 4, 'Lisinopril': 5, 'Amlodipine': 6, 'Omeprazole': 7,
}
_CONDITION_CODES = {
    'Common Cold': 0, 'Influenza': 1, 'Pneumonia': 2, 'Bronchitis': 3,
    'Hypertension': 4, 'Diabetes Type 2': 5, 'Migraine': 6, 'Gastroenteritis': 7,
}


def _encode_category(value, codes):
    """Map a free-text category name to its integer code in *codes*.

    Matching ignores surrounding whitespace and letter case, so
    ``"metformin "`` encodes like ``"Metformin"``.
    """
    wanted = value.strip().casefold()
    # NOTE(review): unrecognised names still fall back to code 0 (Paracetamol /
    # Common Cold), as before; consider rejecting them with a 422 instead.
    return next(
        (code for name, code in codes.items() if name.casefold() == wanted),
        0,
    )


@app.post("/predict/risk")
def predict_risk(request: RiskAssessmentRequest):
    """Predict a patient's risk level.

    Returns a dict with ``success``, ``model``, ``prediction`` (a label from
    ``MODELS['risk_assessment']['output_classes']`` when the model outputs a
    class index, otherwise the model's own label), ``confidence`` (highest
    class probability, 0.0 without a probability output) and ``probabilities``
    (label -> probability, or None).

    Raises:
        HTTPException: 503 if the model session is not loaded; 500 carrying
            the underlying error message if preprocessing or inference fails.
    """
    if risk_session is None:
        raise HTTPException(status_code=503, detail="Risk assessment model is not loaded")
    try:
        labels = MODELS['risk_assessment']['output_classes']
        # Column order follows MODELS['risk_assessment']['features'].
        input_data = np.array([[
            request.age,
            request.bmi,
            request.systolic_bp,
            request.diastolic_bp,
            _count_chronic_conditions(request.chronic_conditions),
            request.severity_score,
        ]], dtype=np.float32)

        input_name = risk_session.get_inputs()[0].name
        result = risk_session.run(None, {input_name: input_data})

        # Output 0 is treated as the predicted label, output 1 (when present)
        # as the class probabilities; [0] selects the single input row.
        prediction = result[0][0]
        probabilities = _class_probabilities(
            result[1][0] if len(result) > 1 else None, labels
        )
        if isinstance(prediction, (int, np.integer)):
            prediction_label = labels[prediction]
        else:
            prediction_label = prediction

        return {
            'success': True,
            'model': 'risk_assessment',
            'prediction': prediction_label,
            'confidence': max(probabilities) if probabilities is not None else 0.0,
            'probabilities': {
                label: probabilities[i] for i, label in enumerate(labels)
            } if probabilities is not None else None,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/predict/treatment")
def predict_treatment(request: TreatmentOutcomeRequest):
    """Predict whether a treatment will succeed.

    Returns a dict with ``success``, ``model``, ``prediction`` (class index as
    int; 1 is "Success"), ``success_probability`` (percent, one decimal; 50.0
    without a probability output), ``confidence`` (highest class probability,
    0.0 without a probability output) and ``probabilities``
    (``failure``/``success``, or None).

    Raises:
        HTTPException: 503 if the model session is not loaded; 500 carrying
            the underlying error message if preprocessing or inference fails.
    """
    if treatment_session is None:
        raise HTTPException(status_code=503, detail="Treatment outcome model is not loaded")
    try:
        labels = MODELS['treatment_outcome']['output_classes']
        # Column order follows MODELS['treatment_outcome']['features'].
        input_data = np.array([[
            request.patient_age,
            request.severity_score,
            request.compliance_rate,
            _encode_category(request.medication, _MEDICATION_CODES),
            _encode_category(request.condition, _CONDITION_CODES),
        ]], dtype=np.float32)

        input_name = treatment_session.get_inputs()[0].name
        result = treatment_session.run(None, {input_name: input_data})

        prediction = result[0][0]
        probabilities = _class_probabilities(
            result[1][0] if len(result) > 1 else None, labels
        )
        # Index 1 is the "Success" class; 0.5 is reported when the model
        # emits no probabilities.
        success_probability = probabilities[1] if probabilities is not None else 0.5

        return {
            'success': True,
            'model': 'treatment_outcome',
            'prediction': int(prediction),
            'success_probability': round(success_probability * 100, 1),
            'confidence': max(probabilities) if probabilities is not None else 0.0,
            'probabilities': {
                'failure': probabilities[0],
                'success': probabilities[1],
            } if probabilities is not None else None,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
def health_check():
    """Liveness probe: report status plus which ONNX sessions are initialised."""
    loaded = {
        "risk_assessment": risk_session is not None,
        "treatment_outcome": treatment_session is not None,
    }
    return {"status": "healthy", "models_loaded": loaded}