AdhamQQ commited on
Commit
68b7985
Β·
verified Β·
1 Parent(s): 92935b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -117
app.py CHANGED
@@ -1,155 +1,98 @@
1
  import streamlit as st
2
  import torch
3
  import torch.nn as nn
4
- from torchvision.models import resnet18
5
  from torchvision import transforms
6
  from PIL import Image
7
  import numpy as np
8
  from tensorflow.keras.models import load_model
9
  import os
10
  import requests
11
- import cv2
12
 
13
- st.title("🧠 Stroke Patient Pain Intensity Detector")
14
- st.markdown("Upload a full-face image. The system will detect the affected side and use the other side to predict pain intensity.")
 
15
 
16
  @st.cache_resource
17
  def download_models():
18
  model_urls = {
19
  "cnn_stroke_model.keras": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/cnn_stroke_model.keras",
20
- "left_side_pain_classifier.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/left_side_pain_classifier.pth",
21
- "right_side_pain_classifier.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/right_side_pain_classifier.pth",
22
  }
23
 
24
  for filename, url in model_urls.items():
25
- if not os.path.exists(filename) or os.path.getsize(filename) < 10000:
26
  st.write(f"πŸ“₯ Downloading {filename}...")
27
- try:
28
- r = requests.get(url)
29
- r.raise_for_status()
30
- with open(filename, "wb") as f:
31
- f.write(r.content)
32
- st.success(f"βœ… {filename} downloaded.")
33
- except Exception as e:
34
- st.error(f"❌ Could not download {filename}: {e}")
35
- st.stop()
36
  else:
37
  st.write(f"βœ”οΈ {filename} already exists.")
38
 
39
- # Haar cascade
40
- haar_path = "haarcascade_frontalface_default.xml"
41
- haar_url = "https://raw.githubusercontent.com/opencv/opencv/master/data/haarcascades/haarcascade_frontalface_default.xml"
42
- if not os.path.exists(haar_path):
43
- r = requests.get(haar_url)
44
- with open(haar_path, "wb") as f:
45
- f.write(r.content)
46
-
47
- # Load stroke model
48
- try:
49
- stroke_model = load_model("cnn_stroke_model.keras")
50
- except Exception as e:
51
- st.error(f"❌ Failed to load stroke model: {e}")
52
- st.stop()
53
-
54
- def load_pain_model(path):
55
- model = resnet18(weights=None)
56
- model.fc = nn.Linear(model.fc.in_features, 1)
57
- model.load_state_dict(torch.load(path, map_location="cpu"))
58
- model.eval()
59
- return model
60
-
61
- try:
62
- left_model = load_pain_model("left_side_pain_classifier.pth")
63
- right_model = load_pain_model("right_side_pain_classifier.pth")
64
- except Exception as e:
65
- st.error(f"❌ Error loading PyTorch pain models: {e}")
66
- st.stop()
67
-
68
- return stroke_model, left_model, right_model
69
-
70
- stroke_model, left_model, right_model = download_models()
71
 
72
  transform = transforms.Compose([
73
  transforms.Resize((224, 224)),
74
  transforms.ToTensor(),
75
- transforms.Normalize([0.485, 0.456, 0.406],
76
- [0.229, 0.224, 0.225])
77
  ])
78
 
79
- uploaded_file = st.file_uploader("πŸ“‚ Upload a full-face image", type=["jpg", "jpeg", "png"])
80
 
81
  if uploaded_file is not None:
 
82
  full_image = Image.open(uploaded_file).convert("RGB")
83
- st.image(full_image, caption="πŸ“Έ Uploaded Full-Face Image", use_column_width=True)
84
-
85
- np_img = np.array(full_image)
86
- gray = cv2.cvtColor(np_img, cv2.COLOR_RGB2GRAY)
87
- face_cascade = cv2.CascadeClassifier("haarcascade_frontalface_default.xml")
88
- faces = face_cascade.detectMultiScale(gray, 1.1, 5)
89
-
90
- if len(faces) == 0:
91
- st.error("❌ No face detected. Please upload a clearer image.")
92
- st.stop()
93
-
94
- # Use first detected face
95
- (x, y, w, h) = faces[0]
96
- center_x = x + w // 2
97
- half_width = w // 2
98
-
99
- left_half_box = (
100
- max(center_x - half_width, 0),
101
- y,
102
- center_x,
103
- y + h
104
- )
105
- right_half_box = (
106
- center_x,
107
- y,
108
- min(center_x + half_width, full_image.width),
109
- y + h
110
- )
111
-
112
- left_half = full_image.crop(left_half_box)
113
- right_half = full_image.crop(right_half_box)
114
-
115
- stroke_input = full_image.resize((224, 224))
116
- stroke_array = np.expand_dims(np.array(stroke_input) / 255.0, axis=0)
117
-
118
- st.write("🧠 Predicting affected side...")
119
- stroke_pred = stroke_model.predict(stroke_array)
120
- affected = int(np.round(stroke_pred[0][0]))
121
 
122
- # Flip logic to match face's own point of view
123
- unaffected_face = left_half if affected == 1 else right_half
124
- selected_model = left_model if affected == 1 else right_model
 
125
 
126
- st.write("πŸ“ˆ Predicting pain...")
127
- with torch.no_grad():
128
- tensor = transform(unaffected_face).unsqueeze(0)
129
- output = selected_model(tensor)
130
- prob = torch.sigmoid(output).item()
131
- label = 1 if prob > 0.5 else 0
132
-
133
- st.subheader("πŸ” Prediction Result")
134
- st.image(unaffected_face, caption="🧍 Unaffected Side Used", width=300)
135
-
136
- # βœ… Fix: if 1 means left is affected, use right side for pain
137
- if affected == 1:
138
- affected_side = "left"
139
- unaffected_side = "right"
140
- unaffected_face = right_half
141
- selected_model = right_model
142
- else:
143
- affected_side = "right"
144
- unaffected_side = "left"
145
- unaffected_face = left_half
146
- selected_model = left_model
147
 
148
- st.write(f"**Affected Side (Face POV): **{face_affected}**")
149
- st.write(f"**Unaffected Side (Face POV): **{face_unaffected}**")
150
- st.write(f"**Predicted PSPI Pain Score:** {prob:.3f}")
151
- st.write(f"**Stroke model raw output:** {stroke_pred[0][0]}")
152
 
 
 
 
153
 
154
 
 
 
155
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import torch
3
  import torch.nn as nn
 
4
  from torchvision import transforms
5
  from PIL import Image
6
  import numpy as np
7
  from tensorflow.keras.models import load_model
8
  import os
9
  import requests
 
10
 
11
+ st.title("Stroke Patient Pain Intensity Detector")
12
+ st.markdown("Upload a full-face image of a stroke patient. The app will detect the affected side and predict pain intensity using the unaffected side.")
13
+ st.write("πŸ”§ App started. Preparing to download models...")
14
 
15
  @st.cache_resource
16
  def download_models():
17
  model_urls = {
18
  "cnn_stroke_model.keras": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/cnn_stroke_model.keras",
19
+ "pain_model.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/pain_model.pth"
 
20
  }
21
 
22
  for filename, url in model_urls.items():
23
+ if not os.path.exists(filename):
24
  st.write(f"πŸ“₯ Downloading {filename}...")
25
+ r = requests.get(url)
26
+ with open(filename, "wb") as f:
27
+ f.write(r.content)
28
+ st.write(f"βœ… {filename} downloaded.")
 
 
 
 
 
29
  else:
30
  st.write(f"βœ”οΈ {filename} already exists.")
31
 
32
+ st.write("πŸ“¦ Loading models...")
33
+ stroke_model = load_model("cnn_stroke_model.keras")
34
+
35
+ class PainRegressor(nn.Module):
36
+ def __init__(self):
37
+ super(PainRegressor, self).__init__()
38
+ from torchvision.models import resnet18, ResNet18_Weights
39
+ self.base = resnet18(weights=ResNet18_Weights.DEFAULT)
40
+ num_features = self.base.fc.in_features
41
+ self.base.fc = nn.Linear(num_features, 1)
42
+ def forward(self, x):
43
+ return self.base(x)
44
+
45
+ pain_model = PainRegressor()
46
+ pain_model.load_state_dict(torch.load("pain_model.pth", map_location=torch.device('cpu')))
47
+ pain_model.eval()
48
+ st.write("βœ… Models loaded.")
49
+
50
+ return stroke_model, pain_model
51
+
52
+ stroke_model, pain_model = download_models()
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  transform = transforms.Compose([
55
  transforms.Resize((224, 224)),
56
  transforms.ToTensor(),
57
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
58
+ std=[0.229, 0.224, 0.225])
59
  ])
60
 
61
+ uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
62
 
63
  if uploaded_file is not None:
64
+ st.write("πŸ“· Image uploaded. Processing...")
65
  full_image = Image.open(uploaded_file).convert("RGB")
66
+ st.image(full_image, caption="Uploaded Full-Face Image", use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
+ w, h = full_image.size
69
+ mid = w // 2
70
+ left_face = full_image.crop((0, 0, mid, h))
71
+ right_face = full_image.crop((mid, 0, w, h))
72
 
73
+ # πŸ” Automatically resize image based on model input
74
+ _, H, W, C = stroke_model.input_shape
75
+ st.write(f"πŸ” Resizing uploaded image to: ({H}, {W}) for stroke model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
+ stroke_input = full_image.resize((W, H))
78
+ stroke_array = np.array(stroke_input).astype("float32") / 255.0
79
+ stroke_array = np.expand_dims(stroke_array, axis=0)
 
80
 
81
+ st.write("🧠 Running stroke model prediction...")
82
+ stroke_pred = stroke_model.predict(stroke_array)
83
+ affected = int(np.round(stroke_pred[0][0]))
84
 
85
 
86
+ unaffected_face = right_face if affected == 0 else left_face
87
+ unaffected_tensor = transform(unaffected_face).unsqueeze(0)
88
 
89
+ st.write("πŸ“ˆ Predicting pain score...")
90
+ with torch.no_grad():
91
+ output = pain_model(unaffected_tensor)
92
+ pspi_score = output.item()
93
+
94
+ st.subheader("Prediction Results")
95
+ st.image(unaffected_face, caption="Unaffected Side Used for Pain Detection", width=300)
96
+ st.write(f"**Affected side:** {'left' if affected == 0 else 'right'}")
97
+ st.write(f"**Unaffected side:** {'right' if affected == 0 else 'left'}")
98
+ st.write(f"**Predicted PSPI Pain Score:** {round(pspi_score, 3)}")