JaredBailey commited on
Commit
edaa5bd
·
verified ·
1 Parent(s): 472d1a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -3
app.py CHANGED
@@ -4,8 +4,11 @@ import torchvision
4
  from torchvision import transforms
5
  from PIL import Image
6
 
7
-
 
8
  # Initialization
 
 
9
  if 'generate_result' not in st.session_state:
10
  st.session_state['generate_result'] = 0
11
  if 'show_result' not in st.session_state:
@@ -13,21 +16,33 @@ if 'show_result' not in st.session_state:
13
  if 'upload_choice' not in st.session_state:
14
  st.session_state['upload_choice'] = 'file_up'
15
 
 
 
 
 
 
 
16
  def change_state():
17
  if st.session_state['upload_choice'] == 'file_up':
18
  st.session_state['upload_choice'] = 'webcam'
19
  else:
20
  st.session_state['upload_choice'] = 'file_up'
21
 
22
-
23
  st.toggle(label="Webcam", help="Click on to use webcam, off to upload a file", on_change=change_state)
24
 
 
25
  if st.session_state['upload_choice'] == 'file_up':
26
  img = st.file_uploader(label="Upload a photo of a squirrel or bird", type=['png', 'jpg'])
27
  else:
28
  st.camera_input(label="Webcam")
29
- # Load the image and apply transformations
30
 
 
 
 
 
 
 
31
  def predict_image(image_path, model):
32
 
33
  image = Image.open(image_path).convert('RGB')
@@ -61,6 +76,11 @@ def predict_image(image_path, model):
61
  # print(f"{class_labels[i]}: {prob:.4f}")
62
 
63
 
 
 
 
 
 
64
  model_loaded = torchvision.models.resnet18(pretrained=False) # Initialize ResNet18 without pretraining
65
  model_loaded.fc = torch.nn.Linear(model_loaded.fc.in_features, 2) # Modify the fully connected layer
66
  model_loaded = model_loaded.to('cpu') # Move the model to the appropriate device (GPU or CPU)
@@ -72,6 +92,12 @@ model_loaded.load_state_dict(torch.load(model_path, map_location='cpu'))
72
  # Set the model to evaluation mode
73
  model_loaded.eval()
74
 
 
 
 
 
 
 
75
  if st.session_state['generate_result'] != 0:
76
  if img is not None:
77
  result = predict_image(image_path=img, model=model_loaded)
 
4
  from torchvision import transforms
5
  from PIL import Image
6
 
7
+ #####
8
+ ###
9
  # Initialization
10
+ ###
11
+ #####
12
  if 'generate_result' not in st.session_state:
13
  st.session_state['generate_result'] = 0
14
  if 'show_result' not in st.session_state:
 
16
  if 'upload_choice' not in st.session_state:
17
  st.session_state['upload_choice'] = 'file_up'
18
 
19
+
20
+ #####
21
+ ###
22
+ # Used to show either the file_uploader or the webcam
23
+ ###
24
+ #####
25
  def change_state():
26
  if st.session_state['upload_choice'] == 'file_up':
27
  st.session_state['upload_choice'] = 'webcam'
28
  else:
29
  st.session_state['upload_choice'] = 'file_up'
30
 
31
+ # User toggle for file_uploader vs webcam
32
  st.toggle(label="Webcam", help="Click on to use webcam, off to upload a file", on_change=change_state)
33
 
34
+ # Use state to know whether to show file_uploader or webcam
35
  if st.session_state['upload_choice'] == 'file_up':
36
  img = st.file_uploader(label="Upload a photo of a squirrel or bird", type=['png', 'jpg'])
37
  else:
38
  st.camera_input(label="Webcam")
 
39
 
40
+
41
+ #####
42
+ ###
43
+ # Load the image and apply transformations
44
+ ###
45
+ #####
46
  def predict_image(image_path, model):
47
 
48
  image = Image.open(image_path).convert('RGB')
 
76
  # print(f"{class_labels[i]}: {prob:.4f}")
77
 
78
 
79
+ #####
80
+ ###
81
+ # Load model and prepare for inference
82
+ ###
83
+ #####
84
  model_loaded = torchvision.models.resnet18(pretrained=False) # Initialize ResNet18 without pretraining
85
  model_loaded.fc = torch.nn.Linear(model_loaded.fc.in_features, 2) # Modify the fully connected layer
86
  model_loaded = model_loaded.to('cpu') # Move the model to the appropriate device (GPU or CPU)
 
92
  # Set the model to evaluation mode
93
  model_loaded.eval()
94
 
95
+
96
+ #####
97
+ ###
98
+ # Toggle view of model output in UI
99
+ ###
100
+ #####
101
  if st.session_state['generate_result'] != 0:
102
  if img is not None:
103
  result = predict_image(image_path=img, model=model_loaded)