AdhamQQ commited on
Commit
0414367
Β·
verified Β·
1 Parent(s): 96b0baf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -13
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import streamlit as st
2
  import torch
3
  import torch.nn as nn
4
- from torchvision import transforms
5
  from torchvision.models import resnet18
 
6
  from PIL import Image
7
  import numpy as np
8
  from tensorflow.keras.models import load_model
@@ -15,26 +15,26 @@ st.markdown("Upload a full-face image. The system will detect the affected side
15
 
16
  @st.cache_resource
17
  def download_models():
18
- model_files = {
19
  "cnn_stroke_model.keras": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/cnn_stroke_model.keras",
20
  "left_side_pain_classifier.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/left_side_pain_classifier.pth",
21
- "right_side_pain_classifier.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/right_side_pain_classifier.pth"
22
  }
23
 
24
- for fname, url in model_files.items():
25
- if not os.path.exists(fname) or os.path.getsize(fname) < 10000:
26
- st.write(f"πŸ“₯ Downloading {fname}...")
27
  try:
28
  r = requests.get(url)
29
  r.raise_for_status()
30
- with open(fname, "wb") as f:
31
  f.write(r.content)
32
- st.success(f"βœ… {fname} downloaded.")
33
  except Exception as e:
34
- st.error(f"❌ Could not download {fname}: {e}")
35
  st.stop()
36
  else:
37
- st.write(f"βœ”οΈ {fname} already exists.")
38
 
39
  # Haar cascade
40
  haar_path = "haarcascade_frontalface_default.xml"
@@ -51,9 +51,18 @@ def download_models():
51
  st.error(f"❌ Failed to load stroke model: {e}")
52
  st.stop()
53
 
 
 
 
 
 
 
 
 
 
 
54
  def load_pain_model(path):
55
- model = resnet18(weights=None)
56
- model.fc = nn.Linear(model.fc.in_features, 1)
57
  model.load_state_dict(torch.load(path, map_location="cpu"))
58
  model.eval()
59
  return model
@@ -72,7 +81,8 @@ stroke_model, left_model, right_model = download_models()
72
  transform = transforms.Compose([
73
  transforms.Resize((224, 224)),
74
  transforms.ToTensor(),
75
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
 
76
  ])
77
 
78
  uploaded_file = st.file_uploader("πŸ“‚ Upload a full-face image", type=["jpg", "jpeg", "png"])
 
1
  import streamlit as st
2
  import torch
3
  import torch.nn as nn
 
4
  from torchvision.models import resnet18
5
+ from torchvision import transforms
6
  from PIL import Image
7
  import numpy as np
8
  from tensorflow.keras.models import load_model
 
15
 
16
  @st.cache_resource
17
  def download_models():
18
+ model_urls = {
19
  "cnn_stroke_model.keras": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/cnn_stroke_model.keras",
20
  "left_side_pain_classifier.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/left_side_pain_classifier.pth",
21
+ "right_side_pain_classifier.pth": "https://huggingface.co/AdhamQQ/cnn_stroke_model/resolve/main/right_side_pain_classifier.pth",
22
  }
23
 
24
+ for filename, url in model_urls.items():
25
+ if not os.path.exists(filename) or os.path.getsize(filename) < 10000:
26
+ st.write(f"πŸ“₯ Downloading {filename}...")
27
  try:
28
  r = requests.get(url)
29
  r.raise_for_status()
30
+ with open(filename, "wb") as f:
31
  f.write(r.content)
32
+ st.success(f"βœ… {filename} downloaded.")
33
  except Exception as e:
34
+ st.error(f"❌ Could not download {filename}: {e}")
35
  st.stop()
36
  else:
37
+ st.write(f"βœ”οΈ {filename} already exists.")
38
 
39
  # Haar cascade
40
  haar_path = "haarcascade_frontalface_default.xml"
 
51
  st.error(f"❌ Failed to load stroke model: {e}")
52
  st.stop()
53
 
54
+ # Define correct PainModel class
55
+ class PainModel(nn.Module):
56
+ def __init__(self):
57
+ super(PainModel, self).__init__()
58
+ self.convnet = resnet18(weights=None)
59
+ self.convnet.fc = nn.Linear(self.convnet.fc.in_features, 1)
60
+
61
+ def forward(self, x):
62
+ return self.convnet(x)
63
+
64
  def load_pain_model(path):
65
+ model = PainModel()
 
66
  model.load_state_dict(torch.load(path, map_location="cpu"))
67
  model.eval()
68
  return model
 
81
  transform = transforms.Compose([
82
  transforms.Resize((224, 224)),
83
  transforms.ToTensor(),
84
+ transforms.Normalize([0.485, 0.456, 0.406],
85
+ [0.229, 0.224, 0.225])
86
  ])
87
 
88
  uploaded_file = st.file_uploader("πŸ“‚ Upload a full-face image", type=["jpg", "jpeg", "png"])