| | |
| | import gradio as gr |
| | import os |
| | import torch |
| |
|
| | from model import create_resnet_model, create_custom_model |
| | from timeit import default_timer as timer |
| | import torchvision |
| | import torchvision.transforms as transforms |
| |
|
# Preprocessing pipeline applied to every input image before inference.
# The image is resized, centre-cropped to a square, converted to a tensor
# and then normalised channel-wise with the mean/std values below.
_IMAGE_SIZE = 256
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

_preprocessing_steps = [
    transforms.Resize(_IMAGE_SIZE),
    transforms.CenterCrop(_IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transformer = transforms.Compose(_preprocessing_steps)
| |
|
| |
|
# Name of the trained network served by this app: 'custom' or 'resnet'.
model_name = 'resnet'

# Maps every supported model name to (factory function, checkpoint path).
_MODEL_SPECS = {
    'custom': (create_custom_model, "./cnn-custom-model-version-4.pt"),
    'resnet': (create_resnet_model, "./cnn-resnet-version-1.pt"),
}

if model_name not in _MODEL_SPECS:
    # Fail at start-up with a clear message instead of leaving `model`
    # undefined and hitting a NameError on the first prediction request.
    raise ValueError(
        f"Unknown model_name {model_name!r}; "
        f"expected one of {sorted(_MODEL_SPECS)}"
    )

_create_model, _checkpoint_path = _MODEL_SPECS[model_name]
model = _create_model()

# Load the saved weights onto the CPU regardless of the device they were
# saved from.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(
    torch.load(
        f=_checkpoint_path,
        map_location=torch.device("cpu"),
    )
)
| | |
| |
|
| |
|
| | |
| |
|
def predict(img):
    """Transform img, run the model on it and return class probabilities.

    Args:
        img: The image handed over by the Gradio input component, or
            ``None`` when the user submitted without providing an image.

    Returns:
        dict: ``{'Covid': p, 'Non Covid': 1 - p}`` where ``p`` is the
        sigmoid of the model output for the image.

    Raises:
        gr.Error: If ``img`` is ``None``.
    """
    if img is None:
        # Surface a readable message in the UI instead of an opaque failure
        # from inside the preprocessing pipeline.
        raise gr.Error("Please provide an image before submitting.")

    # Preprocess and add a leading batch dimension.
    batch = transformer(img).unsqueeze(0)

    # Switch to evaluation mode and disable gradient tracking for inference.
    model.eval()
    with torch.inference_mode():
        # float() requires the model to emit exactly one value per image.
        covid_prob = float(torch.sigmoid(model(batch)))

    return {'Covid': covid_prob, 'Non Covid': 1 - covid_prob}
| |
|
| | |
| |
|
| |
|
| | |
# Text shown on the demo page.
title = "Corona Prediction"
description = "A Convolutional Neural Network To classify whether a person have Corona or not using CT Scans."
article = "Created by Thenujan Nagaratnam for DNN module at UoM"

# Collect the example images offered in the UI. Each example is a
# one-element list because the interface has a single input component.
# The listing is sorted so the order is stable across runs, restricted to
# regular files, and tolerant of a missing "examples" directory.
_EXAMPLES_DIR = "examples"
if os.path.isdir(_EXAMPLES_DIR):
    example_list = [
        [os.path.join(_EXAMPLES_DIR, example)]
        for example in sorted(os.listdir(_EXAMPLES_DIR))
        if os.path.isfile(os.path.join(_EXAMPLES_DIR, example))
    ]
else:
    example_list = []

# Wire the prediction function to an image input and a label output.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Label(num_top_classes=2, label="Predictions")],
    # Pass None rather than an empty list when there are no examples.
    examples=example_list or None,
    title=title,
    description=description,
    article=article,
)

# Start the web server for the demo.
demo.launch()
| |
|