|
|
import torch |
|
|
from torch import nn |
|
|
from torchvision import models, transforms |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
# Backbone: torchvision ResNet-50 with ImageNet-pretrained weights. Its final
# fully connected layer is swapped for a 3-output linear head, one logit per
# class.
# NOTE(review): the new head is a freshly initialized nn.Linear and this file
# loads no fine-tuned weights (no load_state_dict here), so predictions are not
# meaningful until trained weights are loaded -- TODO confirm.
# NOTE: `pretrained=True` is the legacy torchvision API (superseded by
# `weights=models.ResNet50_Weights.DEFAULT` in torchvision >= 0.13). It is kept
# so older installs keep working.
model = models.resnet50(pretrained=True)
_backbone_features = model.fc.in_features  # width of the original fc layer's input
model.fc = nn.Linear(in_features=_backbone_features, out_features=3)
|
|
|
|
|
|
|
|
# Input preprocessing for the ResNet-50 backbone:
#   1. resize to exactly 224x224 (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor,
#   3. normalize each channel with the standard ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
|
|
|
def classify_thyroid_condition(image):
    """Predict a thyroid-condition label for a single uploaded image.

    Args:
        image: Array from the Gradio ``"image"`` input. Presumably an
            H x W x 3 uint8 RGB array (TODO confirm against the Gradio version
            in use). 2-D grayscale, single-channel (H x W x 1) and RGBA arrays
            are also accepted and converted to RGB.

    Returns:
        ``"Normal"``, ``"Hypothyroidism"`` or ``"Hyperthyroidism"``, chosen by
        the arg-max (class index 0, 1 or 2) of the module-level ``model``'s
        logits.

    Raises:
        ValueError: if ``image`` is ``None`` (submitted without an upload).

    NOTE(review): the model's ``fc`` head is created as a new ``nn.Linear``
    and no fine-tuned weights are loaded in this file. The returned label is
    therefore not meaningful until trained weights are loaded -- TODO confirm.
    """
    labels = ("Normal", "Hypothyroidism", "Hyperthyroidism")

    if image is None:
        # Gradio passes None when nothing was uploaded. Without this guard,
        # `None.astype` raises an opaque AttributeError.
        raise ValueError("No image provided; please upload an image.")

    pixels = np.asarray(image).astype("uint8")
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        # Pillow cannot build an image from an (H, W, 1) array; drop the axis.
        pixels = pixels[..., 0]
    # Let Pillow infer the mode from the array shape, then convert to RGB.
    # Forcing mode='RGB' misreads 2-D grayscale and 4-channel RGBA buffers,
    # and fromarray's `mode` argument is deprecated in recent Pillow releases.
    pil_image = Image.fromarray(pixels).convert("RGB")

    batch = transform(pil_image).unsqueeze(0)  # add a leading batch dim of size 1

    model.eval()  # idempotent; ensures inference behaviour for BatchNorm/Dropout
    with torch.no_grad():
        logits = model(batch)

    class_index = int(logits.argmax(dim=1).item())
    # Any index past the known labels falls back to the last label, matching
    # the original if/elif/else chain.
    return labels[min(class_index, len(labels) - 1)]
|
|
|
|
|
|
|
|
# Build the web UI: an image upload feeding the classifier, text output.
# Binding it to `demo` lets the `gradio app.py` reload mode find the app.
demo = gr.Interface(
    fn=classify_thyroid_condition,
    inputs="image",
    outputs="text",
)
demo.launch()
|
|
|