File size: 2,845 Bytes
d213efc
 
 
 
 
 
 
 
 
 
7998818
d213efc
ea236cf
7998818
 
 
ea236cf
d213efc
 
 
 
 
 
 
7998818
d213efc
7998818
 
 
 
 
d213efc
7998818
d213efc
7998818
ea236cf
d213efc
 
b74d07d
415d148
b8d480c
d213efc
 
ea236cf
d213efc
 
 
 
415d148
d213efc
 
415d148
d213efc
ea236cf
d213efc
 
5c276ab
 
ea236cf
d213efc
 
 
 
 
 
 
 
ea236cf
d213efc
ea236cf
5c276ab
 
ea236cf
 
 
 
 
 
 
 
 
c8ef432
d213efc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
### 1. Imports and class names setup ###
import os
import warnings
from timeit import default_timer as timer
from typing import Tuple, Dict

import gradio as gr
import torch
import torchvision
from torch import nn
from torchvision.models import densenet121

def create_densenet121_model(num_classes: int = 2, seed: int = 42):
    """Build a DenseNet121 classifier (no pretrained weights) and its preprocessing.

    The backbone is instantiated with ``weights=None`` because the real weights
    are loaded from a checkpoint afterwards. Every backbone parameter is frozen,
    then the classifier head is swapped for a fresh ``nn.Linear``.

    Args:
        num_classes: Number of logits produced by the new classifier head.
        seed: Passed to ``torch.manual_seed`` right before the head is created so
            its initialisation is reproducible. This reseeds torch's global RNG
            as a side effect.

    Returns:
        A ``(model, transforms)`` tuple. ``transforms`` is a torchvision
        ``Compose`` that resizes (shorter side to 224), converts to a tensor and
        normalises with the ImageNet channel mean/std.
    """
    backbone = densenet121(weights=None)

    # Freeze the whole base network; only modules created after this loop
    # (the replacement head below) keep requires_grad=True.
    for weight in backbone.parameters():
        weight.requires_grad = False

    head_inputs = backbone.classifier.in_features
    torch.manual_seed(seed)
    backbone.classifier = nn.Linear(head_inputs, num_classes)

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    preprocess = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
    )

    return backbone, preprocess

# Build the DenseNet121 architecture and its preprocessing pipeline.
densenet, densenet_transforms = create_densenet121_model()

# Load the saved weights onto the CPU. The checkpoint is a dict whose "model"
# entry holds the state_dict for the network built above.
state_dict = torch.load(
    "model/FL_global_model_4be885f7-8d33-4498-a5ef-85aa301706bd.pt",
    map_location=torch.device("cpu"),
)
model_weights = state_dict["model"]

# strict=False lets loading proceed even when some keys do not line up, but such
# a mismatch would otherwise go unnoticed (e.g. the classifier head silently
# keeping its random initialisation), so report whatever was skipped.
load_result = densenet.load_state_dict(model_weights, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint does not exactly match the DenseNet121 model: "
        f"missing keys={load_result.missing_keys}, "
        f"unexpected keys={load_result.unexpected_keys}",
        stacklevel=1,
    )

def predict(img) -> Tuple[Dict, float]:
    """Classify a chest X-ray image and report how long inference took.

    Args:
        img: The uploaded image as delivered by ``gr.Image(type="pil")`` (a PIL
            image), or ``None`` when the user submitted without an image.

    Returns:
        A ``(label_probs, seconds)`` tuple. ``label_probs`` maps each class name
        ('Nodules', 'Normal') to its softmax probability; ``seconds`` is the
        wall-clock time of preprocessing plus the forward pass, rounded to 5
        decimal places. For a ``None`` input an empty mapping and 0.0 are
        returned instead of raising.
    """
    # Gradio passes None when nothing was uploaded; the transforms cannot
    # handle that, so return an empty result rather than crashing.
    if img is None:
        return {}, 0.0

    # Output index -> class name; order must match the classifier head.
    class_names = ("Nodules", "Normal")

    start_time = timer()

    # Normalize in the transform pipeline expects exactly three channels, so
    # grayscale ("L") or RGBA images are converted to RGB first.
    if img.mode != "RGB":
        img = img.convert("RGB")

    # Transform the image and add a batch dimension of size 1.
    batch = densenet_transforms(img).unsqueeze(0)

    densenet.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(densenet(batch), dim=1).squeeze()

    pred_labels_and_probs = {
        name: prob for name, prob in zip(class_names, pred_probs.tolist())
    }

    pred_time = round(timer() - start_time, 5)

    return pred_labels_and_probs, pred_time

# Example inputs shown under the image box: examples/example1.jpg .. example3.jpg.
example_list = [[f"examples/example{i}.jpg"] for i in range(1, 4)]

title = "ChestXray Classification"
description = "A DenseNet121 computer vision model that classifies chest X-ray images as Normal or Nodules."
article = "Model trained by hamsteryang0"
# article = "Created at (https://github.com/azizche/chest_xray_Classification)."

# One image in; the class probabilities and the inference time out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Label(num_top_classes=2, label="Predictions"), gr.Number(label="Prediction time (s)")],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch only when run as a script, so importing this module (e.g. from tests
# or another app) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()