EngReem85 commited on
Commit
2d7d0ad
·
verified ·
1 Parent(s): 7ba4adc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -61
app.py CHANGED
@@ -1,82 +1,66 @@
1
  import torch
2
- import torch.nn as nn
3
- import torchvision.transforms as transforms
4
- from PIL import Image
5
- import numpy as np
6
  import gradio as gr
 
 
 
 
 
 
7
 
8
- # -------------------------------
9
- # 1️⃣ إعداد الفئات
10
- # -------------------------------
11
  CLASSES = ["NONE", "INFECTION", "ISCHAEMIA", "BOTH"]
12
 
13
- # -------------------------------
14
- # 2️⃣ تعريف نموذج DenseShuffleGCANet
15
- # -------------------------------
16
- class DenseShuffleGCANet(nn.Module):
17
- def __init__(self, num_classes=4, handcrafted_feature_dim=41):
18
- super(DenseShuffleGCANet, self).__init__()
19
- # مثال على backbone، عدلي حسب الكود الأصلي
20
- self.backbone = nn.Sequential(
21
- nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
22
- nn.ReLU(),
23
- nn.AdaptiveAvgPool2d((1,1))
24
- )
25
- self.fc_handcrafted = nn.Linear(handcrafted_feature_dim, 32)
26
- self.classifier = nn.Linear(64 + 32, num_classes)
27
-
28
- def forward(self, x_image, x_features):
29
- x_img = self.backbone(x_image)
30
- x_img = x_img.view(x_img.size(0), -1)
31
- x_feat = self.fc_handcrafted(x_features)
32
- x = torch.cat([x_img, x_feat], dim=1)
33
- out = self.classifier(x)
34
- return out
35
 
36
- # -------------------------------
37
- # 3️⃣ تحميل النموذج مع weights
38
- # -------------------------------
39
  model = DenseShuffleGCANet(num_classes=4, handcrafted_feature_dim=41)
40
- model.load_state_dict(torch.load("best_model_2.pth", map_location=torch.device('cpu')))
 
 
 
41
  model.eval()
42
 
43
- # -------------------------------
44
- # 4️⃣ دالة استخراج الخصائص اليدوية
45
- # -------------------------------
46
- def extract_handcrafted_features(image_array):
47
- features = np.random.rand(41).astype(np.float32) # عدلي حسب خصائصك الحقيقية
48
- return torch.tensor(features)
 
 
 
 
 
 
 
49
 
50
- # -------------------------------
51
- # 5️⃣ دالة التنبؤ
52
- # -------------------------------
53
- def predict_image(image: Image.Image):
54
- transform = transforms.Compose([
55
- transforms.Resize((224, 224)),
56
- transforms.ToTensor(),
57
- transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
58
- ])
59
- image_tensor = transform(image).unsqueeze(0)
60
- features = extract_handcrafted_features(np.array(image)).unsqueeze(0)
61
 
62
  with torch.no_grad():
63
  outputs = model(image_tensor, features)
64
- probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
65
- pred_class = CLASSES[probs.argmax()]
66
 
67
- return {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}, pred_class
 
68
 
69
- # -------------------------------
70
- # 6️⃣ واجهة Gradio
71
- # -------------------------------
72
  interface = gr.Interface(
73
- fn=predict_image,
74
- inputs=gr.Image(type="pil"),
75
- outputs=[gr.Label(num_top_classes=4), gr.Textbox(label="Predicted Class")],
76
- title="DFU Foot Ulcer Classifier",
77
- description="Upload an image of a foot ulcer to classify it as NONE, INFECTION, ISCHAEMIA, or BOTH."
 
 
 
78
  )
79
 
80
  if __name__ == "__main__":
81
  interface.launch()
82
 
 
 
1
  import torch
2
+ import torch.nn.functional as F
 
 
 
3
  import gradio as gr
4
+ import numpy as np
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+
8
+ # استيراد النموذج الحقيقي
9
+ from model import DenseShuffleGCANet, extract_handcrafted_features
10
 
11
+ # الفئات
 
 
12
  CLASSES = ["NONE", "INFECTION", "ISCHAEMIA", "BOTH"]
13
 
14
+ # الجهاز
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ # تحميل النموذج
 
 
18
  model = DenseShuffleGCANet(num_classes=4, handcrafted_feature_dim=41)
19
+ model.load_state_dict(
20
+ torch.load("best_model.pth", map_location=device)
21
+ )
22
+ model.to(device)
23
  model.eval()
24
 
25
+ # تحويل الصورة
26
+ transform = transforms.Compose([
27
+ transforms.Resize((224, 224)),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize(
30
+ mean=[0.485, 0.456, 0.406],
31
+ std=[0.229, 0.224, 0.225]
32
+ )
33
+ ])
34
+
35
+ # دالة التنبؤ
36
+ def predict(image: Image.Image):
37
+ image_tensor = transform(image).unsqueeze(0).to(device)
38
 
39
+ features = extract_handcrafted_features(np.array(image))
40
+ features = features.unsqueeze(0).to(device)
 
 
 
 
 
 
 
 
 
41
 
42
  with torch.no_grad():
43
  outputs = model(image_tensor, features)
44
+ probs = F.softmax(outputs, dim=1)[0]
 
45
 
46
+ result = {CLASSES[i]: float(probs[i]) for i in range(4)}
47
+ predicted_class = CLASSES[int(torch.argmax(probs))]
48
 
49
+ return result, predicted_class
50
+
51
+ # واجهة Gradio
52
  interface = gr.Interface(
53
+ fn=predict,
54
+ inputs=gr.Image(type="pil", label="Upload DFU Image"),
55
+ outputs=[
56
+ gr.Label(num_top_classes=4, label="Probabilities"),
57
+ gr.Textbox(label="Predicted Class")
58
+ ],
59
+ title="DFU Classification System",
60
+ description="Classifies diabetic foot images into NONE, INFECTION, ISCHAEMIA, or BOTH."
61
  )
62
 
63
  if __name__ == "__main__":
64
  interface.launch()
65
 
66
+