### 1. Imports and class names setup ###
import os
import warnings
from timeit import default_timer as timer
from typing import Tuple, Dict

import gradio as gr
import torch
import torchvision
from torch import nn
from torchvision.models import densenet121

# Output classes, in the order of the classifier's output units.
# NOTE(review): this order must match the label indices used when the
# checkpoint was trained — confirm against the training code.
class_names = ["Nodules", "Normal"]


### 2. Model and transforms preparation ###
def create_densenet121_model(num_classes: int = 2, seed: int = 42):
    """Create a DenseNet121 model with a new linear head, plus its transforms.

    The backbone is built without pretrained weights (``weights=None``)
    because a saved checkpoint is loaded into it afterwards. All existing
    parameters are frozen, then the classifier is replaced by a freshly
    initialized ``nn.Linear``.

    Args:
        num_classes: Number of output units of the new classifier head.
        seed: Seed passed to ``torch.manual_seed`` right before the head is
            created, so its random initialization is reproducible.

    Returns:
        A ``(model, transforms)`` tuple; ``transforms`` turns a PIL image
        into a normalized float tensor.
    """
    model = densenet121(weights=None)

    # Freeze the base model's parameters. The replacement head created
    # below is a new module and therefore stays trainable.
    for param in model.parameters():
        param.requires_grad = False

    torch.manual_seed(seed)
    model.classifier = nn.Linear(model.classifier.in_features, num_classes)

    # NOTE(review): Resize(224) scales the shorter side to 224 and keeps the
    # aspect ratio (no crop), so inputs are not necessarily 224x224.
    # Confirm this matches the preprocessing used during training.
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ])
    return model, transforms


densenet, densenet_transforms = create_densenet121_model(
    num_classes=len(class_names)
)

# Load the saved weights; the checkpoint dict keeps the state dict under
# "model". torch.load unpickles the file, so only load trusted checkpoints.
state_dict = torch.load(
    "model/FL_global_model_4be885f7-8d33-4498-a5ef-85aa301706bd.pt",
    map_location=torch.device("cpu"),
)
model_weights = state_dict["model"]

# strict=False tolerates key mismatches, but ignoring them silently could
# leave layers (e.g. the classifier head) randomly initialized — report them.
incompatible = densenet.load_state_dict(model_weights, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    warnings.warn(
        "Checkpoint does not match the model exactly: "
        f"missing keys={incompatible.missing_keys}, "
        f"unexpected keys={incompatible.unexpected_keys}"
    )

# The model is only used for inference, so switch to evaluation mode once.
densenet.eval()


### 3. Predict function ###
def predict(img) -> Tuple[Dict, float]:
    """Classify an image and measure how long the prediction took.

    Args:
        img: PIL image from the Gradio ``Image`` input, or ``None`` when the
            user submits without uploading an image.

    Returns:
        A ``(labels_to_probs, seconds)`` tuple: a dict mapping each class
        name to its softmax probability, and the elapsed prediction time
        rounded to 5 decimal places.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image.")

    start_time = timer()

    # Normalize uses 3-channel statistics, so force RGB (grayscale or RGBA
    # inputs would otherwise fail), then add a batch dimension.
    batch = densenet_transforms(img.convert("RGB")).unsqueeze(0)

    with torch.inference_mode():
        # Drop only the batch dimension -> one probability per class.
        pred_probs = torch.softmax(densenet(batch), dim=1).squeeze(0)

    pred_labels_and_probs = dict(zip(class_names, pred_probs.tolist()))

    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time


### 4. Gradio app ###
example_list = [[f"examples/example{i+1}.jpg"] for i in range(3)]

title = "ChestXray Classification"
description = "A Densenet121 computer vision model to classify images of Xray Chest images as Normal or Nodules."
article = "model train by hamsteryang0"
# article = "Created at (https://github.com/azizche/chest_xray_Classification)."

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=len(class_names), label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch the demo only when run as a script (importers, e.g. `gradio app.py`
# reload mode, can use `demo` without starting a second server).
if __name__ == "__main__":
    demo.launch()