Spaces:
Sleeping
Sleeping
File size: 2,845 Bytes
d213efc 7998818 d213efc ea236cf 7998818 ea236cf d213efc 7998818 d213efc 7998818 d213efc 7998818 d213efc 7998818 ea236cf d213efc b74d07d 415d148 b8d480c d213efc ea236cf d213efc 415d148 d213efc 415d148 d213efc ea236cf d213efc 5c276ab ea236cf d213efc ea236cf d213efc ea236cf 5c276ab ea236cf c8ef432 d213efc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
### 1. Imports ###
import gradio as gr
import os
import torch
from timeit import default_timer as timer
from typing import Tuple, Dict
import torchvision
from torch import nn
from torchvision.models import densenet121
def create_densenet121_model(num_classes: int = 2, seed: int = 42):
    """Build a DenseNet121 classifier and the matching preprocessing pipeline.

    The backbone is created without pretrained weights (real weights are
    loaded from a checkpoint afterwards), and every backbone parameter is
    frozen. The classifier head is then replaced by a fresh ``nn.Linear``
    layer. The RNG is seeded first, so that layer's initial weights are
    reproducible.

    Args:
        num_classes: Number of output logits produced by the new head.
        seed: Value passed to ``torch.manual_seed`` before the head is built.

    Returns:
        A ``(model, transforms)`` tuple. ``model`` is the DenseNet121 module.
        ``transforms`` is a ``Compose`` that resizes the image (an int size
        matches the smaller edge to 224), converts it to a tensor, and
        normalises it with the usual ImageNet mean/std values.
    """
    model = densenet121(weights=None)

    # Freeze the whole backbone. The head created below is a new module,
    # so it stays trainable.
    model.requires_grad_(False)

    # Seed right before building the head so its initialisation is repeatable.
    torch.manual_seed(seed)
    head_inputs = model.classifier.in_features
    model.classifier = nn.Linear(head_inputs, num_classes)

    T = torchvision.transforms
    preprocess = T.Compose(
        [
            T.Resize(224),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    return model, preprocess
import warnings

# Create the DenseNet121 model and its preprocessing transforms.
densenet, densenet_transforms = create_densenet121_model()

# Load the saved checkpoint onto the CPU. The file holds a dict whose "model"
# entry is the network's state dict.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
state_dict = torch.load(
    "model/FL_global_model_4be885f7-8d33-4498-a5ef-85aa301706bd.pt",
    map_location=torch.device("cpu"),
)
model_weights = state_dict["model"]

# strict=False tolerates key mismatches between checkpoint and architecture.
# However, any parameter that is not loaded keeps its random initialisation
# and silently degrades predictions. Report mismatches instead of hiding them.
load_result = densenet.load_state_dict(model_weights, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint does not match the DenseNet121 model exactly: "
        f"missing keys={load_result.missing_keys}, "
        f"unexpected keys={load_result.unexpected_keys}",
        stacklevel=1,
    )
def predict(img) -> Tuple[Dict, float]:
    """Classify an image with the DenseNet model and time the prediction.

    Args:
        img: The image delivered by the ``gr.Image(type="pil")`` input.
            This is ``None`` when the user submits without an image.

    Returns:
        A ``(probabilities, seconds)`` tuple:
        - ``probabilities`` maps each class label ("Nodules", "Normal") to its
          softmax probability.
        - ``seconds`` is the elapsed time, rounded to five decimal places.

    Raises:
        gr.Error: If no image was provided.
    """
    # Without this guard, the transforms would fail with an opaque TypeError.
    if img is None:
        raise gr.Error("Please upload a chest X-ray image.")

    start_time = timer()

    # Preprocess and add a batch dimension of size 1.
    batch = densenet_transforms(img).unsqueeze(0)

    # eval() switches layers such as batch norm to inference behaviour;
    # inference_mode() skips autograd bookkeeping.
    densenet.eval()
    with torch.inference_mode():
        # Take the single row of the batch rather than squeeze(), which would
        # also collapse a size-1 class dimension.
        pred_probs = torch.softmax(densenet(batch), dim=1)[0]

    # Label order must match the classifier's output index order.
    class_names = ("Nodules", "Normal")
    pred_labels_and_probs = {
        name: prob.item() for name, prob in zip(class_names, pred_probs)
    }

    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time
# Example images offered under the input component (example1.jpg .. example3.jpg).
example_list = [[f"examples/example{i}.jpg"] for i in range(1, 4)]

# Text shown on the Gradio page.
title = "ChestXray Classification"
description = (
    "A DenseNet121 computer vision model that classifies chest X-ray images "
    "as Normal or Nodules."
)
article = "Model trained by hamsteryang0."
# article = "Created at (https://github.com/azizche/chest_xray_Classification)."

# Wire predict() to an image input. predict() returns a two-element tuple,
# so there are two outputs: class probabilities and elapsed time.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=2, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch the demo!
demo.launch()
|