|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from torchvision import transforms |
|
|
|
|
|
|
|
|
from model import DenseShuffleGCANet, extract_handcrafted_features |
|
|
|
|
|
|
|
|
# Output labels. predict() pairs position i of this list with the model's
# i-th output probability, so the order here must match the label order the
# checkpoint was trained with (TODO confirm against the training code).
CLASSES = ["NONE", "INFECTION", "ISCHAEMIA", "BOTH"]

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
# Build the network once at import time and restore the trained weights.
# The class count is taken from CLASSES so the output head and the label
# list cannot drift apart.
# NOTE(review): handcrafted_feature_dim=41 presumably equals the length of
# the vector returned by extract_handcrafted_features -- confirm in model.py.
model = DenseShuffleGCANet(
    num_classes=len(CLASSES),
    handcrafted_feature_dim=41,
)
# NOTE(security): torch.load unpickles the checkpoint, which can execute
# arbitrary code. Only load "best_model_2.pth" from a trusted source; if the
# installed torch supports it, consider passing weights_only=True.
model.load_state_dict(
    torch.load("best_model_2.pth", map_location=device)
)
model.to(device)
# Inference only: switch the module out of training mode.
model.eval()
|
|
|
|
|
|
|
|
# Preprocessing applied to every uploaded image before it reaches the model:
# resize to 224x224, convert to a float tensor, then normalize each of the
# three channels with the mean/std values listed below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
|
|
|
|
|
|
|
|
def predict(image: Image.Image):
    """Classify a single diabetic-foot image.

    Args:
        image: PIL image to classify. The Gradio image input passes ``None``
            when the user submits without uploading anything.

    Returns:
        A ``(probabilities, predicted_class)`` tuple. ``probabilities`` maps
        each label in ``CLASSES`` to its softmax probability as a float;
        ``predicted_class`` is the label with the highest probability.

    Raises:
        gr.Error: If ``image`` is ``None``, so the UI shows a readable
            message instead of an internal traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # The normalization step is configured for exactly three channels, so
    # grayscale / RGBA / palette images must be converted first. Images that
    # are already RGB are passed through untouched.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension of 1 for the network.
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Hand-crafted features are computed from the raw pixel array, not from
    # the resized/normalized tensor.
    # NOTE(review): assumes extract_handcrafted_features returns a torch
    # tensor (it is used via .unsqueeze/.to) -- confirm in model.py.
    features = extract_handcrafted_features(np.array(image))
    features = features.unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(image_tensor, features)
        # Softmax over the class dimension; [0] drops the batch dimension.
        probs = F.softmax(outputs, dim=1)[0]

    # Pair labels with probabilities by position instead of a hard-coded
    # range(4); tolist() copies the values to Python floats in one step.
    result = {label: float(p) for label, p in zip(CLASSES, probs.tolist())}
    predicted_class = CLASSES[int(torch.argmax(probs))]

    return result, predicted_class
|
|
|
|
|
|
|
|
# Gradio UI: one image input wired to predict(), whose two return values feed
# the probability label widget and the predicted-class text box, in that order.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload DFU Image"),
    outputs=[
        # Show every class, however many CLASSES defines.
        gr.Label(num_top_classes=len(CLASSES), label="Probabilities"),
        gr.Textbox(label="Predicted Class"),
    ],
    title="DFU Classification System",
    description="Classifies diabetic foot images into NONE, INFECTION, ISCHAEMIA, or BOTH.",
)
|
|
|
|
|
# Start the web UI only when this file is run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()
|
|
|
|
|
|
|
|
|