AdhamQQ commited on
Commit
0fb1aba
Β·
verified Β·
1 Parent(s): 006408c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -32
app.py CHANGED
@@ -9,12 +9,18 @@ from torchvision.models import resnet18
9
  import os
10
  import requests
11
 
12
- # UI title and instructions
13
  st.title("🧠 Stroke Patient Pain Intensity Detector")
14
- st.markdown("Upload a full-face image of a stroke patient. The app will detect the affected side and predict pain intensity using the unaffected side.")
 
 
 
 
 
 
15
  st.write("πŸ”§ App started. Preparing to download models...")
16
 
17
- # Function to download and load models with caching
18
  @st.cache_resource
19
  def download_models():
20
  model_urls = {
@@ -22,69 +28,70 @@ def download_models():
22
  "pain_model.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/pain_model.pth"
23
  }
24
 
25
- # Download models if not already downloaded
26
  for filename, url in model_urls.items():
27
  if not os.path.exists(filename):
28
  st.write(f"πŸ“₯ Downloading {filename}...")
29
  r = requests.get(url)
30
  with open(filename, "wb") as f:
31
  f.write(r.content)
32
- st.write(f"βœ… {filename} downloaded.")
33
  else:
34
  st.write(f"βœ”οΈ {filename} already exists.")
35
 
36
  st.write("πŸ“¦ Loading models...")
37
 
38
- # Load stroke side classification model (Keras)
39
  stroke_model = load_model("cnn_stroke_model.keras")
40
 
41
- # Load pain intensity prediction model (PyTorch ResNet18)
42
  pain_model = resnet18(weights=None)
43
- pain_model.fc = nn.Linear(pain_model.fc.in_features, 1) # One output: PSPI pain score
44
  pain_model.load_state_dict(torch.load("pain_model.pth", map_location=torch.device("cpu")))
45
  pain_model.eval()
46
 
47
  return stroke_model, pain_model
48
 
49
- # Load models
50
  stroke_model, pain_model = download_models()
51
 
52
- # Define preprocessing pipeline for PyTorch pain model
53
  transform = transforms.Compose([
54
- transforms.Resize((224, 224)), # Resize to ResNet18 input
55
- transforms.ToTensor(), # Convert to tensor
56
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize using ImageNet stats
 
57
  ])
58
 
59
- # File uploader widget
60
  uploaded_file = st.file_uploader("πŸ“‚ Choose a full-face image", type=["jpg", "jpeg", "png"])
61
 
62
- # If an image is uploaded, start processing
63
  if uploaded_file is not None:
64
  st.write("πŸ“· Image uploaded. Processing...")
65
- full_image = Image.open(uploaded_file).convert("RGB") # Ensure RGB format
66
  st.image(full_image, caption="Uploaded Full-Face Image", use_column_width=True)
67
 
68
- # Crop left and right halves of the face
69
  w, h = full_image.size
70
  mid = w // 2
71
  left_face = full_image.crop((0, 0, mid, h))
72
  right_face = full_image.crop((mid, 0, w, h))
73
 
74
- # Resize full image for stroke model input
75
  _, H, W, C = stroke_model.input_shape
76
  st.write(f"πŸ” Resizing uploaded image to: ({H}, {W}) for stroke model")
77
  stroke_input = full_image.resize((W, H))
78
  stroke_array = np.array(stroke_input).astype("float32") / 255.0
79
- stroke_array = np.expand_dims(stroke_array, axis=0) # Shape: (1, H, W, C)
80
 
81
- # Predict affected side using stroke model
82
  st.write("🧠 Running stroke model prediction...")
83
  stroke_pred = stroke_model.predict(stroke_array)
84
- affected = int(np.round(stroke_pred[0][0])) # Output: 1 = left affected, 0 = right affected
85
 
86
- # Select unaffected side and label sides (from patient's perspective)
87
- if affected == 1:
 
 
88
  affected_side = "left"
89
  unaffected_side = "right"
90
  unaffected_face = right_face
@@ -93,19 +100,36 @@ if uploaded_file is not None:
93
  unaffected_side = "left"
94
  unaffected_face = left_face
95
 
96
- # Preprocess the unaffected side for the PyTorch model
97
- unaffected_tensor = transform(unaffected_face).unsqueeze(0) # Shape: (1, 3, 224, 224)
98
 
99
- # Predict pain score
100
  st.write("πŸ“ˆ Predicting pain score...")
101
  with torch.no_grad():
102
  output = pain_model(unaffected_tensor)
103
- pspi_score = output.item() # Extract float value
104
 
105
- # Display results
106
- st.subheader("🧾 Prediction Results")
107
  st.image(unaffected_face, caption="Unaffected Side Used for Pain Detection", width=300)
108
- st.write(f"**🧭 Affected side (face POV):** `{affected_side}`")
109
- st.write(f"**βœ… Unaffected side (face POV):** `{unaffected_side}`")
 
110
  st.write(f"**🎯 Predicted PSPI Pain Score:** `{pspi_score:.3f}`")
111
- st.write(f"**πŸ“Š Stroke model raw output:** `{stroke_pred[0][0]}`")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  import os
10
  import requests
11
 
12
+ # App title and instructions
13
  st.title("🧠 Stroke Patient Pain Intensity Detector")
14
+ st.markdown(
15
+ """
16
+ Upload a full-face image of a stroke patient.
17
+ The app will detect the **affected facial side** using a stroke classification model,
18
+ and then use the **unaffected side** to predict **pain intensity**.
19
+ """
20
+ )
21
  st.write("πŸ”§ App started. Preparing to download models...")
22
 
23
+ # Function to download and load models
24
  @st.cache_resource
25
  def download_models():
26
  model_urls = {
 
28
  "pain_model.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/pain_model.pth"
29
  }
30
 
 
31
  for filename, url in model_urls.items():
32
  if not os.path.exists(filename):
33
  st.write(f"πŸ“₯ Downloading {filename}...")
34
  r = requests.get(url)
35
  with open(filename, "wb") as f:
36
  f.write(r.content)
37
+ st.success(f"βœ… {filename} downloaded.")
38
  else:
39
  st.write(f"βœ”οΈ {filename} already exists.")
40
 
41
  st.write("πŸ“¦ Loading models...")
42
 
43
+ # Load stroke side detection model (Keras)
44
  stroke_model = load_model("cnn_stroke_model.keras")
45
 
46
+ # Load pain intensity model (PyTorch ResNet18)
47
  pain_model = resnet18(weights=None)
48
+ pain_model.fc = nn.Linear(pain_model.fc.in_features, 1)
49
  pain_model.load_state_dict(torch.load("pain_model.pth", map_location=torch.device("cpu")))
50
  pain_model.eval()
51
 
52
  return stroke_model, pain_model
53
 
54
+ # Download and load models
55
  stroke_model, pain_model = download_models()
56
 
57
+ # Image preprocessing for pain model (ResNet)
58
  transform = transforms.Compose([
59
+ transforms.Resize((224, 224)),
60
+ transforms.ToTensor(),
61
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
62
+ std=[0.229, 0.224, 0.225])
63
  ])
64
 
65
+ # Upload image
66
  uploaded_file = st.file_uploader("πŸ“‚ Choose a full-face image", type=["jpg", "jpeg", "png"])
67
 
 
68
  if uploaded_file is not None:
69
  st.write("πŸ“· Image uploaded. Processing...")
70
+ full_image = Image.open(uploaded_file).convert("RGB")
71
  st.image(full_image, caption="Uploaded Full-Face Image", use_column_width=True)
72
 
73
+ # Crop left and right sides
74
  w, h = full_image.size
75
  mid = w // 2
76
  left_face = full_image.crop((0, 0, mid, h))
77
  right_face = full_image.crop((mid, 0, w, h))
78
 
79
+ # Resize image for stroke model input
80
  _, H, W, C = stroke_model.input_shape
81
  st.write(f"πŸ” Resizing uploaded image to: ({H}, {W}) for stroke model")
82
  stroke_input = full_image.resize((W, H))
83
  stroke_array = np.array(stroke_input).astype("float32") / 255.0
84
+ stroke_array = np.expand_dims(stroke_array, axis=0)
85
 
86
+ # Predict affected side
87
  st.write("🧠 Running stroke model prediction...")
88
  stroke_pred = stroke_model.predict(stroke_array)
89
+ affected = int(np.round(stroke_pred[0][0])) # Assuming 0 = left affected, 1 = right affected
90
 
91
+ # βœ… FIX: Interpret correctly based on stroke model
92
+ # If 0: left side affected β†’ use right for pain
93
+ # If 1: right side affected β†’ use left for pain
94
+ if affected == 0:
95
  affected_side = "left"
96
  unaffected_side = "right"
97
  unaffected_face = right_face
 
100
  unaffected_side = "left"
101
  unaffected_face = left_face
102
 
103
+ # Preprocess for pain model
104
+ unaffected_tensor = transform(unaffected_face).unsqueeze(0)
105
 
106
+ # Predict pain score (PSPI)
107
  st.write("πŸ“ˆ Predicting pain score...")
108
  with torch.no_grad():
109
  output = pain_model(unaffected_tensor)
110
+ pspi_score = output.item()
111
 
112
+ # Display prediction results
113
+ st.subheader("πŸ” Prediction Results")
114
  st.image(unaffected_face, caption="Unaffected Side Used for Pain Detection", width=300)
115
+
116
+ st.write(f"**🧭 Affected Side (face POV):** `{affected_side}`")
117
+ st.write(f"**βœ… Unaffected Side (face POV):** `{unaffected_side}`")
118
  st.write(f"**🎯 Predicted PSPI Pain Score:** `{pspi_score:.3f}`")
119
+ st.write(f"**πŸ“Š Stroke model raw output:** `{stroke_pred[0][0]:.4f}`")
120
+
121
+ # Pain scale explanation
122
+ st.markdown(
123
+ """
124
+ ---
125
+ ### ℹ️ **What is PSPI?**
126
+ The **Prkachin and Solomon Pain Intensity (PSPI)** score ranges from **0 to 6**:
127
+
128
+ - `0`: No pain
129
+ - `1–2`: Mild pain
130
+ - `3–4`: Moderate pain
131
+ - `5–6`: Severe pain
132
+
133
+ This score is computed from facial action units such as brow lowering, eye closure, and cheek raising.
134
+ """
135
+ )