Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import requests | |
| from PIL import Image | |
| import numpy as np | |
| import time | |
| import tensorflow as tf | |
| from tensorflow.keras.models import Sequential | |
| from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense | |
| tf.config.set_visible_devices([], 'GPU') | |
| # --------------------------- | |
| # Helper Functions | |
| # --------------------------- | |
def fetch_cat_image():
    """Fetch a random cat image from the CATAAS API.

    Returns:
        PIL.Image.Image or None: the downloaded image, or None on any
        failure (bad HTTP status, malformed JSON, network error, or an
        image-decode error).
    """
    try:
        # timeout keeps a stalled connection from hanging the Streamlit UI.
        response = requests.get('https://cataas.com/cat?json=true', timeout=10)
        print(f"STATUS_CODE:{response.status_code}")
        if response.status_code == 200:
            data = response.json()
            url = data['url']
            # NOTE(review): CATAAS has historically returned a relative
            # path (e.g. "/cat/<id>") in `url`; make it absolute so the
            # second request does not fail with MissingSchema.
            if not url.startswith('http'):
                url = 'https://cataas.com' + url
            response_image = requests.get(url, stream=True, timeout=10)
            if response_image.status_code == 200:
                print(f"SECOND STATUS_CODE:{response_image.status_code}")
                return Image.open(response_image.raw)
    except (requests.RequestException, ValueError, KeyError, OSError) as e:
        # ValueError: malformed JSON; KeyError: missing 'url';
        # OSError: PIL could not decode the stream.
        print(f"fetch_cat_image error: {e}")
    # Explicit None on every failure path (original fell through implicitly).
    return None
def fetch_dog_image():
    """Fetch a random dog image from the Dog CEO API.

    Returns:
        PIL.Image.Image or None: the downloaded image, or None on any
        failure (bad HTTP status, API error payload, network error, or
        an image-decode error).
    """
    try:
        # timeout keeps a stalled connection from hanging the Streamlit UI.
        response = requests.get('https://dog.ceo/api/breeds/image/random', timeout=10)
        print(f"STATUS_CODE:{response.status_code}")
        if response.status_code == 200:
            data = response.json()
            # .get avoids a KeyError on an unexpected payload shape.
            if data.get('status') == 'success':
                url = data['message']
                response_image = requests.get(url, stream=True, timeout=10)
                if response_image.status_code == 200:
                    print(f"SECOND STATUS_CODE:{response_image.status_code}")
                    return Image.open(response_image.raw)
    except (requests.RequestException, ValueError, KeyError, OSError) as e:
        # ValueError: malformed JSON; OSError: PIL could not decode the stream.
        print(f"fetch_dog_image error: {e}")
    return None
def fetch_random_image():
    """Fetch one pet image, choosing cat or dog with equal probability."""
    pick_cat = np.random.rand() < 0.5
    fetcher = fetch_cat_image if pick_cat else fetch_dog_image
    return fetcher()
def preprocess_image(image, target_size=(64, 64)):
    """Resize an image, force 3 RGB channels, and normalize to [0, 1].

    Args:
        image: a PIL image (anything with resize/mode/convert and array
            conversion).
        target_size: (width, height) to resize to; defaults to (64, 64)
            to match create_model's input shape.

    Returns:
        numpy.ndarray of shape (H, W, 3) with float values in [0, 1],
        or None if any step fails.
    """
    try:
        image = image.resize(target_size)
        # Generalized from the original RGBA/P-only handling: convert ANY
        # non-RGB mode (RGBA, P, L, LA, CMYK, ...) so the array below is
        # always (H, W, 3).
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image = np.array(image) / 255.0  # Normalize to [0, 1]
        if len(image.shape) == 2:  # defensive: grayscale fallback to 3 channels
            image = np.stack([image] * 3, axis=-1)
    except Exception as e:
        # Best-effort by design: a bad image is reported and skipped upstream.
        print(f"Preprocessing error: {e}")
        return None
    return image
| # --------------------------- | |
| # Model Creation | |
| # --------------------------- | |
def create_model(input_shape=(64, 64, 3)):
    """Build and compile a small CNN for binary (cat=0 / dog=1) classification.

    Args:
        input_shape: shape of one input image; defaults to 64x64 RGB to
            match preprocess_image's output.

    Returns:
        A compiled tf.keras Sequential model with a single sigmoid output.
    """
    model = Sequential()
    model.add(tf.keras.layers.Input(shape=input_shape))
    # Two conv/pool stages, then a small dense head.
    model.add(Conv2D(32, (3, 3), activation='relu'))
    model.add(MaxPooling2D((2, 2)))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D((2, 2)))
    model.add(Flatten())
    model.add(Dense(64, activation='relu'))
    # Sigmoid output pairs with the binary cross-entropy loss below.
    model.add(Dense(1, activation='sigmoid'))
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model
def fetch_image():
    """Download a random pet image and stage it in Streamlit session state.

    On success stores the raw image (for display) and its preprocessed
    array (for the model), clearing the stale label/prediction. On any
    failure shows an error in the UI instead of raising.
    """
    print("entering fetch_image")
    try:
        new_image = fetch_random_image()
        if new_image is None:
            st.error("Failed to fetch an image. Please try again.")
        else:
            st.session_state.unprocessed_image = new_image
            st.session_state.current_image = preprocess_image(new_image)
            st.session_state.current_label = None
            st.session_state.current_prediction = None
            st.info("image ready")
    except Exception as e:
        st.error(f"An error occurred while fetching the image: {e}")
# ---------------------------
# Streamlit UI
# ---------------------------
# Interactive online-learning demo: the user labels each fetched pet image
# as "cat" or "dog", and the CNN is re-fit on all labeled examples so far.
st.title("Neural Network Classification Demo")
placeholder = st.empty()
# Initialize session state on the very first script run only.
#   model              - the persistent Keras CNN
#   training_data      - list of (preprocessed image array, 0/1 label)
#   current_image      - preprocessed array awaiting a label, or None
#   label_input        - which label button was last pressed
#   started            - True once "Start" has been pressed
#   which_pic          - 1-based counter of labeled pictures
#   calculating        - True disables the cat/dog buttons (while busy)
#   next               - True disables the "next" button
if 'model' not in st.session_state:
    st.session_state.model = create_model()
    st.session_state.training_data = []
    st.session_state.current_image = None
    st.session_state.current_label = None
    st.session_state.current_prediction = None
    st.session_state.label_input = None
    st.session_state.started=False
    st.session_state.which_pic=1
    st.session_state.calculating=True
    st.session_state.next=True
# Button row: Start | cat | dog | next
col1, col2, col3, col4 = st.columns([1,1,1,1])
with col1:
    # "Start" fetches the first image and enables the label buttons.
    if st.button("Start",disabled=st.session_state.started):
        if not st.session_state.started:
            fetch_image()
            st.session_state.started=True
            st.session_state.calculating=False
with col2:
    # Record the user's "cat" click; processed in the label block below.
    if st.button("cat",disabled=st.session_state.calculating):
        print("cat pressed")
        st.session_state.label_input="cat"
with col3:
    # Record the user's "dog" click; processed in the label block below.
    if st.button("dog",disabled=st.session_state.calculating):
        print("dog pressed")
        st.session_state.label_input="dog"
# Display the current image together with the model's prediction for it.
if st.session_state.current_image is not None:
    print(f"SHAPE:{st.session_state.current_image.shape}")
    # Batch of one image; predict() returns shape (1, 1), hence [0][0].
    prediction = st.session_state.model.predict(np.array([st.session_state.current_image]))[0][0]
    st.session_state.current_prediction = 'dog' if prediction > 0.5 else 'cat'
    st.success(f"**Model Predicts:** {st.session_state.current_prediction} --- (cat-confidence {(1-prediction)*100:.2f}%; dog-confidence {(prediction)*100:.2f}%)")
    st.image(st.session_state.unprocessed_image)
# A label button was pressed on the previous run: store the labeled
# example, retrain on everything so far, then fetch the next image.
if st.session_state.label_input in ['cat', 'dog']:
    label_input=st.session_state.label_input
    # Consume the click; the sentinel is the STRING "None", not None.
    st.session_state.label_input="None"
    # Convert user input to 0 (cat) or 1 (dog)
    print(f"LABEL CLICKED IS: {label_input.lower()}")
    label = 0 if label_input.lower() == 'cat' else 1
    st.session_state.current_label = label
    # Add the labeled image and label to training data
    st.session_state.training_data.append((st.session_state.current_image, label))
    # Rebuild the full training arrays from every example labeled so far.
    image = np.array([img for img, _ in st.session_state.training_data])
    label = np.array([lab for _, lab in st.session_state.training_data])
    # Clear the current image so the prediction block above is skipped
    # until the next image has been fetched.
    st.session_state.current_image=None
    print("before model fit")
    def model_fit():
        # One epoch of retraining on the whole accumulated dataset;
        # closes over the `image`/`label` arrays built above.
        print("Entering model fit function")
        st.session_state.model.fit(image, label, epochs=1)
        st.write(st.session_state.model.evaluate(image, label, verbose=2))
        # Disable the label buttons again until "next" is pressed.
        st.session_state.calculating=True
        return
    model_fit()
    print("after model fit")
    #st.session_state.unprocessed_image=None
    print("before fetch_image")
    #fetch_image()
    print("after fetch_image")
    st.info(f"You clicked on last picture (picture {st.session_state.which_pic}): {label_input}")
    st.session_state.which_pic=st.session_state.which_pic+1
    st.session_state.next=False
    fetch_image()
with col4:
    # "next" re-enables the label buttons and reruns the script so the
    # freshly fetched image (and its prediction) are displayed.
    if st.button("next",disabled=st.session_state.next):
        if not st.session_state.next:
            st.session_state.next=True
            st.session_state.calculating=False
            st.rerun()