import contextlib
import json
import os
import sqlite3
import threading
import uuid
from datetime import datetime

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from flask import Flask, request, jsonify, send_file
from PIL import Image
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms
# Directory for everything this service writes: uploaded images, the Grad-CAM
# PNGs and the SQLite results database. Created at import time if missing.
OUTPUT_DIR = '/tmp/results'
DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
os.makedirs(OUTPUT_DIR, exist_ok=True)

app = Flask(__name__)
def init_db():
    """Create the ``results`` table in the database at DB_PATH if it is absent.

    Idempotent thanks to ``CREATE TABLE IF NOT EXISTS``. Columns: an
    auto-increment ``id``, the stored upload's file name, the predicted label,
    its confidence, the overlay and grayscale Grad-CAM file names, and a
    timestamp string.
    """
    connection = sqlite3.connect(DB_PATH)
    # Connection.execute opens a throw-away cursor for this single statement.
    connection.execute("""
        CREATE TABLE IF NOT EXISTS results (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            image_filename TEXT,
            prediction TEXT,
            confidence REAL,
            gradcam_filename TEXT,
            gradcam_gray_filename TEXT,
            timestamp TEXT
        )
    """)
    connection.commit()
    connection.close()
init_db()

# Project-local architecture definition.
from efficientnet_transformer_glam import EfficientNetb0_TransformerGLAM

# Label for each output index of the classifier; predict() maps the argmax of
# the softmax through this list.
CLASS_NAMES = ["Advanced", "Early", "Normal"]

# Hyper-parameters must describe the architecture the checkpoint was trained
# with, otherwise load_state_dict() rejects the weights.
model = EfficientNetb0_TransformerGLAM(
    num_classes=len(CLASS_NAMES),
    embed_dim=512,
    num_heads=8,
    mlp_dim=512,
    dropout=0.5,
    window_size=7,
    reduction_ratio=8,
)
# NOTE(review): torch.load unpickles the checkpoint (unless weights_only is in
# effect for the installed torch version) -- only deploy trusted weight files.
model.load_state_dict(torch.load('efficientnet_glam_best.pt', map_location='cpu'))
model.eval()

# Per-upload preprocessing: shorter side to 256 px, central 224x224 crop,
# conversion to a [0, 1] tensor, then per-channel normalisation with the
# usual ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
@app.route('/')
def home():
    """Liveness probe: reply with a fixed banner showing the service is up."""
    banner = "Glaucoma Detection Flask API (EfficientNetB0_TransformerGLAM Model) is running!"
    return banner
@app.route("/test_file")
def test_file():
    """Report whether the model checkpoint is present and readable.

    Checks the same relative path the module loads its weights from at
    start-up ('efficientnet_glam_best.pt', resolved against the current
    working directory) and returns a plain-text status line.
    """
    # Previously checked "densenet169_seed40_best.pt", a file this app never
    # loads, so the endpoint said nothing about the real checkpoint.
    filepath = "efficientnet_glam_best.pt"
    if not os.path.isfile(filepath):
        return "❌ Model file NOT found."
    if not os.access(filepath, os.R_OK):
        return f"❌ Model file at {filepath} is not readable."
    return f"✅ Model file found at: {filepath}"
# GradCAMPlusPlus registers forward/backward hooks on the shared ``model``.
# Flask's development server runs requests in threads, so without this lock a
# concurrent request's forward pass could be recorded by another request's
# hooks (and vice versa).
_INFERENCE_LOCK = threading.Lock()

# Geometric part of ``transform`` (Resize(256) -> CenterCrop(224)) without the
# tensor conversion/normalisation. Keep it in sync with ``transform`` so the
# Grad-CAM background is exactly the region the model classified.
_DISPLAY_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
])


def _run_inference(img):
    """Classify ``img`` and compute the Grad-CAM++ map for the predicted class.

    Returns ``(probabilities, class_index, cam_map)``: the softmax vector for
    the single image (ordered like ``CLASS_NAMES``), the argmax index as a
    plain ``int``, and the CAM for batch item 0 as returned by pytorch_grad_cam.
    """
    input_tensor = transform(img).unsqueeze(0)
    cam_map = None
    with _INFERENCE_LOCK:
        with torch.no_grad():
            output = model(input_tensor)
        probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
        class_index = int(np.argmax(probabilities))
        # The context manager removes the hooks as soon as the map is computed.
        with GradCAMPlusPlus(model=model, target_layers=[model.fusion_block]) as cam:
            cam_map = cam(input_tensor=input_tensor,
                          targets=[ClassifierOutputTarget(class_index)])[0]
    # BaseCAM.__exit__ swallows IndexError, which would leave cam_map unset.
    if cam_map is None:
        raise RuntimeError('Grad-CAM++ computation failed.')
    return probabilities, class_index, cam_map


def _save_gradcam_images(img, cam_map, overlay_path, gray_path):
    """Write the colour overlay and the grayscale Grad-CAM map as PNG files."""
    background = np.asarray(_DISPLAY_TRANSFORM(img), dtype=np.float32) / 255.0
    overlay = show_cam_on_image(background, cam_map, use_rgb=True)
    # Clip before the uint8 cast so any out-of-range value cannot wrap around.
    gray = np.uint8(255 * np.clip(cam_map, 0.0, 1.0))
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(overlay_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)):
        raise OSError(f"Could not write {overlay_path}")
    if not cv2.imwrite(gray_path, gray):
        raise OSError(f"Could not write {gray_path}")


def _store_result(row):
    """Insert one prediction row into the ``results`` table."""
    with contextlib.closing(sqlite3.connect(DB_PATH)) as conn:
        with conn:  # commits on success, rolls back on error
            conn.execute(
                "INSERT INTO results (image_filename, prediction, confidence, "
                "gradcam_filename, gradcam_gray_filename, timestamp) "
                "VALUES (?, ?, ?, ?, ?, ?)",
                row,
            )


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image, store it with its Grad-CAM++ maps, record the result.

    Expects a multipart form field ``file``. On success, returns JSON with:
    the predicted label and its confidence; the per-class probabilities; and
    the names of the three PNG files written to OUTPUT_DIR (the upload, the
    overlay and the grayscale map).

    Returns 400 when the upload is missing or is not a decodable image. Any
    other failure returns 500 with the error message.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded.'}), 400

    uploaded_file = request.files['file']
    if uploaded_file.filename == '':
        return jsonify({'error': 'No file selected.'}), 400

    try:
        try:
            img = Image.open(uploaded_file.stream).convert('RGB')
        except OSError:
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # non-images and OSError for truncated data: a client error.
            return jsonify({'error': 'Uploaded file is not a valid image.'}), 400

        now = datetime.now()
        # A bare second-resolution timestamp collides when two uploads arrive
        # in the same second (the later files overwrite the earlier ones); the
        # random suffix keeps each request's files distinct.
        stem = f"{int(now.timestamp())}_{uuid.uuid4().hex[:8]}"
        uploaded_filename = f"uploaded_{stem}.png"
        gradcam_filename = f"gradcam_{stem}.png"
        gray_filename = f"gradcam_gray_{stem}.png"

        # Re-encode so the stored file really is a PNG, as its name and the
        # image/png mimetype used when serving it both claim.
        img.save(os.path.join(OUTPUT_DIR, uploaded_filename), format='PNG')

        probabilities, class_index, cam_map = _run_inference(img)
        result = CLASS_NAMES[class_index]
        confidence = float(probabilities[class_index])

        _save_gradcam_images(img, cam_map,
                             os.path.join(OUTPUT_DIR, gradcam_filename),
                             os.path.join(OUTPUT_DIR, gray_filename))

        # Fixed microsecond precision keeps ISO strings lexically sortable.
        _store_result((uploaded_filename, result, confidence, gradcam_filename,
                       gray_filename, now.isoformat(timespec='microseconds')))

        # Look probabilities up by label: CLASS_NAMES orders the classes
        # Advanced/Early/Normal, so reading indices 0/1/2 as Normal/Early/
        # Advanced (as before) swapped the Normal and Advanced values.
        by_label = dict(zip(CLASS_NAMES, (float(p) for p in probabilities)))
        return jsonify({
            'prediction': result,
            'confidence': confidence,
            'normal_probability': by_label['Normal'],
            'early_glaucoma_probability': by_label['Early'],
            'advanced_glaucoma_probability': by_label['Advanced'],
            'gradcam_image': gradcam_filename,
            'gradcam_gray_image': gray_filename,
            'image_filename': uploaded_filename,
        })

    except Exception as e:
        # Top-level boundary of the endpoint: log the traceback server-side.
        app.logger.exception('Prediction failed')
        return jsonify({'error': str(e)}), 500
@app.route('/results', methods=['GET'])
def results():
    """Return every stored prediction as a JSON list, newest first.

    Each item has the keys ``id``, ``image_filename``, ``prediction``,
    ``confidence``, ``gradcam_filename``, ``gradcam_gray_filename`` and
    ``timestamp`` (one per column of the ``results`` table).
    """
    # Select columns by name so the JSON keys never depend on the table's
    # physical column order (``SELECT *`` + positional indexes did).
    columns = ('id', 'image_filename', 'prediction', 'confidence',
               'gradcam_filename', 'gradcam_gray_filename', 'timestamp')
    # Built only from the constant tuple above -- no user input reaches the SQL.
    query = (f"SELECT {', '.join(columns)} FROM results "
             "ORDER BY timestamp DESC, id DESC")
    # closing() guarantees the connection is released even if the query fails.
    with contextlib.closing(sqlite3.connect(DB_PATH)) as conn:
        rows = conn.execute(query).fetchall()
    return jsonify([dict(zip(columns, row)) for row in rows])
@app.route('/gradcam/<filename>')
def get_gradcam(filename):
    """Serve a Grad-CAM PNG (overlay or grayscale map) from OUTPUT_DIR.

    Only regular ``.png`` files located directly inside OUTPUT_DIR are served.
    Anything else -- directories, ``..``, symlinks leading outside the
    directory, or non-PNG files such as the SQLite database -- gets the same
    JSON 404 as a missing file.
    """
    base_dir = os.path.realpath(OUTPUT_DIR)
    # realpath resolves '..' and symlinks, so the parent check below cannot be
    # bypassed with relative components or links.
    filepath = os.path.realpath(os.path.join(base_dir, filename))
    if (os.path.dirname(filepath) == base_dir
            and filepath.lower().endswith('.png')
            and os.path.isfile(filepath)):
        return send_file(filepath, mimetype='image/png')
    return jsonify({'error': 'File not found.'}), 404
@app.route('/image/<filename>')
def get_image(filename):
    """Serve a stored uploaded image (saved as PNG) from OUTPUT_DIR.

    Only regular ``.png`` files located directly inside OUTPUT_DIR are served.
    Anything else -- directories, ``..``, symlinks leading outside the
    directory, or non-PNG files such as the SQLite database -- gets the same
    JSON 404 as a missing file.
    """
    base_dir = os.path.realpath(OUTPUT_DIR)
    # realpath resolves '..' and symlinks, so the parent check below cannot be
    # bypassed with relative components or links.
    filepath = os.path.realpath(os.path.join(base_dir, filename))
    if (os.path.dirname(filepath) == base_dir
            and filepath.lower().endswith('.png')
            and os.path.isfile(filepath)):
        return send_file(filepath, mimetype='image/png')
    return jsonify({'error': 'File not found.'}), 404
if __name__ == '__main__':
    # Listen on all network interfaces, port 7860, using Flask's built-in server.
    server_options = {'host': '0.0.0.0', 'port': 7860}
    app.run(**server_options)