Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| from torchvision import models | |
| import gradio as gr | |
# Preprocessing applied to every image before it reaches the network.
# This has to stay identical to the training-time pipeline:
#   1. resize to a fixed 224x224 input,
#   2. convert the PIL image to a float tensor,
#   3. normalize each channel with the mean/std listed below
#      (the standard ImageNet channel statistics).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Class labels in the index order of the model's output logits. The order
# must match the label order used when the checkpoint was trained.
class_labels = ["Mild_Demented 0", "Moderate_Demented 1", "Non_Demented 2", "Very_Mild_Demented 3"]

# ResNet-50 architecture without pretrained weights; the trained weights are
# loaded from the checkpoint below. The classification head is sized from
# ``class_labels`` so the number of outputs always matches the label list.
model = models.resnet50(weights=None)
model.fc = nn.Linear(model.fc.in_features, len(class_labels))

# Load the trained weights onto the CPU so the app also runs without a GPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you
# trust. On torch >= 1.13, passing weights_only=True restricts loading to
# tensors and plain containers.
model.load_state_dict(
    torch.load("alzheimer_model_resnet50.pth", map_location=torch.device("cpu"))
)
model.eval()  # switch layers such as BatchNorm to inference behaviour
# Prediction function called by the Gradio interface.
def predict(image):
    """Classify a single MRI image and return the predicted class label.

    Args:
        image: The image to classify. Accepts a NumPy array (what the
            ``gr.Image(type="numpy")`` input delivers), a PIL image, or a
            file path / file-like object readable by ``PIL.Image.open``.

    Returns:
        The entry of ``class_labels`` whose output logit is highest.

    Raises:
        gr.Error: If no image was provided (Gradio passes ``None`` when the
            input is empty); Gradio displays the message to the user.
    """
    if image is None:
        raise gr.Error("Please upload an MRI image.")

    if isinstance(image, np.ndarray):
        # Let PIL infer the mode from the array shape (e.g. HxW greyscale,
        # HxWx3 RGB, HxWx4 RGBA) and then convert, rather than forcing the
        # 'RGB' mode onto a buffer that may not have three channels.
        pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    elif isinstance(image, Image.Image):
        pil_image = image.convert("RGB")
    else:
        # Context manager releases the underlying file handle promptly.
        with Image.open(image) as opened:
            pil_image = opened.convert("RGB")

    batch = transform(pil_image).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)
    with torch.no_grad():
        logits = model(batch)
    predicted_index = logits.argmax(dim=1).item()
    return class_labels[predicted_index]
# Sample MRI images offered as one-click examples in the UI. Each inner list
# holds the value for the interface's single image input.
examples = [
    [example_path]
    for example_path in ("image.jpg", "image (1).jpg", "image (2).jpg", "image (3).jpg")
]

# Web UI: one image upload in, the predicted class label out as text.
iface = gr.Interface(
    title="Alzheimer MRI Classification",
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload an MRI Image"),
    outputs=gr.Textbox(label="Prediction"),
    examples=examples,
)

if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    iface.launch()