AdhamQQ commited on
Commit
e036c2c
·
verified ·
1 Parent(s): 4747664

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import numpy as np
7
+ from tensorflow.keras.models import load_model
8
+ import os
9
+ import requests
10
+
11
+ st.title("Stroke Patient Pain Intensity Detector")
12
+ st.markdown("Upload a full-face image of a stroke patient. The app will detect the affected side and predict pain intensity using the unaffected side.")
13
+ st.write("🔧 App started. Preparing to download models...")
14
+
15
+ @st.cache_resource
16
+ def download_models():
17
+ model_urls = {
18
+ "cnn_stroke_model.keras": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/cnn_stroke_model.keras",
19
+ "right_side_pain_model.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/right_side_pain_model.pth"
20
+ }
21
+
22
+ for filename, url in model_urls.items():
23
+ if not os.path.exists(filename):
24
+ st.write(f"📥 Downloading {filename}...")
25
+ r = requests.get(url)
26
+ with open(filename, "wb") as f:
27
+ f.write(r.content)
28
+ st.write(f"✅ {filename} downloaded.")
29
+ else:
30
+ st.write(f"✔️ {filename} already exists.")
31
+
32
+ st.write("📦 Loading models...")
33
+ stroke_model = load_model("cnn_stroke_model.keras")
34
+
35
+ class PainRegressor(nn.Module):
36
+ def __init__(self):
37
+ super(PainRegressor, self).__init__()
38
+ from torchvision.models import resnet18, ResNet18_Weights
39
+ self.base = resnet18(weights=ResNet18_Weights.DEFAULT)
40
+ num_features = self.base.fc.in_features
41
+ self.base.fc = nn.Linear(num_features, 1)
42
+ def forward(self, x):
43
+ return self.base(x)
44
+
45
+ pain_model = PainRegressor()
46
+ pain_model.load_state_dict(torch.load("right_side_pain_model.pth", map_location=torch.device('cpu')))
47
+ pain_model.eval()
48
+ st.write("✅ Models loaded.")
49
+
50
+ return stroke_model, pain_model
51
+
52
+ stroke_model, pain_model = download_models()
53
+
54
+ transform = transforms.Compose([
55
+ transforms.Resize((224, 224)),
56
+ transforms.ToTensor(),
57
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
58
+ std=[0.229, 0.224, 0.225])
59
+ ])
60
+
61
+ uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
62
+
63
+ if uploaded_file is not None:
64
+ st.write("📷 Image uploaded. Processing...")
65
+ full_image = Image.open(uploaded_file).convert("RGB")
66
+ st.image(full_image, caption="Uploaded Full-Face Image", use_column_width=True)
67
+
68
+ w, h = full_image.size
69
+ mid = w // 2
70
+ left_face = full_image.crop((0, 0, mid, h))
71
+ right_face = full_image.crop((mid, 0, w, h))
72
+
73
+ stroke_input = full_image.resize((128, 128))
74
+ stroke_array = np.array(stroke_input).astype("float32") / 255.0
75
+ stroke_array = np.expand_dims(stroke_array, axis=0)
76
+ st.write("🧠 Running stroke model prediction...")
77
+ stroke_pred = stroke_model.predict(stroke_array)
78
+ affected = int(np.round(stroke_pred[0][0]))
79
+
80
+ unaffected_face = right_face if affected == 0 else left_face
81
+ unaffected_tensor = transform(unaffected_face).unsqueeze(0)
82
+
83
+ st.write("📈 Predicting pain score...")
84
+ with torch.no_grad():
85
+ output = pain_model(unaffected_tensor)
86
+ pspi_score = output.item()
87
+
88
+ st.subheader("Prediction Results")
89
+ st.image(unaffected_face, caption="Unaffected Side Used for Pain Detection", width=300)
90
+ st.write(f"**Affected side:** {'left' if affected == 0 else 'right'}")
91
+ st.write(f"**Unaffected side:** {'right' if affected == 0 else 'left'}")
92
+ st.write(f"**Predicted PSPI Pain Score:** {round(pspi_score, 3)}")