| """
|
| Diabetic Retinopathy Classification & Segmentation β Flask Backend
|
| Models and preprocessing are exact replicas from the training notebooks.
|
| """
|
|
|
| import os
|
| import io
|
| import base64
|
| import cv2
|
| import numpy as np
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from torchvision import transforms, models
|
| from PIL import Image
|
| from flask import Flask, request, jsonify, send_from_directory
|
| from flask_cors import CORS
|
|
|
# Fail fast at import time: the segmentation model below cannot be built
# without segmentation_models_pytorch, so surface a clear install hint
# and re-raise rather than failing later on first request.
try:
    import segmentation_models_pytorch as smp
except ImportError:
    print("ERROR: segmentation_models_pytorch not installed. Run: pip install segmentation-models-pytorch")
    raise
|
|
|
|
|
|
|
|
|
# Paths are resolved relative to this file so the service can be launched
# from any working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PARENT_DIR = os.path.dirname(BASE_DIR)

# Model checkpoints are expected to sit next to this file.
CL_WEIGHTS = os.path.join(BASE_DIR, "EfficientNetB4_model.pth")
SEG_WEIGHTS = os.path.join(BASE_DIR, "best_effnet_unet.pth")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CL_IMG_SIZE = 384   # classification input resolution (must match training)
SEG_IMG_SIZE = 512  # segmentation ROI resolution (must match training)
# Index 0 is background; indices 1-3 are the lesion classes predicted by the U-Net.
SEG_CLASSES = ["Background", "Haemorrhages (HE)", "Hard Exudates (EX)", "Soft Exudates (SE)"]

# Human-readable label for each of the 5 DR severity grades.
DR_GRADE_LABELS = {
    0: "No DR",
    1: "Mild NPDR",
    2: "Moderate NPDR",
    3: "Severe NPDR",
    4: "Proliferative DR"
}

# Longer clinical description returned alongside each grade in API responses.
DR_GRADE_DESCRIPTIONS = {
    0: "No signs of diabetic retinopathy detected. The retina appears healthy with no visible abnormalities.",
    1: "Mild non-proliferative diabetic retinopathy. Minor microaneurysms may be present.",
    2: "Moderate non-proliferative diabetic retinopathy. Microaneurysms, dot/blot hemorrhages, and hard exudates may be visible.",
    3: "Severe non-proliferative diabetic retinopathy. Extensive hemorrhages, venous beading, and intraretinal microvascular abnormalities present.",
    4: "Proliferative diabetic retinopathy. Neovascularization and/or vitreous/preretinal hemorrhage detected. Urgent referral recommended."
}
|
|
|
|
|
|
|
|
|
|
|
class APTOSModel(nn.Module):
    """EfficientNet-B4 DR grade classifier (architecture mirrors the training notebook).

    The stock torchvision EfficientNet-B4 is kept intact except for its
    classification head, which is swapped for Dropout(0.4) + Linear. The
    attribute name ``backbone`` and the Sequential head layout must not
    change: the saved checkpoint's parameter names depend on them.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        backbone = models.efficientnet_b4(weights=None)
        head_in = backbone.classifier[1].in_features
        backbone.classifier = nn.Sequential(
            nn.Dropout(p=0.4),
            nn.Linear(head_in, num_classes),
        )
        self.backbone = backbone

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        return self.backbone(x)
|
|
|
|
|
def get_classification_model(path, device):
    """Build the APTOS classifier, load its checkpoint, and return it in eval mode on *device*."""
    net = APTOSModel()
    state = torch.load(path, map_location=device)
    net.load_state_dict(state)
    net = net.to(device)
    return net.eval()
|
|
|
|
|
def get_segmentation_model(path, device):
    """Build the EfficientNet-B3 U-Net (4 output classes), load its checkpoint, and return it in eval mode."""
    net = smp.Unet(
        encoder_name="efficientnet-b3",
        encoder_weights=None,  # weights come from the checkpoint, not ImageNet
        in_channels=3,
        classes=4,
    )
    state = torch.load(path, map_location=device)
    net.load_state_dict(state)
    return net.to(device).eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
def crop_fundus(image, threshold=10):
    """Trim the black border around a fundus photograph.

    Thresholds a grayscale copy at *threshold* and crops to the bounding box
    of the remaining bright pixels.

    Returns:
        (cropped_image, (x, y, w, h)); if no pixel exceeds the threshold the
        original image is returned with a full-frame bounding box.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    _, binary = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY)
    nonzero = cv2.findNonZero(binary)
    if nonzero is None:
        return image, (0, 0, image.shape[1], image.shape[0])
    x, y, w, h = cv2.boundingRect(nonzero)
    return image[y:y + h, x:x + w], (x, y, w, h)
|
|
|
|
|
def apply_clahe_green(image):
    """Contrast-enhance the green channel with CLAHE and replicate it across all three channels.

    The green channel carries the most lesion contrast in fundus imagery;
    the training pipeline fed the model this 3x-stacked enhanced channel.
    """
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(image[:, :, 1])
    return np.dstack((enhanced, enhanced, enhanced))
|
|
|
|
|
def overlay_mask(image, mask, alpha=0.5):
    """Alpha-blend fixed lesion colours onto *image* where *mask* is non-background.

    Class 1 -> red, class 2 -> green, class 3 -> blue; class-0 (background)
    pixels are left untouched. Returns a uint8 image of the same shape.
    """
    class_colors = {1: [255, 0, 0], 2: [0, 255, 0], 3: [0, 0, 255]}
    blended = image.astype(np.float64)
    for cls_id, rgb in class_colors.items():
        region = mask == cls_id
        if region.any():
            blended[region] = (1 - alpha) * blended[region] + alpha * np.array(rgb)
    return blended.astype(np.uint8)
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def run_classification(model, image_rgb):
    """Classify a fundus image into a DR severity grade.

    Preprocessing replicates training: resize to CL_IMG_SIZE and normalise
    with ImageNet statistics.

    Returns:
        (grade, probabilities): argmax class index and the softmax
        probability vector over the 5 grades as a numpy array.
    """
    preprocess = transforms.Compose([
        transforms.Resize((CL_IMG_SIZE, CL_IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    batch = preprocess(Image.fromarray(image_rgb)).unsqueeze(0).to(DEVICE)
    probabilities = F.softmax(model(batch), dim=1).cpu().numpy()[0]
    return int(np.argmax(probabilities)), probabilities
|
|
|
|
|
@torch.no_grad()
def run_segmentation(model, image_rgb, patch_size=256, stride=128):
    """Segment DR lesions with sliding-window inference.

    Pipeline (identical to the training notebook): crop black borders,
    resize the ROI to SEG_IMG_SIZE, CLAHE-enhance the green channel, then
    run the model over overlapping patches and average logits where
    patches overlap before taking the per-pixel argmax.

    The ROI/patch geometry was hard-coded (512/256/128) and silently
    duplicated SEG_IMG_SIZE; it now uses the constant plus backward-
    compatible keyword parameters with the original defaults.

    Args:
        model: segmentation network returning (4, patch_size, patch_size) logits.
        image_rgb: H x W x 3 uint8 fundus image.
        patch_size: sliding-window size (default 256, as in training).
        stride: sliding-window step (default 128, i.e. 50% overlap).

    Returns:
        (roi, pred_mask): the resized ROI and a uint8 class-index mask,
        both SEG_IMG_SIZE x SEG_IMG_SIZE.
    """
    cropped, _ = crop_fundus(image_rgb)
    roi = cv2.resize(cropped, (SEG_IMG_SIZE, SEG_IMG_SIZE))
    seg_inp = torch.tensor(
        np.transpose(apply_clahe_green(roi).astype(np.float32) / 255.0, (2, 0, 1))
    ).to(DEVICE)

    # Accumulate logits plus a per-pixel visit count so overlapping patch
    # predictions are averaged rather than overwritten.
    out_map = torch.zeros((4, SEG_IMG_SIZE, SEG_IMG_SIZE), device=DEVICE)
    cnt_map = torch.zeros((SEG_IMG_SIZE, SEG_IMG_SIZE), device=DEVICE)
    for y in range(0, SEG_IMG_SIZE - patch_size + 1, stride):
        for x in range(0, SEG_IMG_SIZE - patch_size + 1, stride):
            patch = seg_inp[:, y:y + patch_size, x:x + patch_size].unsqueeze(0)
            out_map[:, y:y + patch_size, x:x + patch_size] += model(patch)[0]
            cnt_map[y:y + patch_size, x:x + patch_size] += 1

    pred_mask = torch.argmax(out_map / cnt_map.unsqueeze(0), dim=0).cpu().numpy().astype(np.uint8)

    return roi, pred_mask
|
|
|
|
|
def numpy_to_base64(img_array):
    """Encode an RGB numpy array as a base64 string of its PNG bytes."""
    buf = io.BytesIO()
    Image.fromarray(img_array).save(buf, format='PNG')
    return base64.b64encode(buf.getvalue()).decode('utf-8')
|
|
|
|
|
def calculate_lesion_stats(mask):
    """Summarise per-lesion pixel coverage of a predicted segmentation mask.

    Returns:
        {lesion_name: {"pixel_count", "percentage", "detected"}} for the
        three non-background classes in SEG_CLASSES (indices 1-3).
    """
    total = mask.size
    summary = {}
    for cls_id in (1, 2, 3):
        pixels = int((mask == cls_id).sum())
        summary[SEG_CLASSES[cls_id]] = {
            "pixel_count": pixels,
            "percentage": round(pixels / total * 100, 4),
            "detected": pixels > 0,
        }
    return summary
|
|
|
|
|
|
|
|
|
|
|
|
|
# Flask application; CORS enabled so the separately-served frontend can call it.
app = Flask(__name__)
CORS(app)

# Lazily-loaded model singletons, populated on first request by load_models().
cl_model = None
seg_model = None
|
|
|
def load_models():
    """Load both models into the module-level singletons on first call.

    Subsequent calls are no-ops once each model is loaded. Any loading
    error (missing checkpoint, incompatible state dict) is logged and
    re-raised so the caller's request fails visibly.

    Fix: the original success messages contained mojibake-corrupted emoji
    that split the string literals across lines; replaced with plain text.
    """
    global cl_model, seg_model
    if cl_model is None:
        try:
            print(f"Using device: {DEVICE}")
            print(f"Loading classification model from: {CL_WEIGHTS}")
            cl_model = get_classification_model(CL_WEIGHTS, DEVICE)
            print("Classification model loaded")
        except Exception as e:
            print(f"Error loading classification model: {e}")
            raise

    if seg_model is None:
        try:
            print(f"Loading segmentation model from: {SEG_WEIGHTS}")
            seg_model = get_segmentation_model(SEG_WEIGHTS, DEVICE)
            print("Segmentation model loaded")
        except Exception as e:
            print(f"Error loading segmentation model: {e}")
            raise
|
|
|
|
|
@app.route('/')
def index():
    """Liveness message for the service root."""
    payload = {"message": "Retina-AI Microservice Running"}
    return jsonify(payload)
|
|
|
|
|
@app.route('/api/analyze', methods=['POST'])
def analyze():
    """Run the full pipeline — classification plus segmentation — on an upload.

    Expects a multipart form field named 'image'. Responds with the DR grade,
    per-grade confidences, base64-encoded overlay/ROI images, and per-lesion
    pixel statistics. Returns HTTP 400 for missing/undecodable uploads and
    HTTP 500 (with the exception message) if the pipeline fails.
    """
    load_models()
    if 'image' not in request.files:
        return jsonify({'error': 'No image file provided'}), 400

    upload = request.files['image']
    if upload.filename == '':
        return jsonify({'error': 'No file selected'}), 400

    try:
        # Decode upload bytes with OpenCV (BGR), then convert to RGB for the models.
        raw = np.frombuffer(upload.read(), np.uint8)
        bgr = cv2.imdecode(raw, cv2.IMREAD_COLOR)
        if bgr is None:
            return jsonify({'error': 'Could not decode image. Please upload a valid image file.'}), 400
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        grade, probabilities = run_classification(cl_model, rgb)
        roi, pred_mask = run_segmentation(seg_model, rgb)
        overlay = overlay_mask(roi, pred_mask)
        lesion_stats = calculate_lesion_stats(pred_mask)

        confidence = {
            DR_GRADE_LABELS[i]: round(float(probabilities[i]) * 100, 2)
            for i in range(5)
        }
        response = {
            'success': True,
            'classification': {
                'grade': grade,
                'label': DR_GRADE_LABELS[grade],
                'description': DR_GRADE_DESCRIPTIONS[grade],
                'confidence': confidence,
            },
            'segmentation': {
                'overlay_image': numpy_to_base64(overlay),
                'original_roi': numpy_to_base64(roi),
                'lesion_stats': lesion_stats,
                'any_lesions_detected': any(s["detected"] for s in lesion_stats.values()),
            },
        }
        return jsonify(response)

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'error': f'Analysis failed: {str(e)}'}), 500
|
|
|
|
|
@app.route('/api/health', methods=['GET'])
def health():
    """Report service status, the compute device, and whether both models are loaded."""
    status = {
        'status': 'ok',
        'device': str(DEVICE),
        'models_loaded': cl_model is not None and seg_model is not None,
    }
    return jsonify(status)
|
|
|
|
|
# Development entry point: bind all interfaces on port 5001 (debug off).
# Use a production WSGI server (e.g. gunicorn) for deployment.
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5001, debug=False)
|
|
|