Spaces:
Sleeping
Sleeping
File size: 5,460 Bytes
e036c2c 0414367 e036c2c 12d6038 8096d7a e036c2c 9745932 11b70ce e036c2c 9745932 006408c d25279a 9745932 e036c2c 0414367 e5b0777 68b7985 e036c2c 0414367 68b7985 9745932 68b7985 9745932 006408c a94d467 0fb1aba a94d467 9745932 68b7985 e036c2c 9745932 e036c2c 0fb1aba d25279a e036c2c 9745932 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 | import streamlit as st
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
from tensorflow.keras.models import load_model
from torchvision.models import resnet18
import os
import requests
import mediapipe as mp
import cv2
# Page header: app title, a short description of the two-model workflow, and
# a status line announcing the model initialisation that follows.
_INTRO_MARKDOWN = """
Upload a full-face image of a stroke patient.
The app will detect the **affected facial side** using a stroke classification model,
and then use the **unaffected side** to predict **pain intensity** (PSPI score).
"""

st.title("π§ Stroke Patient Pain Intensity Detector")
st.markdown(_INTRO_MARKDOWN)
st.write("π§ Initializing and downloading models...")
# Download and load models
@st.cache_resource
def download_models():
    """Download the model files if missing, then load both models.

    Each file is saved in the current working directory and is only fetched
    when it does not already exist there. ``st.cache_resource`` keeps the
    returned models alive across Streamlit reruns, so this runs once per
    server process.

    Returns:
        tuple: ``(stroke_model, pain_model)`` where ``stroke_model`` is the
        Keras classifier loaded from ``cnn_stroke_model.keras`` and
        ``pain_model`` is a ResNet-18 with a single-output linear head,
        loaded on CPU and switched to eval mode.

    Raises:
        requests.HTTPError: if a download responds with an error status.
        requests.RequestException: on other network failures (e.g. timeout).
    """
    model_urls = {
        "cnn_stroke_model.keras": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/cnn_stroke_model.keras",
        "pain_model.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/pain_model.pth",
    }
    for filename, url in model_urls.items():
        if os.path.exists(filename):
            st.write(f"βοΈ {filename} already exists.")
            continue

        st.write(f"π₯ Downloading {filename}...")
        # Stream into a temporary file and rename it into place only after the
        # whole download succeeded. Writing the response directly would store
        # an HTTP error page or a truncated file under the model's name, and
        # the exists() check above would then never re-download it.
        tmp_path = filename + ".part"
        try:
            with requests.get(url, stream=True, timeout=60) as r:
                r.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            os.replace(tmp_path, filename)
        finally:
            # Only present if something failed before os.replace.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        st.success(f"✅ {filename} downloaded.")

    stroke_model = load_model("cnn_stroke_model.keras")

    pain_model = resnet18(weights=None)
    # Swap the default classification head for a single regression output.
    pain_model.fc = nn.Linear(pain_model.fc.in_features, 1)
    # NOTE(security): torch.load unpickles the file, which is fetched over the
    # network. Only a state dict is expected here, so passing
    # weights_only=True (torch >= 1.13) would be the safer choice.
    state_dict = torch.load("pain_model.pth", map_location=torch.device("cpu"))
    pain_model.load_state_dict(state_dict)
    pain_model.eval()
    return stroke_model, pain_model


stroke_model, pain_model = download_models()
# Pain-model preprocessing: resize the half-face crop to 224x224, convert it
# to a float tensor, then normalise each RGB channel with the standard
# ImageNet mean / std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# MediaPipe solution modules: face detection is used for the upload flow;
# the drawing utilities handle is not referenced by the visible code.
mp_face = mp.solutions.face_detection
mp_draw = mp.solutions.drawing_utils
# Upload UI
uploaded_file = st.file_uploader("π Upload a full-face image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Pipeline: decode upload -> detect face -> split the face crop into two
    # halves -> classify the affected side -> score pain on the other half.
    #
    # NOTE(review): several emoji in this script's UI strings appear
    # mis-encoded (e.g. "π§", "π·"); restore them from the original source.
    st.write("π· Image uploaded. Detecting face...")

    # A corrupt or non-image upload makes PIL raise OSError
    # (UnidentifiedImageError is a subclass); report it instead of crashing.
    # NOTE(review): EXIF orientation is not applied, so phone photos may be
    # processed rotated; consider PIL.ImageOps.exif_transpose.
    try:
        full_image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        st.error("Could not read the uploaded file as an image. Please upload a JPG or PNG photo.")
        st.stop()
    img_np = np.array(full_image)

    # MediaPipe's face detector expects an RGB array, which is exactly what
    # PIL produced above. Converting to BGR (as before) swapped the red and
    # blue channels the detector saw. model_selection=1 selects MediaPipe's
    # full-range detection model.
    with mp_face.FaceDetection(model_selection=1, min_detection_confidence=0.6) as detector:
        results = detector.process(img_np)

    if not results.detections:
        st.error("β No face detected. Please upload a clear frontal face image.")
        st.stop()

    # Use the first detection. Its bounding box is relative (fractions of the
    # image width/height), so scale it to pixels.
    detection = results.detections[0]
    bboxC = detection.location_data.relative_bounding_box
    ih, iw, _ = img_np.shape
    x = int(bboxC.xmin * iw)
    y = int(bboxC.ymin * ih)
    w = int(bboxC.width * iw)
    h = int(bboxC.height * ih)
    # The box may extend past the image border (e.g. negative xmin for a face
    # at the edge). PIL fills out-of-bounds crop areas with black pixels,
    # which would leak into the half-face inputs, so clamp to the image.
    left, top = max(0, x), max(0, y)
    right, bottom = min(iw, x + w), min(ih, y + h)
    # Both halves must be at least one pixel wide/high, or the resizes below fail.
    if right - left < 2 or bottom - top < 1:
        st.error("Detected face region is too small or outside the image. Please upload a clearer face image.")
        st.stop()
    face_crop = full_image.crop((left, top, right, bottom))
    st.image(full_image, caption="Uploaded Full-Face Image", use_column_width=True)

    # Split halves (face POV). Assuming the photo is not mirrored, the
    # patient's right side appears on the viewer's left and vice versa.
    fw, fh = face_crop.size
    fmid = fw // 2
    patient_right = face_crop.crop((0, 0, fmid, fh))  # viewer's left
    patient_left = face_crop.crop((fmid, 0, fw, fh))  # viewer's right

    # Stroke prediction input: the whole face crop resized to the model's
    # spatial size, scaled to [0, 1], with a leading batch axis.
    # Assumes a channels-last (batch, H, W, C) input with C == 3 -- C is not checked.
    _, H, W, C = stroke_model.input_shape
    stroke_input = face_crop.resize((W, H))
    stroke_array = np.array(stroke_input).astype("float32") / 255.0
    stroke_array = np.expand_dims(stroke_array, axis=0)

    st.write("π§ Predicting affected side of the face...")
    stroke_pred = stroke_model.predict(stroke_array)
    stroke_raw = stroke_pred[0][0]
    # np.round rounds half to even, so an output of exactly 0.5 maps to 0.
    affected = int(np.round(stroke_raw))  # 0 = left affected, 1 = right affected

    if affected == 0:
        affected_side = "left"
        unaffected_side = "right"
        unaffected_face = patient_right
    else:
        affected_side = "right"
        unaffected_side = "left"
        unaffected_face = patient_left

    # Pain prediction on the unaffected half only.
    st.write("π Predicting PSPI pain score from unaffected side...")
    input_tensor = transform(unaffected_face).unsqueeze(0)
    with torch.no_grad():
        output = pain_model(input_tensor)
    raw_score = output.item()
    # Clip to the 0-6 range explained in the legend below.
    # NOTE(review): the published PSPI scale spans 0-16; confirm that 0-6
    # matches the range of the pain model's training labels.
    pspi_score = max(0.0, min(raw_score, 6.0))

    # Display results
    st.subheader("π Prediction Results")
    st.image(unaffected_face, caption="Unaffected Side Used for Pain Detection", width=300)
    st.write(f"**π§ Affected Side (face POV):** `{affected_side}`")
    st.write(f"**✅ Unaffected Side (face POV):** `{unaffected_side}`")
    st.write(f"**π― Predicted PSPI Pain Score:** `{pspi_score:.3f}`")
    st.write(f"**π Raw Pain Model Output:** `{raw_score:.3f}`")
    st.write(f"**π Stroke Model Raw Output:** `{stroke_raw:.4f}`")

    st.markdown(
        """
---
### βΉοΈ Stroke Model Output
- Output is between `0` and `1`
- Closer to `0` = Left side is affected
- Closer to `1` = Right side is affected
### βΉοΈ PSPI Score Scale
- `0`: No pain
- `1β2`: Mild pain
- `3β4`: Moderate pain
- `5β6`: Severe pain
"""
    )
|