"""Gradio app that classifies an uploaded image into one of four classes.

The classes are three carcinoma types or 'normal'. Classification uses a
timm ResNet-50 whose classifier head is replaced and whose weights are
loaded from a local checkpoint.
"""
import glob
import os
import warnings

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import timm
import torch
from PIL import Image
from PIL import ImageFile
from torchvision.transforms import transforms

# NOTE(review): os, glob, cv2, requests, numpy and matplotlib are not used
# anywhere in this file; they are kept only so the import surface is unchanged.

warnings.filterwarnings("ignore")
# Let PIL decode images whose files are cut short instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Label for each output position of the model's final layer. Presumably this
# is the label order used at training time -- confirm against the training code.
CLASS_NAMES = ['adenocarcinoma', 'large.cell.carcinoma', 'normal', 'squamous.cell.carcinoma']
MODEL_PATH = 'model_ResNet50_acc_max.pt'

# Resize to 224x224 and normalise with the ImageNet channel mean/std.
# Built once instead of on every prediction.
_PREDICTION_TRANSFORM = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(image, model, device, class_name):
    """Classify a single PIL image with *model*.

    Args:
        image: PIL image to classify. It is converted to 3-channel RGB before
            preprocessing, so grayscale, RGBA, palette and CMYK inputs work.
        model: network that maps a (1, 3, 224, 224) batch to a
            (1, len(class_name)) tensor of class scores.
        device: 'cuda' to run on the GPU when one is available. Any other
            value, or 'cuda' without an available GPU, runs on the CPU.
            The model is moved to the chosen device.
        class_name: labels indexed by the model's output position.

    Returns:
        ``(prob, label)``: ``prob`` is the top class's score multiplied by
        100 (a percentage when the model ends in Softmax, as the module-level
        model does), and ``label`` is that class's name.
    """
    # Always convert to RGB. ToTensor() keeps the native channel count, so the
    # old "slice [:3] and retry on any error" approach mis-read CMYK images
    # and hid unrelated failures behind a bare except.
    batch = _PREDICTION_TRANSFORM(image.convert('RGB')).unsqueeze(0)

    if device == 'cuda' and not torch.cuda.is_available():
        print("You don't have cuda")
    use_cuda = device == 'cuda' and torch.cuda.is_available()
    target = torch.device('cuda' if use_cuda else 'cpu')

    # Model and input must live on the same device. Previously only the input
    # was moved to CUDA, which fails with a device-mismatch error.
    model.to(target)
    model.eval()
    with torch.no_grad():
        scores = model(batch.to(target))

    idx = int(torch.argmax(scores[0]))
    prob = scores[0, idx].item() * 100
    return prob, class_name[idx]


# pretrained=False: every parameter and buffer is overwritten by the strict
# load_state_dict below, so downloading ImageNet weights first is wasted work.
model = timm.create_model('resnet50', pretrained=False)
# Replacement head. Its layer layout must match the checkpoint exactly,
# because load_state_dict is strict by default.
model.fc = torch.nn.Sequential(
    torch.nn.Linear(2048, 256),
    torch.nn.Dropout(0.2),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 64),
    torch.nn.Dropout(0.2),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 32),
    torch.nn.Dropout(0.2),
    torch.nn.ReLU(),
    torch.nn.Linear(32, len(CLASS_NAMES)),
    # dim=1 is what the deprecated implicit-dim Softmax() picks for 2-D input.
    torch.nn.Softmax(dim=1),
)
# torch.load unpickles the file: only ever point this at a trusted checkpoint.
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
model.eval()

display_prob = True  # print each prediction to stdout
show = True  # currently unused; kept for backward compatibility


def show_preds_image(path):
    """Gradio callback: classify the image stored at *path*.

    Args:
        path: filesystem path supplied by the Image input (``type="filepath"``).
            It is None when the user submits without uploading an image.

    Returns:
        ``(label, prob)`` for the top class, where ``prob`` is a percentage.

    Raises:
        gr.Error: if no image was provided.
    """
    if path is None:
        raise gr.Error("Please upload an image.")
    # The context manager closes the file handle that Image.open keeps open.
    # predict() fully loads the pixels (via convert) before the handle closes.
    with Image.open(path) as img:
        prob, result = predict(img, model, 'cpu', CLASS_NAMES)
    if display_prob:
        print('Probability of {} : {:.6f}'.format(result, prob))
    return result, prob


inputs_image = [
    gr.components.Image(type="filepath", label="Input Image"),
]
# One component per returned value. A single "text" output would render the
# (label, prob) tuple as its raw Python repr.
outputs_image = [
    gr.components.Textbox(label="Predicted class"),
    gr.components.Number(label="Probability (%)"),
]
# Keep the Interface object itself. Previously this name held the return
# value of .launch() instead.
interface_image = gr.Interface(
    fn=show_preds_image,
    inputs=inputs_image,
    outputs=outputs_image,
    title="Cancer Detector App using data from Kaggle",
    cache_examples=False,
)

if __name__ == "__main__":
    interface_image.launch()