| | |
| | import streamlit as st |
| | import torch |
| | import torch.nn as nn |
| | from PIL import Image |
| | import numpy as np |
| | from torchvision import transforms |
| | from models.model import DeepLearningModel |
| | import joblib |
| | from scripts.xai_eval import convert_to_gradcam |
| | import cv2 |
| |
|
| |
|
| | |
# Deployment identifiers. Neither is referenced by the code visible in this
# file; presumably a cloud storage bucket and a Vertex AI endpoint -- confirm.
BUCKET_NAME = "aipi540-cv"
VERTEX_AI_ENDPOINT = ""

# Human-readable labels; the entry at position i is shown for model output
# index i (see how main() indexes this list with the predicted class).
class_names = [
    "Normal",
    "Mild Diabetic Retinopathy",
    "Severe Diabetic Retinopathy",
]
| |
|
| |
|
| | |
class ModelHandler:
    """Bundle the compute device with the image preprocessing pipeline."""

    def __init__(self):
        """Select CUDA when available (else CPU) and build the transform."""
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        # Per-channel statistics are hard-coded here; presumably computed
        # from the training images -- confirm before changing them.
        channel_mean = [0.3205, 0.2244, 0.1613]
        channel_std = [0.2996, 0.2158, 0.1711]

        # Resize to 224x224, convert to a tensor, then normalise.
        steps = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
        self.transform = transforms.Compose(steps)

    def preprocess_image(self, image):
        """Apply the transform, add a leading batch axis, move to the device.

        Args:
            image: Input accepted by ``self.transform`` (main() passes a PIL
                RGB image).

        Returns:
            The transformed tensor with one extra leading dimension, placed
            on ``self.device``.
        """
        batched = self.transform(image).unsqueeze(0)
        return batched.to(self.device)
| |
|
@st.cache_resource
def load_model(model_type):
    """Load the classifier once and return it with its preprocessing function.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the checkpoint would be re-read from disk on each rerun.
    ``st.cache_resource`` keeps a single shared instance per server process.

    Args:
        model_type: Label of the requested model. It is not used to choose a
            checkpoint (the same file is always loaded); it only takes part
            in the cache key.

    Returns:
        A ``(model, preprocess)`` tuple: the model in eval mode on the
        handler's device, and the bound ``ModelHandler.preprocess_image``
        method that turns an image into a batched tensor on that device.
    """
    handler = ModelHandler()
    device = handler.device

    model = DeepLearningModel()
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    state_dict = torch.load("models/vgg16_model.pth", map_location=device)
    model.load_state_dict(state_dict)

    model = model.to(device)
    if hasattr(model, "eval"):
        model.eval()

    return model, handler.preprocess_image
| |
|
| | |
def predict(model, image_tensor):
    """Predict the class of the input image using the given model.

    The input is moved to the device the model's parameters actually live
    on, rather than a device re-derived from CUDA availability, so a CPU
    model on a CUDA-capable host no longer hits a device mismatch.

    Args:
        model: A callable torch module returning per-class scores with the
            class axis at dim 1.
        image_tensor: Batched input tensor; only the first item's
            probabilities are returned.

    Returns:
        ``(predicted_class, class_probabilities)``: the arg-max class index
        as a Python int and a NumPy array of softmax probabilities for the
        first batch item.
    """
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        # Parameter-less model: fall back to the global device choice.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()

    with torch.no_grad():
        outputs = model(image_tensor.to(device))

    probabilities = torch.softmax(outputs, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()
    class_probabilities = probabilities[0].cpu().numpy()

    return predicted_class, class_probabilities
| |
|
def generate_gradcam(model, image_tensor):
    """Generate a Grad-CAM heatmap for the input image using the given model.

    Args:
        model: Model passed to ``convert_to_gradcam`` to build the CAM object.
        image_tensor: Batched input tensor forwarded as ``input_tensor``.

    Returns:
        On success, a uint8 colour image produced by ``cv2.applyColorMap``
        with the JET colormap (OpenCV channel order, i.e. BGR). On any
        failure, a string starting with ``"Grad-CAM Error: "``; callers
        tell the two apart with ``isinstance(result, str)``.
    """
    try:
        cam = convert_to_gradcam(model)
        heatmap = cam(input_tensor=image_tensor, targets=None)

        if isinstance(heatmap, torch.Tensor):
            # detach() so a tensor that still tracks gradients can be
            # converted to NumPy.
            heatmap = heatmap.detach().squeeze().cpu().numpy()
        else:
            heatmap = heatmap.squeeze()

        # Min-max normalise to [0, 1]. A constant (or NaN) map has no usable
        # range; emit zeros instead of dividing by zero and casting NaNs.
        low = heatmap.min()
        span = heatmap.max() - low
        if span > 0:
            heatmap = (heatmap - low) / span
        else:
            heatmap = np.zeros_like(heatmap, dtype=np.float64)
        heatmap = np.uint8(255 * heatmap)

        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

        return heatmap
    except Exception as e:
        # Deliberate best-effort: the UI shows this message instead of
        # crashing.
        return f"Grad-CAM Error: {str(e)}"
| |
|
| | |
def main():
    """Run the Streamlit UI: upload a fundus image, classify it, explain it.

    Configures the page, loads the model, and renders the sidebar. Once an
    image is uploaded and "Analyze Image" is pressed, it shows the predicted
    class, the per-class probabilities and a Grad-CAM heatmap. Any exception
    raised during analysis is reported with ``st.error`` rather than
    propagated.
    """
    st.set_page_config(
        page_title="Diabetic Retinopathy Prediction",
        page_icon="👁️",
        layout="wide"
    )

    st.title("👁️ Diabetic Retinopathy Detection System")
    st.write("Upload a fundus image to detect diabetic retinopathy severity")

    st.sidebar.header("Model Information")
    st.sidebar.write("Using Deep Learning Model (VGG16)")

    model, preprocess = load_model("Deep Learning Model")

    st.sidebar.header("About")
    st.sidebar.markdown("""
    This system aims to detect diabetic retinopathy (DR) from fundus images.
    ### Model:
    - **Deep Learning Model**: VGG16-based architecture

    ### Classes:
    - Normal (No DR)
    - Mild DR
    - Severe DR
    """)

    st.header("Image Upload")
    uploaded_file = st.file_uploader(
        "Choose a fundus image",
        type=["jpg", "jpeg", "png"]
    )

    if uploaded_file is not None:
        # Force three channels so grayscale/RGBA uploads match the
        # 3-channel normalisation used in preprocessing.
        image = Image.open(uploaded_file).convert('RGB')
        st.image(image, caption="Uploaded Image", use_container_width=True)

        if st.button("Analyze Image"):
            try:
                processed_image = preprocess(image)

                with st.spinner("Analyzing image..."):
                    predicted_class, class_probs = predict(model, processed_image)

                st.success("Analysis Complete!")

                st.header("Prediction Results")
                st.write(f"**Predicted Condition:** {class_names[predicted_class]}")
                st.write("**Class Probabilities:**")
                st.json({class_names[i]: float(p) for i, p in enumerate(class_probs)})

                with st.spinner("Generating XAI..."):
                    heatmap = generate_gradcam(model, processed_image)

                st.header("Grad-CAM Explanation")
                if isinstance(heatmap, str):
                    # generate_gradcam signals failure by returning a message.
                    st.error(heatmap)
                else:
                    # The heatmap comes from cv2.applyColorMap, which emits
                    # BGR; st.image assumes RGB unless told otherwise, which
                    # would swap the hot (red) and cold (blue) regions.
                    st.image(
                        heatmap,
                        caption="Grad-CAM Heatmap",
                        channels="BGR",
                        use_container_width=True,
                    )

            except Exception as e:
                st.error(f"Error during analysis: {str(e)}")


if __name__ == "__main__":
    main()