AdhamQQ commited on
Commit
006408c
Β·
verified Β·
1 Parent(s): 73f0331

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -41
app.py CHANGED
@@ -9,10 +9,12 @@ from torchvision.models import resnet18
9
  import os
10
  import requests
11
 
12
- st.title("Stroke Patient Pain Intensity Detector")
 
13
  st.markdown("Upload a full-face image of a stroke patient. The app will detect the affected side and predict pain intensity using the unaffected side.")
14
  st.write("πŸ”§ App started. Preparing to download models...")
15
 
 
16
  @st.cache_resource
17
  def download_models():
18
  model_urls = {
@@ -20,6 +22,7 @@ def download_models():
20
  "pain_model.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/pain_model.pth"
21
  }
22
 
 
23
  for filename, url in model_urls.items():
24
  if not os.path.exists(filename):
25
  st.write(f"πŸ“₯ Downloading {filename}...")
@@ -31,70 +34,78 @@ def download_models():
31
  st.write(f"βœ”οΈ {filename} already exists.")
32
 
33
  st.write("πŸ“¦ Loading models...")
34
- stroke_model = load_model("cnn_stroke_model.keras")
35
 
 
 
36
 
 
37
  pain_model = resnet18(weights=None)
38
- pain_model.fc = nn.Linear(pain_model.fc.in_features, 1)
39
  pain_model.load_state_dict(torch.load("pain_model.pth", map_location=torch.device("cpu")))
40
  pain_model.eval()
41
 
42
-
43
  return stroke_model, pain_model
44
 
 
45
  stroke_model, pain_model = download_models()
46
 
 
47
  transform = transforms.Compose([
48
- transforms.Resize((224, 224)),
49
- transforms.ToTensor(),
50
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
51
- std=[0.229, 0.224, 0.225])
52
  ])
53
 
54
- uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
 
55
 
 
56
  if uploaded_file is not None:
57
  st.write("πŸ“· Image uploaded. Processing...")
58
- full_image = Image.open(uploaded_file).convert("RGB")
59
  st.image(full_image, caption="Uploaded Full-Face Image", use_column_width=True)
60
 
 
61
  w, h = full_image.size
62
  mid = w // 2
63
  left_face = full_image.crop((0, 0, mid, h))
64
  right_face = full_image.crop((mid, 0, w, h))
65
 
66
- # πŸ” Automatically resize image based on model input
67
  _, H, W, C = stroke_model.input_shape
68
  st.write(f"πŸ” Resizing uploaded image to: ({H}, {W}) for stroke model")
69
-
70
  stroke_input = full_image.resize((W, H))
71
  stroke_array = np.array(stroke_input).astype("float32") / 255.0
72
- stroke_array = np.expand_dims(stroke_array, axis=0)
73
-
74
- st.write("🧠 Running stroke model prediction...")
75
- stroke_pred = stroke_model.predict(stroke_array)
76
- affected = int(np.round(stroke_pred[0][0])) # 1 = left affected, 0 = right affected
77
-
78
- # Face POV logic: patient's left/right
79
- if affected == 1:
80
- affected_side = "left"
81
- unaffected_side = "right"
82
- unaffected_face = right_face
83
- else:
84
- affected_side = "right"
85
- unaffected_side = "left"
86
- unaffected_face = left_face
87
-
88
- unaffected_tensor = transform(unaffected_face).unsqueeze(0)
89
-
90
- st.write("πŸ“ˆ Predicting pain score...")
91
- with torch.no_grad():
92
- output = pain_model(unaffected_tensor)
93
- pspi_score = output.item()
94
-
95
- st.subheader("Prediction Results")
96
- st.image(unaffected_face, caption="Unaffected Side Used for Pain Detection", width=300)
97
- st.write(f"**Affected side (face POV):** {affected_side}")
98
- st.write(f"**Unaffected side (face POV):** {unaffected_side}")
99
- st.write(f"**Predicted PSPI Pain Score:** {pspi_score:.3f}")
100
- st.write(f"**Stroke model raw output:** {stroke_pred[0][0]}")
 
 
 
 
 
9
  import os
10
  import requests
11
 
12
+ # UI title and instructions
13
+ st.title("🧠 Stroke Patient Pain Intensity Detector")
14
  st.markdown("Upload a full-face image of a stroke patient. The app will detect the affected side and predict pain intensity using the unaffected side.")
15
  st.write("πŸ”§ App started. Preparing to download models...")
16
 
17
+ # Function to download and load models with caching
18
  @st.cache_resource
19
  def download_models():
20
  model_urls = {
 
22
  "pain_model.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/pain_model.pth"
23
  }
24
 
25
+ # Download models if not already downloaded
26
  for filename, url in model_urls.items():
27
  if not os.path.exists(filename):
28
  st.write(f"πŸ“₯ Downloading {filename}...")
 
34
  st.write(f"βœ”οΈ {filename} already exists.")
35
 
36
  st.write("πŸ“¦ Loading models...")
 
37
 
38
+ # Load stroke side classification model (Keras)
39
+ stroke_model = load_model("cnn_stroke_model.keras")
40
 
41
+ # Load pain intensity prediction model (PyTorch ResNet18)
42
  pain_model = resnet18(weights=None)
43
+ pain_model.fc = nn.Linear(pain_model.fc.in_features, 1) # One output: PSPI pain score
44
  pain_model.load_state_dict(torch.load("pain_model.pth", map_location=torch.device("cpu")))
45
  pain_model.eval()
46
 
 
47
  return stroke_model, pain_model
48
 
49
+ # Load models
50
  stroke_model, pain_model = download_models()
51
 
52
+ # Define preprocessing pipeline for PyTorch pain model
53
  transform = transforms.Compose([
54
+ transforms.Resize((224, 224)), # Resize to ResNet18 input
55
+ transforms.ToTensor(), # Convert to tensor
56
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize using ImageNet stats
 
57
  ])
58
 
59
+ # File uploader widget
60
+ uploaded_file = st.file_uploader("πŸ“‚ Choose a full-face image", type=["jpg", "jpeg", "png"])
61
 
62
+ # If an image is uploaded, start processing
63
  if uploaded_file is not None:
64
  st.write("πŸ“· Image uploaded. Processing...")
65
+ full_image = Image.open(uploaded_file).convert("RGB") # Ensure RGB format
66
  st.image(full_image, caption="Uploaded Full-Face Image", use_column_width=True)
67
 
68
+ # Crop left and right halves of the face
69
  w, h = full_image.size
70
  mid = w // 2
71
  left_face = full_image.crop((0, 0, mid, h))
72
  right_face = full_image.crop((mid, 0, w, h))
73
 
74
+ # Resize full image for stroke model input
75
  _, H, W, C = stroke_model.input_shape
76
  st.write(f"πŸ” Resizing uploaded image to: ({H}, {W}) for stroke model")
 
77
  stroke_input = full_image.resize((W, H))
78
  stroke_array = np.array(stroke_input).astype("float32") / 255.0
79
+ stroke_array = np.expand_dims(stroke_array, axis=0) # Shape: (1, H, W, C)
80
+
81
+ # Predict affected side using stroke model
82
+ st.write("🧠 Running stroke model prediction...")
83
+ stroke_pred = stroke_model.predict(stroke_array)
84
+ affected = int(np.round(stroke_pred[0][0])) # Output: 1 = left affected, 0 = right affected
85
+
86
+ # Select unaffected side and label sides (from patient's perspective)
87
+ if affected == 1:
88
+ affected_side = "left"
89
+ unaffected_side = "right"
90
+ unaffected_face = right_face
91
+ else:
92
+ affected_side = "right"
93
+ unaffected_side = "left"
94
+ unaffected_face = left_face
95
+
96
+ # Preprocess the unaffected side for the PyTorch model
97
+ unaffected_tensor = transform(unaffected_face).unsqueeze(0) # Shape: (1, 3, 224, 224)
98
+
99
+ # Predict pain score
100
+ st.write("πŸ“ˆ Predicting pain score...")
101
+ with torch.no_grad():
102
+ output = pain_model(unaffected_tensor)
103
+ pspi_score = output.item() # Extract float value
104
+
105
+ # Display results
106
+ st.subheader("🧾 Prediction Results")
107
+ st.image(unaffected_face, caption="Unaffected Side Used for Pain Detection", width=300)
108
+ st.write(f"**🧭 Affected side (face POV):** `{affected_side}`")
109
+ st.write(f"**βœ… Unaffected side (face POV):** `{unaffected_side}`")
110
+ st.write(f"**🎯 Predicted PSPI Pain Score:** `{pspi_score:.3f}`")
111
+ st.write(f"**πŸ“Š Stroke model raw output:** `{stroke_pred[0][0]}`")