Tegaconsult commited on
Commit
a9ed62f
·
verified ·
1 Parent(s): 9a647e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -398
app.py CHANGED
@@ -1,398 +1,10 @@
1
-
2
- import os
3
- import numpy as np
4
- import onnxruntime as rt
5
- from fastapi import FastAPI, HTTPException
6
- from fastapi.responses import HTMLResponse
7
- from fastapi.staticfiles import StaticFiles
8
- from pydantic import BaseModel
9
- from typing import Dict, Any
10
- from huggingface_hub import hf_hub_download
11
- from dotenv import load_dotenv
12
-
13
- load_dotenv()
14
-
15
- app = FastAPI(title="Digital Doctors Assistant ML API")
16
-
17
- # Model configurations
18
- MODELS = {
19
- 'risk_assessment': {
20
- 'filename': 'risk_assessment.onnx',
21
- 'features': ['age', 'bmi', 'systolic_bp', 'diastolic_bp',
22
- 'chronic_conditions_count', 'severity_score'],
23
- 'output_classes': ['Low', 'Medium', 'High']
24
- },
25
- 'treatment_outcome': {
26
- 'filename': 'treatment_outcome.onnx',
27
- 'features': ['patient_age', 'severity_score', 'compliance_rate',
28
- 'medication_encoded', 'condition_encoded'],
29
- 'output_classes': ['No Success', 'Success']
30
- }
31
- }
32
-
33
- # Load models on startup
34
- risk_session = None
35
- treatment_session = None
36
-
37
- @app.on_event("startup")
38
- async def load_models():
39
- global risk_session, treatment_session
40
-
41
- # Get token from environment (set as Space secret)
42
- token = os.getenv("HUGGINGFACE_TOKEN")
43
-
44
- # Download and load risk assessment model
45
- risk_path = hf_hub_download(
46
- repo_id="Tegaconsult/digital-doctors-assistant-ml",
47
- filename="risk_assessment.onnx",
48
- token=token
49
- )
50
- risk_session = rt.InferenceSession(risk_path)
51
-
52
- # Download and load treatment outcome model
53
- treatment_path = hf_hub_download(
54
- repo_id="Tegaconsult/digital-doctors-assistant-ml",
55
- filename="treatment_outcome.onnx",
56
- token=token
57
- )
58
- treatment_session = rt.InferenceSession(treatment_path)
59
-
60
- print("Models loaded successfully!")
61
-
62
- class RiskAssessmentRequest(BaseModel):
63
- age: float
64
- bmi: float
65
- systolic_bp: float
66
- diastolic_bp: float
67
- chronic_conditions: str = ""
68
- severity_score: float
69
-
70
- class TreatmentOutcomeRequest(BaseModel):
71
- patient_age: float
72
- severity_score: float
73
- compliance_rate: float
74
- medication: str
75
- condition: str
76
-
77
- @app.get("/", response_class=HTMLResponse)
78
- def root():
79
- html_content = """<!DOCTYPE html>
80
- <html lang="en">
81
- <head>
82
- <meta charset="UTF-8">
83
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
84
- <title>Digital Doctors Assistant ML</title>
85
- <style>
86
- * { margin: 0; padding: 0; box-sizing: border-box; }
87
- body { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); min-height: 100vh; padding: 20px; }
88
- .container { max-width: 1200px; margin: 0 auto; }
89
- h1 { color: white; text-align: center; margin-bottom: 30px; font-size: 2.5em; }
90
- .cards { display: grid; grid-template-columns: repeat(auto-fit, minmax(500px, 1fr)); gap: 20px; }
91
- .card { background: white; border-radius: 15px; padding: 30px; box-shadow: 0 10px 30px rgba(0,0,0,0.2); }
92
- .card h2 { color: #667eea; margin-bottom: 20px; font-size: 1.8em; }
93
- .form-group { margin-bottom: 15px; }
94
- label { display: block; margin-bottom: 5px; color: #333; font-weight: 600; }
95
- input, textarea { width: 100%; padding: 10px; border: 2px solid #e0e0e0; border-radius: 8px; font-size: 14px; transition: border 0.3s; }
96
- input:focus, textarea:focus { outline: none; border-color: #667eea; }
97
- button { width: 100%; padding: 12px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border: none; border-radius: 8px; font-size: 16px; font-weight: 600; cursor: pointer; transition: transform 0.2s; }
98
- button:hover { transform: translateY(-2px); }
99
- button:active { transform: translateY(0); }
100
- .result { margin-top: 20px; padding: 20px; background: #f8f9fa; border-radius: 8px; border-left: 4px solid #667eea; }
101
- .result h3 { color: #667eea; margin-bottom: 10px; }
102
- .result-item { margin: 8px 0; color: #555; }
103
- .result-item strong { color: #333; }
104
- .error { background: #fee; border-left-color: #f44; }
105
- .error h3 { color: #f44; }
106
- .hidden { display: none; }
107
- .risk-low { color: #28a745; font-weight: bold; }
108
- .risk-medium { color: #ffc107; font-weight: bold; }
109
- .risk-high { color: #dc3545; font-weight: bold; }
110
- </style>
111
- </head>
112
- <body>
113
- <div class="container">
114
- <h1>Digital Doctors Assistant ML</h1>
115
-
116
- <div class="cards">
117
- <div class="card">
118
- <h2>Risk Assessment</h2>
119
- <form id="riskForm">
120
- <div class="form-group">
121
- <label>Age</label>
122
- <input type="number" id="age" required min="0" max="120" value="45">
123
- </div>
124
- <div class="form-group">
125
- <label>BMI</label>
126
- <input type="number" id="bmi" required step="0.1" min="10" max="50" value="28.5">
127
- </div>
128
- <div class="form-group">
129
- <label>Systolic BP</label>
130
- <input type="number" id="systolic_bp" required min="70" max="200" value="140">
131
- </div>
132
- <div class="form-group">
133
- <label>Diastolic BP</label>
134
- <input type="number" id="diastolic_bp" required min="40" max="130" value="90">
135
- </div>
136
- <div class="form-group">
137
- <label>Chronic Conditions (comma-separated)</label>
138
- <input type="text" id="chronic_conditions" placeholder="e.g., diabetes,hypertension" value="diabetes,hypertension">
139
- </div>
140
- <div class="form-group">
141
- <label>Severity Score (0-10)</label>
142
- <input type="number" id="severity_score" required step="0.1" min="0" max="10" value="7.5">
143
- </div>
144
- <button type="submit">Predict Risk</button>
145
- </form>
146
- <div id="riskResult" class="result hidden"></div>
147
- </div>
148
-
149
- <div class="card">
150
- <h2>Treatment Outcome</h2>
151
- <form id="treatmentForm">
152
- <div class="form-group">
153
- <label>Patient Age</label>
154
- <input type="number" id="patient_age" required min="0" max="120" value="55">
155
- </div>
156
- <div class="form-group">
157
- <label>Severity Score (0-10)</label>
158
- <input type="number" id="treatment_severity" required step="0.1" min="0" max="10" value="6.5">
159
- </div>
160
- <div class="form-group">
161
- <label>Compliance Rate (0-1)</label>
162
- <input type="number" id="compliance_rate" required step="0.01" min="0" max="1" value="0.85">
163
- </div>
164
- <div class="form-group">
165
- <label>Medication</label>
166
- <input type="text" id="medication" required list="medications" value="Metformin">
167
- <datalist id="medications">
168
- <option value="Paracetamol">
169
- <option value="Ibuprofen">
170
- <option value="Amoxicillin">
171
- <option value="Ciprofloxacin">
172
- <option value="Metformin">
173
- <option value="Lisinopril">
174
- <option value="Amlodipine">
175
- <option value="Omeprazole">
176
- </datalist>
177
- </div>
178
- <div class="form-group">
179
- <label>Condition</label>
180
- <input type="text" id="condition" required list="conditions" value="Diabetes Type 2">
181
- <datalist id="conditions">
182
- <option value="Common Cold">
183
- <option value="Influenza">
184
- <option value="Pneumonia">
185
- <option value="Bronchitis">
186
- <option value="Hypertension">
187
- <option value="Diabetes Type 2">
188
- <option value="Migraine">
189
- <option value="Gastroenteritis">
190
- </datalist>
191
- </div>
192
- <button type="submit">Predict Outcome</button>
193
- </form>
194
- <div id="treatmentResult" class="result hidden"></div>
195
- </div>
196
- </div>
197
- </div>
198
-
199
- <script>
200
- document.getElementById('riskForm').addEventListener('submit', async (e) => {
201
- e.preventDefault();
202
- const resultDiv = document.getElementById('riskResult');
203
-
204
- const data = {
205
- age: parseFloat(document.getElementById('age').value),
206
- bmi: parseFloat(document.getElementById('bmi').value),
207
- systolic_bp: parseFloat(document.getElementById('systolic_bp').value),
208
- diastolic_bp: parseFloat(document.getElementById('diastolic_bp').value),
209
- chronic_conditions: document.getElementById('chronic_conditions').value,
210
- severity_score: parseFloat(document.getElementById('severity_score').value)
211
- };
212
-
213
- try {
214
- const response = await fetch('/predict/risk', {
215
- method: 'POST',
216
- headers: { 'Content-Type': 'application/json' },
217
- body: JSON.stringify(data)
218
- });
219
-
220
- const result = await response.json();
221
-
222
- if (result.success) {
223
- const riskClass = result.prediction.toLowerCase().replace(' ', '-');
224
- resultDiv.className = 'result';
225
- resultDiv.innerHTML = `
226
- <h3>Prediction Results</h3>
227
- <div class="result-item"><strong>Risk Level:</strong> <span class="risk-${riskClass}">${result.prediction}</span></div>
228
- <div class="result-item"><strong>Confidence:</strong> ${(result.confidence * 100).toFixed(1)}%</div>
229
- ${result.probabilities ? `
230
- <div class="result-item"><strong>Probabilities:</strong></div>
231
- <div class="result-item">Low: ${(result.probabilities.Low * 100).toFixed(1)}%</div>
232
- <div class="result-item">Medium: ${(result.probabilities.Medium * 100).toFixed(1)}%</div>
233
- <div class="result-item">High: ${(result.probabilities.High * 100).toFixed(1)}%</div>
234
- ` : ''}
235
- `;
236
- } else {
237
- throw new Error('Prediction failed');
238
- }
239
- } catch (error) {
240
- resultDiv.className = 'result error';
241
- resultDiv.innerHTML = `<h3>Error</h3><div class="result-item">${error.message}</div>`;
242
- }
243
-
244
- resultDiv.classList.remove('hidden');
245
- });
246
-
247
- document.getElementById('treatmentForm').addEventListener('submit', async (e) => {
248
- e.preventDefault();
249
- const resultDiv = document.getElementById('treatmentResult');
250
-
251
- const data = {
252
- patient_age: parseFloat(document.getElementById('patient_age').value),
253
- severity_score: parseFloat(document.getElementById('treatment_severity').value),
254
- compliance_rate: parseFloat(document.getElementById('compliance_rate').value),
255
- medication: document.getElementById('medication').value,
256
- condition: document.getElementById('condition').value
257
- };
258
-
259
- try {
260
- const response = await fetch('/predict/treatment', {
261
- method: 'POST',
262
- headers: { 'Content-Type': 'application/json' },
263
- body: JSON.stringify(data)
264
- });
265
-
266
- const result = await response.json();
267
-
268
- if (result.success) {
269
- resultDiv.className = 'result';
270
- resultDiv.innerHTML = `
271
- <h3>Prediction Results</h3>
272
- <div class="result-item"><strong>Outcome:</strong> ${result.prediction === 1 ? 'Success' : 'No Success'}</div>
273
- <div class="result-item"><strong>Success Probability:</strong> ${result.success_probability}%</div>
274
- <div class="result-item"><strong>Confidence:</strong> ${(result.confidence * 100).toFixed(1)}%</div>
275
- ${result.probabilities ? `
276
- <div class="result-item"><strong>Probabilities:</strong></div>
277
- <div class="result-item">Failure: ${(result.probabilities.failure * 100).toFixed(1)}%</div>
278
- <div class="result-item">Success: ${(result.probabilities.success * 100).toFixed(1)}%</div>
279
- ` : ''}
280
- `;
281
- } else {
282
- throw new Error('Prediction failed');
283
- }
284
- } catch (error) {
285
- resultDiv.className = 'result error';
286
- resultDiv.innerHTML = `<h3>Error</h3><div class="result-item">${error.message}</div>`;
287
- }
288
-
289
- resultDiv.classList.remove('hidden');
290
- });
291
- </script>
292
- </body>
293
- </html>"""
294
- return html_content
295
-
296
- @app.post("/predict/risk")
297
- def predict_risk(request: RiskAssessmentRequest):
298
- """Predict patient risk level"""
299
- try:
300
- # Prepare input
301
- chronic_count = len(request.chronic_conditions.split(',')) if request.chronic_conditions else 0
302
- input_data = np.array([[
303
- request.age,
304
- request.bmi,
305
- request.systolic_bp,
306
- request.diastolic_bp,
307
- chronic_count,
308
- request.severity_score
309
- ]], dtype=np.float32)
310
-
311
- # Run inference
312
- input_name = risk_session.get_inputs()[0].name
313
- result = risk_session.run(None, {input_name: input_data})
314
-
315
- # Parse results
316
- prediction = result[0][0]
317
- probabilities = result[1][0] if len(result) > 1 else None
318
-
319
- output_classes = MODELS['risk_assessment']['output_classes']
320
- if isinstance(prediction, (int, np.integer)):
321
- prediction_label = output_classes[prediction]
322
- else:
323
- prediction_label = prediction
324
-
325
- confidence = float(max(probabilities)) if probabilities is not None else 0.0
326
-
327
- return {
328
- 'success': True,
329
- 'model': 'risk_assessment',
330
- 'prediction': prediction_label,
331
- 'confidence': confidence,
332
- 'probabilities': {
333
- output_classes[i]: float(probabilities[i])
334
- for i in range(len(output_classes))
335
- } if probabilities is not None else None
336
- }
337
-
338
- except Exception as e:
339
- raise HTTPException(status_code=500, detail=str(e))
340
-
341
- @app.post("/predict/treatment")
342
- def predict_treatment(request: TreatmentOutcomeRequest):
343
- """Predict treatment outcome"""
344
- try:
345
- # Encode categorical variables
346
- medication_mapping = {
347
- 'Paracetamol': 0, 'Ibuprofen': 1, 'Amoxicillin': 2, 'Ciprofloxacin': 3,
348
- 'Metformin': 4, 'Lisinopril': 5, 'Amlodipine': 6, 'Omeprazole': 7
349
- }
350
-
351
- condition_mapping = {
352
- 'Common Cold': 0, 'Influenza': 1, 'Pneumonia': 2, 'Bronchitis': 3,
353
- 'Hypertension': 4, 'Diabetes Type 2': 5, 'Migraine': 6, 'Gastroenteritis': 7
354
- }
355
-
356
- # Prepare input
357
- input_data = np.array([[
358
- request.patient_age,
359
- request.severity_score,
360
- request.compliance_rate,
361
- medication_mapping.get(request.medication, 0),
362
- condition_mapping.get(request.condition, 0)
363
- ]], dtype=np.float32)
364
-
365
- # Run inference
366
- input_name = treatment_session.get_inputs()[0].name
367
- result = treatment_session.run(None, {input_name: input_data})
368
-
369
- # Parse results
370
- prediction = result[0][0]
371
- probabilities = result[1][0] if len(result) > 1 else None
372
-
373
- success_probability = float(probabilities[1]) if probabilities is not None else 0.5
374
-
375
- return {
376
- 'success': True,
377
- 'model': 'treatment_outcome',
378
- 'prediction': int(prediction),
379
- 'success_probability': round(success_probability * 100, 1),
380
- 'confidence': float(max(probabilities)) if probabilities is not None else 0.0,
381
- 'probabilities': {
382
- 'failure': float(probabilities[0]),
383
- 'success': float(probabilities[1])
384
- } if probabilities is not None else None
385
- }
386
-
387
- except Exception as e:
388
- raise HTTPException(status_code=500, detail=str(e))
389
-
390
- @app.get("/health")
391
- def health_check():
392
- return {
393
- "status": "healthy",
394
- "models_loaded": {
395
- "risk_assessment": risk_session is not None,
396
- "treatment_outcome": treatment_session is not None
397
- }
398
- }
 
1
+ """
2
+ Alternative entry point for the application.
3
+ This file can be used instead of ml.py if needed.
4
+ """
5
+
6
+ from ml import app
7
+
8
+ if __name__ == "__main__":
9
+ import uvicorn
10
+ uvicorn.run(app, host="0.0.0.0", port=7860)