File size: 6,056 Bytes
8b68f8a
88d57d7
ce18868
88d57d7
 
 
edaa5bd
 
e6b4a7c
edaa5bd
 
472d1a7
 
 
 
2cbf422
 
e6b4a7c
 
45bf25d
edaa5bd
 
 
 
 
 
e6b4a7c
46f7e87
 
 
 
 
edaa5bd
496a79b
e6b4a7c
edaa5bd
e6b4a7c
 
e5de783
2cbf422
 
 
e6b4a7c
39a0cd2
496a79b
c262c3c
 
 
45bf25d
edaa5bd
 
 
 
 
 
45bf25d
 
 
 
 
 
 
 
 
 
 
 
3ea1fc5
45bf25d
 
 
 
 
 
 
 
 
 
 
540d089
45bf25d
 
 
 
 
 
 
 
edaa5bd
 
 
 
 
45bf25d
540d089
dfe9e9c
45bf25d
 
 
dfe9e9c
45bf25d
 
 
f87990b
edaa5bd
 
 
 
 
 
2cbf422
 
 
ca90588
 
 
 
 
 
 
 
2cbf422
472d1a7
 
 
 
3bb67ac
472d1a7
660013e
1cd2ad3
 
 
 
 
 
 
524afd6
1cd2ad3
 
6fa8650
 
 
 
 
 
 
 
f859df6
6fa8650
 
1cd2ad3
 
 
 
 
 
 
524afd6
1cd2ad3
 
62f4cca
 
0455e06
1cd2ad3
 
 
 
 
 
 
524afd6
1cd2ad3
 
6fa8650
 
 
 
 
 
 
 
f859df6
6fa8650
 
ef393c5
 
 
 
 
 
 
524afd6
ef393c5
62f4cca
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import streamlit as st
import torch 
import torchvision
from torchvision import transforms
from PIL import Image

#####
###
# Initialization
###
#####
# Seed every session-state key on first run so later reads never KeyError.
_DEFAULT_STATE = {
    'generate_result': 0,
    'show_result': 0,
    'number_of_files': 0,
    'upload_choice': 'file_up',
}
for _key, _default in _DEFAULT_STATE.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default


#####
###
# Used to show either the file_uploader or the webcam
###
#####
def change_state():
    """Flip the upload-choice flag between 'file_up' and 'webcam'."""
    current = st.session_state['upload_choice']
    st.session_state['upload_choice'] = 'webcam' if current == 'file_up' else 'file_up'

# User toggle for file_uploader vs webcam
st.toggle(label="Webcam",  help="Click on to use webcam, off to upload a file", on_change=change_state)   

# Use state to know whether to show file_uploader or webcam
if st.session_state['upload_choice'] == 'file_up':
    img = st.file_uploader(label="Upload a photo of a squirrel or bird", type=['png', 'jpg'])
else:
    img = st.camera_input(label="Webcam")

# Both input widgets return None until the user supplies an image; the
# availability bookkeeping is identical for either source, so do it once.
st.session_state['number_of_files'] = 1 if img is not None else 0


#####
###
# Load the image and apply transformations
###
#####

# Preprocessing pipeline matching the ResNet18 input contract: 224x224 RGB
# tensors normalized with the ImageNet mean/std. Built once at module level
# so it is not reconstructed on every prediction.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


def predict_image(image_path, model):
    """Classify a single image and return its predicted class label.

    Args:
        image_path: Filesystem path or file-like object (e.g. a Streamlit
            UploadedFile) accepted by PIL.Image.open.
        model: A torch.nn.Module emitting two logits ordered (Bird, Squirrel).

    Returns:
        str: "Bird" or "Squirrel".
    """
    image = Image.open(image_path).convert('RGB')
    input_image = _PREPROCESS(image).unsqueeze(0)  # Add batch dimension

    # Inference in this app is CPU-only.
    input_image = input_image.to('cpu')

    model.eval()  # disable dropout / use running batch-norm statistics
    with torch.no_grad():  # gradients are not needed for inference
        output = model(input_image)

    # Convert logits to probabilities and pick the most likely class.
    probabilities = torch.softmax(output, dim=1)[0]
    predicted_class_index = torch.argmax(probabilities).item()

    # Map class index to class label
    class_labels = ["Bird", "Squirrel"]
    return class_labels[predicted_class_index]


#####
###
# Load model and prepare for inference
###
#####
# NOTE(review): torchvision deprecated the `pretrained=` flag in favor of
# `weights=None` (torchvision >= 0.13) — switch once the pinned version is known.
model_loaded = torchvision.models.resnet18(pretrained=False)  # Initialize ResNet18 without pretraining 
model_loaded.fc = torch.nn.Linear(model_loaded.fc.in_features, 2)  # Modify the fully connected layer
model_loaded = model_loaded.to('cpu')  # Move the model to the appropriate device (GPU or CPU)

# Load the saved state dictionary into the model
# map_location='cpu' lets a checkpoint saved on a GPU host load on CPU-only machines.
model_path = 'resnet18_custom_model.pth'
model_loaded.load_state_dict(torch.load(model_path, map_location='cpu'))

# Set the model to evaluation mode
model_loaded.eval()


#####
###
# Toggle view of model output in UI
###
#####
# 'upload_choice' is always either 'file_up' or 'webcam', and both original
# branches set the flags identically, so the condition reduces to whether an
# image is available at all.
if st.session_state['number_of_files'] == 1:
    st.session_state['generate_result'] = 1
    st.session_state['show_result'] = 1
else:
    st.session_state['generate_result'] = 0
    st.session_state['show_result'] = 0



if st.session_state['generate_result'] != 0:
    if img is not None:
        result = predict_image(image_path=img, model=model_loaded)
    st.session_state['generate_result'] = 0

if st.session_state['show_result'] != 0:
    if result == 'Bird':
        st.markdown("""
            <style>
            .centered {
                text-align: center;
            }
            </style>
            <div class="centered">
                ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ
            </div>
            """, unsafe_allow_html=True)
        st.markdown("""
            <style>
            .big-font {
                font-size:30px !important;
                text-align: center;
            }
            </style>
            <div class="big-font">
            That's a Bird
            </div>
        """, unsafe_allow_html=True)
        st.markdown("""
            <style>
            .centered {
                text-align: center;
            }
            </style>
            <div class="centered">
                ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ๐Ÿฆ
            </div>
            """, unsafe_allow_html=True)
        if st.session_state['upload_choice'] == 'file_up':
            st.image(img)
    else:
        st.markdown("""
            <style>
            .centered {
                text-align: center;
            }
            </style>
            <div class="centered">
                ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ
            </div>
            """, unsafe_allow_html=True)
        st.markdown("""
            <style>
            .big-font {
                font-size:30px !important;
                text-align: center;
            }
            </style>
            <div class="big-font">
            That's a Squirrel
            </div>
        """, unsafe_allow_html=True)
        st.markdown("""
            <style>
            .centered {
                text-align: center;
            }
            </style>
            <div class="centered">
                ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ๐Ÿฟ๏ธ
            </div>
            """, unsafe_allow_html=True)
        if st.session_state['upload_choice'] == 'file_up':
            st.image(img)