DFU / app.py
EngReem85's picture
Update app.py
f52464c verified
import torch
import torch.nn.functional as F
import gradio as gr
import numpy as np
from PIL import Image
from torchvision import transforms
# Import the real (trained) model architecture and its handcrafted-feature extractor
from model import DenseShuffleGCANet, extract_handcrafted_features
# Output labels. Index i of this list names index i of the model's
# probability vector (see how predict() pairs them).
CLASSES = [
    "NONE",
    "INFECTION",
    "ISCHAEMIA",
    "BOTH",
]
# Compute device: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network and load the trained weights.
# num_classes is derived from CLASSES so the classifier head and the label
# list cannot drift apart. handcrafted_feature_dim must match the length of
# the vector produced by extract_handcrafted_features (assumed to be 41 --
# NOTE(review): confirm against model.py).
model = DenseShuffleGCANet(num_classes=len(CLASSES), handcrafted_feature_dim=41)
# map_location remaps the checkpoint's tensors onto `device`, so weights saved
# on a GPU still load on a CPU-only host. The path is relative to the current
# working directory.
model.load_state_dict(
    torch.load("best_model_2.pth", map_location=device)
)
model.to(device)
# Switch to inference behaviour (affects layers such as dropout/batch-norm,
# if the architecture contains them).
model.eval()
# Preprocessing for the CNN branch. The image is resized to 224x224, then
# turned into a float tensor, then normalised per channel with the standard
# ImageNet mean/std values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Prediction function wired into the Gradio interface.
def predict(image: Image.Image):
    """Classify a single diabetic-foot image.

    Args:
        image: The uploaded picture as a PIL image. Gradio passes ``None``
            when the user submits without uploading anything.

    Returns:
        A ``(probabilities, predicted_class)`` tuple:
        - a dict mapping each label in ``CLASSES`` to its softmax probability;
        - the label with the highest probability.

    Raises:
        gr.Error: If no image was provided. Gradio shows the message in the UI.
    """
    if image is None:
        # Fail with a readable message instead of a TypeError from the
        # transform pipeline.
        raise gr.Error("Please upload an image.")

    # Normalize in `transform` has exactly three channel statistics. RGBA,
    # palette or grayscale inputs would therefore fail or be misread.
    # The Gradio component normally delivers RGB already, but predict() can
    # also be called directly. Force RGB so the CNN input and the
    # handcrafted-feature extractor see the same 3-channel image.
    image = image.convert("RGB")

    # Add a batch dimension of 1 for the CNN branch.
    image_tensor = transform(image).unsqueeze(0).to(device)

    # NOTE(review): assumes extract_handcrafted_features returns a 1-D torch
    # tensor (it is unsqueezed below) of length handcrafted_feature_dim for an
    # HxWx3 uint8 array -- confirm in model.py.
    features = extract_handcrafted_features(np.array(image))
    features = features.unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(image_tensor, features)

    # Take the single batch row and copy it to the CPU once, rather than
    # syncing per element via float(probs[i]) on the device tensor.
    probs = F.softmax(logits, dim=1)[0].cpu()
    result = {label: float(p) for label, p in zip(CLASSES, probs)}
    predicted_class = CLASSES[int(torch.argmax(probs))]
    return result, predicted_class
# Gradio interface. It takes one PIL image as input. It outputs the
# per-class probability dict and the top-1 label, matching the
# (result, predicted_class) tuple returned by predict().
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload DFU Image"),
    outputs=[
        # Show every class; tied to CLASSES rather than a literal count.
        gr.Label(num_top_classes=len(CLASSES), label="Probabilities"),
        gr.Textbox(label="Predicted Class"),
    ],
    title="DFU Classification System",
    description="Classifies diabetic foot images into NONE, INFECTION, ISCHAEMIA, or BOTH.",
)
def main() -> None:
    """Launch the Gradio app serving the DFU classifier."""
    interface.launch()


if __name__ == "__main__":
    main()