VarunRavichander commited on
Commit
f6ce0e4
·
verified ·
1 Parent(s): 8babbd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -16
app.py CHANGED
@@ -143,26 +143,26 @@ def load_models(unet_weights_path, generator_path=None):
143
  generator = None
144
  if generator_path:
145
  try:
146
- # Try to load the generator model
147
- import tensorflow as tf
 
 
 
148
 
149
- # Option 1: Try direct loading with custom objects
150
  try:
151
- custom_objects = {'Functional': tf.keras.Model}
152
- generator = tf.keras.models.load_model(generator_path, custom_objects=custom_objects)
153
- except:
154
- # Option 2: If that fails, try loading just the weights into a compatible architecture
155
- st.warning("Attempting to recreate generator model architecture...")
156
  generator = create_pix2pix_generator()
157
- generator.load_weights(generator_path, by_name=True, skip_mismatch=True)
158
- st.success("Successfully loaded generator weights into a compatible model")
159
 
160
- except Exception as e:
161
- st.error(f"Error loading generator model: {e}")
 
 
 
162
 
163
  return unet, generator
164
 
165
 
 
166
  # Preprocess SAR data for SAR to Optical Translation
167
  def preprocess_sar_for_optical(sar_data):
168
  """Preprocess SAR data"""
@@ -201,16 +201,51 @@ def process_image(sar_image, unet_model, generator_model=None):
201
  # Get segmentation using U-Net
202
  seg_mask = unet_model.predict(sar_image)
203
 
204
- # Generate optical using segmentation if generator is available
205
  # Generate optical using segmentation if generator is available
206
  colorized = None
207
  if generator_model:
208
- colorized = generator_model.predict([sar_image, seg_mask])
209
- colorized = colorized[0]
210
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  return seg_mask[0], colorized
212
 
213
 
 
214
  # Visualize results for SAR to Optical Translation
215
  def visualize_results(sar_image, seg_mask, colorized=None):
216
  # ESA WorldCover colors
 
143
  generator = None
144
  if generator_path:
145
  try:
146
+ # Try direct loading first (works in Keras 3.x)
147
+ generator = tf.keras.models.load_model(generator_path)
148
+ except Exception as e:
149
+ st.warning(f"Direct model loading failed: {e}")
150
+ st.info("Trying alternative loading method...")
151
 
 
152
  try:
153
+ # Create a compatible generator model
 
 
 
 
154
  generator = create_pix2pix_generator()
 
 
155
 
156
+ # Try to load weights
157
+ generator.load_weights(generator_path, by_name=True, skip_mismatch=True)
158
+ st.success("Loaded generator weights into a compatible model")
159
+ except Exception as e2:
160
+ st.error(f"Alternative loading also failed: {e2}")
161
 
162
  return unet, generator
163
 
164
 
165
+
166
  # Preprocess SAR data for SAR to Optical Translation
167
  def preprocess_sar_for_optical(sar_data):
168
  """Preprocess SAR data"""
 
201
  # Get segmentation using U-Net
202
  seg_mask = unet_model.predict(sar_image)
203
 
 
204
  # Generate optical using segmentation if generator is available
205
  colorized = None
206
  if generator_model:
207
+ try:
208
+ colorized = generator_model.predict([sar_image, seg_mask])
209
+ colorized = colorized[0]
210
+
211
+ # Check if output is valid
212
+ if np.var(colorized) < 0.01:
213
+ st.warning("Generator produced low-variance output. Using segmentation-based colorization instead.")
214
+ colorized = None
215
+ except Exception as e:
216
+ st.error(f"Error in generator prediction: {e}")
217
+
218
+ # If generator failed or produced invalid output, create a colorized version from segmentation
219
+ if colorized is None and generator_model is not None:
220
+ st.info("Creating colorized image from segmentation mask")
221
+ # Create a colorized version based on segmentation
222
+ pred_class = np.argmax(seg_mask[0], axis=-1)
223
+ colorized = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.float32)
224
+
225
+ colors = {
226
+ 0: [0/255, 100/255, 0/255], # Trees - Dark green
227
+ 1: [255/255, 165/255, 0/255], # Shrubland - Orange
228
+ 2: [144/255, 238/255, 144/255], # Grassland - Light green
229
+ 3: [255/255, 255/255, 0/255], # Cropland - Yellow
230
+ 4: [255/255, 0/255, 0/255], # Built-up - Red
231
+ 5: [139/255, 69/255, 19/255], # Bare - Brown
232
+ 6: [255/255, 255/255, 255/255], # Snow - White
233
+ 7: [0/255, 0/255, 255/255], # Water - Blue
234
+ 8: [0/255, 139/255, 139/255], # Wetland - Dark cyan
235
+ 9: [0/255, 255/255, 0/255], # Mangroves - Bright green
236
+ 10: [220/255, 220/255, 220/255] # Moss - Light grey
237
+ }
238
+
239
+ for class_idx, color in colors.items():
240
+ colorized[pred_class == class_idx] = color
241
+
242
+ # Convert to tanh range (-1 to 1) to match generator output format
243
+ colorized = colorized * 2 - 1
244
+
245
  return seg_mask[0], colorized
246
 
247
 
248
+
249
  # Visualize results for SAR to Optical Translation
250
  def visualize_results(sar_image, seg_mask, colorized=None):
251
  # ESA WorldCover colors