# Retina-AI / app.py
# Achyuth12's picture
# Upload 5 files
# b342dd4 verified
"""
Diabetic Retinopathy Classification & Segmentation — Flask Backend
Models and preprocessing are exact replicas from the training notebooks.
"""
import base64
import io
import os
import threading
import traceback

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS

try:
    import segmentation_models_pytorch as smp
except ImportError:
    print("ERROR: segmentation_models_pytorch not installed. Run: pip install segmentation-models-pytorch")
    raise
# ============================================================
# CONFIGURATION
# ============================================================
# Filesystem locations: both weight files are expected next to this module.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PARENT_DIR = os.path.dirname(BASE_DIR)  # NOTE(review): not referenced elsewhere in this file
CL_WEIGHTS = os.path.join(BASE_DIR, "EfficientNetB4_model.pth")
SEG_WEIGHTS = os.path.join(BASE_DIR, "best_effnet_unet.pth")

# Use the GPU whenever PyTorch can see one.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Square input side lengths, in pixels, for the two networks.
CL_IMG_SIZE = 384
SEG_IMG_SIZE = 512

# Segmentation class names, indexed by the per-pixel class id.
SEG_CLASSES = ["Background", "Haemorrhages (HE)", "Hard Exudates (EX)", "Soft Exudates (SE)"]

# Classifier output index (DR severity grade) -> short label.
DR_GRADE_LABELS = dict(enumerate((
    "No DR",
    "Mild NPDR",
    "Moderate NPDR",
    "Severe NPDR",
    "Proliferative DR",
)))

# Classifier output index (DR severity grade) -> human-readable explanation.
DR_GRADE_DESCRIPTIONS = dict(enumerate((
    "No signs of diabetic retinopathy detected. The retina appears healthy with no visible abnormalities.",
    "Mild non-proliferative diabetic retinopathy. Minor microaneurysms may be present.",
    "Moderate non-proliferative diabetic retinopathy. Microaneurysms, dot/blot hemorrhages, and hard exudates may be visible.",
    "Severe non-proliferative diabetic retinopathy. Extensive hemorrhages, venous beading, and intraretinal microvascular abnormalities present.",
    "Proliferative diabetic retinopathy. Neovascularization and/or vitreous/preretinal hemorrhage detected. Urgent referral recommended.",
)))
# ============================================================
# MODEL DEFINITIONS — EXACT COPIES FROM NOTEBOOKS
# ============================================================
class APTOSModel(nn.Module):
    """EfficientNet-B4 grade classifier (architecture from efficientnet-b4.ipynb).

    The torchvision backbone is built without pretrained weights and its
    classifier head is replaced by ``Dropout(0.4) -> Linear(num_classes)``.
    The head stays an ``nn.Sequential`` whose Linear sits at index 1, so the
    checkpoint keys ``backbone.classifier.1.*`` line up with the saved weights.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        self.backbone = models.efficientnet_b4(weights=None)
        # Reuse the feature width of the stock head's Linear layer.
        head_width = self.backbone.classifier[1].in_features
        new_head = nn.Sequential(
            nn.Dropout(p=0.4),
            nn.Linear(head_width, num_classes),
        )
        self.backbone.classifier = new_head

    def forward(self, x):
        """Return raw (pre-softmax) class scores for the batch *x*."""
        scores = self.backbone(x)
        return scores
def get_classification_model(path, device):
    """Build an APTOSModel, load the state dict at *path*, and return it in eval mode on *device*.

    Loading mirrors the pipeline notebook. The checkpoint is expected to be
    a plain state dict matching APTOSModel's parameter names.
    """
    state = torch.load(path, map_location=device)
    net = APTOSModel()
    net.load_state_dict(state)
    net = net.to(device)
    net.eval()
    return net
def get_segmentation_model(path, device):
    """Build the 4-class EfficientNet-B3 U-Net, load weights from *path*, return it in eval mode.

    The architecture arguments mirror the pipeline notebook. ``encoder_weights=None``
    skips any pretrained download because the full state dict is loaded right after.
    """
    unet_kwargs = dict(
        encoder_name="efficientnet-b3",
        encoder_weights=None,
        in_channels=3,
        classes=4,
    )
    state = torch.load(path, map_location=device)
    net = smp.Unet(**unet_kwargs)
    net.load_state_dict(state)
    net = net.to(device)
    net.eval()
    return net
# ============================================================
# PREPROCESSING FUNCTIONS — EXACT COPIES FROM NOTEBOOKS
# ============================================================
def crop_fundus(image, threshold=10):
    """Crop the dark border around a fundus photo (logic from the pipeline notebook).

    Pixels whose grayscale value exceeds *threshold* count as foreground.
    The image is cut to the bounding rectangle of all foreground pixels.

    Args:
        image: RGB image array of shape (H, W, 3).
        threshold: grayscale cut-off passed to ``cv2.threshold``.

    Returns:
        ``(cropped, (x, y, w, h))``. When no pixel passes the threshold, the
        input image is returned unchanged with the full-frame box ``(0, 0, W, H)``.
    """
    luminance = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    _, foreground = cv2.threshold(luminance, threshold, 255, cv2.THRESH_BINARY)
    points = cv2.findNonZero(foreground)
    if points is None:
        # Nothing brighter than the threshold: keep the whole frame.
        return image, (0, 0, image.shape[1], image.shape[0])
    left, top, width, height = cv2.boundingRect(points)
    cropped = image[top:top + height, left:left + width]
    return cropped, (left, top, width, height)
def apply_clahe_green(image):
    """Equalise the green channel with CLAHE and replicate it into 3 channels.

    Uses clipLimit=2.0 on an 8x8 tile grid, as in the pipeline notebook.
    The result has the input's height and width, with three identical channels.
    """
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = equalizer.apply(image[:, :, 1])  # channel 1 = green for RGB input
    return np.dstack((enhanced, enhanced, enhanced))
def overlay_mask(image, mask, alpha=0.5):
    """Alpha-blend a colour per lesion class onto an RGB image (pipeline-notebook logic).

    Colours are: class 1 -> red, class 2 -> green, class 3 -> blue. Class 0 and
    any other value leave the pixel untouched. Blending is done in float64 and
    truncated back to uint8. The input image is not modified.

    Args:
        image: RGB image array (H, W, 3).
        mask: integer class map whose shape matches image's first two axes.
        alpha: weight given to the class colour (0 keeps the image, 1 paints solid).
    """
    palette = {1: [255, 0, 0], 2: [0, 255, 0], 3: [0, 0, 255]}
    blended = image.copy().astype(np.float64)
    for label, rgb in palette.items():
        hit = mask == label
        if not hit.any():
            continue
        blended[hit] = (1 - alpha) * blended[hit] + alpha * np.array(rgb)
    return blended.astype(np.uint8)
# ============================================================
# INFERENCE PIPELINE — EXACT FROM PIPELINE NOTEBOOK
# ============================================================
@torch.no_grad()
def run_classification(model, image_rgb):
    """Predict the DR grade of an RGB fundus image (preprocessing from the pipeline notebook).

    The image is resized to CL_IMG_SIZE x CL_IMG_SIZE, converted to a tensor,
    and normalised with the standard ImageNet channel mean/std. It is then
    scored as a batch of one on DEVICE.

    Returns:
        ``(grade, probabilities)``: the argmax class index as a Python int and
        the softmax probabilities for the image as a 1-D NumPy array.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    preprocess = transforms.Compose([
        transforms.Resize((CL_IMG_SIZE, CL_IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    batch = preprocess(Image.fromarray(image_rgb))[None].to(DEVICE)
    probs = F.softmax(model(batch), dim=1)[0].cpu().numpy()
    return int(probs.argmax()), probs
def _window_starts(size, patch, stride):
    """Return window origins along one axis so windows of *patch* cover ``[0, size)`` completely."""
    if patch >= size:
        return [0]
    starts = list(range(0, size - patch + 1, stride))
    # If the stride does not land exactly on the last valid origin, add it so
    # the trailing edge is covered (otherwise cnt_map would contain zeros).
    if starts[-1] != size - patch:
        starts.append(size - patch)
    return starts


@torch.no_grad()
def run_segmentation(model, image_rgb, patch_size=256, stride=128):
    """Segment lesions with overlapping sliding-window inference (pipeline-notebook logic).

    Steps:
      1. Crop the black border.
      2. Resize the region of interest to SEG_IMG_SIZE x SEG_IMG_SIZE.
      3. Apply green-channel CLAHE and scale to [0, 1].
      4. Run the model on overlapping square windows, average the per-pixel
         outputs, and take the argmax class.

    The window grid is derived from SEG_IMG_SIZE rather than hard-coded, and
    always reaches the image edge. With the defaults and SEG_IMG_SIZE=512 it
    is the notebook's grid: origins 0, 128, 256 on each axis.

    Args:
        model: segmentation network, called on a (1, 3, h, w) tensor. Its result
            is indexed with ``[0]``, so it is assumed to return (1, C, h, w)
            class scores (TODO confirm for other models). C is taken from the
            first window's output.
        image_rgb: RGB image array (H, W, 3).
        patch_size: side length of each square inference window.
        stride: step between window origins; must be positive.

    Returns:
        ``(roi, pred_mask)``:
          - roi: the resized RGB crop (SEG_IMG_SIZE x SEG_IMG_SIZE x 3).
          - pred_mask: a uint8 class-index map of shape (SEG_IMG_SIZE, SEG_IMG_SIZE).
    """
    cropped, _ = crop_fundus(image_rgb)
    roi = cv2.resize(cropped, (SEG_IMG_SIZE, SEG_IMG_SIZE))
    seg_inp = torch.tensor(
        np.transpose(apply_clahe_green(roi).astype(np.float32) / 255.0, (2, 0, 1))
    ).to(DEVICE)

    size = SEG_IMG_SIZE
    starts = _window_starts(size, patch_size, stride)
    out_map = None  # allocated once the model's channel count is known
    cnt_map = torch.zeros((size, size), device=DEVICE)
    for y in starts:
        for x in starts:
            window = seg_inp[:, y:y + patch_size, x:x + patch_size].unsqueeze(0)
            scores = model(window)[0]
            if out_map is None:
                out_map = torch.zeros((scores.shape[0], size, size), device=DEVICE)
            out_map[:, y:y + patch_size, x:x + patch_size] += scores
            cnt_map[y:y + patch_size, x:x + patch_size] += 1

    # Every pixel is covered at least once, so the average is well defined.
    pred_mask = torch.argmax(out_map / cnt_map.unsqueeze(0), dim=0).cpu().numpy().astype(np.uint8)
    return roi, pred_mask
def numpy_to_base64(img_array):
    """Encode an image array as PNG and return the base64 text (UTF-8 decoded).

    The array is handed to ``PIL.Image.fromarray`` as-is, so channel order is
    whatever the caller supplies (RGB in this module).
    """
    with io.BytesIO() as sink:
        Image.fromarray(img_array).save(sink, format='PNG')
        png_bytes = sink.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')
def calculate_lesion_stats(mask, class_names=None):
    """Count the pixels of each lesion class in a predicted segmentation mask.

    Args:
        mask: array of integer class ids. Class 0 is the background and is not
            reported.
        class_names: optional sequence of names indexed by class id; defaults
            to SEG_CLASSES. Every entry after index 0 gets a statistics record.

    Returns:
        A dict mapping each non-background class name to a record with:
          - pixel_count: int.
          - percentage: float, share of all mask pixels in percent, rounded to
            4 decimal places.
          - detected: bool.
        An empty mask yields zero counts and 0.0 percentages instead of raising
        ZeroDivisionError.
    """
    names = SEG_CLASSES if class_names is None else class_names
    mask = np.asarray(mask)
    total_pixels = mask.size
    stats = {}
    for class_id, name in enumerate(names[1:], start=1):
        count = int(np.count_nonzero(mask == class_id))
        percentage = round(count / total_pixels * 100, 4) if total_pixels else 0.0
        stats[name] = {
            "pixel_count": count,
            "percentage": percentage,
            "detected": count > 0,
        }
    return stats
# ============================================================
# FLASK APP
# ============================================================
app = Flask(__name__)
CORS(app)  # flask-cors: enable cross-origin requests on the app's routes

# Model singletons; both stay None until load_models() fills them in.
cl_model = seg_model = None
# Serialises the one-time lazy model load across concurrent request threads.
_model_lock = threading.Lock()


def load_models():
    """Load the classification and segmentation models into the module globals.

    A model is loaded only while its global is still None, so repeated calls
    are cheap. The lock plus re-check stops concurrent first requests (e.g.
    under Flask's threaded server) from each building and holding a duplicate
    copy of the models.

    Raises:
        Whatever the loader raised (missing file, state-dict mismatch, ...).
        The error is printed first and then re-raised so the caller can report it.
    """
    global cl_model, seg_model
    if cl_model is not None and seg_model is not None:
        return  # fast path: already loaded, no locking needed
    with _model_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if cl_model is None:
            try:
                print(f"Using device: {DEVICE}")
                print(f"Loading classification model from: {CL_WEIGHTS}")
                cl_model = get_classification_model(CL_WEIGHTS, DEVICE)
                print("Classification model loaded ✅")
            except Exception as e:
                print(f"Error loading classification model: {e}")
                raise
        if seg_model is None:
            try:
                print(f"Loading segmentation model from: {SEG_WEIGHTS}")
                seg_model = get_segmentation_model(SEG_WEIGHTS, DEVICE)
                print("Segmentation model loaded ✅")
            except Exception as e:
                print(f"Error loading segmentation model: {e}")
                raise
@app.route('/')
def index():
    """Root route: a static JSON banner confirming the service is up."""
    banner = {"message": "Retina-AI Microservice Running"}
    return jsonify(banner)
@app.route('/api/analyze', methods=['POST'])
def analyze():
    """Main analysis endpoint — runs both classification and segmentation.

    Expects a multipart upload with the fundus image in the ``image`` field.
    Every response is JSON:

    * 400: no ``image`` field, empty filename, empty upload, or bytes that
      OpenCV cannot decode as an image.
    * 500: model loading or inference raised; the message is in ``error``.
    * 200: ``success`` plus two sections.
      - ``classification``: grade, label, description, and per-grade
        confidence in percent.
      - ``segmentation``: base64 PNG overlay and ROI, per-lesion pixel stats,
        and ``any_lesions_detected``.
    """
    # Validate the request first so malformed requests never trigger the
    # (slow) first-time model load.
    if 'image' not in request.files:
        return jsonify({'error': 'No image file provided'}), 400
    file = request.files['image']
    if file.filename == '':
        return jsonify({'error': 'No file selected'}), 400

    # A load failure would otherwise escape as Flask's non-JSON error page;
    # report it in the same JSON shape as analysis failures.
    try:
        load_models()
    except Exception as e:
        traceback.print_exc()
        return jsonify({'error': f'Model loading failed: {str(e)}'}), 500

    try:
        # Read image. cv2.imdecode raises on an empty buffer instead of
        # returning None, so an empty upload is short-circuited to a 400.
        file_bytes = np.frombuffer(file.read(), np.uint8)
        img_bgr = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) if file_bytes.size else None
        if img_bgr is None:
            return jsonify({'error': 'Could not decode image. Please upload a valid image file.'}), 400
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        # 1. Classification
        grade, probabilities = run_classification(cl_model, img_rgb)
        # 2. Segmentation
        roi, pred_mask = run_segmentation(seg_model, img_rgb)
        # 3. Create overlay
        overlay = overlay_mask(roi, pred_mask)
        # 4. Calculate lesion stats
        lesion_stats = calculate_lesion_stats(pred_mask)
        # 5. Determine if any lesions detected
        any_lesions = any(s["detected"] for s in lesion_stats.values())

        # 6. Build response (confidence keyed by grade label, in percent)
        response = {
            'success': True,
            'classification': {
                'grade': grade,
                'label': DR_GRADE_LABELS[grade],
                'description': DR_GRADE_DESCRIPTIONS[grade],
                'confidence': {
                    label: round(float(probabilities[grade_id]) * 100, 2)
                    for grade_id, label in DR_GRADE_LABELS.items()
                },
            },
            'segmentation': {
                'overlay_image': numpy_to_base64(overlay),
                'original_roi': numpy_to_base64(roi),
                'lesion_stats': lesion_stats,
                'any_lesions_detected': any_lesions,
            },
        }
        return jsonify(response)
    except Exception as e:
        traceback.print_exc()
        return jsonify({'error': f'Analysis failed: {str(e)}'}), 500
@app.route('/api/health', methods=['GET'])
def health():
    """Report service status, the inference device, and whether both models are loaded."""
    both_loaded = not (cl_model is None or seg_model is None)
    status = {
        'status': 'ok',
        'device': str(DEVICE),
        'models_loaded': both_loaded,
    }
    return jsonify(status)
if __name__ == '__main__':
    # Listen on every interface, port 5001, with the debugger disabled.
    app.run(
        host='0.0.0.0',
        port=5001,
        debug=False,
    )