AdhamQQ commited on
Commit
96b0baf
Β·
verified Β·
1 Parent(s): e74518c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -66
app.py CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
2
  import torch
3
  import torch.nn as nn
4
  from torchvision import transforms
 
5
  from PIL import Image
6
  import numpy as np
7
  from tensorflow.keras.models import load_model
@@ -14,123 +15,90 @@ st.markdown("Upload a full-face image. The system will detect the affected side
14
 
15
  @st.cache_resource
16
  def download_models():
17
- # 🧬 Ensure LFS-tracked files are downloaded
18
- st.write("πŸ”„ Checking and pulling Git LFS files...")
19
- os.system("git lfs pull")
20
-
21
- model_urls = {
22
  "cnn_stroke_model.keras": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/cnn_stroke_model.keras",
23
  "left_side_pain_classifier.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/left_side_pain_classifier.pth",
24
- "right_side_pain_classifier.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/right_side_pain_classifier.pth",
25
  }
26
 
27
- for filename, url in model_urls.items():
28
- if not os.path.exists(filename) or os.path.getsize(filename) < 1000:
29
- st.write(f"πŸ“₯ Downloading {filename}...")
30
  try:
31
- with requests.get(url, stream=True) as r:
32
- r.raise_for_status()
33
- with open(filename, "wb") as f:
34
- for chunk in r.iter_content(chunk_size=8192):
35
- if chunk:
36
- f.write(chunk)
37
- st.success(f"{filename} downloaded.")
38
  except Exception as e:
39
- st.error(f"❌ Failed to download {filename}: {e}")
40
  st.stop()
41
  else:
42
- st.write(f"βœ”οΈ {filename} already exists.")
43
 
44
- # βœ… Haar cascade for face detection
45
  haar_path = "haarcascade_frontalface_default.xml"
46
  haar_url = "https://raw.githubusercontent.com/opencv/opencv/master/data/haarcascades/haarcascade_frontalface_default.xml"
47
  if not os.path.exists(haar_path):
48
- st.write("πŸ“₯ Downloading Haar Cascade...")
49
  r = requests.get(haar_url)
50
  with open(haar_path, "wb") as f:
51
  f.write(r.content)
52
 
53
- if not os.path.exists(haar_path):
54
- st.error("❌ Haar Cascade not found.")
55
- st.stop()
56
-
57
- # βœ… Check and load stroke model
58
- model_path = "cnn_stroke_model.keras"
59
- if not os.path.exists(model_path) or os.path.getsize(model_path) < 10000:
60
- st.error("❌ Model file missing or too small. Likely not downloaded via Git LFS.")
61
- st.stop()
62
-
63
  try:
64
- stroke_model = load_model(model_path)
65
  except Exception as e:
66
  st.error(f"❌ Failed to load stroke model: {e}")
67
  st.stop()
68
 
69
- # βœ… Define and load PyTorch pain models
70
- class PainModel(nn.Module):
71
- def __init__(self):
72
- super(PainModel, self).__init__()
73
- from torchvision.models import resnet18, ResNet18_Weights
74
- self.convnet = resnet18(weights=ResNet18_Weights.DEFAULT)
75
- self.convnet.fc = nn.Linear(self.convnet.fc.in_features, 1)
76
-
77
- def forward(self, x):
78
- return self.convnet(x)
79
 
80
  try:
81
- left_model = PainModel()
82
- right_model = PainModel()
83
- left_model.load_state_dict(torch.load("left_side_pain_classifier.pth", map_location="cpu"))
84
- right_model.load_state_dict(torch.load("right_side_pain_classifier.pth", map_location="cpu"))
85
- left_model.eval()
86
- right_model.eval()
87
  except Exception as e:
88
  st.error(f"❌ Error loading PyTorch pain models: {e}")
89
  st.stop()
90
 
91
  return stroke_model, left_model, right_model
92
 
93
- # βœ… Load models
94
  stroke_model, left_model, right_model = download_models()
95
 
96
- # βœ… Preprocessing transform
97
  transform = transforms.Compose([
98
  transforms.Resize((224, 224)),
99
  transforms.ToTensor(),
100
- transforms.Normalize([0.485, 0.456, 0.406],
101
- [0.229, 0.224, 0.225])
102
  ])
103
 
104
- # βœ… File uploader
105
  uploaded_file = st.file_uploader("πŸ“‚ Upload a full-face image", type=["jpg", "jpeg", "png"])
106
 
107
  if uploaded_file is not None:
108
  full_image = Image.open(uploaded_file).convert("RGB")
109
  st.image(full_image, caption="πŸ“Έ Uploaded Full-Face Image", use_column_width=True)
110
 
111
- # βœ… Face detection
112
- cv_image = np.array(full_image)
113
- gray = cv2.cvtColor(cv_image, cv2.COLOR_RGB2GRAY)
114
  face_cascade = cv2.CascadeClassifier("haarcascade_frontalface_default.xml")
 
115
 
116
- if face_cascade.empty():
117
- st.error("❌ Haar cascade failed to load.")
118
- st.stop()
119
-
120
- faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5)
121
  if len(faces) == 0:
122
- st.error("❌ No face detected. Please upload a clear image of a face.")
123
  st.stop()
124
 
125
- # βœ… Stroke prediction
126
  w, h = full_image.size
127
  mid = w // 2
128
  left_half = full_image.crop((0, 0, mid, h))
129
  right_half = full_image.crop((mid, 0, w, h))
130
 
131
  stroke_input = full_image.resize((224, 224))
132
- stroke_array = np.array(stroke_input).astype("float32") / 255.0
133
- stroke_array = np.expand_dims(stroke_array, axis=0)
134
 
135
  st.write("🧠 Predicting affected side...")
136
  stroke_pred = stroke_model.predict(stroke_array)
@@ -139,7 +107,6 @@ if uploaded_file is not None:
139
  unaffected_face = right_half if affected == 0 else left_half
140
  selected_model = right_model if affected == 0 else left_model
141
 
142
- # βœ… Pain prediction
143
  st.write("πŸ“ˆ Predicting pain...")
144
  with torch.no_grad():
145
  tensor = transform(unaffected_face).unsqueeze(0)
@@ -147,7 +114,6 @@ if uploaded_file is not None:
147
  prob = torch.sigmoid(output).item()
148
  label = 1 if prob > 0.5 else 0
149
 
150
- # βœ… Display results
151
  st.subheader("πŸ” Prediction Result")
152
  st.image(unaffected_face, caption="🧍 Unaffected Side Used", width=300)
153
  st.write(f"🧭 Affected Side: **{'left' if affected == 0 else 'right'}**")
 
2
  import torch
3
  import torch.nn as nn
4
  from torchvision import transforms
5
+ from torchvision.models import resnet18
6
  from PIL import Image
7
  import numpy as np
8
  from tensorflow.keras.models import load_model
 
15
 
16
  @st.cache_resource
17
  def download_models():
18
+ model_files = {
 
 
 
 
19
  "cnn_stroke_model.keras": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/cnn_stroke_model.keras",
20
  "left_side_pain_classifier.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/left_side_pain_classifier.pth",
21
+ "right_side_pain_classifier.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/right_side_pain_classifier.pth"
22
  }
23
 
24
+ for fname, url in model_files.items():
25
+ if not os.path.exists(fname) or os.path.getsize(fname) < 10000:
26
+ st.write(f"πŸ“₯ Downloading {fname}...")
27
  try:
28
+ r = requests.get(url)
29
+ r.raise_for_status()
30
+ with open(fname, "wb") as f:
31
+ f.write(r.content)
32
+ st.success(f"βœ… {fname} downloaded.")
 
 
33
  except Exception as e:
34
+ st.error(f"❌ Could not download {fname}: {e}")
35
  st.stop()
36
  else:
37
+ st.write(f"βœ”οΈ {fname} already exists.")
38
 
39
+ # Haar cascade
40
  haar_path = "haarcascade_frontalface_default.xml"
41
  haar_url = "https://raw.githubusercontent.com/opencv/opencv/master/data/haarcascades/haarcascade_frontalface_default.xml"
42
  if not os.path.exists(haar_path):
 
43
  r = requests.get(haar_url)
44
  with open(haar_path, "wb") as f:
45
  f.write(r.content)
46
 
47
+ # Load Keras model
 
 
 
 
 
 
 
 
 
48
  try:
49
+ stroke_model = load_model("cnn_stroke_model.keras")
50
  except Exception as e:
51
  st.error(f"❌ Failed to load stroke model: {e}")
52
  st.stop()
53
 
54
+ def load_pain_model(path):
55
+ model = resnet18(weights=None)
56
+ model.fc = nn.Linear(model.fc.in_features, 1)
57
+ model.load_state_dict(torch.load(path, map_location="cpu"))
58
+ model.eval()
59
+ return model
 
 
 
 
60
 
61
  try:
62
+ left_model = load_pain_model("left_side_pain_classifier.pth")
63
+ right_model = load_pain_model("right_side_pain_classifier.pth")
 
 
 
 
64
  except Exception as e:
65
  st.error(f"❌ Error loading PyTorch pain models: {e}")
66
  st.stop()
67
 
68
  return stroke_model, left_model, right_model
69
 
 
70
  stroke_model, left_model, right_model = download_models()
71
 
 
72
  transform = transforms.Compose([
73
  transforms.Resize((224, 224)),
74
  transforms.ToTensor(),
75
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
 
76
  ])
77
 
 
78
  uploaded_file = st.file_uploader("πŸ“‚ Upload a full-face image", type=["jpg", "jpeg", "png"])
79
 
80
  if uploaded_file is not None:
81
  full_image = Image.open(uploaded_file).convert("RGB")
82
  st.image(full_image, caption="πŸ“Έ Uploaded Full-Face Image", use_column_width=True)
83
 
84
+ # Face detection
85
+ np_img = np.array(full_image)
86
+ gray = cv2.cvtColor(np_img, cv2.COLOR_RGB2GRAY)
87
  face_cascade = cv2.CascadeClassifier("haarcascade_frontalface_default.xml")
88
+ faces = face_cascade.detectMultiScale(gray, 1.1, 5)
89
 
 
 
 
 
 
90
  if len(faces) == 0:
91
+ st.error("❌ No face detected. Please upload a clearer image.")
92
  st.stop()
93
 
94
+ # Stroke side prediction
95
  w, h = full_image.size
96
  mid = w // 2
97
  left_half = full_image.crop((0, 0, mid, h))
98
  right_half = full_image.crop((mid, 0, w, h))
99
 
100
  stroke_input = full_image.resize((224, 224))
101
+ stroke_array = np.expand_dims(np.array(stroke_input) / 255.0, axis=0)
 
102
 
103
  st.write("🧠 Predicting affected side...")
104
  stroke_pred = stroke_model.predict(stroke_array)
 
107
  unaffected_face = right_half if affected == 0 else left_half
108
  selected_model = right_model if affected == 0 else left_model
109
 
 
110
  st.write("πŸ“ˆ Predicting pain...")
111
  with torch.no_grad():
112
  tensor = transform(unaffected_face).unsqueeze(0)
 
114
  prob = torch.sigmoid(output).item()
115
  label = 1 if prob > 0.5 else 0
116
 
 
117
  st.subheader("πŸ” Prediction Result")
118
  st.image(unaffected_face, caption="🧍 Unaffected Side Used", width=300)
119
  st.write(f"🧭 Affected Side: **{'left' if affected == 0 else 'right'}**")