trohith89 commited on
Commit
88f40ec
·
verified ·
1 Parent(s): 8f7d9c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -17
app.py CHANGED
@@ -9,59 +9,51 @@ import os
9
  # Load the model and label encoder
10
  @st.cache_resource
11
  def load_resources():
12
- model = load_model('captains_cv2_model.keras')
 
 
 
 
 
13
  with open('label_encoder.pkl', 'rb') as file:
14
  le = pickle.load(file)
15
  return model, le
16
 
17
  # Preprocess the image
18
  def preprocess_image(image_path):
19
- # Read and convert image
20
  img1 = cv2.imread(image_path)
21
  img1 = cv2.resize(img1, (64, 64)) # Resize to 64x64
22
- img1 = np.asarray(img1) # Convert to numpy array, shape will be (64, 64, 3)
 
23
  return img1
24
 
25
  # Main app
26
  def main():
27
- # Load resources
28
  model, le = load_resources()
29
 
30
- # Streamlit UI
31
  st.title("Image Classification App")
32
  st.write("Upload an image to get a prediction")
33
 
34
- # File uploader
35
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
36
 
37
  if uploaded_file is not None:
38
- # Display uploaded image
39
  image = Image.open(uploaded_file)
40
  st.image(image, caption='Uploaded Image', use_column_width=True)
41
 
42
- # Get original file extension
43
  file_extension = os.path.splitext(uploaded_file.name)[1].lower()
44
  temp_filename = f"temp_image{file_extension}"
45
 
46
- # Save temporary file with original extension
47
  with open(temp_filename, "wb") as f:
48
  f.write(uploaded_file.getvalue())
49
 
50
  try:
51
- # Preprocess image
52
  processed_img = preprocess_image(temp_filename)
53
-
54
- # Display shape for debugging
55
  st.write(f"Processed image shape: {processed_img.shape}")
56
 
57
- # Make prediction
58
  prediction = model.predict(processed_img)
59
  predicted_class = le.inverse_transform([np.argmax(prediction)])
60
 
61
- # Display prediction
62
  st.write("Prediction:", predicted_class[0])
63
-
64
- # Display prediction probabilities
65
  st.write("Prediction Probabilities:")
66
  for class_name, prob in zip(le.classes_, prediction[0]):
67
  st.write(f"{class_name}: {prob:.4f}")
@@ -69,7 +61,6 @@ def main():
69
  except Exception as e:
70
  st.error(f"An error occurred: {str(e)}")
71
 
72
- # Clean up temporary file
73
  if os.path.exists(temp_filename):
74
  os.remove(temp_filename)
75
 
 
9
  # Load the model and label encoder
10
  @st.cache_resource
11
  def load_resources():
12
+ # Custom loading to handle compatibility
13
+ try:
14
+ model = load_model('captains_cv2_model.keras', compile=False) # Load without compiling first
15
+ except Exception as e:
16
+ st.error(f"Model loading failed: {str(e)}")
17
+ raise
18
  with open('label_encoder.pkl', 'rb') as file:
19
  le = pickle.load(file)
20
  return model, le
21
 
22
  # Preprocess the image
23
  def preprocess_image(image_path):
 
24
  img1 = cv2.imread(image_path)
25
  img1 = cv2.resize(img1, (64, 64)) # Resize to 64x64
26
+ img1 = np.asarray(img1) # Shape: (64, 64, 3)
27
+ img1 = img1[np.newaxis, :, :, :] # Shape: (1, 64, 64, 3)
28
  return img1
29
 
30
  # Main app
31
  def main():
 
32
  model, le = load_resources()
33
 
 
34
  st.title("Image Classification App")
35
  st.write("Upload an image to get a prediction")
36
 
 
37
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
38
 
39
  if uploaded_file is not None:
 
40
  image = Image.open(uploaded_file)
41
  st.image(image, caption='Uploaded Image', use_column_width=True)
42
 
 
43
  file_extension = os.path.splitext(uploaded_file.name)[1].lower()
44
  temp_filename = f"temp_image{file_extension}"
45
 
 
46
  with open(temp_filename, "wb") as f:
47
  f.write(uploaded_file.getvalue())
48
 
49
  try:
 
50
  processed_img = preprocess_image(temp_filename)
 
 
51
  st.write(f"Processed image shape: {processed_img.shape}")
52
 
 
53
  prediction = model.predict(processed_img)
54
  predicted_class = le.inverse_transform([np.argmax(prediction)])
55
 
 
56
  st.write("Prediction:", predicted_class[0])
 
 
57
  st.write("Prediction Probabilities:")
58
  for class_name, prob in zip(le.classes_, prediction[0]):
59
  st.write(f"{class_name}: {prob:.4f}")
 
61
  except Exception as e:
62
  st.error(f"An error occurred: {str(e)}")
63
 
 
64
  if os.path.exists(temp_filename):
65
  os.remove(temp_filename)
66