VarunRavichander commited on
Commit
4de5386
·
verified ·
1 Parent(s): da2a982

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -63
app.py CHANGED
@@ -1136,89 +1136,61 @@ if 'processed_images' not in st.session_state:
1136
  tab1, tab2, tab3 = st.tabs(["📥 Load Model", " Process Single Image", "📁 Process Multiple Images"])
1137
 
1138
  # Add this code in tab1 before the model file uploader
 
 
1139
  with tab1:
1140
  st.markdown("<h3 style='color: #a78bfa;'>Load Segmentation Model</h3>", unsafe_allow_html=True)
1141
 
1142
- # Add model type selection
1143
- model_type = st.selectbox(
1144
- "Select model architecture",
1145
- ["U-Net", "DeepLabV3+", "SegNet"],
1146
- index=0,
1147
- help="Select the architecture of the model you're uploading"
1148
- )
 
1149
 
1150
- # Update the model type in the segmentation object
1151
- st.session_state.segmentation.model_type = model_type.lower().replace('-', '')
1152
 
 
1153
  col1, col2 = st.columns([3, 1])
1154
 
1155
  with col1:
1156
- model_file = st.file_uploader("Upload model weights (.h5 file)", type=["h5"])
1157
-
 
 
 
 
1158
 
1159
  with col2:
1160
  st.markdown("<br>", unsafe_allow_html=True)
1161
  if st.button("Load Model", key="load_model_btn"):
1162
- if model_file is not None:
1163
  with st.spinner("Loading model..."):
1164
- # Save the uploaded model to a temporary file
1165
- with tempfile.NamedTemporaryFile(delete=False, suffix='.h5') as tmp:
1166
- tmp.write(model_file.getvalue())
1167
- model_path = tmp.name
1168
-
1169
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
1170
  # Load the model
1171
  st.session_state.segmentation.load_trained_model(model_path)
1172
  st.session_state.model_loaded = True
1173
- st.success("Model loaded successfully!")
1174
  except Exception as e:
1175
  st.error(f"Error loading model: {str(e)}")
1176
- finally:
1177
- # Clean up the temporary file
1178
- os.unlink(model_path)
1179
  else:
1180
- st.error("Please upload a model file (.h5)")
1181
-
1182
-
1183
- if st.session_state.model_loaded:
1184
- st.markdown("<div class='card'>", unsafe_allow_html=True)
1185
- st.markdown("<h4 style='color: #a78bfa;'>Model Information</h4>", unsafe_allow_html=True)
1186
-
1187
- col1, col2, col3 = st.columns(3)
1188
- with col1:
1189
- st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
1190
- # Display the correct model architecture based on the detected model type
1191
- model_arch_map = {
1192
- 'unet': "U-Net",
1193
- 'deeplabv3plus': "DeepLabV3+",
1194
- 'segnet': "SegNet"
1195
- }
1196
- model_arch = model_arch_map.get(st.session_state.segmentation.model_type, "Unknown")
1197
- st.markdown(f"<p class='metric-value'>{model_arch}</p>", unsafe_allow_html=True)
1198
- st.markdown("<p class='metric-label'>Architecture</p>", unsafe_allow_html=True)
1199
- st.markdown("</div>", unsafe_allow_html=True)
1200
-
1201
-
1202
- with col2:
1203
- st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
1204
- st.markdown("<p class='metric-value'>11</p>", unsafe_allow_html=True)
1205
- st.markdown("<p class='metric-label'>Land Cover Classes</p>", unsafe_allow_html=True)
1206
- st.markdown("</div>", unsafe_allow_html=True)
1207
-
1208
- with col3:
1209
- st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
1210
- st.markdown("<p class='metric-value'>256 x 256</p>", unsafe_allow_html=True)
1211
- st.markdown("<p class='metric-label'>Input Size</p>", unsafe_allow_html=True)
1212
- st.markdown("</div>", unsafe_allow_html=True)
1213
-
1214
- # Display legend
1215
- st.markdown("<h4 style='color: #a78bfa; margin-top: 20px;'>Land Cover Classes</h4>", unsafe_allow_html=True)
1216
- legend_img = create_legend()
1217
- st.image(legend_img, use_column_width=True)
1218
-
1219
- st.markdown("</div>", unsafe_allow_html=True)
1220
- else:
1221
- st.info("Please load a model to continue.")
1222
 
1223
 
1224
  with tab2:
 
1136
  tab1, tab2, tab3 = st.tabs(["📥 Load Model", " Process Single Image", "📁 Process Multiple Images"])
1137
 
1138
  # Add this code in tab1 before the model file uploader
1139
+
1140
+ # Replace the model file uploader section in tab1 with this code
1141
  with tab1:
1142
  st.markdown("<h3 style='color: #a78bfa;'>Load Segmentation Model</h3>", unsafe_allow_html=True)
1143
 
1144
+ # Get list of available models in the models/ directory
1145
+ available_models = []
1146
+ try:
1147
+ for file in os.listdir('models'):
1148
+ if file.endswith('.h5'):
1149
+ available_models.append(file)
1150
+ except FileNotFoundError:
1151
+ st.error("Models directory not found. Please create a 'models' directory with .h5 model files.")
1152
 
1153
+ if not available_models:
1154
+ st.warning("No model files found in the models directory. Please add .h5 model files to the 'models' folder.")
1155
 
1156
+ # Add model selection dropdown
1157
  col1, col2 = st.columns([3, 1])
1158
 
1159
  with col1:
1160
+ selected_model = st.selectbox(
1161
+ "Select a pre-trained model",
1162
+ available_models,
1163
+ index=0 if available_models else None,
1164
+ help="Select one of the available pre-trained models"
1165
+ )
1166
 
1167
  with col2:
1168
  st.markdown("<br>", unsafe_allow_html=True)
1169
  if st.button("Load Model", key="load_model_btn"):
1170
+ if selected_model:
1171
  with st.spinner("Loading model..."):
 
 
 
 
 
1172
  try:
1173
+ # Load the selected model from the models directory
1174
+ model_path = os.path.join('models', selected_model)
1175
+
1176
+ # Determine model type from filename (optional)
1177
+ model_type = "unet" # Default
1178
+ if "deeplabv3" in selected_model.lower():
1179
+ model_type = "deeplabv3plus"
1180
+ elif "segnet" in selected_model.lower():
1181
+ model_type = "segnet"
1182
+
1183
+ # Update model type in the segmentation object
1184
+ st.session_state.segmentation.model_type = model_type
1185
+
1186
  # Load the model
1187
  st.session_state.segmentation.load_trained_model(model_path)
1188
  st.session_state.model_loaded = True
1189
+ st.success(f"Model '{selected_model}' loaded successfully!")
1190
  except Exception as e:
1191
  st.error(f"Error loading model: {str(e)}")
 
 
 
1192
  else:
1193
+ st.error("Please select a model from the dropdown")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1194
 
1195
 
1196
  with tab2: