Z091 commited on
Commit
c511886
·
verified ·
1 Parent(s): 9ca3d11

Create streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +128 -0
streamlit_app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow.keras.models import load_model
5
+ from tensorflow.keras.preprocessing import image
6
+ # Import preprocess_input specifically from the efficientnet application module
7
+ from tensorflow.keras.applications.efficientnet import preprocess_input
8
+ import os
9
+ from PIL import Image # Needed to display image in Streamlit
10
+
11
+ # --- Configuration ---
12
+ # Ensure this matches the IMG_SIZE used during training
13
+ IMG_SIZE = 260
14
+ # Define the expected model filename
15
+ MODEL_FILENAME = 'skin_lesion_model.keras'
16
+ # Define class names based on the training script output
17
+ CLASS_NAMES = ['Benign', 'Malignant'] # From training: ['benign', 'malignant']
18
+
19
+ # --- Model Loading (Cached) ---
20
+ @st.cache_resource # Decorator to cache the model loading
21
+ def load_skin_model():
22
+ """Loads the Keras model. Returns the model or None if not found."""
23
+ if not os.path.exists(MODEL_FILENAME):
24
+ st.error(f"Error: Model file '{MODEL_FILENAME}' not found.")
25
+ st.info(f"Please ensure the model file is in the same directory as the script.")
26
+ return None
27
+ try:
28
+ # Load the model, compile=False speeds up loading for inference only
29
+ model = load_model(MODEL_FILENAME, compile=False)
30
+ print("Model loaded successfully.") # Log for server console
31
+ return model
32
+ except Exception as e:
33
+ st.error(f"Error loading model: {e}")
34
+ print(f"Error loading model: {e}") # Log for server console
35
+ return None
36
+
37
+ # --- Preprocessing Function ---
38
+ def preprocess_image(img_input):
39
+ """Loads and preprocesses an image for EfficientNetB0."""
40
+ try:
41
+ # Load image directly from uploaded file object or path
42
+ # Use PIL to open the image from the BytesIO object provided by file_uploader
43
+ img = Image.open(img_input).convert('RGB') # Ensure image is RGB
44
+ img = img.resize((IMG_SIZE, IMG_SIZE))
45
+
46
+ img_array = image.img_to_array(img)
47
+ img_array = np.expand_dims(img_array, axis=0)
48
+ # Use the appropriate preprocessing function for EfficientNet
49
+ processed_img = preprocess_input(img_array)
50
+ print(f"Image preprocessed successfully. Shape: {processed_img.shape}") # Debug print
51
+ return processed_img
52
+ except Exception as e:
53
+ st.error(f"Error during image preprocessing: {e}")
54
+ print(f"Error during image preprocessing: {e}") # Log for server console
55
+ return None # Return None to indicate failure
56
+
57
+ # --- Prediction Function ---
58
+ def predict_skin_lesion(model, processed_image):
59
+ """Makes predictions using the loaded model and preprocessed image."""
60
+ try:
61
+ # Make prediction
62
+ print("Making prediction...") # Debug print
63
+ prediction = model.predict(processed_image)[0]
64
+ print(f"Raw prediction output: {prediction}") # Debug print
65
+
66
+ # Get the class with highest probability
67
+ class_index = np.argmax(prediction)
68
+ confidence = float(prediction[class_index])
69
+
70
+ # Map class index to label using CLASS_NAMES
71
+ class_label = CLASS_NAMES[class_index]
72
+
73
+ print(f"Predicted class: {class_label}, Confidence: {confidence:.4f}") # Debug print
74
+ return class_label, confidence
75
+
76
+ except Exception as e:
77
+ st.error(f"An error occurred during prediction: {e}")
78
+ print(f"An error occurred during prediction: {e}") # Log for server console
79
+ return None, None # Return None to indicate failure
80
+
81
+ # --- Streamlit App UI ---
82
+ st.set_page_config(page_title="Skin Lesion Classifier", layout="centered")
83
+ st.title("Skin Lesion Classification (EfficientNetB0)")
84
+ st.markdown(f"Upload an image of a skin lesion to classify it as benign or malignant. Model trained on {IMG_SIZE}x{IMG_SIZE} images.")
85
+
86
+ # Load the model using the cached function
87
+ model = load_skin_model()
88
+
89
+ # Only proceed if the model loaded successfully
90
+ if model is not None:
91
+ # File uploader
92
+ uploaded_file = st.file_uploader("Choose a skin lesion image...", type=["jpg", "jpeg", "png"])
93
+
94
+ if uploaded_file is not None:
95
+ # Display the uploaded image
96
+ st.image(uploaded_file, caption='Uploaded Image.', use_column_width=True)
97
+ st.write("") # Add a little space
98
+
99
+ # Classify button
100
+ if st.button('Classify Lesion'):
101
+ with st.spinner('Preprocessing image and making prediction...'):
102
+ # Preprocess the image
103
+ processed_image = preprocess_image(uploaded_file)
104
+
105
+ if processed_image is not None:
106
+ # Make prediction
107
+ label, confidence = predict_skin_lesion(model, processed_image)
108
+
109
+ if label is not None:
110
+ # Display result
111
+ st.success(f'Prediction: **{label}**')
112
+ st.metric(label="Confidence", value=f"{confidence:.2%}")
113
+ # Optional: Display confidence breakdown
114
+ # st.write("Confidence Scores:")
115
+ # st.write({name: f"{pred:.2%}" for name, pred in zip(CLASS_NAMES, model.predict(processed_image)[0])})
116
+ else:
117
+ st.error("Prediction failed. Please check the logs or try a different image.")
118
+ else:
119
+ st.error("Image preprocessing failed. Please ensure the image is valid.")
120
+ else:
121
+ # Message if model loading failed (already handled in load_skin_model, but good practice)
122
+ st.warning("Model could not be loaded. Please check the setup.")
123
+
124
+ # --- How to Run ---
125
+ # Save this code as a Python file (e.g., app.py)
126
+ # Ensure 'skin_lesion_model.keras' is in the same directory.
127
+ # Install libraries: pip install streamlit numpy tensorflow Pillow
128
+ # Run from terminal: streamlit run app.py