# testdep / app.py
# Author: Sravanth Ganta
# Cancer Detector AppV1 (revision d876962)
import gradio as gr
import cv2
import requests
import os
from PIL import Image
import timm
import torch
from torchvision.transforms import transforms
import numpy as np
from PIL import ImageFile
import matplotlib.pyplot as plt
import warnings
import glob
# Allow Pillow to load image files whose data is truncated (see Pillow's
# ImageFile.LOAD_TRUNCATED_IMAGES) instead of failing on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Suppress every warning category process-wide.
warnings.simplefilter("ignore")
def predict(image, model, device, class_name, transform=None):
    """Classify a single image and return the top-scoring class.

    Args:
        image: A PIL image (anything with ``.mode`` and ``.convert``). Any
            mode is accepted; non-RGB images (grayscale, RGBA, palette, ...)
            are converted to RGB before preprocessing.
        model: A torch module mapping a ``(1, C, H, W)`` tensor to per-class
            scores. The score is reported as ``value * 100``, so it reads as
            a percentage only if the model already outputs probabilities
            (e.g. it ends in a softmax).
        device: ``'cuda'`` runs on the GPU when one is available, moving
            ``model`` there in place so it matches the input. Any other value
            leaves the input tensor on the CPU.
        class_name: Sequence of labels indexed by model output position.
        transform: Optional callable mapping the RGB image to a ``(C, H, W)``
            tensor. Defaults to a 224x224 resize, ``ToTensor`` and ImageNet
            mean/std normalization.

    Returns:
        ``(prob, label)``: the top score times 100 and
        ``class_name[top_index]``.
    """
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize(size=(224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
    # The normalization expects exactly three channels. Converting up front is
    # explicit and avoids retrying after an arbitrary failure, which would
    # hide unrelated errors.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    batch = transform(image).unsqueeze(0)  # add the batch dimension
    if device == 'cuda':
        if torch.cuda.is_available():
            # Input and weights must be on the same device, so move both.
            model = model.to('cuda')
            batch = batch.cuda()
        else:
            print("You don't have cuda")
    model.eval()
    with torch.no_grad():
        pred = model(batch)
    idx = int(torch.argmax(pred[0]))
    prob = pred[0][idx].item() * 100
    return prob, class_name[idx]
# ResNet-50 backbone with a custom 4-class head; weights come from the local
# checkpoint. pretrained=False because the strict load_state_dict() below
# overwrites every parameter and buffer, so downloading ImageNet weights first
# would only add a network fetch at startup.
model = timm.create_model('resnet50', pretrained=False)
model.fc = torch.nn.Sequential(
    torch.nn.Linear(2048, 256),
    torch.nn.Dropout(0.2),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 64),
    torch.nn.Dropout(0.2),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 32),
    torch.nn.Dropout(0.2),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 4),
    # dim=1 is the class axis of the (batch, classes) output. The implicit-dim
    # Softmax() picks the same axis for 2-D input but is deprecated.
    torch.nn.Softmax(dim=1),
)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(
    torch.load('model_ResNet50_acc_max.pt', map_location=torch.device('cpu'))
)
# Module-level switches:
#   display_prob -- show_preds_image() also prints each prediction to stdout.
#   show         -- not read by any active code in this file.
display_prob, show = True, True
# Output labels; predict() maps model output index i to CLASS_NAMES[i].
CLASS_NAMES = ['adenocarcinoma',
               'large.cell.carcinoma',
               'normal',
               'squamous.cell.carcinoma']


def show_preds_image(path):
    """Gradio callback: classify the image stored at ``path`` on the CPU.

    Args:
        path: Filesystem path of the uploaded image, as delivered by a
            ``gr.Image(type="filepath")`` input; ``None`` when the user
            submits without an image.

    Returns:
        ``(label, prob)``: the predicted class name and its score times 100.

    Raises:
        gr.Error: if no image was supplied, so the UI shows a readable message
            instead of an AttributeError from ``Image.open(None)``.
    """
    if path is None:
        raise gr.Error("Please upload an image first.")
    # Image.open() reads lazily and keeps the file open; the context manager
    # closes it once prediction (which loads the pixel data) has finished.
    with Image.open(path) as img:
        prob, result = predict(img, model, 'cpu', CLASS_NAMES)
    if display_prob:
        print('Probability of {} : {:.6f}'.format(result, prob))
    return result, prob
# Single image input; Gradio stores the upload and passes its file path.
inputs_image = [
    gr.components.Image(type="filepath", label="Input Image"),
]
# Bind the Interface itself. Chaining .launch() here would bind launch()'s
# return value instead and start the server as an import side effect.
interface_image = gr.Interface(
    fn=show_preds_image,
    inputs=inputs_image,
    outputs="text",
    title="Cancer Detector App using data from Kaggle",
    cache_examples=False,
)

if __name__ == "__main__":
    interface_image.launch()