### 1. Imports and class names setup ###
import os
from timeit import default_timer as timer
from typing import Dict, Tuple

import gradio as gr
import torch
import torchvision
from torch import nn


def create_effnetb2_model(num_classes: int = 1, seed: int = 42, pretrained: bool = True):
    """Create an AlexNet feature-extractor model and its matching image transforms.

    Despite the historical ``effnetb2`` name (kept so existing callers keep
    working), this builds a torchvision AlexNet with a frozen backbone and a
    freshly initialised classifier head.

    Args:
        num_classes (int, optional): number of output units in the new
            classifier head. Defaults to 1 (a single logit, i.e. binary
            classification with a sigmoid).
        seed (int, optional): random seed set just before the new classifier
            head is created, so its initialisation is reproducible.
            Defaults to 42.
        pretrained (bool, optional): if True, initialise the backbone with
            torchvision's default ImageNet AlexNet weights (downloaded on first
            use). Pass False when a complete state dict is loaded afterwards
            anyway, to skip the download. Defaults to True.

    Returns:
        model (torch.nn.Module): AlexNet with frozen parameters and a new head.
        transforms (torchvision.transforms): the preset preprocessing
            transforms associated with the default AlexNet weights.
    """
    weights = torchvision.models.AlexNet_Weights.DEFAULT
    # The preset transforms come from the weights enum itself, so they are
    # available even when the pretrained parameters are not downloaded.
    transforms = weights.transforms()
    model = torchvision.models.alexnet(weights=weights if pretrained else None)

    # Freeze all layers in the base model.
    for param in model.parameters():
        param.requires_grad = False

    # Replace the classifier head; seed first so the head init is reproducible.
    torch.manual_seed(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.2),
        # 9216 = 256 channels x 6 x 6 coming out of torchvision AlexNet's
        # adaptive average pool (the input size of its original classifier).
        nn.Linear(in_features=9216, out_features=num_classes),
    )

    return model, transforms


# Setup class names
class_names = ["Normal", "Pneumonia"]

### 2. Model and transforms preparation ###
# Create the model. The saved checkpoint has a single-logit head, so
# num_classes must be 1 here (len(class_names) == 2 would not match it).
# The ImageNet weights are not downloaded: load_state_dict below is strict by
# default, so it replaces every parameter (and raises on any missing key).
effnetb2, effnetb2_transforms = create_effnetb2_model(
    num_classes=1,
    pretrained=False,
)

# Load saved weights.
# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
# On torch versions that support it, consider passing weights_only=True.
effnetb2.load_state_dict(
    torch.load(
        f="alexnet_pretrained.pth",
        map_location=torch.device("cpu"),  # load to CPU
    )
)


def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify a chest X-ray image and report how long inference took.

    Args:
        img: the input image. Gradio supplies a PIL image because the input
            component below uses ``type="pil"``.

    Returns:
        A tuple ``(probs, seconds)``. ``probs`` maps "Normal" and "Pneumonia"
        to their probabilities, which is the format ``gr.Label`` expects.
        ``seconds`` is the wall-clock prediction time, rounded to 5 decimals.
    """
    start_time = timer()

    # AlexNet's preset transforms normalise with 3-channel (RGB) statistics,
    # so grayscale ("L") or RGBA images must be converted to RGB first or
    # normalisation fails on the channel mismatch.
    if hasattr(img, "convert"):
        img = img.convert("RGB")

    # Transform the target image and add a batch dimension.
    batch = effnetb2_transforms(img).unsqueeze(0)

    # Evaluation mode disables dropout; inference mode disables autograd.
    effnetb2.eval()
    with torch.inference_mode():
        # The head emits one logit; its sigmoid is P("Pneumonia").
        pneumonia_prob = torch.sigmoid(effnetb2(batch)).squeeze().item()

    # Label -> probability dictionary (the format Gradio's Label output needs).
    pred_labels_and_probs = {
        "Normal": 1 - pneumonia_prob,
        "Pneumonia": pneumonia_prob,
    }

    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time


# Only offer example images that actually exist; a missing path can make
# Gradio error out when it loads the examples.
example_list = [
    [path]
    for path in (f"examples/example{i + 1}.jpg" for i in range(3))
    if os.path.exists(path)
]

# Create title, description and article strings
title = "ChestXray Classification"
description = "An Alexnet computer vision model to classify images of Xray Chest images as Normal or Pneumonia."
article = "Created at (https://github.com/azizche/chest_xray_Classification)."

# Create the Gradio demo
demo = gr.Interface(
    fn=predict,  # mapping function from input to output
    inputs=gr.Image(type="pil"),  # what are the inputs?
    outputs=[
        gr.Label(num_top_classes=2, label="Predictions"),  # what are the outputs?
        gr.Number(label="Prediction time (s)"),
    ],  # our fn has two outputs, therefore we have two outputs
    examples=example_list or None,  # None when no example files are present
    title=title,
    description=description,
    article=article,
)

# Launch the demo!
demo.launch()