trohith89 commited on
Commit
3c8cf5f
·
verified ·
1 Parent(s): b83b78f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -25
app.py CHANGED
@@ -14,14 +14,26 @@ st.set_page_config(
14
  initial_sidebar_state="expanded"
15
  )
16
 
17
- # Load the trained model and label encoder
18
  @st.cache_resource
19
  def load_resources():
20
- model = load_model("captains_cv2_model.keras")
 
 
 
 
 
 
 
 
 
 
 
21
  with open("label_encoder.pkl", "rb") as f:
22
  le = pickle.load(f)
23
  return model, le
24
 
 
25
  model, label_encoder = load_resources()
26
 
27
  # Function to preprocess the uploaded image
@@ -33,9 +45,16 @@ def preprocess_image(uploaded_file):
33
 
34
  # Read the image using cv2.imread
35
  img = cv2.imread(temp_path)
 
 
 
36
  # Resize to the model's expected input size (64, 64)
37
- img = cv2.resize(img, (64, 64)) # Note: cv2.resize takes (width, height), not (height, width, channels)
38
- # Add new axis for batch dimension
 
 
 
 
39
  img = img[np.newaxis, :, :, :]
40
 
41
  # Clean up the temporary file
@@ -61,31 +80,35 @@ st.markdown("Upload an image below, and let the model predict its class!")
61
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
62
 
63
  if uploaded_file is not None:
64
- # Display the uploaded image
65
- image = Image.open(uploaded_file)
66
- uploaded_file.seek(0) # Reset file pointer after reading for display
67
- st.image(image, caption="Uploaded Image", use_column_width=True)
 
 
 
 
68
 
69
- # Preprocess the image
70
- processed_image = preprocess_image(uploaded_file)
 
 
 
71
 
72
- # Make prediction
73
- with st.spinner("Predicting..."):
74
- # Predict and decode as per your specified steps
75
- prediction = model.predict(processed_image)
76
- predicted_class_idx = np.argmax(prediction, axis=1)[0]
77
- predicted_class = label_encoder.inverse_transform([predicted_class_idx])[0]
78
 
79
- # Display the prediction
80
- st.success("Prediction complete!")
81
- st.markdown(f"### Predicted Class: **{predicted_class}**")
82
- st.write(f"Prediction Confidence: {prediction[0][predicted_class_idx]:.4f}")
 
83
 
84
- # Optional: Display confidence scores for all classes
85
- if st.checkbox("Show confidence scores for all classes"):
86
- class_names = label_encoder.classes_
87
- confidence_scores = {class_names[i]: float(prediction[0][i]) for i in range(len(class_names))}
88
- st.bar_chart(confidence_scores)
89
 
90
  else:
91
  st.info("Please upload an image to get started.")
 
14
  initial_sidebar_state="expanded"
15
  )
16
 
17
+ # Load the trained model and label encoder with error handling
18
  @st.cache_resource
19
  def load_resources():
20
+ try:
21
+ # Load the model (assuming TensorFlow 2.6+ with batch_shape support)
22
+ model = load_model("captains_cv2_model.keras")
23
+ except TypeError as e:
24
+ # Fallback for compatibility issues
25
+ st.error(f"Model loading failed: {e}")
26
+ st.warning("Attempting to load model without compilation...")
27
+ model = load_model("captains_cv2_model.keras", compile=False)
28
+ # Recompile the model manually if needed
29
+ model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
30
+
31
+ # Load the label encoder
32
  with open("label_encoder.pkl", "rb") as f:
33
  le = pickle.load(f)
34
  return model, le
35
 
36
+ # Load resources
37
  model, label_encoder = load_resources()
38
 
39
  # Function to preprocess the uploaded image
 
45
 
46
  # Read the image using cv2.imread
47
  img = cv2.imread(temp_path)
48
+ if img is None:
49
+ raise ValueError("Failed to load image. Please ensure the file is a valid image.")
50
+
51
  # Resize to the model's expected input size (64, 64)
52
+ img = cv2.resize(img, (64, 64)) # cv2 uses (width, height)
53
+ # Convert BGR (OpenCV default) to RGB if needed
54
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
55
+ # Normalize pixel values to [0, 1] (common for CNNs)
56
+ img = img / 255.0
57
+ # Add batch dimension
58
  img = img[np.newaxis, :, :, :]
59
 
60
  # Clean up the temporary file
 
80
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
81
 
82
  if uploaded_file is not None:
83
+ try:
84
+ # Display the uploaded image
85
+ image = Image.open(uploaded_file)
86
+ uploaded_file.seek(0) # Reset file pointer after reading for display
87
+ st.image(image, caption="Uploaded Image", use_column_width=True)
88
+
89
+ # Preprocess the image
90
+ processed_image = preprocess_image(uploaded_file)
91
 
92
+ # Make prediction
93
+ with st.spinner("Predicting..."):
94
+ prediction = model.predict(processed_image)
95
+ predicted_class_idx = np.argmax(prediction, axis=1)[0]
96
+ predicted_class = label_encoder.inverse_transform([predicted_class_idx])[0]
97
 
98
+ # Display the prediction
99
+ st.success("Prediction complete!")
100
+ st.markdown(f"### Predicted Class: **{predicted_class}**")
101
+ st.write(f"Prediction Confidence: {prediction[0][predicted_class_idx]:.4f}")
 
 
102
 
103
+ # Optional: Display confidence scores for all classes
104
+ if st.checkbox("Show confidence scores for all classes"):
105
+ class_names = label_encoder.classes_
106
+ confidence_scores = {class_names[i]: float(prediction[0][i]) for i in range(len(class_names))}
107
+ st.bar_chart(confidence_scores)
108
 
109
+ except Exception as e:
110
+ st.error(f"An error occurred: {e}")
111
+ st.info("Please try uploading a different image or check the model compatibility.")
 
 
112
 
113
  else:
114
  st.info("Please upload an image to get started.")