Spaces:
Sleeping
Sleeping
File size: 6,056 Bytes
8b68f8a 88d57d7 ce18868 88d57d7 edaa5bd e6b4a7c edaa5bd 472d1a7 2cbf422 e6b4a7c 45bf25d edaa5bd e6b4a7c 46f7e87 edaa5bd 496a79b e6b4a7c edaa5bd e6b4a7c e5de783 2cbf422 e6b4a7c 39a0cd2 496a79b c262c3c 45bf25d edaa5bd 45bf25d 3ea1fc5 45bf25d 540d089 45bf25d edaa5bd 45bf25d 540d089 dfe9e9c 45bf25d dfe9e9c 45bf25d f87990b edaa5bd 2cbf422 ca90588 2cbf422 472d1a7 3bb67ac 472d1a7 660013e 1cd2ad3 524afd6 1cd2ad3 6fa8650 f859df6 6fa8650 1cd2ad3 524afd6 1cd2ad3 62f4cca 0455e06 1cd2ad3 524afd6 1cd2ad3 6fa8650 f859df6 6fa8650 ef393c5 524afd6 ef393c5 62f4cca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 | import streamlit as st
import torch
import torchvision
from torchvision import transforms
from PIL import Image
#####
###
# Initialization
###
#####
def _init_session_state():
    """Seed each session-state key this script reads, keeping existing values.

    Streamlit re-executes the script on every interaction; only keys that are
    still missing get their default, so state from earlier runs is preserved.
    """
    defaults = {
        'generate_result': 0,
        'show_result': 0,
        'number_of_files': 0,
        'upload_choice': 'file_up',
    }
    for state_key, initial_value in defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = initial_value


_init_session_state()
#####
###
# Used to show either the file_uploader or the webcam
###
#####
# Session-state key under which Streamlit stores the Webcam toggle's value.
_WEBCAM_TOGGLE_KEY = 'webcam_toggle'


def change_state():
    """Set ``upload_choice`` from the Webcam toggle's current value.

    Stores ``'webcam'`` when the toggle is on and ``'file_up'`` otherwise.
    Reading the widget's value, rather than blindly flipping the previous
    choice, means the two can never drift apart (e.g. if the toggle's widget
    state is reset while the rest of session state survives).
    """
    use_webcam = st.session_state.get(_WEBCAM_TOGGLE_KEY, False)
    st.session_state['upload_choice'] = 'webcam' if use_webcam else 'file_up'


# User toggle for file_uploader vs webcam
st.toggle(
    label="Webcam",
    help="Click on to use webcam, off to upload a file",
    key=_WEBCAM_TOGGLE_KEY,
    on_change=change_state,
)
# Re-sync on every run, not only when the toggle changes, so a reset toggle
# cannot leave a stale 'webcam' choice behind.
change_state()
# Show the widget matching the current upload choice; `img` stays None until
# the user actually provides a photo.
if st.session_state['upload_choice'] != 'file_up':
    img = st.camera_input(label="Webcam")
else:
    img = st.file_uploader(label="Upload a photo of a squirrel or bird", type=['png', 'jpg'])
# 1 when an image is available this run, 0 otherwise.
st.session_state['number_of_files'] = 0 if img is None else 1
#####
###
# Load the image and apply transformations
###
#####
# Resize to 224x224 and normalise with ImageNet statistics. Built once at
# import time instead of on every prediction.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def predict_image(image_path, model, class_labels=("Bird", "Squirrel")):
    """Classify one image and return its predicted label.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts -- a filesystem path or
            a binary file-like object (this app passes the Streamlit upload /
            camera capture).
        model: Classifier called once with a ``(1, 3, 224, 224)`` batch; output
            column ``i`` is taken to correspond to ``class_labels[i]``.
        class_labels: Label for each output index. Defaults to the app's
            original ``("Bird", "Squirrel")`` ordering.

    Returns:
        The entry of ``class_labels`` with the highest softmax probability.
    """
    # The context manager releases any file PIL opened itself; convert()
    # fully loads the pixels first, so `image` remains valid afterwards.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    input_image = _PREPROCESS(image).unsqueeze(0)  # Add batch dimension

    # Follow the model's own device instead of hard-coding one, so this also
    # works if the model is ever moved off the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    input_image = input_image.to(device)

    model.eval()
    with torch.no_grad():
        output = model(input_image)

    probabilities = torch.softmax(output, dim=1)[0]
    predicted_class_index = torch.argmax(probabilities).item()
    return class_labels[predicted_class_index]
#####
###
# Load model and prepare for inference
###
#####
model_path = 'resnet18_custom_model.pth'


@st.cache_resource
def _load_model(path):
    """Build the 2-class ResNet18, load its weights from ``path``, return it in eval mode.

    ``st.cache_resource`` keeps one model per process. Without it the network
    would be rebuilt and the checkpoint re-read on every Streamlit rerun.
    """
    # Architecture only; the trained weights come from the checkpoint below.
    model = torchvision.models.resnet18(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, 2)  # Two output classes
    model = model.to('cpu')
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state_dict needs, so a tampered file cannot run code.
    state_dict = torch.load(path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


model_loaded = _load_model(model_path)
#####
###
# Toggle view of model output in UI
###
#####
# Banner emoji written as escapes so the file's encoding cannot garble them:
# U+1F426 BIRD and U+1F43F CHIPMUNK + U+FE0F (emoji presentation).
_BIRD_EMOJI = "\N{BIRD}"
_SQUIRREL_EMOJI = "\N{CHIPMUNK}\N{VARIATION SELECTOR-16}"
_BANNER_WIDTH = 16  # emoji per row

_BANNER_CSS = """
<style>
.centered {
text-align: center;
}
.big-font {
font-size:30px !important;
text-align: center;
}
</style>
"""


def _show_result_banner(emoji, message):
    """Render ``message`` in large centred text between two rows of ``emoji``."""
    emoji_row = f'<div class="centered">{emoji * _BANNER_WIDTH}</div>'
    # A <style> tag rendered through st.markdown applies page-wide, so
    # injecting it with the first row also styles the two rows that follow.
    st.markdown(_BANNER_CSS + emoji_row, unsafe_allow_html=True)
    st.markdown(f'<div class="big-font">{message}</div>', unsafe_allow_html=True)
    st.markdown(emoji_row, unsafe_allow_html=True)


# Both input modes behave the same: show a result only when an image exists.
_has_input = (
    st.session_state['upload_choice'] in ('file_up', 'webcam')
    and st.session_state['number_of_files'] == 1
)
st.session_state['generate_result'] = int(_has_input)
st.session_state['show_result'] = int(_has_input)

# Always bound, so the display branch can never hit an undefined name.
result = None
if st.session_state['generate_result'] != 0 and img is not None:
    result = predict_image(image_path=img, model=model_loaded)
    st.session_state['generate_result'] = 0

if st.session_state['show_result'] != 0 and result is not None:
    if result == 'Bird':
        _show_result_banner(_BIRD_EMOJI, "That's a Bird")
    else:
        _show_result_banner(_SQUIRREL_EMOJI, "That's a Squirrel")
    # Only uploaded files are echoed back; presumably st.camera_input
    # already previews the webcam capture itself.
    if st.session_state['upload_choice'] == 'file_up':
        st.image(img)