Thenujan commited on
Commit
adeca79
·
1 Parent(s): 659f247

Added Gradio

Browse files
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### 1. Imports and class names setup ###
2
+ import gradio as gr
3
+ import os
4
+ import torch
5
+
6
+ from model import create_model
7
+ from timeit import default_timer as timer
8
+ import torchvision
9
+ import torchvision.transforms as transforms
10
+
11
+ transformer = transforms.Compose([
12
+ transforms.Resize(256),
13
+ transforms.CenterCrop(256),
14
+ transforms.ToTensor(),
15
+ transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
16
+ ])
17
+
18
+ ### 2. Model and transforms preparation ###
19
+
20
+ # Create model
21
+ model = create_model(
22
+ num_classes=3, # len(class_names) would also work
23
+ )
24
+
25
+
26
+ # Load saved weights
27
+ model.load_state_dict(
28
+ torch.load(
29
+ f="09_pretrained_effnetb2_feature_extractor_pizza_steak_sushi_20_percent.pth",
30
+ map_location=torch.device("cpu"), # load to CPU
31
+ )
32
+ )
33
+
34
+ ### 3. Predict function ###
35
+
36
+ from typing import Tuple, Dict
37
+
38
+ def predict(img):
39
+ """Transforms and performs a prediction on img and returns prediction and time taken.
40
+ """
41
+ # Transform the target image and add a batch dimension
42
+ img = transformer(img).unsqueeze(0)
43
+
44
+
45
+ # Put model into evaluation mode and turn on inference mode
46
+ model.eval()
47
+ with torch.inference_mode():
48
+ # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
49
+ pred_prob = torch.sigmoid(model(img))
50
+
51
+ pred_probs = {'Covid' : float(pred_prob), 'Non Covid' : (1-float(pred_prob))}
52
+
53
+ # Return the prediction dictionary and prediction time
54
+ return pred_probs
55
+
56
+ ### 4. Gradio app ###
57
+
58
+
59
+ # Create title, description and article strings
60
+ title = "Corona Prediction"
61
+ description = "A Convolutional Neural Network To classify whether a person have Corona or not using CT Scans."
62
+ article = "Created by Thenujan Nagaratnam for DNN module at UoM"
63
+
64
+ # Create examples list from "examples/" directory
65
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
66
+
67
+ # Create the Gradio demo
68
+ demo = gr.Interface(fn=predict, # mapping function from input to output
69
+ inputs=gr.Image(type="pil"), # what are the inputs?
70
+ outputs=[gr.Label(num_top_classes=2, label="Predictions")], # our fn has two outputs, therefore we have two outputs
71
+ examples=example_list,
72
+ title=title,
73
+ description=description,
74
+ article=article)
75
+
76
+
77
+ # Launch the demo!
78
+ demo.launch()
examples/1%2.jpg ADDED
examples/2020.01.24.919183-p27-133.png ADDED
examples/2020.02.10.20021584-p6-52%10.png ADDED
model.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ import torchvision.models as models
4
+
5
+ from torch import nn
6
+
7
+
8
+ def create_effnetb2_model(num_classes:int=3,
9
+ seed:int=42):
10
+ """Creates an EfficientNetB2 feature extractor model and transforms.
11
+
12
+ Args:
13
+ num_classes (int, optional): number of classes in the classifier head.
14
+ Defaults to 3.
15
+ seed (int, optional): random seed value. Defaults to 42.
16
+
17
+ Returns:
18
+ model (torch.nn.Module): EffNetB2 feature extractor model.
19
+ transforms (torchvision.transforms): EffNetB2 image transforms.
20
+ """
21
+ # Create EffNetB2 pretrained weights, transforms and model
22
+
23
+ model = models.resnet50(pretrained=True)
24
+ # Freeze all layers in base model
25
+ for param in model.parameters():
26
+ param.requires_grad = False
27
+
28
+ # Change classifier head with random seed for reproducibility
29
+ torch.manual_seed(seed)
30
+ model.classifier = nn.Sequential(
31
+ nn.Dropout(p=0.3, inplace=True),
32
+ nn.Linear(in_features=1408, out_features=num_classes),
33
+ )
34
+
35
+ return model
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch==1.12.0
2
+ torchvision==0.13.0
3
+ gradio==3.1.4