AdhamQQ commited on
Commit
d0a50cd
·
verified ·
1 Parent(s): 607cfcf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -12
app.py CHANGED
@@ -13,12 +13,10 @@ import cv2
13
  st.title("🧠 Stroke Patient Pain Intensity Detector")
14
  st.markdown("Upload a full-face image. The system will detect the affected side and use the other side to predict pain intensity.")
15
 
16
- load_model = tf.keras.models.load_model
17
-
18
  @st.cache_resource
19
  def download_models():
20
  model_urls = {
21
- "cnn_stroke_model.keras": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/cnn_stroke_model.keras",
22
  "left_side_pain_classifier.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/left_side_pain_classifier.pth",
23
  "right_side_pain_classifier.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/right_side_pain_classifier.pth",
24
  }
@@ -42,23 +40,23 @@ def download_models():
42
  f.write(r.content)
43
 
44
  # Load stroke model
45
- stroke_model = load_model("cnn_stroke_model.keras")
46
 
47
- # Load pain models with correct architecture
48
  class PainModel(nn.Module):
49
  def __init__(self):
50
  super(PainModel, self).__init__()
51
  from torchvision.models import resnet18, ResNet18_Weights
52
- self.model = resnet18(weights=ResNet18_Weights.DEFAULT)
53
- self.model.fc = nn.Linear(self.model.fc.in_features, 1)
54
 
55
  def forward(self, x):
56
- return self.model(x)
57
 
58
  left_model = PainModel()
59
  right_model = PainModel()
60
- left_model.load_state_dict(torch.load("left_side_pain_classifier.pth", map_location=torch.device("cpu")))
61
- right_model.load_state_dict(torch.load("right_side_pain_classifier.pth", map_location=torch.device("cpu")))
62
  left_model.eval()
63
  right_model.eval()
64
 
@@ -94,8 +92,7 @@ if uploaded_file is not None:
94
  left_half = full_image.crop((0, 0, mid, h))
95
  right_half = full_image.crop((mid, 0, w, h))
96
 
97
- _, H, W, C = stroke_model.input_shape
98
- stroke_input = full_image.resize((W, H))
99
  stroke_array = np.array(stroke_input).astype("float32") / 255.0
100
  stroke_array = np.expand_dims(stroke_array, axis=0)
101
 
 
13
  st.title("🧠 Stroke Patient Pain Intensity Detector")
14
  st.markdown("Upload a full-face image. The system will detect the affected side and use the other side to predict pain intensity.")
15
 
 
 
16
  @st.cache_resource
17
  def download_models():
18
  model_urls = {
19
+ "cnn_stroke_model.h5": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/cnn_stroke_model.h5",
20
  "left_side_pain_classifier.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/left_side_pain_classifier.pth",
21
  "right_side_pain_classifier.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/right_side_pain_classifier.pth",
22
  }
 
40
  f.write(r.content)
41
 
42
  # Load stroke model
43
+ stroke_model = tf.keras.models.load_model("cnn_stroke_model.h5")
44
 
45
+ # Define corrected PainModel class
46
  class PainModel(nn.Module):
47
  def __init__(self):
48
  super(PainModel, self).__init__()
49
  from torchvision.models import resnet18, ResNet18_Weights
50
+ self.convnet = resnet18(weights=ResNet18_Weights.DEFAULT)
51
+ self.convnet.fc = nn.Linear(self.convnet.fc.in_features, 1)
52
 
53
  def forward(self, x):
54
+ return self.convnet(x)
55
 
56
  left_model = PainModel()
57
  right_model = PainModel()
58
+ left_model.load_state_dict(torch.load("left_side_pain_classifier.pth", map_location="cpu"))
59
+ right_model.load_state_dict(torch.load("right_side_pain_classifier.pth", map_location="cpu"))
60
  left_model.eval()
61
  right_model.eval()
62
 
 
92
  left_half = full_image.crop((0, 0, mid, h))
93
  right_half = full_image.crop((mid, 0, w, h))
94
 
95
+ stroke_input = full_image.resize((224, 224))
 
96
  stroke_array = np.array(stroke_input).astype("float32") / 255.0
97
  stroke_array = np.expand_dims(stroke_array, axis=0)
98