cat_dog_learn_nn / src /streamlit_app.py
nebi's picture
Update src/streamlit_app.py
30e183e verified
import streamlit as st
import requests
from PIL import Image
import numpy as np
import time
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
# Give TensorFlow an empty list of visible GPU devices.
# NOTE(review): presumably this forces the model to run on CPU — confirm.
tf.config.set_visible_devices([], 'GPU')
# ---------------------------
# Helper Functions
# ---------------------------
def fetch_cat_image(timeout=10):
    """Fetch a random cat image from the CATAAS API.

    Args:
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        A PIL image opened from the downloaded stream, or None when either
        the metadata request or the image request does not answer with
        HTTP 200.

    Raises:
        requests.RequestException: On connection problems or when a request
            exceeds ``timeout``.
    """
    # Without a timeout, requests waits indefinitely on a stalled server,
    # which would freeze the whole script run.
    response = requests.get('https://cataas.com/cat?json=true', timeout=timeout)
    print(f"STATUS_CODE:{response.status_code}")
    if response.status_code != 200:
        return None

    url = response.json()['url']
    response_image = requests.get(url, stream=True, timeout=timeout)
    if response_image.status_code != 200:
        return None
    print(f"SECOND STATUS_CODE:{response_image.status_code}")

    # The raw stream skips requests' content decoding; ask it to undo any
    # gzip/deflate content encoding so the image decoder sees real bytes.
    response_image.raw.decode_content = True
    return Image.open(response_image.raw)
def fetch_dog_image(timeout=10):
    """Fetch a random dog image from the Dog CEO API.

    Args:
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        A PIL image opened from the downloaded stream, or None when a
        request does not answer with HTTP 200 or the API payload does not
        report ``status == 'success'``.

    Raises:
        requests.RequestException: On connection problems or when a request
            exceeds ``timeout``.
    """
    # Without a timeout, requests waits indefinitely on a stalled server,
    # which would freeze the whole script run.
    response = requests.get('https://dog.ceo/api/breeds/image/random', timeout=timeout)
    print(f"STATUS_CODE:{response.status_code}")
    if response.status_code != 200:
        return None

    data = response.json()
    # .get() keeps an unexpected payload on the "return None" path instead
    # of raising KeyError.
    if data.get('status') != 'success':
        return None

    url = data['message']
    response_image = requests.get(url, stream=True, timeout=timeout)
    if response_image.status_code != 200:
        return None
    print(f"SECOND STATUS_CODE:{response_image.status_code}")

    # The raw stream skips requests' content decoding; ask it to undo any
    # gzip/deflate content encoding so the image decoder sees real bytes.
    response_image.raw.decode_content = True
    return Image.open(response_image.raw)
def fetch_random_image():
    """Return a cat image or a dog image, decided by one uniform random draw.

    A draw below 0.5 selects the cat fetcher, anything else the dog fetcher;
    whatever the chosen fetcher returns is passed through unchanged.
    """
    pick_cat = np.random.rand() < 0.5
    fetcher = fetch_cat_image if pick_cat else fetch_dog_image
    return fetcher()
def preprocess_image(image, target_size=(64, 64)):
    """Resize an image, force it to three colour channels and scale to [0, 1].

    Args:
        image: Image object offering ``resize``, ``mode`` and ``convert``
            (a PIL image in this app).
        target_size: Size passed to ``image.resize``.

    Returns:
        A float NumPy array with three channels on the last axis, or None
        when any step fails (the error is printed).
    """
    try:
        image = image.resize(target_size)
        # Convert every non-RGB mode, not just RGBA and P: modes such as LA
        # or CMYK would otherwise yield 2- or 4-channel arrays that do not
        # fit the model's 3-channel input.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        pixels = np.array(image) / 255.0  # Normalize to [0, 1]
        if pixels.ndim == 2:
            # Defensive fallback for a single-channel array.
            pixels = np.stack([pixels] * 3, axis=-1)
    except Exception as e:
        print(f"Preprocessing error: {e}")
        return None
    return pixels
# ---------------------------
# Model Creation
# ---------------------------
def create_model(input_shape=(64, 64, 3)):
    """Build and compile a small convolutional binary classifier.

    The stack is two Conv2D + MaxPooling2D stages (32 then 64 filters),
    a Flatten, a 64-unit ReLU Dense layer and a single sigmoid output unit.
    It is compiled with the Adam optimizer, binary cross-entropy loss and
    an accuracy metric.

    Args:
        input_shape: Shape handed to the Input layer.

    Returns:
        The compiled Sequential model.
    """
    layer_stack = [
        tf.keras.layers.Input(shape=input_shape),
        # First convolution stage
        Conv2D(32, (3, 3), activation='relu'),
        MaxPooling2D((2, 2)),
        # Second convolution stage
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D((2, 2)),
        # Classification head
        Flatten(),
        Dense(64, activation='relu'),
        Dense(1, activation='sigmoid'),
    ]
    network = Sequential(layer_stack)
    network.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    return network
def fetch_image():
    """Fetch a random image and stage it in the Streamlit session state.

    On success the original image is stored as ``unprocessed_image``, its
    preprocessed array as ``current_image``, and the label and prediction
    slots are cleared. If fetching or preprocessing fails, an error is
    shown and the session state is left unchanged.
    """
    print("entering fetch_image")
    try:
        image = fetch_random_image()
        if image is None:
            st.error("Failed to fetch an image. Please try again.")
            return

        processed = preprocess_image(image)
        if processed is None:
            # preprocess_image reports failure by returning None. Do not
            # announce "image ready" for an image that cannot be used.
            st.error("Failed to process the fetched image. Please try again.")
            return

        st.session_state.unprocessed_image = image
        st.session_state.current_image = processed
        st.session_state.current_label = None
        st.session_state.current_prediction = None
        st.info("image ready")
    except Exception as e:
        st.error(f"An error occurred while fetching the image: {e}")
# ---------------------------
# Streamlit UI
# ---------------------------
st.title("Neural Network Classification Demo")
placeholder = st.empty()
# Seed the session once; the presence of 'model' marks an initialised session.
if 'model' not in st.session_state:
    initial_state = {
        'model': create_model(),
        'training_data': [],
        'current_image': None,
        'current_label': None,
        'current_prediction': None,
        'label_input': None,
        'started': False,
        'which_pic': 1,
        # 'calculating' and 'next' are used as the disabled flags of the
        # cat/dog buttons and the next button respectively.
        'calculating': True,
        'next': True,
    }
    for key, value in initial_state.items():
        st.session_state[key] = value
# Four equal-width columns, one per control button
col1, col2, col3, col4 = st.columns([1,1,1,1])
# Start: fetches the first image and enables the label buttons
with col1:
    start_clicked = st.button("Start", disabled=st.session_state.started)
    if start_clicked and not st.session_state.started:
        fetch_image()
        st.session_state.started = True
        st.session_state.calculating = False
# Label buttons: a click only records the chosen label in the session state
for column, animal, log_line in ((col2, "cat", "cat pressed"), (col3, "dog", "dog pressed")):
    with column:
        if st.button(animal, disabled=st.session_state.calculating):
            print(log_line)
            st.session_state.label_input = animal
# When an image is staged, show the model's prediction and the original picture
if st.session_state.current_image is not None:
    print(f"SHAPE:{st.session_state.current_image.shape}")
    # The model expects a batch, so wrap the single image in a one-element array
    batch = np.array([st.session_state.current_image])
    prediction = st.session_state.model.predict(batch)[0][0]
    # Sigmoid output above 0.5 is read as "dog", otherwise "cat"
    if prediction > 0.5:
        st.session_state.current_prediction = 'dog'
    else:
        st.session_state.current_prediction = 'cat'
    st.success(f"**Model Predicts:** {st.session_state.current_prediction} --- (cat-confidence {(1-prediction)*100:.2f}%; dog-confidence {(prediction)*100:.2f}%)")
    st.image(st.session_state.unprocessed_image)
# Consume a pending cat/dog click: record the labelled image, retrain on all
# labelled images so far, then stage the next image.
label_input = st.session_state.label_input
if label_input in ['cat', 'dog']:
    # Reset to None (the initial value) so the click is handled exactly once.
    st.session_state.label_input = None
    if st.session_state.current_image is None:
        # No image is staged (e.g. an earlier fetch failed). Appending None
        # to the training data would break every later fit, so try to
        # recover by fetching a fresh image instead.
        st.warning("No image is available to label yet.")
        fetch_image()
        if st.session_state.current_image is not None:
            # Rerun so the new image is displayed before it can be labelled.
            st.rerun()
    else:
        print(f"LABEL CLICKED IS: {label_input}")
        # Encode the label as 0 (cat) or 1 (dog)
        label = 0 if label_input == 'cat' else 1
        st.session_state.current_label = label
        st.session_state.training_data.append((st.session_state.current_image, label))
        # Retrain on every labelled image collected in this session
        images = np.array([img for img, _ in st.session_state.training_data])
        labels = np.array([lab for _, lab in st.session_state.training_data])
        # Clear the staged image; a failed fetch below then leaves no image
        # rather than the one that was just labelled.
        st.session_state.current_image = None
        print("before model fit")
        st.session_state.model.fit(images, labels, epochs=1)
        st.write(st.session_state.model.evaluate(images, labels, verbose=2))
        # Disable the label buttons until "next" is pressed
        st.session_state.calculating = True
        print("after model fit")
        st.info(f"You clicked on last picture (picture {st.session_state.which_pic}): {label_input}")
        st.session_state.which_pic += 1
        # Enable the "next" button
        st.session_state.next = False
        fetch_image()
# Next: re-enable the label buttons, disable itself, and rerun the script
with col4:
    next_clicked = st.button("next", disabled=st.session_state.next)
    if next_clicked and not st.session_state.next:
        st.session_state.next = True
        st.session_state.calculating = False
        st.rerun()