# AdhamQQ's picture
# Update app.py
# 9745932 verified
import streamlit as st
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image, ImageOps
import numpy as np
from tensorflow.keras.models import load_model
from torchvision.models import resnet18
import os
import requests
import mediapipe as mp
import cv2
# App title
st.title("🧠 Stroke Patient Pain Intensity Detector")

# Instructions shown at the top of the page, before any upload.
st.markdown(
    """
Upload a full-face image of a stroke patient.
The app will detect the **affected facial side** using a stroke classification model,
and then use the **unaffected side** to predict **pain intensity** (PSPI score).
"""
)

# Startup status line, written before the (cached) model loader is invoked.
st.write("🔧 Initializing and downloading models...")
# Download and load models


def _download_file(url, dest, *, timeout=60, chunk_size=1 << 20):
    """Stream *url* to the local path *dest*, replacing it only on success.

    The body is written to ``dest + ".part"`` and renamed onto *dest* with
    ``os.replace`` once the whole response has been received. As a result, an HTTP
    error or an interrupted transfer never leaves a bogus file at *dest*. Such a
    file would otherwise pass the "already exists" check on every later run.

    Args:
        url: HTTP(S) URL to fetch.
        dest: destination file path.
        timeout: seconds passed to ``requests.get`` as its connect/read timeout.
        chunk_size: bytes per chunk read from the streamed response.

    Raises:
        requests.HTTPError: if the server answers with a non-2xx status.
        requests.RequestException: on connection errors or timeouts.
    """
    tmp_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as resp:
            resp.raise_for_status()
            with open(tmp_path, "wb") as out:
                for chunk in resp.iter_content(chunk_size=chunk_size):
                    out.write(chunk)
        os.replace(tmp_path, dest)
    except BaseException:
        # Drop the partial file so the next run retries from scratch, then re-raise.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


@st.cache_resource
def download_models():
    """Fetch (if missing) and load the stroke-side and pain-intensity models.

    Model files are stored under their bare file names, i.e. in the current
    working directory. They are downloaded only when the file is not already
    present.

    Returns:
        tuple: ``(stroke_model, pain_model)``.
        ``stroke_model`` is the Keras model loaded from ``cnn_stroke_model.keras``.
        ``pain_model`` is a torchvision ResNet-18 whose ``fc`` layer is replaced by
        a single-output ``nn.Linear``; its state dict is loaded from
        ``pain_model.pth`` onto the CPU and it is put in eval mode.

    Raises:
        requests.HTTPError / requests.RequestException: if a download fails.
    """
    model_urls = {
        "cnn_stroke_model.keras": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/cnn_stroke_model.keras",
        "pain_model.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/pain_model.pth",
    }
    for filename, url in model_urls.items():
        if not os.path.exists(filename):
            st.write(f"📥 Downloading {filename}...")
            _download_file(url, filename)
            st.success(f"✅ {filename} downloaded.")
        else:
            st.write(f"✔️ {filename} already exists.")

    stroke_model = load_model("cnn_stroke_model.keras")

    pain_model = resnet18(weights=None)
    # Replace the classification head with a single scalar output (pain score).
    pain_model.fc = nn.Linear(pain_model.fc.in_features, 1)
    # SECURITY NOTE(review): torch.load unpickles the file, and this file was fetched
    # over the network. Since it only needs to hold a state_dict, consider passing
    # weights_only=True once the pinned torch version supports it.
    pain_model.load_state_dict(torch.load("pain_model.pth", map_location=torch.device("cpu")))
    pain_model.eval()
    return stroke_model, pain_model
# Load both models (download_models is wrapped in st.cache_resource).
stroke_model, pain_model = download_models()

# Per-channel mean / std used to normalise pain-model inputs
# (the standard ImageNet statistics).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Pain-model preprocessing: fixed 224x224 resize -> tensor -> normalisation.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

# MediaPipe solution handles: face detection, plus drawing utilities
# (the latter is not referenced elsewhere in this script as shown).
mp_face, mp_draw = mp.solutions.face_detection, mp.solutions.drawing_utils
# ---------------------------------------------------------------------------
# Helpers for the upload / inference flow
# ---------------------------------------------------------------------------


def _relative_box_to_pixels(bbox, img_w, img_h):
    """Convert a relative bounding box into a pixel crop box clamped to the image.

    MediaPipe gives ``xmin``/``ymin``/``width``/``height`` as fractions of the image
    size, and the box may reach past the image border (e.g. a face touching the
    edge). ``PIL.Image.crop`` pads out-of-image areas with zero (black) pixels, and
    that padding would leak into the face halves. Each coordinate is therefore
    clamped to the image.

    Args:
        bbox: object with ``xmin``, ``ymin``, ``width`` and ``height`` attributes
            (MediaPipe's ``relative_bounding_box``).
        img_w: image width in pixels.
        img_h: image height in pixels.

    Returns:
        tuple[int, int, int, int]: ``(left, top, right, bottom)`` with
        ``0 <= left <= right <= img_w`` and ``0 <= top <= bottom <= img_h``.
    """

    def _clamp(value, upper):
        return min(max(value, 0), upper)

    left = int(bbox.xmin * img_w)
    top = int(bbox.ymin * img_h)
    right = left + int(bbox.width * img_w)
    bottom = top + int(bbox.height * img_h)

    left = _clamp(left, img_w)
    top = _clamp(top, img_h)
    right = max(left, _clamp(right, img_w))
    bottom = max(top, _clamp(bottom, img_h))
    return left, top, right, bottom


def _stroke_model_input(face_img, input_shape):
    """Build the single-image batch fed to the Keras stroke model.

    Args:
        face_img: cropped face as an RGB ``PIL.Image``.
        input_shape: the model's ``input_shape``, unpacked as
            ``(batch, height, width, channels)``.

    Returns:
        numpy.ndarray: float32 array of shape ``(1, height, width, channels)``
        with pixel values scaled to ``[0, 1]``.
    """
    _, height, width, channels = input_shape
    resized = face_img.resize((width, height))
    if channels == 1:
        # Single-channel model: feed luminance instead of RGB.
        resized = resized.convert("L")
    arr = np.asarray(resized, dtype="float32") / 255.0
    if channels == 1:
        # Grayscale PIL images convert to (H, W); restore the channel axis.
        arr = arr[..., np.newaxis]
    return arr[np.newaxis, ...]


# Upload UI
uploaded_file = st.file_uploader("📂 Upload a full-face image", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    st.write("📷 Image uploaded. Detecting face...")
    # Honour the EXIF orientation tag so phone photos are processed upright.
    # Otherwise detection may fail and the face halves would come from a
    # rotated image.
    full_image = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")
    img_np = np.array(full_image)

    # MediaPipe's process() expects an RGB array and full_image is already RGB.
    # Converting to BGR here would swap the red and blue channels.
    with mp_face.FaceDetection(model_selection=1, min_detection_confidence=0.6) as detector:
        results = detector.process(img_np)

    if not results.detections:
        st.error("❌ No face detected. Please upload a clear frontal face image.")
        st.stop()

    # Use first detection
    bbox = results.detections[0].location_data.relative_bounding_box
    ih, iw, _ = img_np.shape
    left, top, right, bottom = _relative_box_to_pixels(bbox, iw, ih)
    # Need at least 2 px of width so both halves below are non-empty.
    if right - left < 2 or bottom - top < 1:
        st.error("❌ Detected face region is too small. Please upload a clearer, closer face image.")
        st.stop()
    face_crop = full_image.crop((left, top, right, bottom))

    # NOTE(review): newer Streamlit releases deprecate use_column_width in favour of
    # use_container_width; switch once the pinned Streamlit version supports it.
    st.image(full_image, caption="Uploaded Full-Face Image", use_column_width=True)

    # Split halves (face POV). Assumes a non-mirrored photo, where the patient's
    # right side appears on the viewer's left. TODO confirm for selfie-camera input.
    fw, fh = face_crop.size
    fmid = fw // 2
    patient_right = face_crop.crop((0, 0, fmid, fh))  # viewer's left
    patient_left = face_crop.crop((fmid, 0, fw, fh))  # viewer's right

    # Stroke prediction: which side of the face is affected.
    stroke_array = _stroke_model_input(face_crop, stroke_model.input_shape)
    st.write("🧠 Predicting affected side of the face...")
    stroke_pred = stroke_model.predict(stroke_array)
    stroke_raw = stroke_pred[0][0]
    affected = int(np.round(stroke_raw))  # 0 = left affected, 1 = right affected
    if affected == 0:
        affected_side = "left"
        unaffected_side = "right"
        unaffected_face = patient_right
    else:
        affected_side = "right"
        unaffected_side = "left"
        unaffected_face = patient_left

    # Pain prediction on the unaffected half only.
    st.write("📈 Predicting PSPI pain score from unaffected side...")
    input_tensor = transform(unaffected_face).unsqueeze(0)
    with torch.no_grad():
        output = pain_model(input_tensor)
    raw_score = output.item()
    # Clamp to the 0–6 range described in the legend below.
    # NOTE(review): the standard PSPI scale runs 0–16; 0–6 presumably mirrors the
    # training-label range. Confirm.
    pspi_score = max(0.0, min(raw_score, 6.0))

    # Display results
    st.subheader("🔍 Prediction Results")
    st.image(unaffected_face, caption="Unaffected Side Used for Pain Detection", width=300)
    st.write(f"**🧭 Affected Side (face POV):** `{affected_side}`")
    st.write(f"**✅ Unaffected Side (face POV):** `{unaffected_side}`")
    st.write(f"**🎯 Predicted PSPI Pain Score:** `{pspi_score:.3f}`")
    st.write(f"**📈 Raw Pain Model Output:** `{raw_score:.3f}`")
    st.write(f"**📊 Stroke Model Raw Output:** `{stroke_raw:.4f}`")
    st.markdown(
        """
---
### ℹ️ Stroke Model Output
- Output is between `0` and `1`
- Closer to `0` = Left side is affected
- Closer to `1` = Right side is affected
### ℹ️ PSPI Score Scale
- `0`: No pain
- `1–2`: Mild pain
- `3–4`: Moderate pain
- `5–6`: Severe pain
"""
    )