AdhamQQ committed on
Commit
9745932
·
verified ·
1 Parent(s): 4691c8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -119
app.py CHANGED
@@ -8,16 +8,23 @@ from tensorflow.keras.models import load_model
8
  from torchvision.models import resnet18
9
  import os
10
  import requests
 
11
  import cv2
12
 
13
- # Title
14
  st.title("🧠 Stroke Patient Pain Intensity Detector")
15
 
16
- st.markdown("""
17
- Upload a full-face image of a stroke patient. The app detects the **affected facial side** using a stroke classification model, then uses the **unaffected side** to predict **pain intensity (PSPI score)**.
18
- """)
19
-
20
- # Download models
 
 
 
 
 
 
21
  @st.cache_resource
22
  def download_models():
23
  model_urls = {
@@ -26,134 +33,117 @@ def download_models():
26
  }
27
  for filename, url in model_urls.items():
28
  if not os.path.exists(filename):
 
29
  r = requests.get(url)
30
  with open(filename, "wb") as f:
31
  f.write(r.content)
32
- if not os.path.exists("haarcascade_frontalface_default.xml"):
33
- r = requests.get("https://raw.githubusercontent.com/opencv/opencv/master/data/haarcascades/haarcascade_frontalface_default.xml")
34
- with open("haarcascade_frontalface_default.xml", "wb") as f:
35
- f.write(r.content)
36
  stroke_model = load_model("cnn_stroke_model.keras")
37
  pain_model = resnet18(weights=None)
38
  pain_model.fc = nn.Linear(pain_model.fc.in_features, 1)
39
  pain_model.load_state_dict(torch.load("pain_model.pth", map_location=torch.device("cpu")))
40
  pain_model.eval()
 
41
  return stroke_model, pain_model
42
 
43
  stroke_model, pain_model = download_models()
44
 
45
- # Rotation using eyes
46
- def auto_rotate_face(image, box):
47
- img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
48
- x, y, w, h = box
49
- roi = img_cv[y:y+h, x:x+w]
50
- gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
51
- eye_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_eye.xml")
52
- eyes = eye_cascade.detectMultiScale(gray, 1.1, 5)
53
- if len(eyes) >= 2:
54
- eyes = sorted(eyes[:2], key=lambda e: e[0])
55
- (x1, y1, w1, h1), (x2, y2, w2, h2) = eyes
56
- center1 = (x1 + w1 // 2, y1 + h1 // 2)
57
- center2 = (x2 + w2 // 2, y2 + h2 // 2)
58
- dx, dy = center2[0] - center1[0], center2[1] - center1[1]
59
- angle = np.degrees(np.arctan2(dy, dx))
60
- return image.rotate(-angle, center=(x + w // 2, y + h // 2), expand=True)
61
- return image
62
-
63
- # Image transform
64
  transform = transforms.Compose([
65
  transforms.Resize((224, 224)),
66
  transforms.ToTensor(),
67
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
68
  ])
69
 
70
- uploaded = st.file_uploader("πŸ“‚ Upload face image", type=["jpg", "jpeg", "png"])
71
- if uploaded:
72
- full_image = Image.open(uploaded).convert("RGB")
73
- st.image(full_image, caption="Uploaded Full-Face Image", use_column_width=True)
74
-
75
- # Detect face
76
- cv_image = cv2.cvtColor(np.array(full_image), cv2.COLOR_RGB2BGR)
77
- face_cascade = cv2.CascadeClassifier("haarcascade_frontalface_default.xml")
78
- faces = face_cascade.detectMultiScale(cv_image, 1.1, 5)
79
-
80
- if len(faces) == 0:
81
- st.error("❌ No face detected. Please upload a clear frontal face image.")
82
- st.stop()
83
-
84
- # Rotate and redetect
85
- x, y, w, h = faces[0]
86
- rotated = auto_rotate_face(full_image, (x, y, w, h))
87
- rotated_cv = cv2.cvtColor(np.array(rotated), cv2.COLOR_RGB2BGR)
88
- faces2 = face_cascade.detectMultiScale(rotated_cv, 1.1, 5)
89
-
90
- if len(faces2) == 0:
91
- st.warning("⚠️ Redetection failed after rotation, using original box.")
92
- x, y, w, h = faces[0]
93
- else:
94
- x, y, w, h = faces2[0]
95
-
96
- # Add padding
97
- pad = int(0.2 * max(w, h))
98
- x1, y1 = max(0, x - pad), max(0, y - pad)
99
- x2, y2 = min(rotated.width, x + w + pad), min(rotated.height, y + h + pad)
100
- face_crop = rotated.crop((x1, y1, x2, y2))
101
-
102
- # Split face
103
- fw, fh = face_crop.size
104
- mid = fw // 2
105
- patient_right = face_crop.crop((0, 0, mid, fh)) # viewer's left
106
- patient_left = face_crop.crop((mid, 0, fw, fh)) # viewer's right
107
-
108
- # Stroke model input
109
- _, H, W, C = stroke_model.input_shape
110
- stroke_input = face_crop.resize((W, H))
111
- stroke_array = np.array(stroke_input).astype("float32") / 255.0
112
- stroke_array = np.expand_dims(stroke_array, axis=0)
113
-
114
- # Predict affected side
115
- st.write("🧠 Predicting affected side...")
116
- stroke_pred = stroke_model.predict(stroke_array)
117
- stroke_raw = stroke_pred[0][0]
118
- affected = int(np.round(stroke_raw)) # 0 = left, 1 = right
119
-
120
- # Assign faces
121
- if affected == 0:
122
- affected_side = "left"
123
- unaffected_side = "right"
124
- unaffected_face = patient_right
125
- else:
126
- affected_side = "right"
127
- unaffected_side = "left"
128
- unaffected_face = patient_left
129
-
130
- # Predict pain score
131
- st.write("πŸ“ˆ Predicting PSPI pain score...")
132
- tensor = transform(unaffected_face).unsqueeze(0)
133
- with torch.no_grad():
134
- output = pain_model(tensor)
135
- raw_score = output.item()
136
- pspi_score = max(0.0, min(raw_score, 6.0))
137
-
138
- # Output
139
- st.subheader("πŸ” Prediction Results")
140
- st.image(unaffected_face, caption="Unaffected Side Used for Pain Detection", width=300)
141
- st.write(f"**🧭 Affected Side (face POV):** `{affected_side}`")
142
- st.write(f"**βœ… Unaffected Side (face POV):** `{unaffected_side}`")
143
- st.write(f"**🎯 Predicted PSPI Pain Score:** `{pspi_score:.3f}`")
144
- st.write(f"**πŸ“ˆ Raw Pain Model Output:** `{raw_score:.3f}`")
145
- st.write(f"**πŸ“Š Stroke Model Raw Output:** `{stroke_raw:.4f}`")
146
-
147
- st.markdown("""
148
- ---
149
- ### ℹ️ Stroke Model Output
150
- - Output ∈ [0, 1]
151
- - Closer to `0` = Left side affected
152
- - Closer to `1` = Right side affected
153
-
154
- ### ℹ️ PSPI Score (0–6)
155
- - 0: No pain
156
- - 1–2: Mild
157
- - 3–4: Moderate
158
- - 5–6: Severe
159
- """)
 
8
  from torchvision.models import resnet18
9
  import os
10
  import requests
11
+ import mediapipe as mp
12
  import cv2
13
 
14
+ # App title
15
  st.title("🧠 Stroke Patient Pain Intensity Detector")
16
 
17
+ # Instructions
18
+ st.markdown(
19
+ """
20
+ Upload a full-face image of a stroke patient.
21
+ The app will detect the **affected facial side** using a stroke classification model,
22
+ and then use the **unaffected side** to predict **pain intensity** (PSPI score).
23
+ """
24
+ )
25
+ st.write("πŸ”§ Initializing and downloading models...")
26
+
27
+ # Download and load models
28
  @st.cache_resource
29
  def download_models():
30
  model_urls = {
 
33
  }
34
  for filename, url in model_urls.items():
35
  if not os.path.exists(filename):
36
+ st.write(f"📥 Downloading {filename}...")
37
  r = requests.get(url)
38
  with open(filename, "wb") as f:
39
  f.write(r.content)
40
+ st.success(f"✅ {filename} downloaded.")
41
+ else:
42
+ st.write(f"✔️ {filename} already exists.")
43
+
44
  stroke_model = load_model("cnn_stroke_model.keras")
45
  pain_model = resnet18(weights=None)
46
  pain_model.fc = nn.Linear(pain_model.fc.in_features, 1)
47
  pain_model.load_state_dict(torch.load("pain_model.pth", map_location=torch.device("cpu")))
48
  pain_model.eval()
49
+
50
  return stroke_model, pain_model
51
 
52
  stroke_model, pain_model = download_models()
53
 
54
+ # Preprocessing for pain model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  transform = transforms.Compose([
56
  transforms.Resize((224, 224)),
57
  transforms.ToTensor(),
58
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
59
  ])
60
 
61
+ # MediaPipe Face Detection
62
+ mp_face = mp.solutions.face_detection
63
+ mp_draw = mp.solutions.drawing_utils
64
+
65
+ # Upload UI
66
+ uploaded_file = st.file_uploader("πŸ“‚ Upload a full-face image", type=["jpg", "jpeg", "png"])
67
+
68
+ if uploaded_file is not None:
69
+ st.write("πŸ“· Image uploaded. Detecting face...")
70
+ full_image = Image.open(uploaded_file).convert("RGB")
71
+ img_np = np.array(full_image)
72
+
73
+ with mp_face.FaceDetection(model_selection=1, min_detection_confidence=0.6) as detector:
74
+ results = detector.process(cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR))
75
+
76
+ if not results.detections:
77
+ st.error("❌ No face detected. Please upload a clear frontal face image.")
78
+ st.stop()
79
+
80
+ # Use first detection
81
+ detection = results.detections[0]
82
+ bboxC = detection.location_data.relative_bounding_box
83
+ ih, iw, _ = img_np.shape
84
+ x = int(bboxC.xmin * iw)
85
+ y = int(bboxC.ymin * ih)
86
+ w = int(bboxC.width * iw)
87
+ h = int(bboxC.height * ih)
88
+
89
+ face_crop = full_image.crop((x, y, x + w, y + h))
90
+ st.image(full_image, caption="Uploaded Full-Face Image", use_column_width=True)
91
+
92
+ # Split halves (face POV)
93
+ fw, fh = face_crop.size
94
+ fmid = fw // 2
95
+ patient_right = face_crop.crop((0, 0, fmid, fh)) # viewer's left
96
+ patient_left = face_crop.crop((fmid, 0, fw, fh)) # viewer's right
97
+
98
+ # Stroke prediction input
99
+ _, H, W, C = stroke_model.input_shape
100
+ stroke_input = face_crop.resize((W, H))
101
+ stroke_array = np.array(stroke_input).astype("float32") / 255.0
102
+ stroke_array = np.expand_dims(stroke_array, axis=0)
103
+
104
+ st.write("🧠 Predicting affected side of the face...")
105
+ stroke_pred = stroke_model.predict(stroke_array)
106
+ stroke_raw = stroke_pred[0][0]
107
+ affected = int(np.round(stroke_raw)) # 0 = left affected, 1 = right affected
108
+
109
+ if affected == 0:
110
+ affected_side = "left"
111
+ unaffected_side = "right"
112
+ unaffected_face = patient_right
113
+ else:
114
+ affected_side = "right"
115
+ unaffected_side = "left"
116
+ unaffected_face = patient_left
117
+
118
+ # Pain prediction
119
+ st.write("πŸ“ˆ Predicting PSPI pain score from unaffected side...")
120
+ input_tensor = transform(unaffected_face).unsqueeze(0)
121
+ with torch.no_grad():
122
+ output = pain_model(input_tensor)
123
+ raw_score = output.item()
124
+ pspi_score = max(0.0, min(raw_score, 6.0))
125
+
126
+ # Display results
127
+ st.subheader("πŸ” Prediction Results")
128
+ st.image(unaffected_face, caption="Unaffected Side Used for Pain Detection", width=300)
129
+ st.write(f"**🧭 Affected Side (face POV):** `{affected_side}`")
130
+ st.write(f"**βœ… Unaffected Side (face POV):** `{unaffected_side}`")
131
+ st.write(f"**🎯 Predicted PSPI Pain Score:** `{pspi_score:.3f}`")
132
+ st.write(f"**πŸ“ˆ Raw Pain Model Output:** `{raw_score:.3f}`")
133
+ st.write(f"**πŸ“Š Stroke Model Raw Output:** `{stroke_raw:.4f}`")
134
+
135
+ st.markdown(
136
+ """
137
+ ---
138
+ ### ℹ️ Stroke Model Output
139
+ - Output is between `0` and `1`
140
+ - Closer to `0` = Left side is affected
141
+ - Closer to `1` = Right side is affected
142
+
143
+ ### ℹ️ PSPI Score Scale
144
+ - `0`: No pain
145
+ - `1–2`: Mild pain
146
+ - `3–4`: Moderate pain
147
+ - `5–6`: Severe pain
148
+ """
149
+ )