# Author: hamsteryang
# Last commit: "Update app.py" (5aa8756)
### 1. Imports and class names setup ###
import os
import warnings
from timeit import default_timer as timer
from typing import Tuple, Dict

import gradio as gr
import torch
import torchvision
from torch import nn
from torchvision.models import densenet121
def create_densenet121_model(num_classes: int = 2, seed: int = 42):
    """Build a DenseNet121 classifier skeleton and its preprocessing pipeline.

    The backbone is created with ``weights=None`` because trained weights are
    loaded separately from a checkpoint. All backbone parameters are frozen,
    and the classifier head is replaced with a fresh ``nn.Linear`` producing
    ``num_classes`` logits. The global torch RNG is seeded immediately before
    the head is built so its initial weights are reproducible.

    Args:
        num_classes: Number of outputs of the replacement classifier head.
        seed: Value passed to ``torch.manual_seed`` before building the head
            (note: this reseeds the global torch RNG).

    Returns:
        A ``(model, transforms)`` tuple. ``transforms`` is a
        ``torchvision.transforms.Compose`` that resizes the image (shorter
        side to 224), converts it to a tensor, and normalises it with the
        standard ImageNet per-channel mean/std used by torchvision models.
    """
    net = densenet121(weights=None)  # trained weights are loaded from our own checkpoint later

    # Freeze the backbone; the replacement head created below is trainable by default.
    for p in net.parameters():
        p.requires_grad = False

    # Seed right before creating the head so its random init is deterministic.
    torch.manual_seed(seed)
    in_dim = net.classifier.in_features
    net.classifier = nn.Linear(in_dim, num_classes)

    tv = torchvision.transforms
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    preprocess = tv.Compose(
        [
            tv.Resize(224),
            tv.ToTensor(),
            tv.Normalize(mean=channel_mean, std=channel_std),
        ]
    )
    return net, preprocess
# Create the DenseNet121 model (randomly initialised) and its preprocessing transforms.
densenet, densenet_transforms = create_densenet121_model()

# Load the saved weights. map_location forces every tensor onto the CPU so the
# app also runs on machines without a GPU. The checkpoint is a dict whose
# "model" entry holds the model's state dict.
state_dict = torch.load(
    "model/FL_global_model_4be885f7-8d33-4498-a5ef-85aa301706bd.pt",
    map_location=torch.device("cpu"),
)
model_weights = state_dict["model"]

# strict=False tolerates key mismatches between checkpoint and model, but does
# so silently: any parameter the checkpoint fails to provide keeps its random
# initial value and predictions become meaningless. Keep loading lenient, but
# surface any mismatch instead of hiding it.
incompatible = densenet.load_state_dict(model_weights, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    warnings.warn(
        "Checkpoint does not match the DenseNet121 model: "
        f"{len(incompatible.missing_keys)} missing key(s) "
        f"(e.g. {incompatible.missing_keys[:3]}), "
        f"{len(incompatible.unexpected_keys)} unexpected key(s) "
        f"(e.g. {incompatible.unexpected_keys[:3]}). "
        "Missing parameters keep their random initialisation.",
        RuntimeWarning,
        stacklevel=1,
    )
def predict(img) -> Tuple[Dict, float]:
    """Classify a chest X-ray image and report how long the prediction took.

    Args:
        img: The image delivered by the ``gr.Image(type="pil")`` input, i.e.
            a PIL image. It may be ``None`` if the form is submitted without
            an image.

    Returns:
        A ``(probabilities, seconds)`` tuple: a dict mapping each class name
        ("Nodules", "Normal") to its softmax probability, and the elapsed
        wall-clock time rounded to 5 decimal places.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Without this guard the transform pipeline fails with an opaque type
        # error; gr.Error shows a readable message in the UI instead.
        raise gr.Error("Please upload a chest X-ray image.")

    # Start the timer
    start_time = timer()

    # The Normalize step uses 3-channel statistics, so grayscale ("L") or
    # alpha-channel images would fail there; coerce them to RGB first.
    if getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")

    # Transform the target image and add a batch dimension
    batch = densenet_transforms(img).unsqueeze(0)

    # Put model into evaluation mode and turn on inference mode
    densenet.eval()
    with torch.inference_mode():
        # squeeze(0) drops only the batch dim, leaving one probability per class.
        pred_probs = torch.softmax(densenet(batch), dim=1).squeeze(0)

    # Index order must match the classifier's output order.
    class_names = ("Nodules", "Normal")
    pred_labels_and_probs = {
        name: pred_probs[i].item() for i, name in enumerate(class_names)
    }

    # Calculate the prediction time
    pred_time = round(timer() - start_time, 5)

    # Return the prediction dictionary and prediction time
    return pred_labels_and_probs, pred_time
# Example inputs shown below the interface: examples/example1.jpg .. example3.jpg
example_list = [[f"examples/example{i}.jpg"] for i in range(1, 4)]

title = "ChestXray Classification"
description = (
    "A DenseNet121 computer vision model that classifies chest X-ray images "
    "as Normal or Nodules."
)
article = "Model trained by hamsteryang0"
# article = "Created at (https://github.com/azizche/chest_xray_Classification)."

# Wire predict() into a simple UI: one PIL image in; class probabilities and
# the measured prediction time out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=2, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch only when executed as a script (e.g. `python app.py`), so importing
# this module does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()