Spaces:
Running
Running
| import gradio as gr | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| from tensorflow.keras import layers, models | |
| from tensorflow.keras.applications import EfficientNetB0 | |
| import cv2 | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from PIL import Image | |
| import io | |
| import base64 | |
| from datetime import datetime | |
| import warnings | |
| import json | |
| from scipy import ndimage | |
| from skimage import measure, morphology, filters | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| from plotly.subplots import make_subplots | |
| import logging | |
| import re | |
| from typing import Dict, Tuple, Optional, List, Any | |
# Silence all Python warnings process-wide (TensorFlow and friends are noisy).
warnings.filterwarnings('ignore')

# Root logging at INFO; `logger` is the module logger used throughout this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Report the compute environment once, at import time.
_gpu_devices = tf.config.list_physical_devices('GPU')
print("GPU Available: ", _gpu_devices)
print("TensorFlow version:", tf.__version__)

# --- Model input geometry (square RGB input of IMAGE_SIZE x IMAGE_SIZE) -----
IMAGE_SIZE = 512

# --- Patient-input validation limits -----------------------------------------
MIN_AGE, MAX_AGE = 0, 120
MAX_PATIENT_ID_LENGTH = 50

# --- Confidence intervals (two-sided standard-normal quantiles) --------------
DEFAULT_CONFIDENCE_LEVEL = 0.95
Z_SCORE_95, Z_SCORE_99 = 1.96, 2.58

# --- Preprocessing: z-score clip range and CLAHE settings ---------------------
NORMALIZATION_CLIP_MIN, NORMALIZATION_CLIP_MAX = -3, 3
CLAHE_CLIP_LIMIT = 3.0
CLAHE_TILE_GRID_SIZE = (16, 16)
# Clinical eye conditions with ICD-10 codes and severity levels.
#
# Key order matters: the classifier has one sigmoid output per entry, and
# predictions are paired with keys in this dict's iteration order.
# 'severity_levels' are consumed as an ordinal scale, mildest first.
# NOTE(review): codes mix category-level entries (e.g. E11.31) with a
# laterality-specific one (H35.341, right eye) - confirm against ICD-10-CM.
CLINICAL_CONDITIONS = {
    'diabetic_retinopathy': {
        'name': 'Diabetic Retinopathy',
        'icd10': 'E11.31',
        'severity_levels': ['Mild NPDR', 'Moderate NPDR', 'Severe NPDR', 'PDR'],
        'urgency': 'high',
        'description': 'Retinal vascular damage secondary to diabetes mellitus'
    },
    'diabetic_macular_edema': {
        'name': 'Diabetic Macular Edema',
        'icd10': 'E11.311',
        'severity_levels': ['Mild', 'Moderate', 'Severe'],
        'urgency': 'urgent',
        'description': 'Macular thickening with retinal exudates secondary to diabetes'
    },
    'glaucoma': {
        'name': 'Glaucoma',
        'icd10': 'H40.9',
        'severity_levels': ['Suspect', 'Early', 'Moderate', 'Advanced'],
        'urgency': 'high',
        'description': 'Progressive optic neuropathy with characteristic optic disc changes'
    },
    'age_related_macular_degeneration': {
        'name': 'Age-Related Macular Degeneration',
        'icd10': 'H35.30',
        'severity_levels': ['Early', 'Intermediate', 'Advanced Dry', 'Wet AMD'],
        'urgency': 'moderate',
        'description': 'Progressive degeneration of the macula affecting central vision'
    },
    'macular_hole': {
        'name': 'Macular Hole',
        'icd10': 'H35.341',
        'severity_levels': ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4'],
        'urgency': 'urgent',
        'description': 'Full-thickness defect in the neurosensory retina at the fovea'
    },
    'epiretinal_membrane': {
        'name': 'Epiretinal Membrane',
        'icd10': 'H35.37',
        'severity_levels': ['Mild', 'Moderate', 'Severe'],
        'urgency': 'moderate',
        'description': 'Fibrocellular proliferation on the inner retinal surface'
    },
    'retinal_detachment': {
        'name': 'Retinal Detachment',
        # NOTE(review): H33.9 does not appear in ICD-10-CM (other retinal
        # detachments are H33.8) - confirm the intended code.
        'icd10': 'H33.9',
        'severity_levels': ['Localized', 'Extensive', 'Total'],
        'urgency': 'emergency',
        'description': 'Separation of neurosensory retina from retinal pigment epithelium'
    },
    'retinal_vein_occlusion': {
        'name': 'Retinal Vein Occlusion',
        'icd10': 'H34.8',
        # NOTE(review): these labels mix occlusion site (BRVO/CRVO) with
        # perfusion status. Because the list is read as mildest-to-most-severe,
        # 'Non-ischemic' is the top grade - confirm the intended ordering.
        'severity_levels': ['BRVO', 'CRVO', 'Ischemic', 'Non-ischemic'],
        'urgency': 'urgent',
        'description': 'Blockage of retinal venous circulation'
    },
    'posterior_uveitis': {
        'name': 'Posterior Uveitis',
        # Chorioretinal inflammation (H30). The previous value H20.2 is
        # lens-induced iridocyclitis, an anterior-segment code.
        'icd10': 'H30.9',
        'severity_levels': ['Mild', 'Moderate', 'Severe'],
        'urgency': 'high',
        'description': 'Inflammation of posterior uveal tract including choroid'
    },
    'normal': {
        'name': 'Normal Fundus',
        'icd10': 'Z01.00',
        'severity_levels': ['Normal'],
        'urgency': 'routine',
        'description': 'No pathological findings detected'
    }
}
class ClinicalRetinalAnalyzer:
    """
    Fundus-image preprocessing plus conversion of model scores into findings.

    The model emits one sigmoid score per CLINICAL_CONDITIONS entry, in dict
    order. Each score is then calibrated, thresholded, graded, and annotated
    with a confidence interval and a referral recommendation.

    NOTE(review): ``create_clinical_model`` builds a freshly initialised
    classifier head, and this class never loads trained weights. Predictions
    are therefore not clinically meaningful until a trained checkpoint is
    loaded.
    """

    def __init__(self, training_sample_size: Optional[int] = None):
        """
        Initialize the clinical retinal analyzer.

        Args:
            training_sample_size: Size of the training dataset, used as ``n``
                in the Wilson-score confidence interval. When None, every
                confidence interval is the placeholder
                {'lower': 0.0, 'upper': 0.0}.
        """
        self.model = self.create_clinical_model()
        self.training_sample_size = training_sample_size
        self.initialize_clinical_parameters()

    def create_clinical_model(self):
        """
        Build and compile the EfficientNetB0-based multi-label classifier.

        This is a single network, not an ensemble: an ImageNet-pretrained
        EfficientNetB0 backbone followed by a dense head with one sigmoid unit
        per CLINICAL_CONDITIONS entry.

        Returns:
            The compiled Sequential model, or None if construction raised
            (the error is logged).
        """
        try:
            base_model = EfficientNetB0(
                weights='imagenet',
                include_top=False,
                input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3)
            )
            # Fine-tuning setup: freeze everything except the top 20 layers.
            # The container itself must stay trainable. When a model's
            # ``trainable`` flag is False, Keras exposes no trainable weights
            # for any of its sub-layers, so freezing the container and then
            # re-enabling individual layers would leave nothing trainable.
            base_model.trainable = True
            for layer in base_model.layers[:-20]:
                layer.trainable = False

            # NOTE(review): keras.applications EfficientNet models rescale
            # inputs internally and document a [0, 255] input range, but
            # advanced_image_preprocessing produces values in [0, 1] - confirm.
            model = models.Sequential([
                base_model,
                layers.GlobalAveragePooling2D(),
                layers.BatchNormalization(),
                layers.Dropout(0.4),
                layers.Dense(
                    1024,
                    activation='relu',
                    kernel_regularizer=tf.keras.regularizers.l2(0.001)
                ),
                layers.BatchNormalization(),
                layers.Dropout(0.3),
                layers.Dense(
                    512,
                    activation='relu',
                    kernel_regularizer=tf.keras.regularizers.l2(0.001)
                ),
                layers.Dropout(0.2),
                # Sigmoid rather than softmax: conditions are scored
                # independently, so several can be positive at once.
                layers.Dense(
                    len(CLINICAL_CONDITIONS),
                    activation='sigmoid',
                    name='main_output'
                )
            ])
            model.compile(
                optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
                loss='binary_crossentropy',
                metrics=['accuracy', 'precision', 'recall', 'auc']
            )
            return model
        except Exception as e:
            logger.error(f"Error creating model: {str(e)}")
            return None

    def initialize_clinical_parameters(self):
        """Set per-condition thresholds, calibration factors and performance targets."""
        # Decision thresholds. They are compared against the *calibrated*
        # probability, both for 'positive' and in determine_severity.
        self.clinical_thresholds = {
            'diabetic_retinopathy': 0.3,
            'diabetic_macular_edema': 0.4,
            'glaucoma': 0.35,
            'age_related_macular_degeneration': 0.4,
            'macular_hole': 0.5,
            'epiretinal_membrane': 0.3,
            'retinal_detachment': 0.6,
            'retinal_vein_occlusion': 0.4,
            'posterior_uveitis': 0.35,
            'normal': 0.5
        }
        # Multiplicative factors applied to raw scores by
        # apply_clinical_calibration. Every factor is < 1, so a calibrated
        # probability can never exceed its factor; e.g. 'normal' tops out
        # at 0.70.
        self.prevalence_factors = {
            'diabetic_retinopathy': 0.85,
            'diabetic_macular_edema': 0.90,
            'glaucoma': 0.80,
            'age_related_macular_degeneration': 0.75,
            'macular_hole': 0.95,
            'epiretinal_membrane': 0.80,
            'retinal_detachment': 0.98,
            'retinal_vein_occlusion': 0.85,
            'posterior_uveitis': 0.85,
            'normal': 0.70
        }
        # Target operating characteristics. They are not enforced by this
        # class; the clinical report only prints them.
        self.performance_targets = {
            'sensitivity': 0.90,  # High sensitivity for screening
            'specificity': 0.85,  # Good specificity to reduce false positives
            'ppv': 0.80,          # Positive predictive value
            'npv': 0.95           # Negative predictive value
        }

    def validate_input_data(
        self, patient_id: str, patient_age: str
    ) -> Tuple[str, Optional[int]]:
        """
        Validate and sanitize patient input.

        Args:
            patient_id: Patient identifier (may be empty).
            patient_age: Patient age as a string (may be empty).

        Returns:
            Tuple of (sanitized patient_id, age as int or None when no age
            was given).

        Raises:
            ValueError: If the age is not an integer, or lies outside
                [MIN_AGE, MAX_AGE].
        """
        if patient_id:
            # Keep only letters, digits, hyphens and underscores, then
            # truncate to the maximum length.
            patient_id = re.sub(r'[^a-zA-Z0-9\-_]', '', patient_id)
            patient_id = patient_id[:MAX_PATIENT_ID_LENGTH]

        validated_age = None
        if patient_age:
            # Parse and range-check separately. Otherwise the range error
            # would be caught and re-labelled as "Must be a number."
            try:
                validated_age = int(patient_age)
            except (ValueError, TypeError):
                raise ValueError("Invalid patient age. Must be a number.") from None
            if not MIN_AGE <= validated_age <= MAX_AGE:
                raise ValueError(
                    f"Patient age must be between {MIN_AGE} and {MAX_AGE}."
                )
        return patient_id, validated_age

    def advanced_image_preprocessing(self, image) -> Tuple[
        Optional[np.ndarray], float, str
    ]:
        """
        Quality-gate, enhance and normalise a fundus image for the model.

        Pipeline:
            1. Convert to 3-channel RGB.
            2. Compute a quality score; reject the image if it is below 0.5.
            3. Resize to IMAGE_SIZE x IMAGE_SIZE.
            4. Apply CLAHE to the green channel.
            5. Run enhance_retinal_features.
            6. Apply a global z-score, clipped to
               [NORMALIZATION_CLIP_MIN, NORMALIZATION_CLIP_MAX].
            7. Rescale to [0, 1].

        Args:
            image: PIL Image (any mode) or numpy array of shape (H, W),
                (H, W, 1), (H, W, 3) or (H, W, 4). Arrays are assumed to be
                uint8 RGB(A) - TODO confirm for other dtypes.

        Returns:
            Tuple of (processed_image, quality_score, quality_message).
            processed_image is a float32 array of shape
            (1, IMAGE_SIZE, IMAGE_SIZE, 3), or None when the image is rejected
            or an error occurs. In that case the message explains why.
        """
        try:
            if image is None:
                return None, 0.0, "No image provided"

            if isinstance(image, Image.Image):
                # convert('RGB') also normalises palette, RGBA, CMYK and L modes.
                original_array = np.array(image.convert('RGB'))
            else:
                original_array = np.asarray(image)

            if original_array.ndim == 2:
                original_array = cv2.cvtColor(original_array, cv2.COLOR_GRAY2RGB)
            elif original_array.ndim == 3 and original_array.shape[2] == 1:
                original_array = cv2.cvtColor(original_array, cv2.COLOR_GRAY2RGB)
            elif original_array.ndim == 3 and original_array.shape[2] == 4:
                # Drop alpha: the LAB conversion and the model need 3 channels.
                original_array = cv2.cvtColor(original_array, cv2.COLOR_RGBA2RGB)

            if original_array.ndim != 3 or original_array.shape[2] != 3:
                return None, 0.0, "Invalid image format: Must be RGB or grayscale"

            quality_score = self.assess_image_quality(original_array)
            if quality_score < 0.5:
                return (
                    None,
                    quality_score,
                    "Image quality insufficient for analysis (score < 0.5)"
                )

            processed = cv2.resize(
                original_array,
                (IMAGE_SIZE, IMAGE_SIZE),
                interpolation=cv2.INTER_LANCZOS4
            )
            logger.info(f"Resized image shape: {processed.shape}")

            # CLAHE on the green channel only, written back in place.
            clahe = cv2.createCLAHE(
                clipLimit=CLAHE_CLIP_LIMIT,
                tileGridSize=CLAHE_TILE_GRID_SIZE
            )
            green_channel = processed[:, :, 1]
            processed[:, :, 1] = clahe.apply(green_channel)

            processed = self.enhance_retinal_features(processed)

            # Global z-score; machine epsilon guards a constant (zero-std) image.
            processed = processed.astype(np.float32)
            epsilon = np.finfo(processed.dtype).eps
            processed = (processed - np.mean(processed)) / (np.std(processed) + epsilon)
            processed = np.clip(
                processed,
                NORMALIZATION_CLIP_MIN,
                NORMALIZATION_CLIP_MAX
            )
            # Map the clip range linearly onto [0, 1]. Deriving it from the
            # constants keeps the two from drifting apart.
            clip_span = NORMALIZATION_CLIP_MAX - NORMALIZATION_CLIP_MIN
            processed = (processed - NORMALIZATION_CLIP_MIN) / clip_span

            return np.expand_dims(processed, axis=0), quality_score, "Quality acceptable"
        except Exception as e:
            logger.error(f"Error in image preprocessing: {str(e)}")
            return None, 0.0, f"Error in image preprocessing: {str(e)}"

    def assess_image_quality(self, image: np.ndarray) -> float:
        """
        Score image quality for clinical analysis.

        Weighted sum of four terms, each in [0, 1]:
            - Sharpness: Laplacian variance / 500 (weight 0.3).
            - Contrast: grey-level std / 50 (weight 0.3).
            - Brightness: closeness of the mean to mid-grey 128 (weight 0.2).
            - Dynamic range: peak-to-peak / 255 (weight 0.2).

        Args:
            image: RGB (H, W, 3) or single-channel (H, W) image array.

        Returns:
            Quality score between 0 and 1 (0.0 if scoring raised).
        """
        try:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if image.ndim == 3 else image

            sharpness = cv2.Laplacian(gray, cv2.CV_64F).var()
            contrast = gray.std()
            brightness = np.mean(gray)
            dynamic_range = np.ptp(gray)

            sharpness_term = min(sharpness / 500, 1.0)
            contrast_term = min(contrast / 50, 1.0)
            # Full credit at mid-grey, falling to zero at fully black or
            # white. The un-inverted form rewarded badly exposed images:
            # a black frame scored higher than a well-exposed one.
            brightness_term = 1.0 - min(abs(brightness - 128) / 128, 1.0)
            range_term = min(dynamic_range / 255, 1.0)

            quality_score = (
                0.3 * sharpness_term
                + 0.3 * contrast_term
                + 0.2 * brightness_term
                + 0.2 * range_term
            )
            return float(min(1.0, quality_score))
        except Exception as e:
            logger.error(f"Error assessing image quality: {str(e)}")
            return 0.0

    def enhance_retinal_features(self, image: np.ndarray) -> np.ndarray:
        """
        Denoise and locally boost the lightness channel of an RGB image.

        Steps:
            1. Convert to LAB.
            2. Bilateral-filter the L channel.
            3. Add its white top-hat (15x15 ellipse).
            4. Convert back to RGB.

        Args:
            image: RGB image array (3 channels; uint8 assumed - TODO confirm).

        Returns:
            Enhanced RGB image array, or the input unchanged if any OpenCV
            call raised (the error is logged).
        """
        try:
            lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
            l_channel = lab[:, :, 0]
            # Bilateral filter: reduces noise while preserving edges.
            filtered = cv2.bilateralFilter(l_channel, 9, 75, 75)
            # NOTE(review): a white top-hat highlights structures *brighter*
            # than their surroundings. Vessels usually appear darker than the
            # background in fundus photos, where a black-hat is the usual
            # choice - confirm intent.
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
            tophat = cv2.morphologyEx(filtered, cv2.MORPH_TOPHAT, kernel)
            # cv2.add saturates instead of wrapping on uint8 overflow.
            lab[:, :, 0] = cv2.add(filtered, tophat)
            return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
        except Exception as e:
            logger.error(f"Error enhancing retinal features: {str(e)}")
            return image

    def clinical_prediction(self, processed_image: np.ndarray) -> Tuple[
        Optional[Dict], str
    ]:
        """
        Run the model and turn its scores into per-condition clinical results.

        Args:
            processed_image: Output of advanced_image_preprocessing, shape
                (1, IMAGE_SIZE, IMAGE_SIZE, 3).

        Returns:
            Tuple of (clinical_results, status_message). clinical_results maps
            each CLINICAL_CONDITIONS key to a dict with these keys:
            'probability', 'raw_score', 'positive' (bool), 'severity',
            'confidence_interval', 'clinical_significance' and
            'condition_info'. It is None on failure, and the status message
            says why.
        """
        try:
            if processed_image is None:
                return None, "Processed image is None"

            expected_shape = (1, IMAGE_SIZE, IMAGE_SIZE, 3)
            if processed_image.shape != expected_shape:
                return None, (
                    f"Invalid input shape: {processed_image.shape}, "
                    f"expected {expected_shape}"
                )
            if not np.all(np.isfinite(processed_image)):
                return None, "Processed image contains NaN or infinite values"
            if self.model is None:
                return None, "Model not initialized"

            logger.info("Running model prediction...")
            predictions = self.model.predict(processed_image, verbose=0)[0]
            logger.info(f"Predictions shape: {predictions.shape}, values: {predictions}")

            # Outputs are positional: output i belongs to the i-th dict key.
            condition_keys = list(CLINICAL_CONDITIONS.keys())
            if len(predictions) != len(condition_keys):
                return None, (
                    f"Prediction length mismatch: {len(predictions)} "
                    f"vs {len(condition_keys)}"
                )

            clinical_results = {}
            for condition_key, pred_value in zip(condition_keys, predictions):
                clinical_prob = self.apply_clinical_calibration(pred_value, condition_key)
                threshold = self.clinical_thresholds[condition_key]
                clinical_results[condition_key] = {
                    'probability': float(clinical_prob),
                    'raw_score': float(pred_value),
                    # Plain bool (not numpy.bool_) so results serialise cleanly.
                    'positive': bool(clinical_prob >= threshold),
                    'severity': self.determine_severity(clinical_prob, condition_key),
                    'confidence_interval': self.calculate_confidence_interval(
                        clinical_prob
                    ),
                    'clinical_significance': self.assess_clinical_significance(
                        clinical_prob, condition_key
                    ),
                    'condition_info': CLINICAL_CONDITIONS[condition_key]
                }
            return clinical_results, "Success"
        except Exception as e:
            logger.error(f"Error in clinical prediction: {str(e)}")
            return None, f"Prediction failed: {str(e)}"

    def apply_clinical_calibration(self, raw_prediction: float, condition_key: str) -> float:
        """
        Scale a raw model score by the condition's prevalence factor.

        Args:
            raw_prediction: Raw sigmoid score from the model.
            condition_key: CLINICAL_CONDITIONS key; unknown keys use factor 0.80.

        Returns:
            raw_prediction * factor, clipped to [0, 1] (0.0 if this raised).
        """
        try:
            factor = self.prevalence_factors.get(condition_key, 0.80)
            return float(np.clip(raw_prediction * factor, 0.0, 1.0))
        except Exception as e:
            logger.error(f"Error in clinical calibration: {str(e)}")
            return 0.0

    def determine_severity(self, probability: float, condition_key: str) -> str:
        """
        Map a calibrated probability onto the condition's severity scale.

        Below the condition's threshold the result is 'Not detected'.
        Otherwise fixed bands select an ordinal rank:
            - below 0.5 -> rank 0
            - below 0.7 -> rank 1
            - below 0.85 -> rank 2
            - otherwise -> rank 3
        The rank indexes severity_levels, clamped to the last defined level.

        NOTE(review): the bands are absolute, so for conditions whose
        threshold is >= 0.5 (macular_hole, retinal_detachment, normal) the
        first level is unreachable.

        Args:
            probability: Calibrated detection probability.
            condition_key: CLINICAL_CONDITIONS key.

        Returns:
            Severity label, 'Not detected', or 'N/A' if lookup raised.
        """
        try:
            severity_levels = CLINICAL_CONDITIONS[condition_key]['severity_levels']
            if probability < self.clinical_thresholds[condition_key]:
                return 'Not detected'

            if probability < 0.5:
                rank = 0
            elif probability < 0.7:
                rank = 1
            elif probability < 0.85:
                rank = 2
            else:
                rank = 3

            if severity_levels:
                # Clamp so short scales (e.g. 'normal' has only ['Normal'])
                # never fall back to a generic grade like 'Moderate'.
                return severity_levels[min(rank, len(severity_levels) - 1)]
            return ('Mild', 'Moderate', 'Severe', 'Severe')[rank]
        except Exception as e:
            logger.error(f"Error determining severity: {str(e)}")
            return 'N/A'

    def calculate_confidence_interval(
        self,
        probability: float,
        confidence_level: float = DEFAULT_CONFIDENCE_LEVEL
    ) -> Dict[str, float]:
        """
        Compute a Wilson-score interval for ``probability`` with n = training_sample_size.

        NOTE(review): this treats the model probability as an observed
        proportion over n training samples. That is an approximation; confirm
        it is the intended uncertainty measure.

        Args:
            probability: Detection probability in [0, 1].
            confidence_level: Two-sided confidence level in (0, 1), default
                0.95. Levels 0.95 and 0.99 use Z_SCORE_95 and Z_SCORE_99; any
                other level uses the exact standard-normal quantile.

        Returns:
            Dictionary with 'lower' and 'upper' bounds clipped to [0, 1].
            Returns the placeholder {'lower': 0.0, 'upper': 0.0} when
            training_sample_size is unset or the computation fails.
        """
        from statistics import NormalDist

        try:
            if self.training_sample_size is None:
                logger.warning(
                    "Training sample size 'n' is not set. "
                    "Confidence intervals may be inaccurate."
                )
                return {'lower': 0.0, 'upper': 0.0}
            if not 0.0 < confidence_level < 1.0:
                raise ValueError(
                    f"confidence_level must be in (0, 1), got {confidence_level}"
                )

            if np.isclose(confidence_level, 0.95):
                z = Z_SCORE_95
            elif np.isclose(confidence_level, 0.99):
                z = Z_SCORE_99
            else:
                # Exact two-sided quantile, instead of silently reusing the 99% z.
                z = NormalDist().inv_cdf(0.5 + confidence_level / 2)

            n = self.training_sample_size
            p = probability
            denominator = 1 + z**2 / n
            center = p + z**2 / (2 * n)
            margin = z * np.sqrt(p * (1 - p) / n + z**2 / (4 * n**2))
            return {
                'lower': float(max(0.0, (center - margin) / denominator)),
                'upper': float(min(1.0, (center + margin) / denominator)),
            }
        except Exception as e:
            logger.error(f"Error calculating confidence interval: {str(e)}")
            return {'lower': 0.0, 'upper': 0.0}

    def assess_clinical_significance(
        self,
        probability: float,
        condition_key: str
    ) -> str:
        """
        Translate a calibrated probability and the condition's urgency into an action.

        Args:
            probability: Calibrated detection probability.
            condition_key: CLINICAL_CONDITIONS key.

        Returns:
            'Not significant' below threshold. Above threshold:
                - 'Immediate referral required' for emergency urgency > 0.7.
                - 'Urgent referral recommended' for urgent urgency > 0.6.
                - 'Prompt evaluation needed' for high urgency > 0.5.
                - 'Monitor and follow-up' otherwise.
        """
        try:
            urgency = CLINICAL_CONDITIONS[condition_key]['urgency']
            if probability < self.clinical_thresholds[condition_key]:
                return 'Not significant'
            elif urgency == 'emergency' and probability > 0.7:
                return 'Immediate referral required'
            elif urgency == 'urgent' and probability > 0.6:
                return 'Urgent referral recommended'
            elif urgency == 'high' and probability > 0.5:
                return 'Prompt evaluation needed'
            else:
                return 'Monitor and follow-up'
        except Exception as e:
            logger.error(f"Error assessing clinical significance: {str(e)}")
            return 'Not significant'
# Shared, module-wide analyzer used by the report builder and the Gradio handler.
# TODO: Set training_sample_size based on actual training data. Until then the
# analyzer returns a {'lower': 0.0, 'upper': 0.0} placeholder for every
# confidence interval.
analyzer = ClinicalRetinalAnalyzer()
def generate_clinical_visualization(results: Dict) -> Tuple[
    Optional[go.Figure], Optional[go.Figure]
]:
    """
    Build the two Plotly figures shown next to the clinical report.

    Args:
        results: Per-condition results keyed by CLINICAL_CONDITIONS key, each
            with 'probability', 'positive' and 'confidence_interval' entries.

    Returns:
        Tuple of (probability_figure, confidence_figure):
            - probability_figure: horizontal bar chart of every positive
              condition or any condition above 10%, coloured by urgency.
            - confidence_figure: confidence intervals of positive findings
              plus a pie chart of their urgency levels.
        Returns (None, None) when results is empty or figure construction
        fails (the error is logged).
    """
    from collections import Counter

    # Single palette shared by the bar chart and the urgency pie.
    urgency_colors = {
        'emergency': 'red',
        'urgent': 'orange',
        'high': 'yellow',
        'moderate': 'lightblue',
        'routine': 'green'
    }
    try:
        if not results:
            return None, None

        conditions = []
        probabilities = []
        colors = []
        for condition_key, result in results.items():
            if result['positive'] or result['probability'] > 0.1:
                info = CLINICAL_CONDITIONS[condition_key]
                conditions.append(info['name'])
                probabilities.append(result['probability'])
                colors.append(urgency_colors.get(info['urgency'], 'gray'))

        if not conditions and 'normal' in results:
            # Nothing cleared the 10% display cut-off: show the model's actual
            # 'normal' probability rather than a fabricated fixed value.
            conditions = [CLINICAL_CONDITIONS['normal']['name']]
            probabilities = [results['normal']['probability']]
            colors = [urgency_colors['routine']]

        fig1 = go.Figure()
        fig1.add_trace(go.Bar(
            y=conditions,
            x=probabilities,
            orientation='h',
            marker_color=colors,
            text=[f'{p:.1%}' for p in probabilities],
            textposition='auto',
            name='Detection Probability'
        ))
        fig1.update_layout(
            title='Clinical Detection Probability',
            xaxis_title='Probability',
            yaxis_title='Conditions',
            height=400,
            margin=dict(l=200, r=50, t=50, b=50)
        )

        positive_keys = [key for key, result in results.items() if result['positive']]

        fig2 = make_subplots(
            rows=1, cols=2,
            subplot_titles=('Confidence Intervals', 'Urgency Distribution'),
            specs=[[{"secondary_y": False}, {"type": "pie"}]]
        )

        for condition_key in positive_keys:
            result = results[condition_key]
            ci = result['confidence_interval']
            condition_name = CLINICAL_CONDITIONS[condition_key]['name']
            if ci['upper'] > ci['lower']:
                xs = [ci['lower'], result['probability'], ci['upper']]
                sizes = [8, 12, 8]
            else:
                # Placeholder interval (no sample size configured): plot the
                # point estimate alone instead of a spurious 0-to-p line.
                xs = [result['probability']]
                sizes = [12]
            fig2.add_trace(
                go.Scatter(
                    x=xs,
                    y=[condition_name] * len(xs),
                    mode='markers+lines',
                    name=condition_name,
                    line=dict(width=3),
                    marker=dict(size=sizes)
                ),
                row=1, col=1
            )

        # Counter keeps first-seen order, matching the condition order.
        urgency_counts = Counter(
            CLINICAL_CONDITIONS[key]['urgency'] for key in positive_keys
        )
        if urgency_counts:
            urgency_labels = list(urgency_counts)
            fig2.add_trace(
                go.Pie(
                    labels=urgency_labels,
                    values=[urgency_counts[u] for u in urgency_labels],
                    marker_colors=[urgency_colors.get(u, 'gray') for u in urgency_labels]
                ),
                row=1, col=2
            )
        else:
            # Fallback for no positive findings
            fig2.add_trace(
                go.Pie(
                    labels=['Normal'],
                    values=[1],
                    marker_colors=['green']
                ),
                row=1, col=2
            )

        fig2.update_layout(height=400, showlegend=True)
        return fig1, fig2
    except Exception as e:
        logger.error(f"Error in visualization: {str(e)}")
        return None, None
def _format_ci(ci: Dict[str, float]) -> str:
    """
    Format a confidence-interval dict as 'lower%-upper%'.

    A genuine Wilson-score interval always has upper > lower. An interval
    with upper <= lower is therefore treated as a placeholder (e.g.
    {'lower': 0.0, 'upper': 0.0}) and rendered as 'not available', instead of
    the misleading "0.0%-0.0%".
    """
    if ci['upper'] <= ci['lower']:
        return 'not available'
    return f"{ci['lower']:.1%}-{ci['upper']:.1%}"


def generate_clinical_report(
    results: Dict,
    image_quality: float,
    patient_info: Optional[Dict] = None
) -> str:
    """
    Render the Markdown clinical report.

    Args:
        results: Per-condition results keyed by CLINICAL_CONDITIONS key (as
            produced by the analyzer's clinical_prediction).
        image_quality: Image quality score in [0, 1].
        patient_info: Optional dict with 'id', 'age' and 'history' entries.

    Returns:
        Formatted Markdown report, or an "Error: ..." string on failure.
    """
    # Local import: the file imports only the datetime *class* at top level.
    from datetime import timezone

    try:
        if not results:
            return "Error: Unable to generate clinical report."

        # 'normal' is itself a model output. A positive 'normal' result is not
        # a pathological finding and must not trigger the "findings" wording.
        positive_findings = [
            key for key, value in results.items()
            if value['positive'] and key != 'normal'
        ]
        high_priority = [
            key for key in positive_findings
            if CLINICAL_CONDITIONS[key]['urgency'] in ('emergency', 'urgent')
        ]

        # Take the timestamp in UTC so the "UTC" label is truthful.
        timestamp = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')
        # Preprocessing rejects scores below 0.5, so 0.5 itself is acceptable.
        quality_label = 'Acceptable' if image_quality >= 0.5 else 'Suboptimal'

        report = (
            "\n# CLINICAL RETINAL ANALYSIS REPORT\n"
            "## EXAMINATION DETAILS\n"
            f"- **Date & Time:** {timestamp}\n"
            "- **Analysis System:** AI-Assisted Retinal Screening v2.0\n"
            f"- **Image Quality Score:** {image_quality:.2f}/1.00 ({quality_label})\n"
            "- **Analysis Method:** Deep Learning Ensemble (EfficientNet + Clinical Calibration)\n"
        )
        if patient_info:
            report += (
                "## PATIENT INFORMATION\n"
                f"- **Patient ID:** {patient_info.get('id', 'Not provided')}\n"
                f"- **Age:** {patient_info.get('age', 'Not provided')}\n"
                f"- **Medical History:** {patient_info.get('history', 'Not provided')}\n"
            )

        # Executive Summary
        report += "## EXECUTIVE SUMMARY\n\n"
        if high_priority:
            report += "🚨 **URGENT FINDINGS DETECTED**\n\n"
            for condition_key in high_priority:
                condition_info = CLINICAL_CONDITIONS[condition_key]
                result = results[condition_key]
                report += f"- **{condition_info['name']}** (ICD-10: {condition_info['icd10']})\n"
                report += f"  - Probability: {result['probability']:.1%} (CI: {_format_ci(result['confidence_interval'])})\n"
                report += f"  - Severity: {result['severity']}\n"
                report += f"  - Action: {result['clinical_significance']}\n"
                report += f"  - Description: {condition_info['description']}\n\n"
        else:
            report += "✅ **No urgent findings detected**\n\n"
            if positive_findings:
                report += "Non-urgent findings detected requiring monitoring or follow-up.\n\n"
            else:
                report += "No pathological findings detected. Routine follow-up recommended.\n\n"

        # Detailed Findings
        report += "## DETAILED CLINICAL FINDINGS\n\n"
        for condition_key, result in results.items():
            condition_info = CLINICAL_CONDITIONS[condition_key]
            report += f"### {condition_info['name']} (ICD-10: {condition_info['icd10']})\n"
            report += f"- **Detection Status:** {'Positive' if result['positive'] else 'Negative'}\n"
            report += f"- **Probability:** {result['probability']:.1%} (95% CI: {_format_ci(result['confidence_interval'])})\n"
            report += f"- **Severity:** {result['severity']}\n"
            report += f"- **Clinical Significance:** {result['clinical_significance']}\n"
            report += f"- **Description:** {condition_info['description']}\n"
            report += f"- **Urgency Level:** {condition_info['urgency'].capitalize()}\n\n"

        # Recommendations
        report += "## CLINICAL RECOMMENDATIONS\n\n"
        if high_priority:
            report += "- **Immediate Action:** Urgent referral to retina specialist recommended.\n"
            report += "- **Diagnostic Confirmation:** Confirm findings with clinical examination and additional imaging (OCT, FFA if indicated).\n"
        else:
            report += "- **Follow-up:** Routine ophthalmologic examination recommended based on clinical guidelines.\n"
            report += "- **Monitoring:** Regular screening as per patient risk factors and age.\n"
        report += f"- **Image Quality Note:** Ensure high-quality fundus photography for optimal analysis (current quality: {image_quality:.2f}).\n"

        # Performance Metrics (targets only; see analyzer.performance_targets)
        targets = analyzer.performance_targets
        report += "\n## SYSTEM PERFORMANCE METRICS\n"
        report += f"- **Sensitivity Target:** {targets['sensitivity']*100:.0f}%\n"
        report += f"- **Specificity Target:** {targets['specificity']*100:.0f}%\n"
        report += f"- **Positive Predictive Value Target:** {targets['ppv']*100:.0f}%\n"
        report += f"- **Negative Predictive Value Target:** {targets['npv']*100:.0f}%\n"
        report += "\n**Note:** This report is generated by an AI-assisted system and must be reviewed by a qualified ophthalmologist. Results are intended for clinical decision support and not as a definitive diagnosis."
        return report
    except Exception as e:
        logger.error(f"Error generating clinical report: {str(e)}")
        return f"Error: Unable to generate clinical report due to {str(e)}"
def analyze_retinal_image(
    image_input: Any,
    patient_id: str = "",
    patient_age: str = "",
    medical_history: str = ""
) -> Tuple[str, Optional[go.Figure], Optional[go.Figure]]:
    """
    Gradio entry point: validate input, analyse the image, build report and figures.

    Args:
        image_input: Input image as a PIL Image or numpy array (file paths
            are not supported by the preprocessing step).
        patient_id: Patient identifier.
        patient_age: Patient age as string.
        medical_history: Free-text patient medical history.

    Returns:
        Tuple of (clinical_report, probability_figure, confidence_figure).
        On any failure the report is an "Error: ..." message and both figures
        are None.
    """
    try:
        try:
            validated_id, validated_age = analyzer.validate_input_data(
                patient_id, patient_age
            )
        except ValueError as e:
            # Bad patient input is a user error, not an analysis failure.
            return f"Error: Invalid patient information. {e}", None, None

        patient_info = {
            'id': validated_id or 'Not provided',
            # Compare against None so that a valid age of 0 is still reported.
            'age': validated_age if validated_age is not None else 'Not provided',
            'history': medical_history or 'Not provided'
        }

        processed_image, quality_score, quality_message = (
            analyzer.advanced_image_preprocessing(image_input)
        )
        if processed_image is None:
            return (
                f"Error: Image preprocessing failed. {quality_message}",
                None,
                None
            )

        results, status = analyzer.clinical_prediction(processed_image)
        if results is None:
            return (
                f"Error: Analysis failed. {status}",
                None,
                None
            )

        prob_fig, conf_fig = generate_clinical_visualization(results)
        report = generate_clinical_report(results, quality_score, patient_info)
        return report, prob_fig, conf_fig
    except Exception as e:
        logger.error(f"Error in retinal image analysis: {str(e)}")
        return (
            f"Error: Analysis failed due to {str(e)}",
            None,
            None
        )
def create_gradio_interface():
    """
    Assemble the Gradio Blocks UI for the analysis workflow.

    Layout: image and patient inputs on the left, with the Markdown report
    and two plots on the right. The button is wired to analyze_retinal_image.

    Returns:
        The gr.Blocks instance (not yet launched).
    """
    intro_markdown = """
# Clinical Retinal Analysis System
AI-assisted retinal screening for medical professionals. Upload a fundus image and enter patient details for comprehensive analysis.
"""
    with gr.Blocks(title="Clinical Retinal Analysis System") as demo:
        gr.Markdown(intro_markdown)

        with gr.Row():
            with gr.Column(scale=2):
                fundus_image = gr.Image(type="pil", label="Upload Fundus Image")
                id_box = gr.Textbox(label="Patient ID")
                age_box = gr.Textbox(label="Patient Age")
                history_box = gr.Textbox(label="Medical History", lines=3)
                run_button = gr.Button("Analyze Retinal Image")
            with gr.Column(scale=3):
                report_md = gr.Markdown(label="Clinical Report")
                probability_plot = gr.Plot(label="Detection Probabilities")
                confidence_plot = gr.Plot(label="Confidence Intervals & Urgency")

        handler_inputs = [fundus_image, id_box, age_box, history_box]
        handler_outputs = [report_md, probability_plot, confidence_plot]
        run_button.click(
            fn=analyze_retinal_image,
            inputs=handler_inputs,
            outputs=handler_outputs,
        )
    return demo
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    create_gradio_interface().launch()