fjurie commited on
Commit
09dc5fa
·
verified ·
1 Parent(s): 9f50beb

Upload 2 files

Browse files
Files changed (2) hide show
  1. src/app.py +3 -1
  2. src/streamlit_app.py +4 -2
src/app.py CHANGED
@@ -23,7 +23,9 @@ def load_model():
23
  model.classifier[3] = nn.Linear(num_ftrs, len(class_names))
24
 
25
  # Load the state dictionary
26
- model_save_path = 'mobilenetv3_hymenoptera.pth'
 
 
27
  try:
28
  model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))
29
  model.eval()
 
23
  model.classifier[3] = nn.Linear(num_ftrs, len(class_names))
24
 
25
  # Load the state dictionary
26
+ # model_save_path = 'mobilenetv3_hymenoptera.pth'
27
+ file_dir = os.path.dirname(os.path.abspath(__file__))
28
+ model_save_path = os.path.join(file_dir, 'mobilenetv3_hymenoptera.pth')
29
  try:
30
  model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))
31
  model.eval()
src/streamlit_app.py CHANGED
@@ -23,7 +23,9 @@ def load_model():
23
  model.classifier[3] = nn.Linear(num_ftrs, len(class_names))
24
 
25
  # Load the state dictionary
26
- model_save_path = 'mobilenetv3_hymenoptera.pth'
 
 
27
  try:
28
  model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))
29
  model.eval()
@@ -51,7 +53,7 @@ model = load_model()
51
  st.title("Ant vs. Bee Classifier")
52
  st.write("Upload an image to classify whether it's an ant or a bee.")
53
 
54
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
55
 
56
  if uploaded_file is not None:
57
  image = Image.open(uploaded_file).convert('RGB')
 
23
  model.classifier[3] = nn.Linear(num_ftrs, len(class_names))
24
 
25
  # Load the state dictionary
26
+ # model_save_path = 'mobilenetv3_hymenoptera.pth'
27
+ file_dir = os.path.dirname(os.path.abspath(__file__))
28
+ model_save_path = os.path.join(file_dir, 'mobilenetv3_hymenoptera.pth')
29
  try:
30
  model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))
31
  model.eval()
 
53
  st.title("Ant vs. Bee Classifier")
54
  st.write("Upload an image to classify whether it's an ant or a bee.")
55
 
56
+ uploaded_file = st.file_uploader("Choose an image ...", type=["jpg", "jpeg", "png"])
57
 
58
  if uploaded_file is not None:
59
  image = Image.open(uploaded_file).convert('RGB')