File size: 1,984 Bytes
11a28bc
 
67d3aa0
11a28bc
 
 
523a6d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11a28bc
 
 
 
 
523a6d5
11a28bc
 
 
523a6d5
 
 
f1e3cf2
11a28bc
523a6d5
 
11a28bc
523a6d5
cd78198
 
67d3aa0
 
cd78198
f1e3cf2
67d3aa0
 
 
 
 
cd78198
 
 
67d3aa0
11a28bc
523a6d5
67d3aa0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import tempfile

import streamlit as st
import numpy as np
from tensorflow.keras.preprocessing import image
from keras.models import load_model


# Define the dictionary of classes and load the model.
# Maps each breed label to the index of the model output that represents it;
# presumably this must match the class order used when training
# best_model.h5 — TODO confirm.
CLASSES = {
    'french_bulldog': 0,
    'german_shepherd': 1,
    'golden_retriever': 2,
    'poodle': 3,
    'yorkshire_terrier': 4
}
MODEL_PATH = 'best_model.h5'


@st.cache_resource
def _load_cached_model(path):
    """Load the Keras model stored at *path*, reusing it across app reruns.

    Streamlit re-executes this script on every widget interaction. Without
    caching, the model file would be read from disk on every rerun.
    ``st.cache_resource`` returns the same model object for the same *path*
    on later reruns instead.
    """
    return load_model(path)


model = _load_cached_model(MODEL_PATH)


# Define a function to make predictions on a given image
def predict_breed(image_file, target_size=(256, 256)):
    """Return the breed label the model predicts for an image file.

    Args:
        image_file: Path of the image on disk, loaded with ``image.load_img``.
        target_size: ``(height, width)`` the image is resized to before
            inference. Defaults to ``(256, 256)``, the size previously
            hard-coded here. Presumably it must match the input size the
            model was trained with — TODO confirm.

    Returns:
        The ``CLASSES`` key whose index equals the arg-max of the model's
        output for this image.

    Raises:
        IndexError: if the model predicts an index that has no label in
            ``CLASSES`` (e.g. the model and the label table are out of sync).
    """
    img = image.load_img(image_file, target_size=target_size)
    # Add a leading batch axis (the model predicts on batches) and scale pixel
    # values from [0, 255] to [0, 1]. Presumably this is the same scaling used
    # during training — TODO confirm.
    batch = np.expand_dims(image.img_to_array(img), axis=0) / 255.
    preds = model.predict(batch)
    class_idx = int(np.argmax(preds[0]))

    # Invert label -> index once, so the lookup is a dict access instead of
    # scanning every entry and indexing into a possibly empty list.
    index_to_class = {idx: label for label, idx in CLASSES.items()}
    if class_idx not in index_to_class:
        raise IndexError(
            f'Model predicted class index {class_idx}, which has no label in CLASSES'
        )
    return index_to_class[class_idx]


# Create the Streamlit app
def _save_upload(uploaded_file, directory):
    """Write the bytes of *uploaded_file* into *directory* and return the path.

    The client-supplied file name is never used as a path. Only its extension
    is kept, so a name such as ``../best_model.h5`` cannot write outside
    *directory* or overwrite files that live next to the app.
    """
    suffix = os.path.splitext(uploaded_file.name)[1].lower()
    path = os.path.join(directory, 'upload' + suffix)
    with open(path, 'wb') as f:
        f.write(uploaded_file.getbuffer())
    return path


def _show_prediction(uploaded_file):
    """Classify *uploaded_file* and render it together with the predicted breed."""
    # ``predict_breed`` takes a path, so the upload is written to a private
    # temporary directory that is deleted as soon as inference finishes.
    # This avoids leaving copies of every upload in the working directory.
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            predicted_class = predict_breed(_save_upload(uploaded_file, tmp_dir))
    except OSError:
        # Presumably raised when the bytes cannot be decoded as an image.
        # Show a message instead of a traceback — TODO confirm which errors
        # the image loader raises for corrupt files.
        st.error('Could not read the uploaded file as an image.')
        return
    st.image(uploaded_file, caption='', use_column_width=True)
    st.write(f'Predicted class: **{predicted_class}**')


def main():
    """Render the dog-breed classification page.

    Shows one uploader restricted to JPEG/PNG files. Once a file is uploaded,
    the image is displayed with the model's predicted breed.

    There is intentionally no separate "Upload and Predict" button with its
    own uploader. ``st.button`` is True only on the rerun caused by the
    click, so an uploader nested under it disappears on the next rerun (the
    one triggered by uploading), and its file would never be classified.
    """
    st.title('Dog Breed Classification')
    uploaded_file = st.file_uploader('Choose an image of a dog', key='file_uploader', type=['jpg', 'jpeg', 'png'])
    if uploaded_file is not None:
        _show_prediction(uploaded_file)


# Render the app only when this file is executed as the main script
# (presumably via `streamlit run`), not when it is imported as a module.
if __name__ == '__main__':
    main()