Nomi78600 commited on
Commit
485d30a
·
1 Parent(s): decceac
Files changed (2) hide show
  1. app.py +48 -49
  2. requirements.txt +0 -1
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import streamlit as st
3
  import tensorflow as tf
4
  from tensorflow.keras.models import load_model
@@ -6,7 +5,6 @@ from PIL import Image
6
  import numpy as np
7
  import requests
8
  import os
9
- from tqdm import tqdm
10
 
11
  # Set page config
12
  st.set_page_config(
@@ -37,33 +35,33 @@ st.markdown(
37
  )
38
 
39
  # --- Model Downloading and Loading ---
40
- def download_file_from_google_drive(id, destination):
41
  URL = f'https://drive.google.com/uc?export=download&id={id}'
42
  session = requests.Session()
 
43
  response = session.get(URL, stream=True)
44
-
45
  token = None
46
  for key, value in response.cookies.items():
47
  if key.startswith('download_warning'):
48
  token = value
49
- break
50
-
51
  if token:
52
  params = {'id': id, 'confirm': token}
53
  response = session.get(URL, params=params, stream=True)
54
-
55
  total_size = int(response.headers.get('content-length', 0))
56
  block_size = 1024 # 1 Kibibyte
57
-
58
- progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
59
  with open(destination, 'wb') as f:
60
  for data in response.iter_content(block_size):
61
- progress_bar.update(len(data))
62
  f.write(data)
63
- progress_bar.close()
64
-
65
- if total_size != 0 and progress_bar.n != total_size:
66
- st.error("An error occurred during file download.")
 
67
  return False
68
  return True
69
 
@@ -74,19 +72,26 @@ def load_keras_model():
74
  The `st.cache_resource` decorator ensures the model is loaded only once.
75
  """
76
  MODEL_PATH = "my_model.keras"
77
- FILE_ID = "1M-HNEJqbz6PzjhX6WHHKLPbjZpPRWLjP" # Replace with your file ID
78
 
79
  if not os.path.exists(MODEL_PATH):
80
  st.info("Model not found locally. Downloading from Google Drive... (this may take a moment)")
81
- download_file_from_google_drive(FILE_ID, MODEL_PATH)
82
- st.success("Model downloaded successfully!")
 
 
 
 
 
 
 
83
 
84
  try:
85
  model = load_model(MODEL_PATH)
86
  return model
87
  except Exception as e:
88
  st.error(f"Error loading model: {e}")
89
- st.info("Please ensure the Google Drive File ID is correct and the file is accessible.")
90
  return None
91
 
92
  model = load_keras_model()
@@ -95,15 +100,11 @@ model = load_keras_model()
95
  def preprocess_image(image):
96
  """
97
  Preprocesses the uploaded image to fit the model's input requirements.
98
- - Resizes to (256, 256)
99
- - Converts to a NumPy array
100
- - Normalizes pixel values
101
- - Expands dimensions for the model
102
  """
103
  img = image.resize((256, 256))
104
  img_array = np.array(img)
105
  img_array = img_array / 255.0
106
- img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
107
  return img_array
108
 
109
  # --- UI Layout ---
@@ -120,33 +121,31 @@ with col1:
120
  "Choose an image...", type=["jpg", "jpeg", "png"]
121
  )
122
 
123
- if uploaded_file is not None and model is not None:
124
- # Display the uploaded image
125
  image = Image.open(uploaded_file)
126
-
127
  with col1:
128
  st.image(image, caption="Uploaded Image", use_column_width=True)
129
-
130
- # Preprocess the image and make a prediction
131
- processed_image = preprocess_image(image)
132
- prediction = model.predict(processed_image)
133
- confidence = prediction[0][0]
134
-
135
- with col2:
136
- st.header("Prediction")
137
- if confidence > 0.5:
138
- st.markdown(
139
- f"## This is a Dog! 🐶"
140
- )
141
- st.progress(confidence)
142
- st.write(f"**Confidence:** {confidence:.2f}")
143
- else:
144
- st.markdown(
145
- f"## This is a Cat! 🐱"
146
- )
147
- st.progress(1-confidence)
148
- st.write(f"**Confidence:** {1-confidence:.2f}")
149
  else:
150
- with col2:
151
- st.info("Please upload an image to see the prediction.")
152
-
 
 
1
  import streamlit as st
2
  import tensorflow as tf
3
  from tensorflow.keras.models import load_model
 
5
  import numpy as np
6
  import requests
7
  import os
 
8
 
9
  # Set page config
10
  st.set_page_config(
 
35
  )
36
 
37
  # --- Model Downloading and Loading ---
38
+ def download_file_from_google_drive(id, destination, progress_bar):
39
  URL = f'https://drive.google.com/uc?export=download&id={id}'
40
  session = requests.Session()
41
+
42
  response = session.get(URL, stream=True)
 
43
  token = None
44
  for key, value in response.cookies.items():
45
  if key.startswith('download_warning'):
46
  token = value
47
+
 
48
  if token:
49
  params = {'id': id, 'confirm': token}
50
  response = session.get(URL, params=params, stream=True)
51
+
52
  total_size = int(response.headers.get('content-length', 0))
53
  block_size = 1024 # 1 Kibibyte
54
+ current_size = 0
55
+
56
  with open(destination, 'wb') as f:
57
  for data in response.iter_content(block_size):
58
+ current_size += len(data)
59
  f.write(data)
60
+ # Update Streamlit progress bar
61
+ progress_percentage = min(int((current_size / total_size) * 100), 100)
62
+ progress_bar.progress(progress_percentage, text=f"Downloading... {current_size // (1024*1024)}MB / {total_size // (1024*1024)}MB")
63
+
64
+ if total_size != 0 and current_size != total_size:
65
  return False
66
  return True
67
 
 
72
  The `st.cache_resource` decorator ensures the model is loaded only once.
73
  """
74
  MODEL_PATH = "my_model.keras"
75
+ FILE_ID = "1M-HNEJqbz6PzjhX6WHHKLPbjZpPRWLjP"
76
 
77
  if not os.path.exists(MODEL_PATH):
78
  st.info("Model not found locally. Downloading from Google Drive... (this may take a moment)")
79
+ progress_bar = st.progress(0, text="Starting download...")
80
+
81
+ download_successful = download_file_from_google_drive(FILE_ID, MODEL_PATH, progress_bar)
82
+
83
+ progress_bar.empty() # Clear the progress bar after completion
84
+
85
+ if not download_successful:
86
+ st.error("Failed to download the model. Please check the file ID and permissions on Google Drive.")
87
+ st.stop()
88
 
89
  try:
90
  model = load_model(MODEL_PATH)
91
  return model
92
  except Exception as e:
93
  st.error(f"Error loading model: {e}")
94
+ st.info("The downloaded file might be corrupted. Try deleting 'my_model.keras' and restarting the app.")
95
  return None
96
 
97
  model = load_keras_model()
 
100
  def preprocess_image(image):
101
  """
102
  Preprocesses the uploaded image to fit the model's input requirements.
 
 
 
 
103
  """
104
  img = image.resize((256, 256))
105
  img_array = np.array(img)
106
  img_array = img_array / 255.0
107
+ img_array = np.expand_dims(img_array, axis=0)
108
  return img_array
109
 
110
  # --- UI Layout ---
 
121
  "Choose an image...", type=["jpg", "jpeg", "png"]
122
  )
123
 
124
+ if uploaded_file is not None:
 
125
  image = Image.open(uploaded_file)
 
126
  with col1:
127
  st.image(image, caption="Uploaded Image", use_column_width=True)
128
+
129
+ if model is not None:
130
+ # Preprocess the image and make a prediction
131
+ processed_image = preprocess_image(image)
132
+ prediction = model.predict(processed_image)
133
+ confidence = prediction[0][0]
134
+
135
+ with col2:
136
+ st.header("Prediction")
137
+ if confidence > 0.5:
138
+ st.markdown(f"## This is a Dog! 🐶")
139
+ st.progress(float(confidence))
140
+ st.write(f"**Confidence:** {confidence:.2f}")
141
+ else:
142
+ st.markdown(f"## This is a Cat! 🐱")
143
+ st.progress(float(1-confidence))
144
+ st.write(f"**Confidence:** {1-confidence:.2f}")
145
+ else:
146
+ with col2:
147
+ st.error("Model could not be loaded. Cannot make a prediction.")
148
  else:
149
+ if model is not None:
150
+ with col2:
151
+ st.info("Please upload an image to see the prediction.")
requirements.txt CHANGED
@@ -3,4 +3,3 @@ tensorflow
3
  numpy
4
  Pillow
5
  requests
6
- tqdm
 
3
  numpy
4
  Pillow
5
  requests