# DeBillFriends / streamlit_app.py
# Last commit: 3eed454 "Update streamlit_app.py" (Z091, verified)
import streamlit as st
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
# Import preprocess_input specifically from the efficientnet application module
from tensorflow.keras.applications.efficientnet import preprocess_input
import os
from PIL import Image # Needed to display image in Streamlit
import io # Needed to handle camera input which is a BytesIO object
# --- Configuration ---
# Name of the saved Keras model file that the app loads at startup.
MODEL_FILENAME: str = 'skin_lesion_model.keras'
# Every input image is resized to an IMG_SIZE x IMG_SIZE square before inference.
IMG_SIZE: int = 224
# Human-readable labels in model-output order: entry i names output unit i.
# (The training script reported the classes as ['benign', 'malignant'].)
CLASS_NAMES = ['Benign', 'Malignant']
# --- Model Loading (Cached) ---
def _resolve_model_path(model_path):
    """Return an existing path for ``model_path``, or None if it cannot be found.

    The path is first checked as given (a relative path resolves against the
    current working directory, exactly as before). If that fails, the same
    name is looked up next to this script, which is where the UI tells users
    to place the model file.
    """
    if os.path.exists(model_path):
        return model_path
    script_dir = os.path.dirname(os.path.abspath(__file__))
    candidate = os.path.join(script_dir, model_path)
    return candidate if os.path.exists(candidate) else None


@st.cache_resource  # Cache the loaded model so it is not reloaded on every rerun
def _load_model_cached(model_path):
    """Load and return the Keras model stored at ``model_path``.

    Errors propagate as exceptions on purpose: st.cache_resource stores only
    successful return values, so a failed load is retried on the next rerun
    instead of being memoized as ``None`` until the server restarts.
    """
    # compile=False: the app only runs inference, so the training
    # configuration (optimizer, loss, metrics) is not needed.
    model = load_model(model_path, compile=False)
    print("Model loaded successfully.")  # Log for server console
    return model


def load_skin_model(model_path=MODEL_FILENAME):
    """Load the skin-lesion model, reporting any problem in the Streamlit UI.

    Args:
        model_path: Path to the saved Keras model. Defaults to MODEL_FILENAME.

    Returns:
        The loaded Keras model, or None if the file is missing or fails to
        load. An error message is shown in the app in both failure cases.
    """
    resolved_path = _resolve_model_path(model_path)
    if resolved_path is None:
        st.error(f"Error: Model file '{model_path}' not found.")
        st.info("Please ensure the model file is in the same directory as the script.")
        return None
    try:
        return _load_model_cached(resolved_path)
    except Exception as e:  # UI boundary: report the failure instead of crashing the app
        st.error(f"Error loading model: {e}")
        print(f"Error loading model: {e}")  # Log for server console
        return None
# --- Preprocessing Function ---
def preprocess_image(img_input):
    """Load an image and turn it into a model-ready batch for EfficientNetB0.

    Args:
        img_input: A file-like object (such as the value returned by
            st.file_uploader / st.camera_input) or a filesystem path.

    Returns:
        A batch holding one image, (1, IMG_SIZE, IMG_SIZE, 3) under Keras'
        default channels-last format, passed through EfficientNet's
        preprocess_input. Returns None if the image could not be read or
        processed; the error is shown in the UI and logged.
    """
    try:
        # A file-like object may already have been consumed earlier in the
        # run; rewind it so PIL sees the whole image.
        if hasattr(img_input, "seek"):
            img_input.seek(0)
        # The context manager closes the file PIL opens itself when given a
        # path; a caller-supplied file object is left open.
        with Image.open(img_input) as img:
            rgb_img = img.convert('RGB')  # Ensure image is RGB (fully loads the pixels)
        rgb_img = rgb_img.resize((IMG_SIZE, IMG_SIZE))
        img_array = image.img_to_array(rgb_img)
        img_array = np.expand_dims(img_array, axis=0)  # Add the batch dimension
        # Use the preprocessing function that matches the EfficientNet backbone
        processed_img = preprocess_input(img_array)
        print(f"Image preprocessed successfully. Shape: {processed_img.shape}")  # Debug print
        return processed_img
    except Exception as e:  # UI boundary: report instead of crashing the app
        st.error(f"Error during image preprocessing: {e}")
        print(f"Error during image preprocessing: {e}")  # Log for server console
        return None  # Return None to indicate failure
# --- Prediction Function ---
def predict_skin_lesion(model, processed_image):
    """Classify a preprocessed image and return its label and confidence.

    Args:
        model: A loaded Keras model.
        processed_image: A batch containing one image, as produced by
            preprocess_image.

    Returns:
        ``(label, confidence)``: ``label`` is an entry of CLASS_NAMES and
        ``confidence`` is the model's score for that label as a float.
        Returns ``(None, None)`` if prediction failed; the error is shown
        in the UI and logged.
    """
    try:
        print("Making prediction...")  # Debug print
        # verbose=0 keeps Keras' per-call progress bar out of the server log.
        scores = np.ravel(model.predict(processed_image, verbose=0)[0])
        print(f"Raw prediction output: {scores}")  # Debug print
        if scores.size == 1:
            # Single-unit head: argmax over one value would always pick index 0.
            # Treat the value as the probability of CLASS_NAMES[1] (Keras'
            # binary-label convention with alphabetically ordered classes).
            # NOTE(review): confirm against the training script.
            p_positive = float(scores[0])
            class_index = 1 if p_positive > 0.5 else 0
            confidence = p_positive if class_index == 1 else 1.0 - p_positive
        else:
            # One score per class: take the highest-scoring class.
            class_index = int(np.argmax(scores))
            confidence = float(scores[class_index])
        class_label = CLASS_NAMES[class_index]
        print(f"Predicted class: {class_label}, Confidence: {confidence:.4f}")  # Debug print
        return class_label, confidence
    except Exception as e:  # UI boundary: report instead of crashing the app
        st.error(f"An error occurred during prediction: {e}")
        print(f"An error occurred during prediction: {e}")  # Log for server console
        return None, None  # Return None to indicate failure
# --- Streamlit App UI ---
# st.set_page_config must run before any other Streamlit element is rendered.
st.set_page_config(page_title="Skin Lesion Classifier", layout="centered")
st.title("Skin Lesion Classification (EfficientNetB0)")
st.markdown("Upload an image of a skin lesion or use your camera for a potential classification (benign or malignant). Important: This tool uses a probabilistic model and its output is not a substitute for a professional medical diagnosis. Always consult your physician.")

# Load the model (returns None and shows an error if loading failed)
model = load_skin_model()

if model is not None:
    # Two ways to supply an image: file upload or camera shot
    uploaded_file = st.file_uploader("Choose a skin lesion image...", type=["jpg", "jpeg", "png"])
    camera_input = st.camera_input("Or take a picture using your camera:")

    # Choose the image source; an uploaded file takes priority over the camera.
    image_source = None
    source_type = None  # Used in the image caption
    if uploaded_file is not None:
        image_source = uploaded_file
        source_type = "Uploaded"
        if camera_input is not None:
            # Both widgets keep their values; the camera shot is simply ignored.
            st.warning("You provided both an uploaded file and a camera shot. Using the uploaded file.")
    elif camera_input is not None:
        image_source = camera_input
        source_type = "Camera"

    if image_source is not None:
        # NOTE(review): newer Streamlit releases deprecate `use_column_width`
        # for st.image; switch to its replacement once the pinned Streamlit
        # version supports it.
        st.image(image_source, caption=f'{source_type} Image.', use_column_width=True)
        st.write("")  # Add a little space

        if st.button('Classify Lesion'):
            with st.spinner('Preprocessing image and making prediction...'):
                # Same preprocessing for file-uploader and camera images
                processed_image = preprocess_image(image_source)
                if processed_image is not None:
                    label, confidence = predict_skin_lesion(model, processed_image)
                    if label is not None:
                        st.success(f'Prediction: **{label}**')
                        st.metric(label="Confidence", value=f"{confidence:.2%}")
                    else:
                        st.error("Prediction failed. Please check the logs or try a different image.")
                else:
                    st.error("Image preprocessing failed. Please ensure the image is valid.")
else:
    st.warning("Model could not be loaded. Please check the setup.")