File size: 7,087 Bytes
21ce0fd
 
8738fa4
 
 
 
 
3b3778d
8738fa4
 
 
 
 
 
dcd4c33
 
8738fa4
 
 
 
 
 
 
 
 
 
 
 
 
1cded8d
f82ef56
8738fa4
1cded8d
 
 
 
eca2ba9
1cded8d
 
 
8738fa4
1cded8d
8738fa4
 
 
 
 
 
 
 
 
 
 
 
f82ef56
8738fa4
 
 
 
 
 
 
 
eca2ba9
8738fa4
 
21facf8
8738fa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d205550
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8738fa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21e201d
 
 
d205550
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21e201d
8738fa4
 
 
 
 
 
 
bec7b2a
c6a3540
8738fa4
 
 
bec7b2a
8738fa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33fe24f
 
a7c365c
 
e989d99
 
142610c
5069b01
66eab87
92f674b
c6a3540
7a28b6d
5069b01
 
 
 
 
 
 
 
 
c66dfd5
5069b01
7a28b6d
 
 
 
142610c
 
 
 
 
 
 
 
 
7a28b6d
 
 
 
 
 
 
 
bec7b2a
92f674b
 
 
 
7a28b6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92f674b
7a28b6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92f674b
7a28b6d
 
 
 
2436474
7a28b6d
 
 
142610c
5069b01
30e183e
92f674b
5069b01
3dd47cb
5069b01
 
30e183e
5069b01
e6b832e
 
7a28b6d
92f674b
 
6e5cdca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
import time
from io import BytesIO
from urllib.parse import urljoin

import numpy as np
import requests
import streamlit as st
import tensorflow as tf
from PIL import Image
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Sequential

tf.config.set_visible_devices([], 'GPU')


# ---------------------------

# Helper Functions

# ---------------------------

 

def fetch_cat_image(timeout=10):
    """Fetch a random cat image from the CATAAS API.

    Makes two requests: one for JSON metadata describing a random cat, then
    one for the image bytes at the URL named by the metadata's ``url`` field.

    Args:
        timeout: Seconds to wait on each HTTP request, so a stalled API
            cannot hang the Streamlit script indefinitely.

    Returns:
        A ``PIL.Image.Image`` on success, or ``None`` if either request does
        not answer with HTTP 200.

    Raises:
        requests.RequestException: on network errors or timeouts.
        KeyError: if the metadata has no ``url`` field.
    """
    response = requests.get('https://cataas.com/cat?json=true', timeout=timeout)
    print(f"STATUS_CODE:{response.status_code}")
    if response.status_code != 200:
        return None

    # Resolve against the site root: urljoin leaves an absolute URL unchanged
    # and turns a root-relative path into a full URL.
    url = urljoin('https://cataas.com/', response.json()['url'])

    response_image = requests.get(url, timeout=timeout)
    if response_image.status_code != 200:
        return None
    print(f"SECOND STATUS_CODE:{response_image.status_code}")

    # Decode from the fully read body instead of ``response.raw``: the raw
    # stream is not content-decoded (gzip/deflate), and reading the body
    # releases the connection back to the pool.
    return Image.open(BytesIO(response_image.content))

 

def fetch_dog_image(timeout=10):
    """Fetch a random dog image from the Dog CEO API.

    Requests JSON metadata, checks that its ``status`` is ``'success'``, then
    downloads the image whose URL is given in the ``message`` field.

    Args:
        timeout: Seconds to wait on each HTTP request, so a stalled API
            cannot hang the Streamlit script indefinitely.

    Returns:
        A ``PIL.Image.Image`` on success, or ``None`` if a request does not
        answer with HTTP 200 or the metadata status is not ``'success'``.

    Raises:
        requests.RequestException: on network errors or timeouts.
    """
    response = requests.get('https://dog.ceo/api/breeds/image/random', timeout=timeout)
    print(f"STATUS_CODE:{response.status_code}")
    if response.status_code != 200:
        return None

    data = response.json()
    # .get: a payload without "status" is treated as a failure, not a KeyError.
    if data.get('status') != 'success':
        return None

    response_image = requests.get(data['message'], timeout=timeout)
    if response_image.status_code != 200:
        return None
    print(f"SECOND STATUS_CODE:{response_image.status_code}")

    # Decode from the fully read body instead of ``response.raw``: the raw
    # stream is not content-decoded (gzip/deflate), and reading the body
    # releases the connection back to the pool.
    return Image.open(BytesIO(response_image.content))

 

def fetch_random_image():
    """Return a random cat or dog image, each picked with probability 1/2.

    Delegates to ``fetch_cat_image`` or ``fetch_dog_image`` and passes that
    call's result (an image or ``None``) straight through.
    """
    chosen_fetcher = fetch_cat_image if np.random.rand() < 0.5 else fetch_dog_image
    return chosen_fetcher()

 

def preprocess_image(image, target_size=(64, 64)):
    """Turn a PIL image into a normalized 3-channel array for the model.

    Args:
        image: A PIL image in any mode.
        target_size: ``(width, height)`` passed to ``Image.resize``.

    Returns:
        A float array of shape ``(height, width, 3)`` with values in
        ``[0, 1]``, or ``None`` if the image could not be processed (the
        error is printed).
    """
    try:
        # Normalize every mode (L, LA, P, RGBA, CMYK, 1, ...) to RGB so the
        # array always has exactly 3 channels; other modes would yield
        # 2-D, 2-channel or 4-channel arrays the model cannot accept.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Resize after converting: Pillow always resizes "P" and "1" images
        # with nearest-neighbour, but RGB gets the normal interpolating filter.
        image = image.resize(target_size)
        return np.array(image) / 255.0  # Normalize to [0, 1]
    except Exception as e:
        # Deliberately best-effort: callers treat None as "skip this image".
        print(f"Preprocessing error: {e}")
        return None

 

# ---------------------------

# Model Creation

# ---------------------------

 

def create_model(input_shape=(64, 64, 3)):
    """Build and compile a small CNN binary classifier.

    Architecture: two feature stages, each a 3x3 ReLU ``Conv2D`` (32, then
    64 filters) followed by 2x2 ``MaxPooling2D``; then ``Flatten``, a
    64-unit ReLU ``Dense`` layer and a single sigmoid output unit. The model
    is compiled with Adam, binary cross-entropy loss and an accuracy metric.

    Args:
        input_shape: Shape of one sample, ``(height, width, channels)``
            under Keras' default channels-last layout.

    Returns:
        A compiled, untrained ``Sequential`` model.
    """
    feature_stages = []
    for filters in (32, 64):
        feature_stages.append(Conv2D(filters, (3, 3), activation='relu'))
        feature_stages.append(MaxPooling2D((2, 2)))

    classifier_head = [
        Flatten(),
        Dense(64, activation='relu'),
        Dense(1, activation='sigmoid'),
    ]

    input_layer = tf.keras.layers.Input(shape=input_shape)
    network = Sequential([input_layer] + feature_stages + classifier_head)
    network.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return network

def fetch_image():
    """Fetch a random cat/dog image and store it in Streamlit session state.

    On success: sets ``unprocessed_image`` (the original image, used for
    display) and ``current_image`` (the preprocessed model input), clears
    ``current_label`` and ``current_prediction``, and shows an info message.

    On failure (network error, no image, or preprocessing error): shows an
    error message and leaves the session state untouched, so ``None`` is
    never stored as the model input.
    """
    print("entering fetch_image")

    try:
        image = fetch_random_image()
    except Exception as e:
        # UI boundary: surface any fetch/decode failure to the user.
        st.error(f"An error occurred while fetching the image: {e}")
        return

    if image is None:
        st.error("Failed to fetch an image. Please try again.")
        return

    processed = preprocess_image(image)
    if processed is None:
        # Storing None here would later be appended to the training data
        # and break model.fit.
        st.error("Failed to process the fetched image. Please try again.")
        return

    st.session_state.unprocessed_image = image
    st.session_state.current_image = processed
    st.session_state.current_label = None
    st.session_state.current_prediction = None
    st.info("image ready")


# ---------------------------

# Streamlit UI

# ---------------------------



 

st.title("Neural Network Classification Demo")
placeholder = st.empty()

# One-time, per-session initialization. Streamlit reruns this whole script
# on every interaction; the presence of "model" in session state marks that
# the defaults below have already been installed.
if 'model' not in st.session_state:
    for _state_key, _state_value in {
        'model': create_model(),
        'training_data': [],           # list of (preprocessed image, 0/1 label)
        'current_image': None,
        'current_label': None,
        'current_prediction': None,
        'label_input': None,           # pending "cat"/"dog" click, if any
        'started': False,              # True once Start loaded the first image
        'which_pic': 1,                # 1-based counter of labeled pictures
        'calculating': True,           # True disables the cat/dog buttons
        'next': True,                  # True disables the "next" button
    }.items():
        st.session_state[_state_key] = _state_value


    
# Start button: loads the first image and unlocks the cat/dog buttons.
col1, col2, col3, col4 = st.columns([1, 1, 1, 1])


with col1:
    if st.button("Start", disabled=st.session_state.started):
        if not st.session_state.started:
            fetch_image()
            # Only leave the "not started" state once an image was actually
            # loaded; otherwise Start would be disabled with nothing to label.
            if st.session_state.current_image is not None:
                st.session_state.started = True
                st.session_state.calculating = False
            



# Label buttons: record which class the user assigns to the displayed image.
# They are disabled while st.session_state.calculating is True; the click is
# processed further down in this same script run.
for label_column, animal in ((col2, "cat"), (col3, "dog")):
    with label_column:
        if st.button(animal, disabled=st.session_state.calculating):
            print(f"{animal} pressed")
            st.session_state.label_input = animal
# Show the model's prediction and the original picture for the current image.
if st.session_state.current_image is not None:
    model_input = st.session_state.current_image
    print(f"SHAPE:{model_input.shape}")

    batch = np.array([model_input])  # leading batch axis of size 1
    prediction = st.session_state.model.predict(batch)[0][0]  # sigmoid output: >0.5 means dog

    if prediction > 0.5:
        st.session_state.current_prediction = 'dog'
    else:
        st.session_state.current_prediction = 'cat'

    cat_share = (1 - prediction) * 100
    dog_share = prediction * 100
    st.success(f"**Model Predicts:** {st.session_state.current_prediction} --- (cat-confidence {cat_share:.2f}%; dog-confidence {dog_share:.2f}%)")

    st.image(st.session_state.unprocessed_image)
        
    
     
     
    
# Handle a pending label click: store the labeled example, retrain on every
# labeled example so far, then pre-fetch the next image (displayed after the
# user presses "next").
if st.session_state.label_input in ['cat', 'dog']:
    label_input = st.session_state.label_input
    # Consume the click so later reruns do not process it again.
    st.session_state.label_input = None

    print(f"LABEL CLICKED IS: {label_input.lower()}")

    if st.session_state.current_image is None:
        # Nothing valid to label (e.g. the previous fetch failed). Skip
        # training rather than appending None to the training data.
        st.warning("There is no image to label; fetching a new one.")
    else:
        # Encode as 0 (cat) or 1 (dog), matching the sigmoid output where
        # a value above 0.5 is read as "dog".
        label = 0 if label_input.lower() == 'cat' else 1
        st.session_state.current_label = label
        st.session_state.training_data.append((st.session_state.current_image, label))

        # Retrain on the full labeled set collected so far.
        images = np.array([img for img, _ in st.session_state.training_data])
        labels = np.array([lab for _, lab in st.session_state.training_data])
        st.session_state.current_image = None

        print("before model fit")
        st.session_state.model.fit(images, labels, epochs=1)
        st.write(st.session_state.model.evaluate(images, labels, verbose=2))
        print("after model fit")

        st.info(f"You clicked on last picture (picture {st.session_state.which_pic}): {label_input}")
        st.session_state.which_pic += 1

    # Lock the label buttons and enable "next" until the user advances.
    st.session_state.calculating = True
    st.session_state.next = False
    fetch_image()
    

# "next" button: re-enables the label buttons and reruns the script, so the
# image fetched right after the last label gets displayed and predicted.
with col4:
    next_clicked = st.button("next", disabled=st.session_state.next)
    if next_clicked and not st.session_state.next:
        st.session_state.next = True
        st.session_state.calculating = False
        st.rerun()