nebi committed on
Commit
c6a3540
·
verified ·
1 Parent(s): e30857d

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +91 -90
src/streamlit_app.py CHANGED
@@ -13,7 +13,6 @@ from tensorflow.keras.models import Sequential
13
  from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
14
 
15
  tf.config.set_visible_devices([], 'GPU')
16
- st.write("TensorFlow version:", tf.__version__)
17
 
18
 
19
  # ---------------------------
@@ -170,6 +169,8 @@ def fetch_image():
170
 
171
  # ---------------------------
172
 
 
 
173
 
174
 
175
  st.title("Neural Network Classification Demo")
@@ -199,100 +200,100 @@ if 'model' not in st.session_state:
199
  st.session_state.calculating=False
200
 
201
 
 
202
 
203
-
204
- # Button to fetch a new image
205
- col1, col2, col3 = st.columns([1,1,1])
206
-
207
- if st.session_state.current_image is not None and st.session_state.started is False:
208
- st.session_state.started=True
209
-
210
- with col1:
211
- if st.button("Start",disabled=st.session_state.started):
212
- if not st.session_state.started:
213
- fetch_image()
214
-
215
-
216
- # Display the current image
217
-
218
- if st.session_state.current_image is not None:
219
- with col2:
220
- if st.button("cat"):#,disabled=st.session_state.calculating):
221
- print("cat pressed")
222
- st.session_state.label_input="cat"
223
- with col3:
224
- if st.button("dog"):#,disabled=st.session_state.calculating):
225
- print("dog pressed")
226
- st.session_state.label_input="dog"
227
- print(f"SHAPE:{st.session_state.current_image.shape}")
228
- prediction = st.session_state.model.predict(np.array([st.session_state.current_image]))[0][0]
229
-
230
- st.session_state.current_prediction = 'dog' if prediction > 0.5 else 'cat'
231
- st.success(f"**Model Predicts:** {st.session_state.current_prediction} --- (cat-confidence {(1-prediction)*100:.2f}%; dog-confidence {(prediction)*100:.2f}%)")
232
-
233
- st.image(st.session_state.unprocessed_image)
234
 
235
-
236
-
237
-
238
- # User input for label
239
-
240
-
241
- if st.session_state.label_input in ['cat', 'dog']:
242
- label_input=st.session_state.label_input
243
- st.session_state.label_input="None"
244
-
245
- # Convert user input to 0 (cat) or 1 (dog)
246
- print(f"LABEL CLICKED IS: {label_input.lower()}")
247
- label = 0 if label_input.lower() == 'cat' else 1
248
-
249
- st.session_state.current_label = label
250
-
251
-
252
-
253
- # Add the labeled image and label to training data
254
-
255
- st.session_state.training_data.append((st.session_state.current_image, label))
256
-
257
-
258
-
259
- # Retrain the model
260
-
261
- image = np.array([img for img, _ in st.session_state.training_data])
262
-
263
- label = np.array([lab for _, lab in st.session_state.training_data])
264
-
265
 
266
-
267
-
268
-
269
- # Predict the current image
270
-
271
- st.session_state.current_image=None
272
-
273
- print("before model fit")
274
- def model_fit():
275
- print("Entering model fit function")
276
- st.session_state.model.fit(image, label, epochs=1)
277
- st.write(st.session_state.model.evaluate(image, label, verbose=2))
278
- st.session_state.calculating=True
279
- st.write("hi")
280
- return
281
- model_fit()
282
-
283
- if st.session_state.calculating:
284
 
285
- print("after model fit")
286
- #st.session_state.unprocessed_image=None
287
- print("before fetch_image")
288
-
289
- fetch_image()
290
- print("after fetch_image")
291
- st.info(f"You clicked on last picture (picture {st.session_state.which_pic}): {label_input}")
292
- st.session_state.which_pic=st.session_state.which_pic+1
293
- st.session_state.calculating=False
294
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
 
 
 
 
 
 
 
 
 
 
 
296
 
297
 
298
 
 
13
  from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
14
 
15
  tf.config.set_visible_devices([], 'GPU')
 
16
 
17
 
18
  # ---------------------------
 
169
 
170
  # ---------------------------
171
 
172
+ placeholder = st.empty()
173
+
174
 
175
 
176
  st.title("Neural Network Classification Demo")
 
200
  st.session_state.calculating=False
201
 
202
 
203
+ with placeholder.container():
204
 
205
+ # Button to fetch a new image
206
+ col1, col2, col3 = st.columns([1,1,1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
+ if st.session_state.current_image is not None and st.session_state.started is False:
209
+ st.session_state.started=True
210
+
211
+ with col1:
212
+ if st.button("Start",disabled=st.session_state.started):
213
+ if not st.session_state.started:
214
+ fetch_image()
215
+
216
+
217
+ # Display the current image
218
+
219
+ if st.session_state.current_image is not None:
220
+ with col2:
221
+ if st.button("cat"):#,disabled=st.session_state.calculating):
222
+ print("cat pressed")
223
+ st.session_state.label_input="cat"
224
+ with col3:
225
+ if st.button("dog"):#,disabled=st.session_state.calculating):
226
+ print("dog pressed")
227
+ st.session_state.label_input="dog"
228
+ print(f"SHAPE:{st.session_state.current_image.shape}")
229
+ prediction = st.session_state.model.predict(np.array([st.session_state.current_image]))[0][0]
230
+
231
+ st.session_state.current_prediction = 'dog' if prediction > 0.5 else 'cat'
232
+ st.success(f"**Model Predicts:** {st.session_state.current_prediction} --- (cat-confidence {(1-prediction)*100:.2f}%; dog-confidence {(prediction)*100:.2f}%)")
233
+
234
+ st.image(st.session_state.unprocessed_image)
 
 
 
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
+
238
+
239
+ # User input for label
240
+
241
+
242
+ if st.session_state.label_input in ['cat', 'dog']:
243
+ label_input=st.session_state.label_input
244
+ st.session_state.label_input="None"
245
+
246
+ # Convert user input to 0 (cat) or 1 (dog)
247
+ print(f"LABEL CLICKED IS: {label_input.lower()}")
248
+ label = 0 if label_input.lower() == 'cat' else 1
249
+
250
+ st.session_state.current_label = label
251
+
252
+
253
+
254
+ # Add the labeled image and label to training data
255
+
256
+ st.session_state.training_data.append((st.session_state.current_image, label))
257
+
258
+
259
+
260
+ # Retrain the model
261
+
262
+ image = np.array([img for img, _ in st.session_state.training_data])
263
+
264
+ label = np.array([lab for _, lab in st.session_state.training_data])
265
+
266
+
267
+
268
+
269
+
270
+ # Predict the current image
271
+
272
+ st.session_state.current_image=None
273
+
274
+ print("before model fit")
275
+ def model_fit():
276
+ print("Entering model fit function")
277
+ st.session_state.model.fit(image, label, epochs=1)
278
+ st.write(st.session_state.model.evaluate(image, label, verbose=2))
279
+ st.session_state.calculating=True
280
+ st.write("hi")
281
+ return
282
+ model_fit()
283
+
284
+ if st.session_state.calculating:
285
 
286
+ print("after model fit")
287
+ #st.session_state.unprocessed_image=None
288
+ print("before fetch_image")
289
+
290
+ fetch_image()
291
+ print("after fetch_image")
292
+ st.info(f"You clicked on last picture (picture {st.session_state.which_pic}): {label_input}")
293
+ st.session_state.which_pic=st.session_state.which_pic+1
294
+ st.session_state.calculating=False
295
+
296
+
297
 
298
 
299