File size: 5,460 Bytes
e036c2c
 
 
0414367
e036c2c
 
12d6038
8096d7a
e036c2c
 
9745932
11b70ce
e036c2c
9745932
006408c
d25279a
9745932
 
 
 
 
 
 
 
 
 
 
e036c2c
 
0414367
e5b0777
68b7985
e036c2c
0414367
68b7985
9745932
68b7985
 
 
9745932
 
 
 
006408c
a94d467
0fb1aba
a94d467
 
9745932
68b7985
 
 
e036c2c
9745932
e036c2c
0fb1aba
 
d25279a
e036c2c
 
9745932
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import streamlit as st
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
from tensorflow.keras.models import load_model
from torchvision.models import resnet18
import os
import requests
import mediapipe as mp
import cv2

# App title
st.title("🧠 Stroke Patient Pain Intensity Detector")

# Instructions: describe the two-stage pipeline (stroke-side classification,
# then pain scoring on the opposite half of the face).
st.markdown(
    """
    Upload a full-face image of a stroke patient.
    The app will detect the **affected facial side** using a stroke classification model,
    and then use the **unaffected side** to predict **pain intensity** (PSPI score).
    """
)
# Status line emitted before the models are fetched and loaded further down.
# The leading emoji was previously stored as mis-decoded UTF-8 bytes and
# rendered as garbage characters in the UI.
st.write("🔧 Initializing and downloading models...")

# Download and load models
@st.cache_resource
def download_models():
    """Fetch the model files if they are missing, then load both models.

    Each file is streamed to a temporary ``<name>.part`` path and only moved
    to its final name once the transfer finished with a successful HTTP
    status. This keeps an error page or a truncated transfer from being
    mistaken for a valid model file on the next run, where the existence
    check would otherwise skip the download for good.

    Returns:
        tuple: ``(stroke_model, pain_model)`` -- the Keras model loaded from
        ``cnn_stroke_model.keras`` and a ResNet-18 with a single-output
        linear head, loaded from ``pain_model.pth`` and put in eval mode.

    Raises:
        requests.RequestException: if a download fails, times out, or the
            server answers with an error status.
    """
    model_urls = {
        "cnn_stroke_model.keras": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/cnn_stroke_model.keras",
        "pain_model.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/pain_model.pth"
    }
    for filename, url in model_urls.items():
        if os.path.exists(filename):
            st.write(f"✔️ {filename} already exists.")
            continue

        st.write(f"📥 Downloading {filename}...")
        partial_path = f"{filename}.part"
        try:
            # Stream in chunks so the whole model is never held in memory, and
            # bound the wait so a stalled connection cannot hang the app.
            with requests.get(url, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            # Publish the file only after a complete, successful transfer.
            os.replace(partial_path, filename)
        finally:
            # After a successful replace the partial file is gone; on any
            # failure this removes the incomplete download.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        st.success(f"✅ {filename} downloaded.")

    stroke_model = load_model("cnn_stroke_model.keras")

    # Rebuild the architecture the checkpoint is loaded into: ResNet-18
    # without pretrained weights and a one-unit regression head.
    pain_model = resnet18(weights=None)
    pain_model.fc = nn.Linear(pain_model.fc.in_features, 1)
    # NOTE(review): torch.load unpickles the downloaded file. On PyTorch
    # versions where ``weights_only`` defaults to False a tampered checkpoint
    # can execute arbitrary code -- consider passing ``weights_only=True``
    # once the minimum supported PyTorch version is confirmed.
    pain_model.load_state_dict(torch.load("pain_model.pth", map_location=torch.device("cpu")))
    pain_model.eval()

    return stroke_model, pain_model

# Obtain both models from the cached loader defined above.
stroke_model, pain_model = download_models()

# Preprocessing for pain model: resize to 224x224, convert to a tensor, then
# normalise each channel with the mean/std given here.
# NOTE(review): these constants look like the usual ImageNet statistics --
# confirm they match what the pain model was trained with.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# MediaPipe Face Detection solution module and its drawing helpers.
mp_face, mp_draw = mp.solutions.face_detection, mp.solutions.drawing_utils

# Upload UI: accepts a single JPEG/PNG image. The value is None until the
# user has picked a file. The label's emoji was previously stored as
# mis-decoded UTF-8 bytes and rendered as garbage characters.
uploaded_file = st.file_uploader("📂 Upload a full-face image", type=["jpg", "jpeg", "png"])

# Main flow, run on every script rerun once an image has been uploaded:
# detect the face, crop it, classify the affected side, score pain on the
# opposite half and render the results.
if uploaded_file is not None:
    st.write("📷 Image uploaded. Detecting face...")
    # Force three RGB channels so greyscale, palette and RGBA uploads all
    # yield a (height, width, 3) array.
    full_image = Image.open(uploaded_file).convert("RGB")
    img_np = np.array(full_image)

    # FaceDetection.process() expects an RGB image. img_np is already RGB, so
    # it is passed as-is; converting it to BGR first would feed the detector
    # swapped colour channels. The detector is only needed for this one call
    # and is closed before the heavier model inference below.
    with mp_face.FaceDetection(model_selection=1, min_detection_confidence=0.6) as detector:
        detections = detector.process(img_np).detections

    if not detections:
        st.error("❌ No face detected. Please upload a clear frontal face image.")
        st.stop()

    # Use first detection; its bounding box is expressed as fractions of the
    # image width/height and is scaled to pixels here.
    bboxC = detections[0].location_data.relative_bounding_box
    ih, iw, _ = img_np.shape
    x = int(bboxC.xmin * iw)
    y = int(bboxC.ymin * ih)
    w = int(bboxC.width * iw)
    h = int(bboxC.height * ih)

    # The box may extend past the frame for faces near the border. Clamp it,
    # because Image.crop would otherwise fill the out-of-frame area with
    # blank pixels that end up in both model inputs.
    left, top = max(x, 0), max(y, 0)
    right, bottom = min(x + w, iw), min(y + h, ih)
    # At least two columns are required so both face halves are non-empty.
    if right - left < 2 or bottom - top < 1:
        st.error("❌ Detected face region is too small. Please upload a clear frontal face image.")
        st.stop()

    face_crop = full_image.crop((left, top, right, bottom))
    st.image(full_image, caption="Uploaded Full-Face Image", use_column_width=True)

    # Split halves (face POV): the patient's right side is on the viewer's
    # left, assuming the photo is not mirrored.
    fw, fh = face_crop.size
    fmid = fw // 2
    patient_right = face_crop.crop((0, 0, fmid, fh))    # viewer's left
    patient_left = face_crop.crop((fmid, 0, fw, fh))    # viewer's right

    # Stroke prediction input: resize to the model's expected height/width,
    # scale pixel values to [0, 1] and add a leading batch axis.
    _, H, W, _ = stroke_model.input_shape
    stroke_input = face_crop.resize((W, H))
    stroke_array = np.array(stroke_input).astype("float32") / 255.0
    stroke_array = np.expand_dims(stroke_array, axis=0)

    st.write("🧠 Predicting affected side of the face...")
    stroke_pred = stroke_model.predict(stroke_array)
    stroke_raw = stroke_pred[0][0]
    affected = int(np.round(stroke_raw))  # 0 = left affected, 1 = right affected

    if affected == 0:
        affected_side = "left"
        unaffected_side = "right"
        unaffected_face = patient_right
    else:
        affected_side = "right"
        unaffected_side = "left"
        unaffected_face = patient_left

    # Pain prediction on the unaffected half only; gradients are not needed
    # for inference.
    st.write("📈 Predicting PSPI pain score from unaffected side...")
    input_tensor = transform(unaffected_face).unsqueeze(0)
    with torch.no_grad():
        output = pain_model(input_tensor)
        raw_score = output.item()
        # Clamp the regression output to the 0-6 range shown in the legend.
        pspi_score = max(0.0, min(raw_score, 6.0))

    # Display results
    st.subheader("🔍 Prediction Results")
    st.image(unaffected_face, caption="Unaffected Side Used for Pain Detection", width=300)
    st.write(f"**🧭 Affected Side (face POV):** `{affected_side}`")
    st.write(f"**✅ Unaffected Side (face POV):** `{unaffected_side}`")
    st.write(f"**🎯 Predicted PSPI Pain Score:** `{pspi_score:.3f}`")
    st.write(f"**📈 Raw Pain Model Output:** `{raw_score:.3f}`")
    st.write(f"**📊 Stroke Model Raw Output:** `{stroke_raw:.4f}`")

    st.markdown(
        """
        ---
        ### ℹ️ Stroke Model Output
        - Output is between `0` and `1`
        - Closer to `0` = Left side is affected
        - Closer to `1` = Right side is affected

        ### ℹ️ PSPI Score Scale
        - `0`: No pain  
        - `1–2`: Mild pain  
        - `3–4`: Moderate pain  
        - `5–6`: Severe pain
        """
    )