```python
import streamlit as st
import torch
from torchvision import transforms, models
from PIL import Image
import pandas as pd
from collections import defaultdict
import os

# Title
st.markdown("<h2 style='color: #2E86C1;'>Upload & Predict</h2>", unsafe_allow_html=True)
st.markdown("""
### About This Feature: Upload & Predict
This section of the **DR Assistive Tool** allows users to upload retinal images and get an AI-based prediction of the **Diabetic Retinopathy stage**. It uses a fine-tuned **DenseNet-121** model trained specifically for detecting DR severity levels.

The model classifies the uploaded image into one of five classes:
- **No DR**
- **Mild**
- **Moderate**
- **Severe**
- **Proliferative DR**

This is especially helpful for:
- Students learning about AI in healthcare
- Researchers testing model robustness
- Clinicians exploring AI-assisted screening tools

The tool also shows **sample images from the test set** for each class. You can use these images to test the model’s performance and see what the different DR stages look like.

---

### How to Use:
1. **View sample images** from the test set, grouped by DR stage.
   - Click the **"Predict"** button under a sample image to see how the model classifies it.
2. **Upload your own retinal image** (JPG or PNG) using the file uploader.
3. Click the **"Predict"** button after uploading.
   - The model will analyze the image and display:
     - **Predicted DR stage**
     - **Model confidence score (in %)**

⚠️ *Make sure your image is a clear, centered fundus photograph for best results.*

---

### Behind the Scenes:
- Model: fine-tuned **DenseNet-121**
- Input size: images are resized so the shorter side is 256 pixels, then center-cropped to 224×224
- Normalization: ImageNet mean and standard deviation per RGB channel
- Output: the highest-probability class among the 5 DR categories, computed with **softmax**

*This tool is for educational and research purposes only, not for clinical use.*
""", unsafe_allow_html=True)

# DR class names, indexed by the integer label used in the test CSV
class_names = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR']


@st.cache_data  # fetch and parse the remote CSV once instead of on every rerun
def load_sample_images_from_csv():
    """Return up to 5 test-set image paths per DR class, keyed by class name."""
    csv_url = "https://huggingface.co/datasets/Ci-Dave/DDR_dataset_train_test/raw/main/splits/test_labels.csv"
    df = pd.read_csv(csv_url)
    samples = defaultdict(list)
    for label, class_name in enumerate(class_names):
        class_samples = df[df['label'] == label].head(5)
        for _, row in class_samples.iterrows():
            img_path = row['new_path']
            # Works only if the images are available locally; the CSV lists paths, not image data
            if os.path.exists(img_path):
                samples[class_name].append(img_path)
    return samples
```
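The selection rule in `load_sample_images_from_csv` (for each label 0 to 4, take the first five CSV rows, then keep only the paths that exist on disk) can be sketched with the standard `csv` module instead of pandas. The CSV text and the `exists` callback below are hypothetical stand-ins for `test_labels.csv` and `os.path.exists`; only the `label` and `new_path` column names come from the app.

```python
import csv
import io
from collections import defaultdict

CLASS_NAMES = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR']


def group_samples(csv_text, exists, per_class=5):
    """Mirror the app's rule: first `per_class` rows per label, then drop missing files."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    samples = defaultdict(list)
    for label, class_name in enumerate(CLASS_NAMES):
        # Like df[df['label'] == label].head(per_class): limit the rows first...
        class_rows = [r for r in rows if int(r['label']) == label][:per_class]
        for row in class_rows:
            # ...and only then keep the paths the (hypothetical) exists() check accepts
            if exists(row['new_path']):
                samples[class_name].append(row['new_path'])
    return samples


# Hypothetical CSV contents, not taken from the real dataset
demo_csv = "new_path,label\na.jpg,0\nb.jpg,0\nc.jpg,2\nmissing.jpg,2\n"
demo = group_samples(demo_csv, exists=lambda p: p != 'missing.jpg')
print(dict(demo))  # {'No DR': ['a.jpg', 'b.jpg'], 'Moderate': ['c.jpg']}
```

Because the row limit is applied before the existence check, a class can show fewer than five images even when other files for that class exist locally.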
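```python
# Load the fine-tuned model; st.cache_resource keeps one instance across reruns
@st.cache_resource
def load_model():
    # Build the DenseNet-121 architecture without ImageNet weights (`pretrained=` is deprecated)
    model = models.densenet121(weights=None)
    # Replace the 1000-class ImageNet head with a head for the 5 DR classes
    model.classifier = torch.nn.Linear(model.classifier.in_features, len(class_names))
    state_dict = torch.load("./Model/Pretrained_Densenet-121.pth", map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()  # inference mode for dropout and batch-norm layers
    return model
```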
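Because Streamlit reruns the entire script on every widget interaction, any expensive setup inside `load_model()` is repeated unless its result is cached; Streamlit provides the `st.cache_resource` decorator for objects such as models. The load-once idea can be sketched with the standard library's `functools.lru_cache`. This is only a stand-in for illustration, not how Streamlit implements its cache, and `load_fake_model` is a hypothetical placeholder for the real loader.

```python
import functools

load_calls = 0  # how many times the expensive loader body actually ran


@functools.lru_cache(maxsize=1)
def load_fake_model():
    """Hypothetical stand-in for an expensive loader such as load_model()."""
    global load_calls
    load_calls += 1
    return {"arch": "densenet121", "num_classes": 5}


# Simulate three script reruns, each asking for the model
models_seen = [load_fake_model() for _ in range(3)]
print(load_calls)                                     # 1
print(all(m is models_seen[0] for m in models_seen))  # True
```

Every call after the first returns the same object, which is why a cached loader avoids rereading the checkpoint on each click.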
```python
# Image transform: resize the shorter side to 256, center-crop to 224x224,
# convert to a [0, 1] tensor, then normalize with ImageNet channel statistics
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
```
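The transform scales the image so its shorter side is 256 pixels (keeping the aspect ratio), crops the central 224×224 region, maps pixel values to [0, 1], and applies (x - mean) / std per channel. The arithmetic can be sketched in plain Python. The function names are hypothetical, and the rounding mirrors torchvision's behavior as we understand it (the longer side truncated to an integer, crop offsets rounded), so treat exact pixel offsets as approximate.

```python
MEAN = (0.485, 0.456, 0.406)  # ImageNet channel means, as in transforms.Normalize
STD = (0.229, 0.224, 0.225)   # ImageNet channel standard deviations


def resized_size(width, height, size=256):
    """(width, height) after scaling the shorter side to `size`, aspect ratio kept."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop=224):
    """(left, top, right, bottom) of the central crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def normalize_pixel(rgb):
    """Map an 8-bit RGB pixel to the values the model receives."""
    # ToTensor divides by 255; Normalize then applies (x - mean) / std per channel
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))


w, h = resized_size(1000, 800)  # a hypothetical 1000x800 fundus photo
print((w, h))                   # (320, 256)
print(center_crop_box(w, h))    # (48, 16, 272, 240)
print([round(x, 3) for x in normalize_pixel((255, 255, 255))])  # [2.249, 2.429, 2.64]
```

Resizing before cropping keeps the whole fundus in view for moderately non-square photos while still producing the fixed 224×224 input DenseNet-121 expects.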
```python
# Prediction function: returns the predicted class name and its softmax probability in percent
def predict_image(model, image):
    img_tensor = transform(image).unsqueeze(0)  # add a batch dimension: [1, 3, 224, 224]
    with torch.no_grad():
        outputs = model(img_tensor)
    probs = torch.nn.functional.softmax(outputs, dim=1)[0]
    prob, pred = torch.max(probs, dim=0)
    return class_names[pred.item()], prob.item() * 100
```
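`predict_image` turns the model's five raw scores (logits) into probabilities with softmax, p_i = exp(z_i) / sum_j exp(z_j), and reports the largest one as a percentage. A minimal pure-Python version follows; the logits are made-up values, not real model outputs. Subtracting the largest logit before exponentiating leaves the result unchanged but avoids overflow.

```python
import math

CLASS_NAMES = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR']


def softmax(logits):
    """Numerically stable softmax: shift by the max logit before exponentiating."""
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_from_logits(logits):
    """Return (class name, confidence in percent) for the highest-probability class."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)  # argmax; first index wins ties
    return CLASS_NAMES[best], probs[best] * 100


# Hypothetical logits: 'Moderate' gets weight exp(ln 3) = 3 against 1 for each other class
name, conf = predict_from_logits([0.0, 0.0, math.log(3), 0.0, 0.0])
print(name, f"{conf:.2f}%")  # Moderate 42.86%
```

The confidence is 3 / (3 + 4) = 3/7, so it reflects how the model splits probability across the five stages rather than a calibrated clinical certainty.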
```python
# Create two tabs for better separation of features
tab1, tab2 = st.tabs(["🧪 Sample Images", "Upload & Predict"])

with tab1:
    st.markdown("### 🧪 Sample Images from Test Set")
    st.markdown("""
#### About This Feature: Sample Images
In this tab, you can explore sample retinal images from the test set, grouped by their **Diabetic Retinopathy (DR)** stage. This helps you:
- Understand the **visual differences** between DR stages
- Test the model’s performance on known data
- Get familiar with the model’s prediction behavior

#### How to Use:
1. Browse the sample images under each DR class.
2. Click **Predict** under an image to let the AI model analyze it.
3. The result will show:
   - **Predicted DR stage**
   - **Confidence score**

> *Ideal for researchers and students testing the model with known data.*
""", unsafe_allow_html=True)

    sample_images = load_sample_images_from_csv()
    for class_name in class_names:
        if class_name in sample_images and sample_images[class_name]:
            st.markdown(f"#### {class_name}")  # label each row of samples with its DR class
            cols = st.columns(5)
            for i, img_path in enumerate(sample_images[class_name]):
                with cols[i]:
                    st.image(img_path, use_container_width=True)
                    if st.button("Predict", key=f"predict_{img_path}_{i}"):
                        image = Image.open(img_path).convert('RGB')
                        model = load_model()
                        pred_class, prob = predict_image(model, image)
                        st.success(f"Prediction: **{pred_class}** ({prob:.2f}% confidence)")
        else:
            st.warning(f"⚠️ No images found for **{class_name}**")

with tab2:
    st.markdown("### Upload & Predict")
    st.markdown("""
#### About This Feature: Upload & Predict
This tool allows you to upload a **retinal image** and get an **AI-based prediction** of the DR stage using a fine-tuned **DenseNet-121** model.

The model classifies the image into one of:
- No DR
- Mild
- Moderate
- Severe
- Proliferative DR

#### How to Use:
1. Upload a **clear fundus image** (JPG or PNG).
2. Click **Predict** to let the model analyze it.
3. You'll see:
   - The predicted DR stage
   - The confidence level (in percent)
""", unsafe_allow_html=True)

    uploaded_file = st.file_uploader("Upload Retinal Image", type=["jpg", "png"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file).convert('RGB')
        st.image(image, caption='Uploaded Image', use_container_width=True)
        if st.button("Predict"):
            with st.spinner('Analyzing image...'):
                model = load_model()
                pred_class, prob = predict_image(model, image)
            st.success(f"Prediction: **{pred_class}** ({prob:.2f}% confidence)")
```