Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from PIL import Image | |
| import numpy as np | |
| from tensorflow.keras.models import load_model | |
| from torchvision.models import resnet18 | |
| import os | |
| import requests | |
| import mediapipe as mp | |
| import cv2 | |
# ---------------------------------------------------------------------------
# Page header: title, usage instructions and a start-up status line.
# ---------------------------------------------------------------------------
_INSTRUCTIONS = """
Upload a full-face image of a stroke patient.
The app will detect the **affected facial side** using a stroke classification model,
and then use the **unaffected side** to predict **pain intensity** (PSPI score).
"""

st.title("π§ Stroke Patient Pain Intensity Detector")
st.markdown(_INSTRUCTIONS)
st.write("π§ Initializing and downloading models...")
# Download and load models


def _download_file(filename, url, timeout=60):
    """Stream *url* into *filename*, replacing it only once the download succeeds.

    The response body is written to ``<filename>.part`` and moved into place
    with ``os.replace``. An interrupted or failed download therefore never
    leaves a truncated file behind for the ``os.path.exists`` check in
    ``download_models`` to accept on the next run.

    Args:
        filename: Local destination path.
        url: Remote URL to fetch.
        timeout: Per-socket-operation timeout in seconds passed to ``requests``.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-2xx HTTP status (via ``raise_for_status``).
    """
    tmp_path = filename + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this, an HTTP error page would be saved as the model file.
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(tmp_path, filename)
    finally:
        # Only still present if something failed before os.replace.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


@st.cache_resource
def download_models():
    """Ensure both model files are on disk, then load and return them.

    Streamlit re-executes the whole script on every widget interaction, so the
    loaded models are cached with ``st.cache_resource`` instead of being
    reloaded from disk on each rerun.

    Returns:
        tuple: ``(stroke_model, pain_model)``. The first is the Keras stroke
        classifier. The second is a ResNet-18 with a single-output ``fc``
        head, with weights loaded onto CPU and set to eval mode.
    """
    model_urls = {
        "cnn_stroke_model.keras": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/cnn_stroke_model.keras",
        "pain_model.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/pain_model.pth",
    }
    for filename, url in model_urls.items():
        if os.path.exists(filename):
            st.write(f"βοΈ {filename} already exists.")
            continue
        st.write(f"π₯ Downloading {filename}...")
        try:
            _download_file(filename, url)
        except requests.RequestException as err:
            st.error(f"Failed to download {filename}: {err}")
            st.stop()
        st.success(f"β {filename} downloaded.")

    stroke_model = load_model("cnn_stroke_model.keras")

    pain_model = resnet18(weights=None)
    pain_model.fc = nn.Linear(pain_model.fc.in_features, 1)
    # The checkpoint comes from a remote URL. weights_only=True refuses
    # arbitrary pickled objects and only accepts tensors/containers, which is
    # all a state_dict needs.
    state_dict = torch.load(
        "pain_model.pth", map_location=torch.device("cpu"), weights_only=True
    )
    pain_model.load_state_dict(state_dict)
    pain_model.eval()
    return stroke_model, pain_model


stroke_model, pain_model = download_models()
# Preprocessing for the pain model: resize to 224x224, convert the PIL image
# to a tensor, then normalize each channel with the standard ImageNet
# mean/std values.
_pain_preprocess_steps = [
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_pain_preprocess_steps)
# MediaPipe solution modules: face_detection is used below to locate the face.
# NOTE(review): mp_draw is not referenced anywhere in the visible code.
mp_face, mp_draw = mp.solutions.face_detection, mp.solutions.drawing_utils
# Upload UI
from PIL import ImageOps  # local to this block; used to honour EXIF orientation


def _clamped_face_box(bbox, img_w, img_h):
    """Convert a MediaPipe relative bounding box to pixel coords clipped to the image.

    The detector can report a box that extends past the image border (e.g. a
    negative ``xmin``). PIL's ``crop`` pads out-of-range regions with black,
    which would also shift the vertical midline used to split the face. The
    box is therefore clipped to ``[0, img_w] x [0, img_h]``.

    Returns:
        ``(left, top, right, bottom)`` in pixels, or ``None`` if the clipped
        box has no area.
    """
    x = int(bbox.xmin * img_w)
    y = int(bbox.ymin * img_h)
    w = int(bbox.width * img_w)
    h = int(bbox.height * img_h)
    left, top = max(0, x), max(0, y)
    right, bottom = min(img_w, x + w), min(img_h, y + h)
    if right <= left or bottom <= top:
        return None
    return left, top, right, bottom


uploaded_file = st.file_uploader("π Upload a full-face image", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    st.write("π· Image uploaded. Detecting face...")
    try:
        # Phone photos often store rotation in EXIF instead of in the pixels.
        # Apply it so the face is upright before detection and halving.
        full_image = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")
    except OSError:  # PIL.UnidentifiedImageError subclasses OSError
        st.error("Could not read the uploaded file as an image. Please upload a JPG or PNG file.")
        st.stop()
    img_np = np.array(full_image)

    # MediaPipe's Python solutions expect an RGB array. PIL already produced
    # RGB, so pass it straight through (converting to BGR would swap channels).
    with mp_face.FaceDetection(model_selection=1, min_detection_confidence=0.6) as detector:
        results = detector.process(img_np)
    if not results.detections:
        st.error("β No face detected. Please upload a clear frontal face image.")
        st.stop()

    # Use first detection
    ih, iw, _ = img_np.shape
    face_box = _clamped_face_box(
        results.detections[0].location_data.relative_bounding_box, iw, ih
    )
    if face_box is None:
        st.error("The detected face region is empty. Please upload a clear frontal face image.")
        st.stop()
    face_crop = full_image.crop(face_box)
    st.image(full_image, caption="Uploaded Full-Face Image", use_column_width=True)

    # Split halves (face POV). In an unmirrored photo the viewer's left is the
    # patient's right. NOTE(review): assumes the upload is not a mirrored
    # front-camera selfie; confirm.
    fw, fh = face_crop.size
    fmid = fw // 2
    patient_right = face_crop.crop((0, 0, fmid, fh))  # viewer's left
    patient_left = face_crop.crop((fmid, 0, fw, fh))  # viewer's right

    # Stroke prediction input. Match the model's expected channel count so
    # both grayscale (C == 1) and RGB models receive a correctly shaped batch.
    _, H, W, C = stroke_model.input_shape
    stroke_input = face_crop.convert("L" if C == 1 else "RGB").resize((W, H))
    stroke_array = np.array(stroke_input).astype("float32") / 255.0
    if stroke_array.ndim == 2:
        stroke_array = stroke_array[..., np.newaxis]
    stroke_array = np.expand_dims(stroke_array, axis=0)

    st.write("π§ Predicting affected side of the face...")
    stroke_pred = stroke_model.predict(stroke_array)
    stroke_raw = stroke_pred[0][0]
    # 0 = left affected, 1 = right affected. A raw value of exactly 0.5
    # rounds to 0 because numpy rounds half to even.
    affected = int(np.round(stroke_raw))
    if affected == 0:
        affected_side = "left"
        unaffected_side = "right"
        unaffected_face = patient_right
    else:
        affected_side = "right"
        unaffected_side = "left"
        unaffected_face = patient_left

    # Pain prediction
    st.write("π Predicting PSPI pain score from unaffected side...")
    input_tensor = transform(unaffected_face).unsqueeze(0)
    with torch.no_grad():
        output = pain_model(input_tensor)
    raw_score = output.item()
    # NOTE(review): the standard PSPI scale spans 0-16. The 0-6 clamp here and
    # the legend below presumably reflect this model's training labels; confirm.
    pspi_score = max(0.0, min(raw_score, 6.0))

    # Display results
    st.subheader("π Prediction Results")
    st.image(unaffected_face, caption="Unaffected Side Used for Pain Detection", width=300)
    st.write(f"**π§ Affected Side (face POV):** `{affected_side}`")
    st.write(f"**β Unaffected Side (face POV):** `{unaffected_side}`")
    st.write(f"**π― Predicted PSPI Pain Score:** `{pspi_score:.3f}`")
    st.write(f"**π Raw Pain Model Output:** `{raw_score:.3f}`")
    st.write(f"**π Stroke Model Raw Output:** `{stroke_raw:.4f}`")
    st.markdown(
        """
        ---
        ### βΉοΈ Stroke Model Output
        - Output is between `0` and `1`
        - Closer to `0` = Left side is affected
        - Closer to `1` = Right side is affected
        ### βΉοΈ PSPI Score Scale
        - `0`: No pain
        - `1β2`: Mild pain
        - `3β4`: Moderate pain
        - `5β6`: Severe pain
        """
    )