"""Streamlit demo: label random cat/dog images and train a small CNN on the fly."""
import io
import time
from urllib.parse import urljoin

import numpy as np
import requests
import streamlit as st
import tensorflow as tf
from PIL import Image
from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
from tensorflow.keras.models import Sequential

# Hide all GPUs from TensorFlow so the model runs on the CPU.
tf.config.set_visible_devices([], 'GPU')

# ---------------------------
# Helper Functions
# ---------------------------
CATAAS_JSON_URL = 'https://cataas.com/cat?json=true'
DOG_CEO_RANDOM_URL = 'https://dog.ceo/api/breeds/image/random'
# Seconds to wait for each HTTP request; without a timeout requests can block forever.
REQUEST_TIMEOUT = 10


def _download_image(url):
    """Download *url* and return it as a fully decoded PIL image, or None on a non-200 reply."""
    response_image = requests.get(url, timeout=REQUEST_TIMEOUT)
    if response_image.status_code != 200:
        return None
    print(f"SECOND STATUS_CODE:{response_image.status_code}")
    # Read the whole (content-decoded) body into memory so the connection is released
    # and the image does not depend on a live network stream.
    image = Image.open(io.BytesIO(response_image.content))
    image.load()  # decode now so corrupt data fails here, not on first use
    return image


def fetch_cat_image():
    """Fetch a random cat image from the CATAAS API.

    Returns:
        A PIL image, or None if either HTTP request does not return 200 or the
        JSON reply has no ``url`` field.

    Raises:
        requests.RequestException: on network errors or timeouts.
    """
    response = requests.get(CATAAS_JSON_URL, timeout=REQUEST_TIMEOUT)
    print(f"STATUS_CODE:{response.status_code}")
    if response.status_code != 200:
        return None
    url = response.json().get('url')
    if not url:
        return None
    # urljoin resolves a host-relative path and leaves an absolute URL unchanged.
    return _download_image(urljoin(CATAAS_JSON_URL, url))


def fetch_dog_image():
    """Fetch a random dog image from the Dog CEO API.

    Returns:
        A PIL image, or None if a request does not return 200 or the API
        reports a status other than ``'success'``.

    Raises:
        requests.RequestException: on network errors or timeouts.
    """
    response = requests.get(DOG_CEO_RANDOM_URL, timeout=REQUEST_TIMEOUT)
    print(f"STATUS_CODE:{response.status_code}")
    if response.status_code != 200:
        return None
    data = response.json()
    if data.get('status') != 'success':
        return None
    # On success the API puts the image URL in the 'message' field.
    return _download_image(data['message'])


def fetch_random_image():
    """Fetch a cat or a dog image with equal probability (None on failure)."""
    if np.random.rand() < 0.5:
        return fetch_cat_image()
    return fetch_dog_image()


def preprocess_image(image, target_size=(64, 64)):
    """Turn a PIL image into a normalized 3-channel array for the model.

    Args:
        image: Any PIL image, in any mode.
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        A float32 array of shape ``(height, width, 3)`` with values in [0, 1],
        or None if conversion failed (the error is printed).
    """
    try:
        # Convert every non-RGB mode (RGBA, P, L, LA, CMYK, 1, I;16, ...) so the result
        # always has exactly three 8-bit channels before scaling. Converting before
        # resizing also avoids Pillow's forced nearest-neighbour resize for palette images.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image = image.resize(target_size)
        array = np.asarray(image, dtype=np.float32) / 255.0  # normalize to [0, 1]
    except Exception as e:  # best-effort: callers treat None as "unusable image"
        print(f"Preprocessing error: {e}")
        return None
    return array
# ---------------------------
# Model Creation
# ---------------------------
def create_model(input_shape=(64, 64, 3)):
    """Build and compile a small CNN binary classifier (cat = 0, dog = 1).

    Two Conv2D + MaxPooling2D stages feed a 64-unit dense layer and a single
    sigmoid unit whose output is the predicted probability of "dog".

    Args:
        input_shape: ``(height, width, channels)`` of the input images; must match
            the arrays produced by ``preprocess_image``.

    Returns:
        A ``Sequential`` model compiled with Adam, binary cross-entropy loss and
        an accuracy metric.
    """
    model = Sequential([
        tf.keras.layers.Input(shape=input_shape),
        Conv2D(32, (3, 3), activation='relu'),
        MaxPooling2D((2, 2)),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D((2, 2)),
        Flatten(),
        Dense(64, activation='relu'),
        Dense(1, activation='sigmoid'),
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model


def fetch_image():
    """Fetch a random image, preprocess it, and store it in session state.

    On success sets ``unprocessed_image`` and ``current_image`` and clears
    ``current_label`` / ``current_prediction``. On failure an error is shown
    and session state is left unchanged.

    Returns:
        True if a usable image is now in session state, False otherwise.
    """
    try:
        image = fetch_random_image()
        processed = preprocess_image(image) if image is not None else None
    except Exception as e:  # UI boundary: report any network/decoding error to the user
        st.error(f"An error occurred while fetching the image: {e}")
        return False
    if processed is None:
        # Either the download failed or the image could not be converted.
        st.error("Failed to fetch an image. Please try again.")
        return False
    st.session_state.unprocessed_image = image
    st.session_state.current_image = processed
    st.session_state.current_label = None
    st.session_state.current_prediction = None
    st.info("image ready")
    return True


# ---------------------------
# Streamlit UI
# ---------------------------
st.title("Neural Network Classification Demo")

# Streamlit re-runs this whole script on every interaction, so per-session
# values are initialized exactly once.
if 'model' not in st.session_state:
    st.session_state.model = create_model()
    st.session_state.training_data = []  # (image_array, label) pairs labelled so far
    st.session_state.unprocessed_image = None
    st.session_state.current_image = None
    st.session_state.current_label = None
    st.session_state.current_prediction = None
    st.session_state.label_input = None
    st.session_state.started = False  # True once Start has loaded the first image
    st.session_state.which_pic = 1  # 1-based counter shown after each label
    st.session_state.calculating = True  # True disables the cat/dog buttons
    st.session_state.next = True  # True disables the next button

col1, col2, col3, col4 = st.columns([1, 1, 1, 1])

with col1:
    if st.button("Start", disabled=st.session_state.started):
        # Lock Start only once an image actually loaded, so a failed first fetch can be retried.
        if not st.session_state.started and fetch_image():
            st.session_state.started = True
            st.session_state.calculating = False

with col2:
    if st.button("cat", disabled=st.session_state.calculating):
        st.session_state.label_input = "cat"

with col3:
    if st.button("dog", disabled=st.session_state.calculating):
        st.session_state.label_input = "dog"

if st.session_state.current_image is not None:
    prediction = st.session_state.model.predict(
        np.array([st.session_state.current_image]), verbose=0
    )[0][0]
    st.session_state.current_prediction = 'dog' if prediction > 0.5 else 'cat'
    st.success(f"**Model Predicts:** {st.session_state.current_prediction} --- (cat-confidence {(1-prediction)*100:.2f}%; dog-confidence {(prediction)*100:.2f}%)")
    st.image(st.session_state.unprocessed_image)

    if st.session_state.label_input in ['cat', 'dog']:
        label_input = st.session_state.label_input
        # Consume the click; label_input lives in session state and would otherwise
        # re-trigger training on the next rerun.
        st.session_state.label_input = None

        label = 0 if label_input == 'cat' else 1  # cat -> 0, dog -> 1
        st.session_state.current_label = label
        st.session_state.training_data.append((st.session_state.current_image, label))

        # Retrain for one epoch on every example labelled so far.
        images = np.array([img for img, _ in st.session_state.training_data])
        labels = np.array([lab for _, lab in st.session_state.training_data])
        st.session_state.current_image = None
        st.session_state.model.fit(images, labels, epochs=1, verbose=0)
        st.write(st.session_state.model.evaluate(images, labels, verbose=2))
        st.session_state.calculating = True  # lock labels until "next" is pressed

        st.info(f"You clicked on last picture (picture {st.session_state.which_pic}): {label_input}")
        st.session_state.which_pic += 1
        st.session_state.next = False
        fetch_image()
else:
    # A label click with no image to attach it to is discarded rather than
    # being applied to whatever image loads later.
    st.session_state.label_input = None

with col4:
    if st.button("next", disabled=st.session_state.next):
        if not st.session_state.next:
            # If the fetch after labelling failed there is no image; retry before
            # unlocking the label buttons.
            if st.session_state.current_image is None:
                fetch_image()
            if st.session_state.current_image is not None:
                st.session_state.next = True
                st.session_state.calculating = False
                st.rerun()
            # On failure, stay on this run so the error stays visible and "next"
            # remains enabled for another attempt.