Update app.py
Browse files
app.py
CHANGED
|
@@ -1136,89 +1136,61 @@ if 'processed_images' not in st.session_state:
|
|
| 1136 |
tab1, tab2, tab3 = st.tabs(["📥 Load Model", " Process Single Image", "📁 Process Multiple Images"])
|
| 1137 |
|
| 1138 |
# Add this code in tab1 before the model file uploader
|
|
|
|
|
|
|
| 1139 |
with tab1:
|
| 1140 |
st.markdown("<h3 style='color: #a78bfa;'>Load Segmentation Model</h3>", unsafe_allow_html=True)
|
| 1141 |
|
| 1142 |
-
#
|
| 1143 |
-
|
| 1144 |
-
|
| 1145 |
-
|
| 1146 |
-
|
| 1147 |
-
|
| 1148 |
-
|
|
|
|
| 1149 |
|
| 1150 |
-
|
| 1151 |
-
|
| 1152 |
|
|
|
|
| 1153 |
col1, col2 = st.columns([3, 1])
|
| 1154 |
|
| 1155 |
with col1:
|
| 1156 |
-
|
| 1157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1158 |
|
| 1159 |
with col2:
|
| 1160 |
st.markdown("<br>", unsafe_allow_html=True)
|
| 1161 |
if st.button("Load Model", key="load_model_btn"):
|
| 1162 |
-
if
|
| 1163 |
with st.spinner("Loading model..."):
|
| 1164 |
-
# Save the uploaded model to a temporary file
|
| 1165 |
-
with tempfile.NamedTemporaryFile(delete=False, suffix='.h5') as tmp:
|
| 1166 |
-
tmp.write(model_file.getvalue())
|
| 1167 |
-
model_path = tmp.name
|
| 1168 |
-
|
| 1169 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1170 |
# Load the model
|
| 1171 |
st.session_state.segmentation.load_trained_model(model_path)
|
| 1172 |
st.session_state.model_loaded = True
|
| 1173 |
-
st.success("Model loaded successfully!")
|
| 1174 |
except Exception as e:
|
| 1175 |
st.error(f"Error loading model: {str(e)}")
|
| 1176 |
-
finally:
|
| 1177 |
-
# Clean up the temporary file
|
| 1178 |
-
os.unlink(model_path)
|
| 1179 |
else:
|
| 1180 |
-
st.error("Please
|
| 1181 |
-
|
| 1182 |
-
|
| 1183 |
-
if st.session_state.model_loaded:
|
| 1184 |
-
st.markdown("<div class='card'>", unsafe_allow_html=True)
|
| 1185 |
-
st.markdown("<h4 style='color: #a78bfa;'>Model Information</h4>", unsafe_allow_html=True)
|
| 1186 |
-
|
| 1187 |
-
col1, col2, col3 = st.columns(3)
|
| 1188 |
-
with col1:
|
| 1189 |
-
st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
|
| 1190 |
-
# Display the correct model architecture based on the detected model type
|
| 1191 |
-
model_arch_map = {
|
| 1192 |
-
'unet': "U-Net",
|
| 1193 |
-
'deeplabv3plus': "DeepLabV3+",
|
| 1194 |
-
'segnet': "SegNet"
|
| 1195 |
-
}
|
| 1196 |
-
model_arch = model_arch_map.get(st.session_state.segmentation.model_type, "Unknown")
|
| 1197 |
-
st.markdown(f"<p class='metric-value'>{model_arch}</p>", unsafe_allow_html=True)
|
| 1198 |
-
st.markdown("<p class='metric-label'>Architecture</p>", unsafe_allow_html=True)
|
| 1199 |
-
st.markdown("</div>", unsafe_allow_html=True)
|
| 1200 |
-
|
| 1201 |
-
|
| 1202 |
-
with col2:
|
| 1203 |
-
st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
|
| 1204 |
-
st.markdown("<p class='metric-value'>11</p>", unsafe_allow_html=True)
|
| 1205 |
-
st.markdown("<p class='metric-label'>Land Cover Classes</p>", unsafe_allow_html=True)
|
| 1206 |
-
st.markdown("</div>", unsafe_allow_html=True)
|
| 1207 |
-
|
| 1208 |
-
with col3:
|
| 1209 |
-
st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
|
| 1210 |
-
st.markdown("<p class='metric-value'>256 x 256</p>", unsafe_allow_html=True)
|
| 1211 |
-
st.markdown("<p class='metric-label'>Input Size</p>", unsafe_allow_html=True)
|
| 1212 |
-
st.markdown("</div>", unsafe_allow_html=True)
|
| 1213 |
-
|
| 1214 |
-
# Display legend
|
| 1215 |
-
st.markdown("<h4 style='color: #a78bfa; margin-top: 20px;'>Land Cover Classes</h4>", unsafe_allow_html=True)
|
| 1216 |
-
legend_img = create_legend()
|
| 1217 |
-
st.image(legend_img, use_column_width=True)
|
| 1218 |
-
|
| 1219 |
-
st.markdown("</div>", unsafe_allow_html=True)
|
| 1220 |
-
else:
|
| 1221 |
-
st.info("Please load a model to continue.")
|
| 1222 |
|
| 1223 |
|
| 1224 |
with tab2:
|
|
|
|
| 1136 |
tab1, tab2, tab3 = st.tabs(["📥 Load Model", " Process Single Image", "📁 Process Multiple Images"])
|
| 1137 |
|
| 1138 |
# Add this code in tab1 before the model file uploader
|
| 1139 |
+
|
| 1140 |
+
# Replace the model file uploader section in tab1 with this code
|
| 1141 |
with tab1:
|
| 1142 |
st.markdown("<h3 style='color: #a78bfa;'>Load Segmentation Model</h3>", unsafe_allow_html=True)
|
| 1143 |
|
| 1144 |
+
# Get list of available models in the models/ directory
|
| 1145 |
+
available_models = []
|
| 1146 |
+
try:
|
| 1147 |
+
for file in os.listdir('models'):
|
| 1148 |
+
if file.endswith('.h5'):
|
| 1149 |
+
available_models.append(file)
|
| 1150 |
+
except FileNotFoundError:
|
| 1151 |
+
st.error("Models directory not found. Please create a 'models' directory with .h5 model files.")
|
| 1152 |
|
| 1153 |
+
if not available_models:
|
| 1154 |
+
st.warning("No model files found in the models directory. Please add .h5 model files to the 'models' folder.")
|
| 1155 |
|
| 1156 |
+
# Add model selection dropdown
|
| 1157 |
col1, col2 = st.columns([3, 1])
|
| 1158 |
|
| 1159 |
with col1:
|
| 1160 |
+
selected_model = st.selectbox(
|
| 1161 |
+
"Select a pre-trained model",
|
| 1162 |
+
available_models,
|
| 1163 |
+
index=0 if available_models else None,
|
| 1164 |
+
help="Select one of the available pre-trained models"
|
| 1165 |
+
)
|
| 1166 |
|
| 1167 |
with col2:
|
| 1168 |
st.markdown("<br>", unsafe_allow_html=True)
|
| 1169 |
if st.button("Load Model", key="load_model_btn"):
|
| 1170 |
+
if selected_model:
|
| 1171 |
with st.spinner("Loading model..."):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1172 |
try:
|
| 1173 |
+
# Load the selected model from the models directory
|
| 1174 |
+
model_path = os.path.join('models', selected_model)
|
| 1175 |
+
|
| 1176 |
+
# Determine model type from filename (optional)
|
| 1177 |
+
model_type = "unet" # Default
|
| 1178 |
+
if "deeplabv3" in selected_model.lower():
|
| 1179 |
+
model_type = "deeplabv3plus"
|
| 1180 |
+
elif "segnet" in selected_model.lower():
|
| 1181 |
+
model_type = "segnet"
|
| 1182 |
+
|
| 1183 |
+
# Update model type in the segmentation object
|
| 1184 |
+
st.session_state.segmentation.model_type = model_type
|
| 1185 |
+
|
| 1186 |
# Load the model
|
| 1187 |
st.session_state.segmentation.load_trained_model(model_path)
|
| 1188 |
st.session_state.model_loaded = True
|
| 1189 |
+
st.success(f"Model '{selected_model}' loaded successfully!")
|
| 1190 |
except Exception as e:
|
| 1191 |
st.error(f"Error loading model: {str(e)}")
|
|
|
|
|
|
|
|
|
|
| 1192 |
else:
|
| 1193 |
+
st.error("Please select a model from the dropdown")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1194 |
|
| 1195 |
|
| 1196 |
with tab2:
|