Spaces:
Runtime error
Runtime error
| # Directory Structure Suggestion: | |
| # diabetic_retinopathy_app/ | |
| # ├── Home.py (Landing Page) | |
| # ├── pages/ | |
| # │ ├── 1_Upload_and_Predict.py | |
| # │ └── 2_Model_Evaluation.py | |
| # └── assets/ | |
| # └── banner.jpg | |
| # Home.py (Landing Page) | |
| import streamlit as st | |
| from PIL import Image | |
def main():
    """Render the landing page of the DR assistive tool.

    Configures the page before any other Streamlit output, shows the title and
    a short feature overview, and directs the user to the sidebar pages.
    """
    st.set_page_config(page_title="DR Assistive Tool", layout="centered")
    st.title("Welcome to the Diabetic Retinopathy Assistive Tool")
    # The blank line before the final sentence closes the bullet list; without
    # it Markdown treats the sentence as a lazy continuation of the last bullet.
    st.markdown("""
### 🌟 Your AI-powered assistant for early detection of Diabetic Retinopathy.
#### Features:
- 🖼️ Upload a retinal image and receive a prediction of its DR stage.
- 📊 Evaluate model performance using real test datasets.

Select a page from the left sidebar to get started.
""")
    # Optional banner image; re-enable once assets/banner.jpg is available.
    # image = Image.open("assets/banner.jpg")
    # st.image(image, use_column_width=True)


if __name__ == '__main__':
    main()
| # pages/1_Upload_and_Predict.py | |
| import streamlit as st | |
| import torch | |
| from torchvision import transforms, models | |
| from PIL import Image | |
| import numpy as np | |
st.title("📷 Upload & Predict Diabetic Retinopathy")
# Human-readable labels; position i is the label for classifier output i
# (the classifier head is sized from len(class_names)).
class_names = [
    'No DR',
    'Mild',
    'Moderate',
    'Severe',
    'Proliferative DR',
]
@st.cache_resource(show_spinner=False)
def load_model(weights_path="training/Pretrained_Densenet-121.pth"):
    """Build the DenseNet-121 DR classifier and load its trained weights.

    Streamlit re-runs this script on every interaction, so without caching the
    network would be rebuilt and its checkpoint re-read from disk on each click
    of the Predict button. ``st.cache_resource`` keeps one instance and hands
    the same object to every rerun/session; it is only used for inference.
    ``show_spinner=False`` because the caller already shows its own spinner.

    Args:
        weights_path: Path to a ``state_dict`` checkpoint whose classifier head
            has ``len(class_names)`` outputs. Defaults to the original path.

    Returns:
        The model on CPU, switched to eval mode.
    """
    model = models.densenet121(pretrained=False)
    # Replace the default classification head with one output per class name.
    model.classifier = torch.nn.Linear(model.classifier.in_features, len(class_names))
    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Inference preprocessing: resize so the shorter side is 256 px, take the
# central 224x224 crop, convert to a float tensor, then normalize each RGB
# channel with the standard ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
def predict_image(model, image):
    """Classify a single PIL image and report the model's confidence.

    Args:
        model: Classifier called on a one-image batch; assumed to return
            logits of shape (1, len(class_names)) — not checked here.
        image: PIL image, preprocessed with the module-level ``transform``.

    Returns:
        Tuple ``(label, confidence)``: the name of the highest-scoring class
        and its softmax probability as a percentage (0-100).
    """
    batch = transform(image).unsqueeze(0)  # prepend a batch dimension of 1
    with torch.no_grad():
        logits = model(batch)
    top_index = int(torch.max(logits, 1)[1])
    probabilities = torch.softmax(logits, dim=1)
    confidence = probabilities[0, top_index].item() * 100
    return class_names[top_index], confidence
# Upload -> preview -> predict flow. Streamlit re-runs this on every interaction.
uploaded_file = st.file_uploader("Choose a retinal image", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file).convert('RGB')
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) or a plain
        # OSError for undecodable/truncated files; report it instead of
        # crashing the page with a traceback.
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG.")
    else:
        st.image(image, caption='Uploaded Retinal Image', use_column_width=True)
        if st.button("🧠 Predict"):
            with st.spinner('Analyzing image...'):
                model = load_model()
                pred_class, prob = predict_image(model, image)
            st.success(f"Prediction: **{pred_class}** ({prob:.2f}% confidence)")
| # pages/2_Model_Evaluation.py | |
| import streamlit as st | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from torchvision import datasets, transforms, models | |
| import torch.nn as nn | |
| from tqdm import tqdm | |
# Heading for the model-evaluation page.
st.title(body="📈 Model Evaluation on Test Dataset")
def load_test_data(data_dir="test_dataset_path", batch_size=32):
    """Build a DataLoader over a test set stored in ``ImageFolder`` layout.

    ``data_dir`` must contain one sub-directory per class. ``ImageFolder``
    assigns label indices from the sorted sub-directory names, so those names
    must sort in the same order as the model's output classes.

    Args:
        data_dir: Root directory of the test set. The default is the original
            placeholder path; pass the real location.
        batch_size: Number of images per batch.

    Returns:
        A non-shuffled ``DataLoader`` yielding ``(images, labels)`` batches.
    """
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Same preprocessing as the prediction page, so evaluation matches inference.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    test_data = datasets.ImageFolder(data_dir, transform=transform)
    return DataLoader(test_data, batch_size=batch_size, shuffle=False)
def evaluate(model, loader):
    """Compute mean cross-entropy loss and accuracy of ``model`` over ``loader``.

    Args:
        model: ``nn.Module`` mapping an input batch to class logits of shape
            ``(batch, num_classes)``; it is switched to eval mode.
        loader: Iterable of ``(inputs, labels)`` batches, e.g. a ``DataLoader``.
            ``len()`` is not required.

    Returns:
        ``(mean_loss, accuracy_percent)``. The loss is averaged over samples,
        not batches, so a smaller final batch is weighted correctly.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    # Sum per-sample losses and divide by the sample count at the end;
    # averaging per-batch means would over-weight a short final batch.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    correct, total, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs)
            loss_sum += criterion(outputs, labels).item()
            _, pred = torch.max(outputs, 1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        raise ValueError("evaluate() received an empty loader; there are no samples to score.")
    return loss_sum / total, correct / total * 100
# NOTE(review): this page loads "dr_densenet121.pth" while the prediction page
# loads "training/Pretrained_Densenet-121.pth"; confirm both are the intended
# checkpoint.
if st.button("🧪 Evaluate Trained Model"):
    try:
        with st.spinner('Evaluating model on the test set...'):
            test_loader = load_test_data()
            model = models.densenet121(pretrained=False)
            # 5 outputs, matching the five class names used on the prediction page.
            model.classifier = nn.Linear(model.classifier.in_features, 5)
            model.load_state_dict(torch.load("dr_densenet121.pth", map_location='cpu'))
            model.eval()
            loss, acc = evaluate(model, test_loader)
    except FileNotFoundError as err:
        # Missing test-set directory, class folders, or checkpoint: report it
        # on the page instead of crashing the app with a traceback.
        st.error(f"Evaluation could not start because a required file is missing: {err}")
    else:
        st.write(f"**Test Loss:** {loss:.4f}")
        st.write(f"**Test Accuracy:** {acc:.2f}%")