"""Streamlit app: estimate pain intensity (PSPI) for stroke patients.

Pipeline implemented below:
1. Download the two model files from Hugging Face if they are not already on
   disk, then load them (cached with ``st.cache_resource``).
2. Detect the face in an uploaded image with MediaPipe and crop it.
3. A Keras CNN classifies which facial side is affected by the stroke.
4. A ResNet-18 regressor predicts a PSPI score from the *unaffected* half of
   the face; the displayed score is clipped to [0, 6].
"""
import streamlit as st
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
from tensorflow.keras.models import load_model
from torchvision.models import resnet18
import os
import requests
import mediapipe as mp
import cv2  # no longer needed for face detection; import kept unchanged

# App title
st.title("🧠 Stroke Patient Pain Intensity Detector")

# Instructions
st.markdown(
    """
    Upload a full-face image of a stroke patient.
    The app will detect the **affected facial side** using a stroke classification model,
    and then use the **unaffected side** to predict **pain intensity** (PSPI score).
    """
)

st.write("🔧 Initializing and downloading models...")

# Local filename -> Hugging Face download URL for each model file.
MODEL_URLS = {
    "cnn_stroke_model.keras": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/cnn_stroke_model.keras",
    "pain_model.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/pain_model.pth",
}
# Timeout (seconds) for connecting / waiting on each read, so a stalled
# server cannot hang the app indefinitely.
DOWNLOAD_TIMEOUT_S = 60
# Size of each streamed chunk written to disk (1 MiB).
DOWNLOAD_CHUNK_BYTES = 1 << 20


def _download_file(url, filename):
    """Download *url* to *filename*, raising on HTTP or network errors.

    The body is streamed into ``<filename>.part`` and renamed onto *filename*
    only after it has been fully written. A failed or interrupted download
    therefore never leaves a truncated file or an HTTP error page at
    *filename*, which the existence check in ``download_models`` would
    otherwise treat as a valid model forever.

    Raises:
        requests.HTTPError: the server answered with a 4xx/5xx status.
        requests.RequestException: connection problems or timeouts.
        OSError: the file could not be written or renamed.
    """
    tmp_path = filename + ".part"
    try:
        with requests.get(url, stream=True, timeout=DOWNLOAD_TIMEOUT_S) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=DOWNLOAD_CHUNK_BYTES):
                    f.write(chunk)
        os.replace(tmp_path, filename)
    finally:
        # After a successful os.replace the temp file is gone; this only
        # cleans up after a failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# Download and load models
@st.cache_resource
def download_models():
    """Make sure both model files exist locally, then load them.

    Files already present on disk are reused without re-downloading.

    Returns:
        tuple: ``(stroke_model, pain_model)``. ``stroke_model`` is the Keras
        model loaded from ``cnn_stroke_model.keras``. ``pain_model`` is a
        ResNet-18 with a single-output linear head, loaded on CPU and put in
        eval mode.
    """
    for filename, url in MODEL_URLS.items():
        if not os.path.exists(filename):
            st.write(f"📥 Downloading {filename}...")
            _download_file(url, filename)
            st.success(f"✅ {filename} downloaded.")
        else:
            st.write(f"✔️ {filename} already exists.")

    stroke_model = load_model("cnn_stroke_model.keras")

    pain_model = resnet18(weights=None)
    pain_model.fc = nn.Linear(pain_model.fc.in_features, 1)
    # NOTE(security): torch.load unpickles the file, which can execute code.
    # The file comes from a remote URL; if that source is not fully trusted,
    # pass weights_only=True (torch >= 1.13), since only a state_dict is needed.
    pain_model.load_state_dict(torch.load("pain_model.pth", map_location=torch.device("cpu")))
    pain_model.eval()
    return stroke_model, pain_model


stroke_model, pain_model = download_models()

# Preprocessing for pain model: resize to 224x224 and normalize with the
# ImageNet channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# MediaPipe Face Detection
mp_face = mp.solutions.face_detection
mp_draw = mp.solutions.drawing_utils

# Smallest face crop (pixels per side) we accept. At least 2 px wide is
# needed so that both face halves are non-empty.
MIN_FACE_SIZE = 2


def _clamped_face_box(bbox, img_w, img_h):
    """Convert a MediaPipe relative bounding box to a pixel box inside the image.

    The relative box may extend past the image edges. ``PIL.Image.crop``
    would pad such regions with black pixels, which would then be fed to
    both models. The box is therefore clipped to ``[0, img_w] x [0, img_h]``.

    Args:
        bbox: object with relative ``xmin``, ``ymin``, ``width``, ``height``.
        img_w: image width in pixels.
        img_h: image height in pixels.

    Returns:
        ``(left, top, right, bottom)`` suitable for ``Image.crop``, or
        ``None`` if the clipped box is smaller than ``MIN_FACE_SIZE``.
    """
    x = int(bbox.xmin * img_w)
    y = int(bbox.ymin * img_h)
    w = int(bbox.width * img_w)
    h = int(bbox.height * img_h)

    left, top = max(x, 0), max(y, 0)
    right, bottom = min(x + w, img_w), min(y + h, img_h)
    if right - left < MIN_FACE_SIZE or bottom - top < MIN_FACE_SIZE:
        return None
    return left, top, right, bottom


# Upload UI
uploaded_file = st.file_uploader("📂 Upload a full-face image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    st.write("📷 Image uploaded. Detecting face...")
    full_image = Image.open(uploaded_file).convert("RGB")
    img_np = np.array(full_image)

    # MediaPipe's FaceDetection.process() expects an RGB image. img_np is
    # already RGB (from PIL), so it is passed as-is. Converting it to BGR
    # would swap the red/blue channels seen by the detector.
    with mp_face.FaceDetection(model_selection=1, min_detection_confidence=0.6) as detector:
        results = detector.process(img_np)

    if not results.detections:
        st.error("❌ No face detected. Please upload a clear frontal face image.")
        st.stop()

    # Use first detection
    detection = results.detections[0]
    ih, iw, _ = img_np.shape
    face_box = _clamped_face_box(detection.location_data.relative_bounding_box, iw, ih)
    if face_box is None:
        st.error("❌ Detected face region is too small or outside the image. Please upload a clear frontal face image.")
        st.stop()
    face_crop = full_image.crop(face_box)

    st.image(full_image, caption="Uploaded Full-Face Image", use_column_width=True)

    # Split halves (face POV): the patient's right side appears on the
    # viewer's left in a non-mirrored photo.
    fw, fh = face_crop.size
    fmid = fw // 2
    patient_right = face_crop.crop((0, 0, fmid, fh))  # viewer's left
    patient_left = face_crop.crop((fmid, 0, fw, fh))  # viewer's right

    # Stroke prediction input: resize to the model's (H, W), scale to [0, 1],
    # add a batch axis. Assumes a single NHWC input with fixed H and W.
    _, H, W, C = stroke_model.input_shape
    stroke_input = face_crop.resize((W, H))
    stroke_array = np.array(stroke_input).astype("float32") / 255.0
    stroke_array = np.expand_dims(stroke_array, axis=0)

    st.write("🧠 Predicting affected side of the face...")
    stroke_pred = stroke_model.predict(stroke_array)
    stroke_raw = stroke_pred[0][0]
    affected = int(np.round(stroke_raw))  # 0 = left affected, 1 = right affected

    if affected == 0:
        affected_side = "left"
        unaffected_side = "right"
        unaffected_face = patient_right
    else:
        affected_side = "right"
        unaffected_side = "left"
        unaffected_face = patient_left

    # Pain prediction on the unaffected half only.
    st.write("📈 Predicting PSPI pain score from unaffected side...")
    input_tensor = transform(unaffected_face).unsqueeze(0)
    with torch.no_grad():
        output = pain_model(input_tensor)
    raw_score = output.item()
    # Clip the regression output to the 0-6 range shown in the legend below.
    pspi_score = max(0.0, min(raw_score, 6.0))

    # Display results
    st.subheader("🔍 Prediction Results")
    st.image(unaffected_face, caption="Unaffected Side Used for Pain Detection", width=300)
    st.write(f"**🧭 Affected Side (face POV):** `{affected_side}`")
    st.write(f"**✅ Unaffected Side (face POV):** `{unaffected_side}`")
    st.write(f"**🎯 Predicted PSPI Pain Score:** `{pspi_score:.3f}`")
    st.write(f"**📈 Raw Pain Model Output:** `{raw_score:.3f}`")
    st.write(f"**📊 Stroke Model Raw Output:** `{stroke_raw:.4f}`")

    st.markdown(
        """
        ---
        ### ℹ️ Stroke Model Output
        - Output is between `0` and `1`
        - Closer to `0` = Left side is affected
        - Closer to `1` = Right side is affected

        ### ℹ️ PSPI Score Scale
        - `0`: No pain
        - `1–2`: Mild pain
        - `3–4`: Moderate pain
        - `5–6`: Severe pain
        """
    )