| import gradio as gr |
| import cv2 |
| import requests |
|
|
| import os |
| from PIL import Image |
| import timm |
| import torch |
| from torchvision.transforms import transforms |
| import numpy as np |
| from PIL import ImageFile |
| import matplotlib.pyplot as plt |
| import warnings |
| import glob |
|
|
|
|
# Let PIL decode image files whose data stream is cut short rather than
# raising an error on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Install a catch-all "ignore" filter: every warning, from every module, is
# suppressed for the rest of the process.
warnings.filterwarnings("ignore")
|
|
|
|
def predict(image, model, device, class_name):
    """Classify a single PIL image with ``model``.

    The image is converted to RGB, resized to 224x224, turned into a tensor
    and normalized with the ImageNet channel mean/std, then run through the
    model as a batch of one.

    Args:
        image: PIL image to classify.
        model: ``torch.nn.Module`` returning a ``(1, num_classes)`` tensor for
            a batch of one image. The returned percentage is read straight
            from that output, so it is only a probability if the model ends
            in a softmax.
        device: ``'cuda'`` to run on the GPU when one is available; any other
            value runs on the CPU. NOTE: the model is moved to the selected
            device in place (``nn.Module.to`` mutates the module).
        class_name: Sequence of labels indexed by model output position.

    Returns:
        ``(prob, label)``: the top output value multiplied by 100, and the
        label at that output's index.
    """
    prediction_transform = transforms.Compose([
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Normalize needs exactly three channels. Converting up front covers
    # grayscale / palette / LA images, drops alpha from RGBA (as the old
    # [:3] slice did), and turns CMYK/YCbCr/HSV into real RGB instead of
    # silently feeding their raw channels to the model.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")
    batch = prediction_transform(image).unsqueeze(0)

    use_cuda = device == "cuda" and torch.cuda.is_available()
    if device == "cuda" and not use_cuda:
        print("You don't have cuda")
    target = torch.device("cuda" if use_cuda else "cpu")
    # Model and input must share a device; previously only the input was
    # moved, which raised a device-mismatch error on CUDA machines.
    model.to(target)
    batch = batch.to(target)

    model.eval()
    with torch.no_grad():
        pred = model(batch)

    # Argmax over the single sample's outputs, as a plain Python int.
    idx = int(torch.argmax(pred[0]))
    prob = pred[0][idx].item() * 100
    return prob, class_name[idx]
|
|
|
|
# Fine-tuned checkpoint holding the full state dict of the model built below.
MODEL_PATH = 'model_ResNet50_acc_max.pt'

# ResNet-50 backbone. pretrained=False: the strict load_state_dict call below
# overwrites every parameter and buffer, so first downloading the ImageNet
# weights was wasted network I/O (and made start-up fail offline).
model = timm.create_model('resnet50', pretrained=False)

# Replace the classifier with a small MLP head producing 4 class scores.
# The input width is read from the backbone's original fc layer instead of
# being hard-coded. Softmax(dim=1) normalizes across classes for the
# (batch, features) input, so the head outputs per-class probabilities;
# giving dim explicitly avoids the deprecated implicit-dim behavior.
model.fc = torch.nn.Sequential(
    torch.nn.Linear(model.fc.in_features, 256),
    torch.nn.Dropout(0.2),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 64),
    torch.nn.Dropout(0.2),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 32),
    torch.nn.Dropout(0.2),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 4),
    torch.nn.Softmax(dim=1),
)

# Load weights onto the CPU regardless of the device they were saved from.
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
# Inference only: disable dropout / use stored batch-norm statistics.
model.eval()

# When True, show_preds_image prints each prediction to stdout.
display_prob = True
show = True
| |
|
|
# Labels in the order of the classifier head's four outputs.
CLASS_NAMES = (
    'adenocarcinoma',
    'large.cell.carcinoma',
    'normal',
    'squamous.cell.carcinoma',
)


def show_preds_image(path):
    """Classify the image file at ``path`` (Gradio callback).

    Args:
        path: Filesystem path of the uploaded image, as supplied by the
            ``type="filepath"`` Image input. May be ``None``/empty when the
            user submits without uploading an image.

    Returns:
        ``(label, probability_percent)``. When no image is supplied, returns
        ``("No image provided", None)`` instead of crashing.
    """
    if not path:
        return "No image provided", None

    # The context manager closes the file handle once predict() has read the
    # pixel data; previously the handle was never closed.
    with Image.open(path) as img:
        prob, result = predict(img, model, 'cpu', CLASS_NAMES)

    if display_prob:
        print(f'Probability of {result} : {prob:.6f}')

    return result, prob
| |
inputs_image = [
    gr.components.Image(type="filepath", label="Input Image"),
]

# One output component per value returned by show_preds_image
# (label, probability). With the previous single "text" output, the whole
# tuple was rendered as its string form.
outputs_image = [
    gr.components.Textbox(label="Prediction"),
    gr.components.Number(label="Probability (%)"),
]

# Keep a reference to the Interface itself; the old code bound the name to
# the return value of .launch() instead.
interface_image = gr.Interface(
    fn=show_preds_image,
    inputs=inputs_image,
    outputs=outputs_image,
    title="Cancer Detector App using data from Kaggle",
    cache_examples=False,
)

# Only start the web server when run as a script, not on import.
if __name__ == "__main__":
    interface_image.launch()
|
|
|
|