nebi commited on
Commit
bec7b2a
·
verified ·
1 Parent(s): c6a3540

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +93 -92
src/streamlit_app.py CHANGED
@@ -169,11 +169,12 @@ def fetch_image():
169
 
170
  # ---------------------------
171
 
172
- placeholder = st.empty()
173
 
174
 
175
 
176
  st.title("Neural Network Classification Demo")
 
177
 
178
 
179
 
@@ -199,102 +200,102 @@ if 'model' not in st.session_state:
199
 
200
  st.session_state.calculating=False
201
 
202
-
203
- with placeholder.container():
204
-
205
- # Button to fetch a new image
206
- col1, col2, col3 = st.columns([1,1,1])
207
-
208
- if st.session_state.current_image is not None and st.session_state.started is False:
209
- st.session_state.started=True
210
-
211
- with col1:
212
- if st.button("Start",disabled=st.session_state.started):
213
- if not st.session_state.started:
214
- fetch_image()
215
-
216
 
217
- # Display the current image
218
-
219
- if st.session_state.current_image is not None:
220
- with col2:
221
- if st.button("cat"):#,disabled=st.session_state.calculating):
222
- print("cat pressed")
223
- st.session_state.label_input="cat"
224
- with col3:
225
- if st.button("dog"):#,disabled=st.session_state.calculating):
226
- print("dog pressed")
227
- st.session_state.label_input="dog"
228
- print(f"SHAPE:{st.session_state.current_image.shape}")
229
- prediction = st.session_state.model.predict(np.array([st.session_state.current_image]))[0][0]
230
-
231
- st.session_state.current_prediction = 'dog' if prediction > 0.5 else 'cat'
232
- st.success(f"**Model Predicts:** {st.session_state.current_prediction} --- (cat-confidence {(1-prediction)*100:.2f}%; dog-confidence {(prediction)*100:.2f}%)")
233
-
234
- st.image(st.session_state.unprocessed_image)
235
 
236
-
237
-
238
-
239
- # User input for label
240
-
241
-
242
- if st.session_state.label_input in ['cat', 'dog']:
243
- label_input=st.session_state.label_input
244
- st.session_state.label_input="None"
245
-
246
- # Convert user input to 0 (cat) or 1 (dog)
247
- print(f"LABEL CLICKED IS: {label_input.lower()}")
248
- label = 0 if label_input.lower() == 'cat' else 1
249
-
250
- st.session_state.current_label = label
251
-
252
-
253
-
254
- # Add the labeled image and label to training data
255
-
256
- st.session_state.training_data.append((st.session_state.current_image, label))
257
-
258
-
259
-
260
- # Retrain the model
261
-
262
- image = np.array([img for img, _ in st.session_state.training_data])
263
-
264
- label = np.array([lab for _, lab in st.session_state.training_data])
265
-
266
 
267
-
268
-
269
-
270
- # Predict the current image
271
-
272
- st.session_state.current_image=None
273
-
274
- print("before model fit")
275
- def model_fit():
276
- print("Entering model fit function")
277
- st.session_state.model.fit(image, label, epochs=1)
278
- st.write(st.session_state.model.evaluate(image, label, verbose=2))
279
- st.session_state.calculating=True
280
- st.write("hi")
281
- return
282
- model_fit()
283
-
284
- if st.session_state.calculating:
285
 
286
- print("after model fit")
287
- #st.session_state.unprocessed_image=None
288
- print("before fetch_image")
289
-
290
- fetch_image()
291
- print("after fetch_image")
292
- st.info(f"You clicked on last picture (picture {st.session_state.which_pic}): {label_input}")
293
- st.session_state.which_pic=st.session_state.which_pic+1
294
- st.session_state.calculating=False
295
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
 
297
-
 
 
 
 
 
 
 
 
 
 
 
298
 
299
 
300
 
 
169
 
170
  # ---------------------------
171
 
172
+
173
 
174
 
175
 
176
  st.title("Neural Network Classification Demo")
177
+ placeholder = st.empty()
178
 
179
 
180
 
 
200
 
201
  st.session_state.calculating=False
202
 
203
+ while True:
204
+ with placeholder.container():
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
+ # Button to fetch a new image
207
+ col1, col2, col3 = st.columns([1,1,1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
+ if st.session_state.current_image is not None and st.session_state.started is False:
210
+ st.session_state.started=True
211
+
212
+ with col1:
213
+ if st.button("Start",disabled=st.session_state.started):
214
+ if not st.session_state.started:
215
+ fetch_image()
216
+
217
+
218
+ # Display the current image
219
+
220
+ if st.session_state.current_image is not None:
221
+ with col2:
222
+ if st.button("cat"):#,disabled=st.session_state.calculating):
223
+ print("cat pressed")
224
+ st.session_state.label_input="cat"
225
+ with col3:
226
+ if st.button("dog"):#,disabled=st.session_state.calculating):
227
+ print("dog pressed")
228
+ st.session_state.label_input="dog"
229
+ print(f"SHAPE:{st.session_state.current_image.shape}")
230
+ prediction = st.session_state.model.predict(np.array([st.session_state.current_image]))[0][0]
231
+
232
+ st.session_state.current_prediction = 'dog' if prediction > 0.5 else 'cat'
233
+ st.success(f"**Model Predicts:** {st.session_state.current_prediction} --- (cat-confidence {(1-prediction)*100:.2f}%; dog-confidence {(prediction)*100:.2f}%)")
234
+
235
+ st.image(st.session_state.unprocessed_image)
 
 
 
236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
+
239
+
240
+ # User input for label
241
+
242
+
243
+ if st.session_state.label_input in ['cat', 'dog']:
244
+ label_input=st.session_state.label_input
245
+ st.session_state.label_input="None"
246
+
247
+ # Convert user input to 0 (cat) or 1 (dog)
248
+ print(f"LABEL CLICKED IS: {label_input.lower()}")
249
+ label = 0 if label_input.lower() == 'cat' else 1
250
+
251
+ st.session_state.current_label = label
252
+
253
+
254
+
255
+ # Add the labeled image and label to training data
256
+
257
+ st.session_state.training_data.append((st.session_state.current_image, label))
258
+
259
+
260
+
261
+ # Retrain the model
262
+
263
+ image = np.array([img for img, _ in st.session_state.training_data])
264
+
265
+ label = np.array([lab for _, lab in st.session_state.training_data])
266
+
267
+
268
+
269
+
270
+
271
+ # Predict the current image
272
+
273
+ st.session_state.current_image=None
274
+
275
+ print("before model fit")
276
+ def model_fit():
277
+ print("Entering model fit function")
278
+ st.session_state.model.fit(image, label, epochs=1)
279
+ st.write(st.session_state.model.evaluate(image, label, verbose=2))
280
+ st.session_state.calculating=True
281
+ st.write("hi")
282
+ return
283
+ model_fit()
284
+
285
+ if st.session_state.calculating:
286
 
287
+ print("after model fit")
288
+ #st.session_state.unprocessed_image=None
289
+ print("before fetch_image")
290
+
291
+ fetch_image()
292
+ print("after fetch_image")
293
+ st.info(f"You clicked on last picture (picture {st.session_state.which_pic}): {label_input}")
294
+ st.session_state.which_pic=st.session_state.which_pic+1
295
+ st.session_state.calculating=False
296
+
297
+
298
+ time.sleep(1)
299
 
300
 
301