EngReem85 commited on
Commit
0c30c29
·
verified ·
1 Parent(s): 4ad0372

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -15
app.py CHANGED
@@ -1,31 +1,54 @@
1
  import torch
2
- from PIL import Image
3
  import torchvision.transforms as transforms
 
 
4
  import gradio as gr
5
 
6
  # -------------------------------
7
- # 1️⃣ تحميل النموذج المحفوظ كاملًا
8
  # -------------------------------
9
- # تأكدي أن best_model_2.pth تم حفظه باستخدام torch.save(model, "best_model_2.pth")
10
- model = torch.load("best_model_2.pth", map_location=torch.device('cpu'))
11
- model.eval()
12
 
13
  # -------------------------------
14
- # 2️⃣ إعداد الفئات
15
  # -------------------------------
16
- CLASSES = ["NONE", "INFECTION", "ISCHAEMIA", "BOTH"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  # -------------------------------
19
- # 3️⃣ دالة استخراج الخصائص اليدوية
 
 
 
 
 
 
 
20
  # -------------------------------
21
  def extract_handcrafted_features(image_array):
22
- # مثال عشوائي لتوضيح الفكرة، عدلي حسب ما كنتِ تستخدمينه
23
- import numpy as np
24
- features = np.random.rand(49).astype(np.float32)
25
  return torch.tensor(features)
26
 
27
  # -------------------------------
28
- # 4️⃣ دالة التنبؤ
29
  # -------------------------------
30
  def predict_image(image: Image.Image):
31
  transform = transforms.Compose([
@@ -34,8 +57,6 @@ def predict_image(image: Image.Image):
34
  transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
35
  ])
36
  image_tensor = transform(image).unsqueeze(0)
37
-
38
- # استخراج الخصائص اليدوية
39
  features = extract_handcrafted_features(np.array(image)).unsqueeze(0)
40
 
41
  with torch.no_grad():
@@ -46,7 +67,7 @@ def predict_image(image: Image.Image):
46
  return {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}, pred_class
47
 
48
  # -------------------------------
49
- # 5️⃣ واجهة Gradio
50
  # -------------------------------
51
  interface = gr.Interface(
52
  fn=predict_image,
@@ -58,3 +79,4 @@ interface = gr.Interface(
58
 
59
  if __name__ == "__main__":
60
  interface.launch()
 
 
1
  import torch
2
+ import torch.nn as nn
3
  import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ import numpy as np
6
  import gradio as gr
7
 
8
  # -------------------------------
9
+ # 1️⃣ إعداد الفئات
10
  # -------------------------------
11
+ CLASSES = ["NONE", "INFECTION", "ISCHAEMIA", "BOTH"]
 
 
12
 
13
  # -------------------------------
14
+ # 2️⃣ تعريف نموذج DenseShuffleGCANet
15
  # -------------------------------
16
+ class DenseShuffleGCANet(nn.Module):
17
+ def __init__(self, num_classes=4, handcrafted_feature_dim=49):
18
+ super(DenseShuffleGCANet, self).__init__()
19
+ # مثال على backbone، عدلي حسب الكود الأصلي
20
+ self.backbone = nn.Sequential(
21
+ nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
22
+ nn.ReLU(),
23
+ nn.AdaptiveAvgPool2d((1,1))
24
+ )
25
+ self.fc_handcrafted = nn.Linear(handcrafted_feature_dim, 32)
26
+ self.classifier = nn.Linear(64 + 32, num_classes)
27
+
28
+ def forward(self, x_image, x_features):
29
+ x_img = self.backbone(x_image)
30
+ x_img = x_img.view(x_img.size(0), -1)
31
+ x_feat = self.fc_handcrafted(x_features)
32
+ x = torch.cat([x_img, x_feat], dim=1)
33
+ out = self.classifier(x)
34
+ return out
35
 
36
  # -------------------------------
37
+ # 3️⃣ تحميل النموذج مع weights
38
+ # -------------------------------
39
+ model = DenseShuffleGCANet(num_classes=4, handcrafted_feature_dim=49)
40
+ model.load_state_dict(torch.load("best_model_2.pth", map_location=torch.device('cpu')))
41
+ model.eval()
42
+
43
+ # -------------------------------
44
+ # 4️⃣ دالة استخراج الخصائص اليدوية
45
  # -------------------------------
46
  def extract_handcrafted_features(image_array):
47
+ features = np.random.rand(49).astype(np.float32) # عدلي حسب خصائصك الحقيقية
 
 
48
  return torch.tensor(features)
49
 
50
  # -------------------------------
51
+ # 5️⃣ دالة التنبؤ
52
  # -------------------------------
53
  def predict_image(image: Image.Image):
54
  transform = transforms.Compose([
 
57
  transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
58
  ])
59
  image_tensor = transform(image).unsqueeze(0)
 
 
60
  features = extract_handcrafted_features(np.array(image)).unsqueeze(0)
61
 
62
  with torch.no_grad():
 
67
  return {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}, pred_class
68
 
69
  # -------------------------------
70
+ # 6️⃣ واجهة Gradio
71
  # -------------------------------
72
  interface = gr.Interface(
73
  fn=predict_image,
 
79
 
80
  if __name__ == "__main__":
81
  interface.launch()
82
+