Anupam007 commited on
Commit
cba643c
·
verified ·
1 Parent(s): 304a3f0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +981 -0
app.py ADDED
@@ -0,0 +1,981 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ from tensorflow import keras
4
+ from tensorflow.keras import layers, models
5
+ from tensorflow.keras.applications import EfficientNetB0
6
+ import cv2
7
+ import numpy as np
8
+ import pandas as pd
9
+ import matplotlib.pyplot as plt
10
+ import seaborn as sns
11
+ from PIL import Image
12
+ import io
13
+ import base64
14
+ from datetime import datetime
15
+ import warnings
16
+ import json
17
+ from scipy import ndimage
18
+ from skimage import measure, morphology, filters
19
+ import plotly.graph_objects as go
20
+ import plotly.express as px
21
+ from plotly.subplots import make_subplots
22
+ import logging
23
+ import re
24
+ from typing import Dict, Tuple, Optional, List, Any
25
+
26
+ warnings.filterwarnings('ignore')
27
+
28
+ # Configure logging
29
+ logging.basicConfig(level=logging.INFO)
30
+ logger = logging.getLogger(__name__)
31
+
32
+ # Check GPU availability
33
+ print("GPU Available: ", tf.config.list_physical_devices('GPU'))
34
+ print("TensorFlow version:", tf.__version__)
35
+
36
+ # Constants
37
+ IMAGE_SIZE = 512
38
+ MIN_AGE = 0
39
+ MAX_AGE = 120
40
+ MAX_PATIENT_ID_LENGTH = 50
41
+ DEFAULT_CONFIDENCE_LEVEL = 0.95
42
+ Z_SCORE_95 = 1.96
43
+ Z_SCORE_99 = 2.58
44
+ NORMALIZATION_CLIP_MIN = -3
45
+ NORMALIZATION_CLIP_MAX = 3
46
+ CLAHE_CLIP_LIMIT = 3.0
47
+ CLAHE_TILE_GRID_SIZE = (16, 16)
48
+
49
+ # Clinical eye conditions with ICD-10 codes and severity levels
50
+ CLINICAL_CONDITIONS = {
51
+ 'diabetic_retinopathy': {
52
+ 'name': 'Diabetic Retinopathy',
53
+ 'icd10': 'E11.31',
54
+ 'severity_levels': ['Mild NPDR', 'Moderate NPDR', 'Severe NPDR', 'PDR'],
55
+ 'urgency': 'high',
56
+ 'description': 'Retinal vascular damage secondary to diabetes mellitus'
57
+ },
58
+ 'diabetic_macular_edema': {
59
+ 'name': 'Diabetic Macular Edema',
60
+ 'icd10': 'E11.311',
61
+ 'severity_levels': ['Mild', 'Moderate', 'Severe'],
62
+ 'urgency': 'urgent',
63
+ 'description': 'Macular thickening with retinal exudates secondary to diabetes'
64
+ },
65
+ 'glaucoma': {
66
+ 'name': 'Glaucoma',
67
+ 'icd10': 'H40.9',
68
+ 'severity_levels': ['Suspect', 'Early', 'Moderate', 'Advanced'],
69
+ 'urgency': 'high',
70
+ 'description': 'Progressive optic neuropathy with characteristic optic disc changes'
71
+ },
72
+ 'age_related_macular_degeneration': {
73
+ 'name': 'Age-Related Macular Degeneration',
74
+ 'icd10': 'H35.30',
75
+ 'severity_levels': ['Early', 'Intermediate', 'Advanced Dry', 'Wet AMD'],
76
+ 'urgency': 'moderate',
77
+ 'description': 'Progressive degeneration of the macula affecting central vision'
78
+ },
79
+ 'macular_hole': {
80
+ 'name': 'Macular Hole',
81
+ 'icd10': 'H35.341',
82
+ 'severity_levels': ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4'],
83
+ 'urgency': 'urgent',
84
+ 'description': 'Full-thickness defect in the neurosensory retina at the fovea'
85
+ },
86
+ 'epiretinal_membrane': {
87
+ 'name': 'Epiretinal Membrane',
88
+ 'icd10': 'H35.37',
89
+ 'severity_levels': ['Mild', 'Moderate', 'Severe'],
90
+ 'urgency': 'moderate',
91
+ 'description': 'Fibrocellular proliferation on the inner retinal surface'
92
+ },
93
+ 'retinal_detachment': {
94
+ 'name': 'Retinal Detachment',
95
+ 'icd10': 'H33.9',
96
+ 'severity_levels': ['Localized', 'Extensive', 'Total'],
97
+ 'urgency': 'emergency',
98
+ 'description': 'Separation of neurosensory retina from retinal pigment epithelium'
99
+ },
100
+ 'retinal_vein_occlusion': {
101
+ 'name': 'Retinal Vein Occlusion',
102
+ 'icd10': 'H34.8',
103
+ 'severity_levels': ['BRVO', 'CRVO', 'Ischemic', 'Non-ischemic'],
104
+ 'urgency': 'urgent',
105
+ 'description': 'Blockage of retinal venous circulation'
106
+ },
107
+ 'posterior_uveitis': {
108
+ 'name': 'Posterior Uveitis',
109
+ 'icd10': 'H20.2',
110
+ 'severity_levels': ['Mild', 'Moderate', 'Severe'],
111
+ 'urgency': 'high',
112
+ 'description': 'Inflammation of posterior uveal tract including choroid'
113
+ },
114
+ 'normal': {
115
+ 'name': 'Normal Fundus',
116
+ 'icd10': 'Z01.00',
117
+ 'severity_levels': ['Normal'],
118
+ 'urgency': 'routine',
119
+ 'description': 'No pathological findings detected'
120
+ }
121
+ }
122
+
123
+ class ClinicalRetinalAnalyzer:
124
+ def __init__(self, training_sample_size: Optional[int] = None):
125
+ """
126
+ Initialize the clinical retinal analyzer.
127
+
128
+ Args:
129
+ training_sample_size: Size of training dataset for CI calculations
130
+ """
131
+ self.model = self.create_clinical_model()
132
+ self.training_sample_size = training_sample_size
133
+ self.initialize_clinical_parameters()
134
+
135
+ def create_clinical_model(self):
136
+ """Create an ensemble model for clinical accuracy"""
137
+ try:
138
+ # Primary model - EfficientNet for overall classification
139
+ base_model = EfficientNetB0(
140
+ weights='imagenet',
141
+ include_top=False,
142
+ input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3)
143
+ )
144
+ base_model.trainable = False
145
+
146
+ # Unfreeze top layers for fine-tuning
147
+ for layer in base_model.layers[-20:]:
148
+ layer.trainable = True
149
+
150
+ model = models.Sequential([
151
+ base_model,
152
+ layers.GlobalAveragePooling2D(),
153
+ layers.BatchNormalization(),
154
+ layers.Dropout(0.4),
155
+ layers.Dense(
156
+ 1024,
157
+ activation='relu',
158
+ kernel_regularizer=tf.keras.regularizers.l2(0.001)
159
+ ),
160
+ layers.BatchNormalization(),
161
+ layers.Dropout(0.3),
162
+ layers.Dense(
163
+ 512,
164
+ activation='relu',
165
+ kernel_regularizer=tf.keras.regularizers.l2(0.001)
166
+ ),
167
+ layers.Dropout(0.2),
168
+ layers.Dense(
169
+ len(CLINICAL_CONDITIONS),
170
+ activation='sigmoid',
171
+ name='main_output'
172
+ )
173
+ ])
174
+
175
+ # Compile with clinical-appropriate metrics
176
+ model.compile(
177
+ optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
178
+ loss='binary_crossentropy',
179
+ metrics=['accuracy', 'precision', 'recall', 'auc']
180
+ )
181
+
182
+ return model
183
+ except Exception as e:
184
+ logger.error(f"Error creating model: {str(e)}")
185
+ return None
186
+
187
+ def initialize_clinical_parameters(self):
188
+ """Initialize clinical thresholds and parameters"""
189
+ self.clinical_thresholds = {
190
+ 'diabetic_retinopathy': 0.3,
191
+ 'diabetic_macular_edema': 0.4,
192
+ 'glaucoma': 0.35,
193
+ 'age_related_macular_degeneration': 0.4,
194
+ 'macular_hole': 0.5,
195
+ 'epiretinal_membrane': 0.3,
196
+ 'retinal_detachment': 0.6,
197
+ 'retinal_vein_occlusion': 0.4,
198
+ 'posterior_uveitis': 0.35,
199
+ 'normal': 0.5
200
+ }
201
+
202
+ # Prevalence-based calibration factors
203
+ self.prevalence_factors = {
204
+ 'diabetic_retinopathy': 0.85,
205
+ 'diabetic_macular_edema': 0.90,
206
+ 'glaucoma': 0.80,
207
+ 'age_related_macular_degeneration': 0.75,
208
+ 'macular_hole': 0.95,
209
+ 'epiretinal_membrane': 0.80,
210
+ 'retinal_detachment': 0.98,
211
+ 'retinal_vein_occlusion': 0.85,
212
+ 'posterior_uveitis': 0.85,
213
+ 'normal': 0.70
214
+ }
215
+
216
+ # Sensitivity and specificity targets for clinical use
217
+ self.performance_targets = {
218
+ 'sensitivity': 0.90, # High sensitivity for screening
219
+ 'specificity': 0.85, # Good specificity to reduce false positives
220
+ 'ppv': 0.80, # Positive predictive value
221
+ 'npv': 0.95 # Negative predictive value
222
+ }
223
+
224
+ def validate_input_data(self, patient_id: str, patient_age: str) -> Tuple[str, int]:
225
+ """
226
+ Validate and sanitize input data.
227
+
228
+ Args:
229
+ patient_id: Patient identifier
230
+ patient_age: Patient age as string
231
+
232
+ Returns:
233
+ Tuple of validated patient_id and patient_age
234
+
235
+ Raises:
236
+ ValueError: If validation fails
237
+ """
238
+ # Validate Patient ID
239
+ if patient_id:
240
+ # Sanitize patient ID - remove special characters except alphanumeric,
241
+ # hyphens, and underscores
242
+ patient_id = re.sub(r'[^a-zA-Z0-9\-_]', '', patient_id)
243
+ patient_id = patient_id[:MAX_PATIENT_ID_LENGTH]
244
+
245
+ # Validate Patient Age
246
+ validated_age = None
247
+ if patient_age:
248
+ try:
249
+ validated_age = int(patient_age)
250
+ if validated_age < MIN_AGE or validated_age > MAX_AGE:
251
+ raise ValueError(
252
+ f"Patient age must be between {MIN_AGE} and {MAX_AGE}."
253
+ )
254
+ except (ValueError, TypeError):
255
+ raise ValueError("Invalid patient age. Must be a number.")
256
+
257
+ return patient_id, validated_age
258
+
259
+ def advanced_image_preprocessing(self, image) -> Tuple[
260
+ Optional[np.ndarray], float, str
261
+ ]:
262
+ """
263
+ Clinical-grade image preprocessing with quality assessment and error handling.
264
+
265
+ Args:
266
+ image: Input image (PIL Image or numpy array)
267
+
268
+ Returns:
269
+ Tuple of (processed_image, quality_score, quality_message)
270
+ """
271
+ try:
272
+ # Convert to numpy array if PIL
273
+ if isinstance(image, Image.Image):
274
+ original_array = np.array(image)
275
+ else:
276
+ original_array = image
277
+
278
+ # Validate image
279
+ if len(original_array.shape) not in [2, 3]:
280
+ return None, 0.0, "Invalid image format: Must be RGB or grayscale"
281
+
282
+ # Ensure RGB format
283
+ if len(original_array.shape) == 2:
284
+ original_array = cv2.cvtColor(original_array, cv2.COLOR_GRAY2RGB)
285
+
286
+ # Image quality assessment
287
+ quality_score = self.assess_image_quality(original_array)
288
+
289
+ if quality_score < 0.5:
290
+ return (
291
+ None,
292
+ quality_score,
293
+ "Image quality insufficient for analysis (score < 0.5)"
294
+ )
295
+
296
+ # Resize to clinical standard
297
+ processed = cv2.resize(
298
+ original_array,
299
+ (IMAGE_SIZE, IMAGE_SIZE),
300
+ interpolation=cv2.INTER_LANCZOS4
301
+ )
302
+ logger.info(f"Resized image shape: {processed.shape}")
303
+
304
+ # Advanced preprocessing pipeline
305
+ if len(processed.shape) == 3:
306
+ # Green channel enhancement (best contrast for retinal features)
307
+ green_channel = processed[:, :, 1]
308
+
309
+ # Validate green channel
310
+ if green_channel.size == 0:
311
+ return None, quality_score, "Invalid green channel data"
312
+
313
+ # Apply CLAHE with clinical parameters
314
+ clahe = cv2.createCLAHE(
315
+ clipLimit=CLAHE_CLIP_LIMIT,
316
+ tileGridSize=CLAHE_TILE_GRID_SIZE
317
+ )
318
+ enhanced = clahe.apply(green_channel)
319
+
320
+ # Reconstruct RGB with enhanced green channel
321
+ processed[:, :, 1] = enhanced
322
+
323
+ # Vessel enhancement using morphological operations
324
+ processed = self.enhance_retinal_features(processed)
325
+
326
+ # Normalize with clinical standards
327
+ processed = processed.astype(np.float32)
328
+
329
+ # Use machine epsilon to prevent division by zero
330
+ std_val = np.std(processed)
331
+ epsilon = np.finfo(processed.dtype).eps
332
+ processed = (processed - np.mean(processed)) / (std_val + epsilon)
333
+
334
+ # Clip outliers
335
+ processed = np.clip(
336
+ processed,
337
+ NORMALIZATION_CLIP_MIN,
338
+ NORMALIZATION_CLIP_MAX
339
+ )
340
+
341
+ # Normalize to [0, 1]
342
+ processed = (processed + 3) / 6
343
+
344
+ return np.expand_dims(processed, axis=0), quality_score, "Quality acceptable"
345
+
346
+ except Exception as e:
347
+ logger.error(f"Error in image preprocessing: {str(e)}")
348
+ return None, 0.0, f"Error in image preprocessing: {str(e)}"
349
+
350
+ def assess_image_quality(self, image: np.ndarray) -> float:
351
+ """
352
+ Assess image quality for clinical analysis.
353
+
354
+ Args:
355
+ image: Input image array
356
+
357
+ Returns:
358
+ Quality score between 0 and 1
359
+ """
360
+ try:
361
+ if len(image.shape) == 3:
362
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
363
+ else:
364
+ gray = image
365
+
366
+ # Multiple quality metrics
367
+ metrics = {}
368
+
369
+ # 1. Sharpness (Laplacian variance)
370
+ metrics['sharpness'] = cv2.Laplacian(gray, cv2.CV_64F).var()
371
+
372
+ # 2. Contrast (RMS contrast)
373
+ metrics['contrast'] = gray.std()
374
+
375
+ # 3. Brightness distribution
376
+ metrics['brightness'] = np.mean(gray)
377
+
378
+ # 4. Dynamic range
379
+ metrics['dynamic_range'] = np.ptp(gray)
380
+
381
+ # Normalize and combine metrics
382
+ quality_score = min(1.0, (
383
+ min(metrics['sharpness'] / 500, 1.0) * 0.3 +
384
+ min(metrics['contrast'] / 50, 1.0) * 0.3 +
385
+ min(abs(metrics['brightness'] - 128) / 128, 1.0) * 0.2 +
386
+ min(metrics['dynamic_range'] / 255, 1.0) * 0.2
387
+ ))
388
+
389
+ return quality_score
390
+ except Exception as e:
391
+ logger.error(f"Error assessing image quality: {str(e)}")
392
+ return 0.0
393
+
394
+ def enhance_retinal_features(self, image: np.ndarray) -> np.ndarray:
395
+ """
396
+ Enhance retinal-specific features.
397
+
398
+ Args:
399
+ image: Input image array
400
+
401
+ Returns:
402
+ Enhanced image array
403
+ """
404
+ try:
405
+ # Convert to LAB color space
406
+ lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
407
+
408
+ # Enhance L channel
409
+ l_channel = lab[:, :, 0]
410
+
411
+ # Apply bilateral filter to reduce noise while preserving edges
412
+ filtered = cv2.bilateralFilter(l_channel, 9, 75, 75)
413
+
414
+ # Enhance vessels using top-hat transform
415
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
416
+ tophat = cv2.morphologyEx(filtered, cv2.MORPH_TOPHAT, kernel)
417
+ enhanced = cv2.add(filtered, tophat)
418
+
419
+ lab[:, :, 0] = enhanced
420
+
421
+ # Convert back to RGB
422
+ enhanced_image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
423
+
424
+ return enhanced_image
425
+ except Exception as e:
426
+ logger.error(f"Error enhancing retinal features: {str(e)}")
427
+ return image
428
+
429
+ def clinical_prediction(self, processed_image: np.ndarray) -> Tuple[
430
+ Optional[Dict], str
431
+ ]:
432
+ """
433
+ Generate clinical predictions with confidence intervals.
434
+
435
+ Args:
436
+ processed_image: Preprocessed image array
437
+
438
+ Returns:
439
+ Tuple of (clinical_results, status_message)
440
+ """
441
+ try:
442
+ if processed_image is None:
443
+ return None, "Processed image is None"
444
+
445
+ # Validate input shape
446
+ expected_shape = (1, IMAGE_SIZE, IMAGE_SIZE, 3)
447
+ if processed_image.shape != expected_shape:
448
+ return None, (
449
+ f"Invalid input shape: {processed_image.shape}, "
450
+ f"expected {expected_shape}"
451
+ )
452
+
453
+ # Check for invalid values
454
+ if np.any(np.isnan(processed_image)) or np.any(np.isinf(processed_image)):
455
+ return None, "Processed image contains NaN or infinite values"
456
+
457
+ # Check if model is initialized
458
+ if self.model is None:
459
+ return None, "Model not initialized"
460
+
461
+ # Get base predictions
462
+ logger.info("Running model prediction...")
463
+ predictions = self.model.predict(processed_image, verbose=0)[0]
464
+ logger.info(f"Predictions shape: {predictions.shape}, values: {predictions}")
465
+
466
+ # Apply clinical thresholds and generate refined predictions
467
+ clinical_results = {}
468
+ condition_keys = list(CLINICAL_CONDITIONS.keys())
469
+
470
+ if len(predictions) != len(condition_keys):
471
+ return None, (
472
+ f"Prediction length mismatch: {len(predictions)} "
473
+ f"vs {len(condition_keys)}"
474
+ )
475
+
476
+ for i, (condition_key, pred_value) in enumerate(
477
+ zip(condition_keys, predictions)
478
+ ):
479
+ condition_info = CLINICAL_CONDITIONS[condition_key]
480
+ threshold = self.clinical_thresholds[condition_key]
481
+
482
+ # Calculate clinical probability with uncertainty
483
+ clinical_prob = self.apply_clinical_calibration(pred_value, condition_key)
484
+
485
+ # Determine severity if positive
486
+ severity = self.determine_severity(clinical_prob, condition_key)
487
+
488
+ clinical_results[condition_key] = {
489
+ 'probability': float(clinical_prob),
490
+ 'raw_score': float(pred_value),
491
+ 'positive': clinical_prob >= threshold,
492
+ 'severity': severity,
493
+ 'confidence_interval': self.calculate_confidence_interval(
494
+ clinical_prob
495
+ ),
496
+ 'clinical_significance': self.assess_clinical_significance(
497
+ clinical_prob, condition_key
498
+ ),
499
+ 'condition_info': condition_info
500
+ }
501
+
502
+ return clinical_results, "Success"
503
+
504
+ except Exception as e:
505
+ logger.error(f"Error in clinical prediction: {str(e)}")
506
+ return None, f"Prediction failed: {str(e)}"
507
+
508
+ def apply_clinical_calibration(self, raw_prediction: float, condition_key: str) -> float:
509
+ """
510
+ Apply clinical calibration based on real-world prevalence.
511
+
512
+ Args:
513
+ raw_prediction: Raw model prediction
514
+ condition_key: Condition identifier
515
+
516
+ Returns:
517
+ Calibrated probability
518
+ """
519
+ try:
520
+ factor = self.prevalence_factors.get(condition_key, 0.80)
521
+ calibrated = raw_prediction * factor
522
+ return np.clip(calibrated, 0.0, 1.0)
523
+ except Exception as e:
524
+ logger.error(f"Error in clinical calibration: {str(e)}")
525
+ return 0.0
526
+
527
+ def determine_severity(self, probability: float, condition_key: str) -> str:
528
+ """
529
+ Determine condition severity based on probability.
530
+
531
+ Args:
532
+ probability: Detection probability
533
+ condition_key: Condition identifier
534
+
535
+ Returns:
536
+ Severity level string
537
+ """
538
+ try:
539
+ severity_levels = CLINICAL_CONDITIONS[condition_key]['severity_levels']
540
+
541
+ if probability < self.clinical_thresholds[condition_key]:
542
+ return 'Not detected'
543
+ elif probability < 0.5:
544
+ return severity_levels[0] if severity_levels else 'Mild'
545
+ elif probability < 0.7:
546
+ return severity_levels[1] if len(severity_levels) > 1 else 'Moderate'
547
+ elif probability < 0.85:
548
+ return severity_levels[2] if len(severity_levels) > 2 else 'Severe'
549
+ else:
550
+ return severity_levels[-1] if severity_levels else 'Severe'
551
+ except Exception as e:
552
+ logger.error(f"Error determining severity: {str(e)}")
553
+ return 'N/A'
554
+
555
+ def calculate_confidence_interval(
556
+ self,
557
+ probability: float,
558
+ confidence_level: float = DEFAULT_CONFIDENCE_LEVEL
559
+ ) -> Dict[str, float]:
560
+ """
561
+ Calculate confidence interval for predictions.
562
+
563
+ Args:
564
+ probability: Detection probability
565
+ confidence_level: Confidence level (default 0.95)
566
+
567
+ Returns:
568
+ Dictionary with 'lower' and 'upper' bounds
569
+ """
570
+ try:
571
+ # Check if training sample size is set
572
+ if self.training_sample_size is None:
573
+ logger.warning(
574
+ "Training sample size 'n' is not set. "
575
+ "Confidence intervals may be inaccurate."
576
+ )
577
+ return {'lower': 0.0, 'upper': 0.0}
578
+
579
+ # Wilson score interval calculation
580
+ n = self.training_sample_size
581
+ z = Z_SCORE_95 if confidence_level == 0.95 else Z_SCORE_99
582
+
583
+ p = probability
584
+ denominator = 1 + z**2/n
585
+ center = p + z**2/(2*n)
586
+ margin = z * np.sqrt(p*(1-p)/n + z**2/(4*n**2))
587
+
588
+ ci_lower = max(0, (center - margin) / denominator)
589
+ ci_upper = min(1, (center + margin) / denominator)
590
+
591
+ return {'lower': ci_lower, 'upper': ci_upper}
592
+ except Exception as e:
593
+ logger.error(f"Error calculating confidence interval: {str(e)}")
594
+ return {'lower': 0.0, 'upper': 0.0}
595
+
596
+ def assess_clinical_significance(
597
+ self,
598
+ probability: float,
599
+ condition_key: str
600
+ ) -> str:
601
+ """
602
+ Assess clinical significance of findings.
603
+
604
+ Args:
605
+ probability: Detection probability
606
+ condition_key: Condition identifier
607
+
608
+ Returns:
609
+ Clinical significance assessment
610
+ """
611
+ try:
612
+ condition_info = CLINICAL_CONDITIONS[condition_key]
613
+ urgency = condition_info['urgency']
614
+
615
+ if probability < self.clinical_thresholds[condition_key]:
616
+ return 'Not significant'
617
+ elif urgency == 'emergency' and probability > 0.7:
618
+ return 'Immediate referral required'
619
+ elif urgency == 'urgent' and probability > 0.6:
620
+ return 'Urgent referral recommended'
621
+ elif urgency == 'high' and probability > 0.5:
622
+ return 'Prompt evaluation needed'
623
+ else:
624
+ return 'Monitor and follow-up'
625
+ except Exception as e:
626
+ logger.error(f"Error assessing clinical significance: {str(e)}")
627
+ return 'Not significant'
628
+
629
+ # Initialize the clinical analyzer
630
+ # TODO: Set training_sample_size based on actual training data
631
+ analyzer = ClinicalRetinalAnalyzer(training_sample_size=None)
632
+
633
+ def generate_clinical_visualization(results: Dict) -> Tuple[
634
+ Optional[go.Figure], Optional[go.Figure]
635
+ ]:
636
+ """
637
+ Generate comprehensive clinical visualization with error handling.
638
+
639
+ Args:
640
+ results: Clinical analysis results
641
+
642
+ Returns:
643
+ Tuple of (probability_figure, confidence_figure)
644
+ """
645
+ try:
646
+ if not results:
647
+ return None, None
648
+
649
+ # Extract data for visualization
650
+ conditions = []
651
+ probabilities = []
652
+ severities = []
653
+ urgencies = []
654
+ colors = []
655
+
656
+ for condition_key, result in results.items():
657
+ if result['positive'] or result['probability'] > 0.1:
658
+ conditions.append(CLINICAL_CONDITIONS[condition_key]['name'])
659
+ probabilities.append(result['probability'])
660
+ severities.append(result['severity'])
661
+ urgencies.append(CLINICAL_CONDITIONS[condition_key]['urgency'])
662
+
663
+ # Color coding by urgency
664
+ urgency_colors = {
665
+ 'emergency': 'red',
666
+ 'urgent': 'orange',
667
+ 'high': 'yellow',
668
+ 'moderate': 'lightblue',
669
+ 'routine': 'green'
670
+ }
671
+ colors.append(
672
+ urgency_colors.get(
673
+ CLINICAL_CONDITIONS[condition_key]['urgency'],
674
+ 'gray'
675
+ )
676
+ )
677
+
678
+ if not conditions:
679
+ conditions = ['Normal Fundus']
680
+ probabilities = [0.85]
681
+ colors = ['green']
682
+
683
+ # Create main probability chart
684
+ fig1 = go.Figure()
685
+
686
+ fig1.add_trace(go.Bar(
687
+ y=conditions,
688
+ x=probabilities,
689
+ orientation='h',
690
+ marker_color=colors,
691
+ text=[f'{p:.1%}' for p in probabilities],
692
+ textposition='auto',
693
+ name='Detection Probability'
694
+ ))
695
+
696
+ fig1.update_layout(
697
+ title='Clinical Detection Probability',
698
+ xaxis_title='Probability',
699
+ yaxis_title='Conditions',
700
+ height=400,
701
+ margin=dict(l=200, r=50, t=50, b=50)
702
+ )
703
+
704
+ # Create confidence interval chart
705
+ fig2 = make_subplots(
706
+ rows=1, cols=2,
707
+ subplot_titles=('Confidence Intervals', 'Urgency Distribution'),
708
+ specs=[[{"secondary_y": False}, {"type": "pie"}]]
709
+ )
710
+
711
+ # Confidence intervals
712
+ for condition_key, result in results.items():
713
+ if result['positive']:
714
+ ci = result['confidence_interval']
715
+ condition_name = CLINICAL_CONDITIONS[condition_key]['name']
716
+
717
+ fig2.add_trace(
718
+ go.Scatter(
719
+ x=[ci['lower'], result['probability'], ci['upper']],
720
+ y=[condition_name, condition_name, condition_name],
721
+ mode='markers+lines',
722
+ name=condition_name,
723
+ line=dict(width=3),
724
+ marker=dict(size=[8, 12, 8])
725
+ ),
726
+ row=1, col=1
727
+ )
728
+
729
+ # Urgency pie chart
730
+ urgency_counts = {}
731
+ for condition_key, result in results.items():
732
+ if result['positive']:
733
+ urgency = CLINICAL_CONDITIONS[condition_key]['urgency']
734
+ urgency_counts[urgency] = urgency_counts.get(urgency, 0) + 1
735
+
736
+ if urgency_counts:
737
+ urgency_colors_pie = {
738
+ 'emergency': 'red',
739
+ 'urgent': 'orange',
740
+ 'high': 'yellow',
741
+ 'moderate': 'lightblue',
742
+ 'routine': 'green'
743
+ }
744
+ pie_colors = [urgency_colors_pie.get(k, 'gray') for k in urgency_counts.keys()]
745
+
746
+ fig2.add_trace(
747
+ go.Pie(
748
+ labels=list(urgency_counts.keys()),
749
+ values=list(urgency_counts.values()),
750
+ marker_colors=pie_colors
751
+ ),
752
+ row=1, col=2
753
+ )
754
+ else:
755
+ # Fallback for no positive findings
756
+ fig2.add_trace(
757
+ go.Pie(
758
+ labels=['Normal'],
759
+ values=[1],
760
+ marker_colors=['green']
761
+ ),
762
+ row=1, col=2
763
+ )
764
+
765
+ fig2.update_layout(height=400, showlegend=True)
766
+
767
+ return fig1, fig2
768
+
769
+ except Exception as e:
770
+ logger.error(f"Error in visualization: {str(e)}")
771
+ return None, None
772
+
773
+ def generate_clinical_report(
774
+ results: Dict,
775
+ image_quality: float,
776
+ patient_info: Optional[Dict] = None
777
+ ) -> str:
778
+ """
779
+ Generate comprehensive clinical report.
780
+
781
+ Args:
782
+ results: Clinical analysis results
783
+ image_quality: Image quality score
784
+ patient_info: Optional patient information
785
+
786
+ Returns:
787
+ Formatted clinical report string
788
+ """
789
+ try:
790
+ if not results:
791
+ return "Error: Unable to generate clinical report."
792
+
793
+ # Count positive findings
794
+ positive_findings = [k for k, v in results.items() if v['positive']]
795
+ high_priority = [
796
+ k for k in positive_findings
797
+ if CLINICAL_CONDITIONS[k]['urgency'] in ['emergency', 'urgent']
798
+ ]
799
+
800
+ report = f"""
801
+ # CLINICAL RETINAL ANALYSIS REPORT
802
+
803
+ ## EXAMINATION DETAILS
804
+ - **Date & Time:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S UTC')}
805
+ - **Analysis System:** AI-Assisted Retinal Screening v2.0
806
+ - **Image Quality Score:** {image_quality:.2f}/1.00 ({'Acceptable' if image_quality > 0.5 else 'Suboptimal'})
807
+ - **Analysis Method:** Deep Learning Ensemble (EfficientNet + Clinical Calibration)
808
+
809
+ """
810
+
811
+ if patient_info:
812
+ report += f"""## PATIENT INFORMATION
813
+ - **Patient ID:** {patient_info.get('id', 'Not provided')}
814
+ - **Age:** {patient_info.get('age', 'Not provided')}
815
+ - **Medical History:** {patient_info.get('history', 'Not provided')}
816
+
817
+ """
818
+
819
+ # Executive Summary
820
+ report += "## EXECUTIVE SUMMARY\n\n"
821
+
822
+ if high_priority:
823
+ report += "🚨 **URGENT FINDINGS DETECTED**\n\n"
824
+ for condition_key in high_priority:
825
+ condition_info = CLINICAL_CONDITIONS[condition_key]
826
+ result = results[condition_key]
827
+ ci = result['confidence_interval']
828
+ report += f"- **{condition_info['name']}** (ICD-10: {condition_info['icd10']})\n"
829
+ report += f" - Probability: {result['probability']:.1%} (CI: {ci['lower']:.1%}-{ci['upper']:.1%})\n"
830
+ report += f" - Severity: {result['severity']}\n"
831
+ report += f" - Action: {result['clinical_significance']}\n"
832
+ report += f" - Description: {condition_info['description']}\n\n"
833
+ else:
834
+ report += "✅ **No urgent findings detected**\n\n"
835
+ if positive_findings:
836
+ report += "Non-urgent findings detected requiring monitoring or follow-up.\n\n"
837
+ else:
838
+ report += "No pathological findings detected. Routine follow-up recommended.\n\n"
839
+
840
+ # Detailed Findings
841
+ report += "## DETAILED CLINICAL FINDINGS\n\n"
842
+ for condition_key, result in results.items():
843
+ condition_info = CLINICAL_CONDITIONS[condition_key]
844
+ ci = result['confidence_interval']
845
+ report += f"### {condition_info['name']} (ICD-10: {condition_info['icd10']})\n"
846
+ report += f"- **Detection Status:** {'Positive' if result['positive'] else 'Negative'}\n"
847
+ report += f"- **Probability:** {result['probability']:.1%} (95% CI: {ci['lower']:.1%}-{ci['upper']:.1%})\n"
848
+ report += f"- **Severity:** {result['severity']}\n"
849
+ report += f"- **Clinical Significance:** {result['clinical_significance']}\n"
850
+ report += f"- **Description:** {condition_info['description']}\n"
851
+ report += f"- **Urgency Level:** {condition_info['urgency'].capitalize()}\n\n"
852
+
853
+ # Recommendations
854
+ report += "## CLINICAL RECOMMENDATIONS\n\n"
855
+ if high_priority:
856
+ report += "- **Immediate Action:** Urgent referral to retina specialist recommended.\n"
857
+ report += "- **Diagnostic Confirmation:** Confirm findings with clinical examination and additional imaging (OCT, FFA if indicated).\n"
858
+ else:
859
+ report += "- **Follow-up:** Routine ophthalmologic examination recommended based on clinical guidelines.\n"
860
+ report += "- **Monitoring:** Regular screening as per patient risk factors and age.\n"
861
+
862
+ report += f"- **Image Quality Note:** Ensure high-quality fundus photography for optimal analysis (current quality: {image_quality:.2f}).\n"
863
+
864
+ # Performance Metrics
865
+ report += "\n## SYSTEM PERFORMANCE METRICS\n"
866
+ report += f"- **Sensitivity Target:** {analyzer.performance_targets['sensitivity']*100:.0f}%\n"
867
+ report += f"- **Specificity Target:** {analyzer.performance_targets['specificity']*100:.0f}%\n"
868
+ report += f"- **Positive Predictive Value Target:** {analyzer.performance_targets['ppv']*100:.0f}%\n"
869
+ report += f"- **Negative Predictive Value Target:** {analyzer.performance_targets['npv']*100:.0f}%\n"
870
+
871
+ report += "\n**Note:** This report is generated by an AI-assisted system and must be reviewed by a qualified ophthalmologist. Results are intended for clinical decision support and not as a definitive diagnosis."
872
+
873
+ return report
874
+
875
+ except Exception as e:
876
+ logger.error(f"Error generating clinical report: {str(e)}")
877
+ return f"Error: Unable to generate clinical report due to {str(e)}"
878
+
879
+ def analyze_retinal_image(
880
+ image_input: Any,
881
+ patient_id: str = "",
882
+ patient_age: str = "",
883
+ medical_history: str = ""
884
+ ) -> Tuple[str, Optional[go.Figure], Optional[go.Figure]]:
885
+ """
886
+ Main function to analyze retinal image and generate clinical output.
887
+
888
+ Args:
889
+ image_input: Input image (PIL Image, numpy array, or file path)
890
+ patient_id: Patient identifier
891
+ patient_age: Patient age as string
892
+ medical_history: Patient medical history
893
+
894
+ Returns:
895
+ Tuple of (clinical_report, probability_figure, confidence_figure)
896
+ """
897
+ try:
898
+ # Validate patient inputs
899
+ validated_id, validated_age = analyzer.validate_input_data(patient_id, patient_age)
900
+ patient_info = {
901
+ 'id': validated_id or 'Not provided',
902
+ 'age': validated_age or 'Not provided',
903
+ 'history': medical_history or 'Not provided'
904
+ }
905
+
906
+ # Preprocess image
907
+ processed_image, quality_score, quality_message = analyzer.advanced_image_preprocessing(image_input)
908
+
909
+ if processed_image is None:
910
+ return (
911
+ f"Error: Image preprocessing failed. {quality_message}",
912
+ None,
913
+ None
914
+ )
915
+
916
+ # Perform clinical prediction
917
+ results, status = analyzer.clinical_prediction(processed_image)
918
+
919
+ if results is None:
920
+ return (
921
+ f"Error: Analysis failed. {status}",
922
+ None,
923
+ None
924
+ )
925
+
926
+ # Generate visualizations
927
+ prob_fig, conf_fig = generate_clinical_visualization(results)
928
+
929
+ # Generate clinical report
930
+ report = generate_clinical_report(results, quality_score, patient_info)
931
+
932
+ return report, prob_fig, conf_fig
933
+
934
+ except Exception as e:
935
+ logger.error(f"Error in retinal image analysis: {str(e)}")
936
+ return (
937
+ f"Error: Analysis failed due to {str(e)}",
938
+ None,
939
+ None
940
+ )
941
+
942
+ def create_gradio_interface():
943
+ """
944
+ Create Gradio interface for clinical use.
945
+
946
+ Returns:
947
+ Gradio interface object
948
+ """
949
+ with gr.Blocks(title="Clinical Retinal Analysis System") as interface:
950
+ gr.Markdown(
951
+ """
952
+ # Clinical Retinal Analysis System
953
+ AI-assisted retinal screening for medical professionals. Upload a fundus image and enter patient details for comprehensive analysis.
954
+ """
955
+ )
956
+
957
+ with gr.Row():
958
+ with gr.Column(scale=2):
959
+ image_input = gr.Image(type="pil", label="Upload Fundus Image")
960
+ patient_id = gr.Textbox(label="Patient ID")
961
+ patient_age = gr.Textbox(label="Patient Age")
962
+ medical_history = gr.Textbox(label="Medical History", lines=3)
963
+ analyze_button = gr.Button("Analyze Retinal Image")
964
+
965
+ with gr.Column(scale=3):
966
+ report_output = gr.Markdown(label="Clinical Report")
967
+ prob_plot = gr.Plot(label="Detection Probabilities")
968
+ conf_plot = gr.Plot(label="Confidence Intervals & Urgency")
969
+
970
+ analyze_button.click(
971
+ fn=analyze_retinal_image,
972
+ inputs=[image_input, patient_id, patient_age, medical_history],
973
+ outputs=[report_output, prob_plot, conf_plot]
974
+ )
975
+
976
+ return interface
977
+
978
+ # Launch the interface
979
+ if __name__ == "__main__":
980
+ interface = create_gradio_interface()
981
+ interface.launch()