trohith89 commited on
Commit
7bc0302
·
verified ·
1 Parent(s): 3c8cf5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -101
app.py CHANGED
@@ -1,118 +1,87 @@
1
  import streamlit as st
2
- import numpy as np
3
  import cv2
 
4
  from tensorflow.keras.models import load_model
5
  import pickle
6
  from PIL import Image
7
  import os
8
 
9
- # Set page configuration
10
- st.set_page_config(
11
- page_title="Image Detection App",
12
- page_icon="📸",
13
- layout="centered",
14
- initial_sidebar_state="expanded"
15
- )
16
-
17
- # Load the trained model and label encoder with error handling
18
  @st.cache_resource
19
  def load_resources():
20
- try:
21
- # Load the model (assuming TensorFlow 2.6+ with batch_shape support)
22
- model = load_model("captains_cv2_model.keras")
23
- except TypeError as e:
24
- # Fallback for compatibility issues
25
- st.error(f"Model loading failed: {e}")
26
- st.warning("Attempting to load model without compilation...")
27
- model = load_model("captains_cv2_model.keras", compile=False)
28
- # Recompile the model manually if needed
29
- model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
30
-
31
- # Load the label encoder
32
- with open("label_encoder.pkl", "rb") as f:
33
- le = pickle.load(f)
34
  return model, le
35
 
36
- # Load resources
37
- model, label_encoder = load_resources()
38
-
39
- # Function to preprocess the uploaded image
40
- def preprocess_image(uploaded_file):
41
- # Save the uploaded file temporarily to disk
42
- temp_path = "temp_image.jpg"
43
- with open(temp_path, "wb") as f:
44
- f.write(uploaded_file.read())
45
 
46
- # Read the image using cv2.imread
47
- img = cv2.imread(temp_path)
48
- if img is None:
49
- raise ValueError("Failed to load image. Please ensure the file is a valid image.")
50
 
51
- # Resize to the model's expected input size (64, 64)
52
- img = cv2.resize(img, (64, 64)) # cv2 uses (width, height)
53
- # Convert BGR (OpenCV default) to RGB if needed
54
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
55
- # Normalize pixel values to [0, 1] (common for CNNs)
56
- img = img / 255.0
57
- # Add batch dimension
58
- img = img[np.newaxis, :, :, :]
59
 
60
- # Clean up the temporary file
61
- os.remove(temp_path)
62
- return img
63
 
64
- # Sidebar
65
- st.sidebar.title("About")
66
- st.sidebar.info(
67
- "This app uses a Convolutional Neural Network (CNN) to classify images into one of 10 categories. "
68
- "Upload an image, and the model will predict its class!"
69
- )
70
- st.sidebar.markdown("### Classes")
71
- st.sidebar.write(
72
- "The model can predict: lifeboat, ladybug, pizza, bell pepper, school bus, koala, espresso, red panda, orange, sports car."
73
- )
74
-
75
- # Main content
76
- st.title("📸 Image Classification App")
77
- st.markdown("Upload an image below, and let the model predict its class!")
78
-
79
- # File uploader
80
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
81
-
82
- if uploaded_file is not None:
83
- try:
84
- # Display the uploaded image
85
  image = Image.open(uploaded_file)
86
- uploaded_file.seek(0) # Reset file pointer after reading for display
87
- st.image(image, caption="Uploaded Image", use_column_width=True)
88
-
89
- # Preprocess the image
90
- processed_image = preprocess_image(uploaded_file)
91
-
92
- # Make prediction
93
- with st.spinner("Predicting..."):
94
- prediction = model.predict(processed_image)
95
- predicted_class_idx = np.argmax(prediction, axis=1)[0]
96
- predicted_class = label_encoder.inverse_transform([predicted_class_idx])[0]
97
-
98
- # Display the prediction
99
- st.success("Prediction complete!")
100
- st.markdown(f"### Predicted Class: **{predicted_class}**")
101
- st.write(f"Prediction Confidence: {prediction[0][predicted_class_idx]:.4f}")
102
-
103
- # Optional: Display confidence scores for all classes
104
- if st.checkbox("Show confidence scores for all classes"):
105
- class_names = label_encoder.classes_
106
- confidence_scores = {class_names[i]: float(prediction[0][i]) for i in range(len(class_names))}
107
- st.bar_chart(confidence_scores)
108
-
109
- except Exception as e:
110
- st.error(f"An error occurred: {e}")
111
- st.info("Please try uploading a different image or check the model compatibility.")
112
-
113
- else:
114
- st.info("Please upload an image to get started.")
 
 
 
 
 
 
115
 
116
- # Footer
117
- st.markdown("---")
118
- st.markdown("Created with ❤️ using Streamlit | Hosted on [Hugging Face Spaces](https://huggingface.co/spaces)")
 
1
  import streamlit as st
 
2
  import cv2
3
+ import numpy as np
4
  from tensorflow.keras.models import load_model
5
  import pickle
6
  from PIL import Image
7
  import os
8
 
9
+ # Load the model and label encoder
 
 
 
 
 
 
 
 
10
  @st.cache_resource
11
  def load_resources():
12
+ model = load_model('captains_cv2_model.keras')
13
+ with open('label_encoder.pkl', 'rb') as file:
14
+ le = pickle.load(file)
 
 
 
 
 
 
 
 
 
 
 
15
  return model, le
16
 
17
+ # Preprocess the image
18
+ def preprocess_image(image_path):
19
+ # Read and convert image
20
+ img1 = cv2.imread(image_path)
21
+ img1 = cv2.resize(img1, (64, 64)) # Resize to 64x64
22
+ img1 = np.asarray(img1) # Convert to numpy array, shape will be (64, 64, 3)
 
 
 
23
 
24
+ # Add batch dimension to get (1, 64, 64, 3)
25
+ img1 = img1[np.newaxis, :, :, :]
 
 
26
 
27
+ # Verify shape
28
+ if len(img1.shape) != 4:
29
+ raise ValueError(f"Expected shape (batch_size, 64, 64, 3), got {img1.shape}")
30
+ if img1.shape[1:] != (64, 64, 3):
31
+ raise ValueError(f"Image dimensions should be (64, 64, 3), got {img1.shape[1:]}")
 
 
 
32
 
33
+ return img1
 
 
34
 
35
+ # Main app
36
+ def main():
37
+ # Load resources
38
+ model, le = load_resources()
39
+
40
+ # Streamlit UI
41
+ st.title("Image Classification App")
42
+ st.write("Upload an image to get a prediction")
43
+
44
+ # File uploader
45
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
46
+
47
+ if uploaded_file is not None:
48
+ # Display uploaded image
 
 
 
 
 
 
 
49
  image = Image.open(uploaded_file)
50
+ st.image(image, caption='Uploaded Image', use_column_width=True)
51
+
52
+ # Get original file extension
53
+ file_extension = os.path.splitext(uploaded_file.name)[1].lower()
54
+ temp_filename = f"temp_image{file_extension}"
55
+
56
+ # Save temporary file with original extension
57
+ with open(temp_filename, "wb") as f:
58
+ f.write(uploaded_file.getvalue())
59
+
60
+ try:
61
+ # Preprocess image
62
+ processed_img = preprocess_image(temp_filename)
63
+
64
+ # Display shape for debugging
65
+ st.write(f"Processed image shape: {processed_img.shape}")
66
+
67
+ # Make prediction
68
+ prediction = model.predict(processed_img)
69
+ predicted_class = le.inverse_transform([np.argmax(prediction)])
70
+
71
+ # Display prediction
72
+ st.write("Prediction:", predicted_class[0])
73
+
74
+ # Display prediction probabilities
75
+ st.write("Prediction Probabilities:")
76
+ for class_name, prob in zip(le.classes_, prediction[0]):
77
+ st.write(f"{class_name}: {prob:.4f}")
78
+
79
+ except Exception as e:
80
+ st.error(f"An error occurred: {str(e)}")
81
+
82
+ # Clean up temporary file
83
+ if os.path.exists(temp_filename):
84
+ os.remove(temp_filename)
85
 
86
+ if __name__ == '__main__':
87
+ main()