File size: 1,783 Bytes
bbbee9c
2d7d0ad
bbbee9c
2d7d0ad
 
 
 
 
 
bbbee9c
2d7d0ad
0c30c29
bbbee9c
2d7d0ad
 
bbbee9c
2d7d0ad
7ba4adc
2d7d0ad
f52464c
2d7d0ad
 
0c30c29
 
2d7d0ad
 
 
 
 
 
 
 
 
 
 
 
 
bbbee9c
2d7d0ad
 
bbbee9c
 
 
2d7d0ad
bbbee9c
2d7d0ad
 
bbbee9c
2d7d0ad
 
 
bbbee9c
2d7d0ad
 
 
 
 
 
 
 
bbbee9c
 
 
 
0c30c29
2d7d0ad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn.functional as F
import gradio as gr
import numpy as np
from PIL import Image
from torchvision import transforms

# استيراد النموذج الحقيقي
from model import DenseShuffleGCANet, extract_handcrafted_features

# Output labels, indexed in the same order as the model's logits
# (predict() pairs CLASSES[i] with softmax probability i).
CLASSES = [
    "NONE",
    "INFECTION",
    "ISCHAEMIA",
    "BOTH",
]

# Run on CUDA when a GPU is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network and load the trained weights onto the selected device.
# num_classes is tied to CLASSES so the classifier head and the label list used
# by predict() cannot drift apart.
# NOTE(review): handcrafted_feature_dim=41 presumably must equal the length of
# the vector returned by extract_handcrafted_features — confirm in model.py.
model = DenseShuffleGCANet(num_classes=len(CLASSES), handcrafted_feature_dim=41)
# The checkpoint is fed straight into load_state_dict, so it only needs to hold
# tensors. weights_only=True stops torch.load from unpickling arbitrary objects
# (supported since torch 1.13; it is already the default from torch 2.6).
model.load_state_dict(
    torch.load("best_model_2.pth", map_location=device, weights_only=True)
)
model.to(device)
model.eval()  # evaluation mode (affects dropout / batch-norm layers, if any)

# Preprocessing for the CNN input. The steps are:
#   1. Resize to a fixed 224x224.
#   2. Convert to a float tensor (ToTensor scales 8-bit images to [0, 1]).
#   3. Normalise each channel with the standard ImageNet mean/std used by
#      torchvision's pretrained models.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Prediction function
def predict(image: Image.Image):
    """Classify a single DFU image.

    Args:
        image: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A ``(probabilities, predicted_class)`` pair:
        - a dict mapping every label in ``CLASSES`` to its softmax probability;
        - the name of the most probable label.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message
            instead of a stack trace.
        RuntimeError: if the model emits a different number of scores than
            there are labels in ``CLASSES``.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Normalize uses 3-channel statistics, so RGBA / greyscale images must be
    # converted first or the channel count will not match.
    image = image.convert("RGB")

    image_tensor = transform(image).unsqueeze(0).to(device)  # add batch dim

    # NOTE(review): extract_handcrafted_features receives the full-resolution
    # RGB uint8 array (not the resized 224x224 input) and is assumed to return
    # a 1-D torch tensor of length 41 — confirm against model.py.
    features = extract_handcrafted_features(np.array(image))
    features = features.unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(image_tensor, features)
        # One device->host transfer instead of one per class.
        probs = F.softmax(logits, dim=1)[0].cpu().tolist()

    if len(probs) != len(CLASSES):
        raise RuntimeError(
            f"Model produced {len(probs)} scores but {len(CLASSES)} classes are defined."
        )

    result = dict(zip(CLASSES, probs))
    # max() keeps the first maximum on ties, matching torch.argmax.
    predicted_class = max(result, key=result.get)

    return result, predicted_class

# Gradio UI: one uploaded image in; a probability label and the predicted
# class name out (the two values returned by predict).
interface = gr.Interface(
    predict,
    gr.Image(label="Upload DFU Image", type="pil"),
    [
        gr.Label(label="Probabilities", num_top_classes=4),
        gr.Textbox(label="Predicted Class"),
    ],
    description="Classifies diabetic foot images into NONE, INFECTION, ISCHAEMIA, or BOTH.",
    title="DFU Classification System",
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    interface.launch()