Update app.py
Browse files
app.py
CHANGED
|
@@ -143,26 +143,26 @@ def load_models(unet_weights_path, generator_path=None):
|
|
| 143 |
generator = None
|
| 144 |
if generator_path:
|
| 145 |
try:
|
| 146 |
-
# Try
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
-
# Option 1: Try direct loading with custom objects
|
| 150 |
try:
|
| 151 |
-
|
| 152 |
-
generator = tf.keras.models.load_model(generator_path, custom_objects=custom_objects)
|
| 153 |
-
except:
|
| 154 |
-
# Option 2: If that fails, try loading just the weights into a compatible architecture
|
| 155 |
-
st.warning("Attempting to recreate generator model architecture...")
|
| 156 |
generator = create_pix2pix_generator()
|
| 157 |
-
generator.load_weights(generator_path, by_name=True, skip_mismatch=True)
|
| 158 |
-
st.success("Successfully loaded generator weights into a compatible model")
|
| 159 |
|
| 160 |
-
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
return unet, generator
|
| 164 |
|
| 165 |
|
|
|
|
| 166 |
# Preprocess SAR data for SAR to Optical Translation
|
| 167 |
def preprocess_sar_for_optical(sar_data):
|
| 168 |
"""Preprocess SAR data"""
|
|
@@ -201,16 +201,51 @@ def process_image(sar_image, unet_model, generator_model=None):
|
|
| 201 |
# Get segmentation using U-Net
|
| 202 |
seg_mask = unet_model.predict(sar_image)
|
| 203 |
|
| 204 |
-
# Generate optical using segmentation if generator is available
|
| 205 |
# Generate optical using segmentation if generator is available
|
| 206 |
colorized = None
|
| 207 |
if generator_model:
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
return seg_mask[0], colorized
|
| 212 |
|
| 213 |
|
|
|
|
| 214 |
# Visualize results for SAR to Optical Translation
|
| 215 |
def visualize_results(sar_image, seg_mask, colorized=None):
|
| 216 |
# ESA WorldCover colors
|
|
|
|
| 143 |
generator = None
|
| 144 |
if generator_path:
|
| 145 |
try:
|
| 146 |
+
# Try direct loading first (works in Keras 3.x)
|
| 147 |
+
generator = tf.keras.models.load_model(generator_path)
|
| 148 |
+
except Exception as e:
|
| 149 |
+
st.warning(f"Direct model loading failed: {e}")
|
| 150 |
+
st.info("Trying alternative loading method...")
|
| 151 |
|
|
|
|
| 152 |
try:
|
| 153 |
+
# Create a compatible generator model
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
generator = create_pix2pix_generator()
|
|
|
|
|
|
|
| 155 |
|
| 156 |
+
# Try to load weights
|
| 157 |
+
generator.load_weights(generator_path, by_name=True, skip_mismatch=True)
|
| 158 |
+
st.success("Loaded generator weights into a compatible model")
|
| 159 |
+
except Exception as e2:
|
| 160 |
+
st.error(f"Alternative loading also failed: {e2}")
|
| 161 |
|
| 162 |
return unet, generator
|
| 163 |
|
| 164 |
|
| 165 |
+
|
| 166 |
# Preprocess SAR data for SAR to Optical Translation
|
| 167 |
def preprocess_sar_for_optical(sar_data):
|
| 168 |
"""Preprocess SAR data"""
|
|
|
|
| 201 |
# Get segmentation using U-Net
|
| 202 |
seg_mask = unet_model.predict(sar_image)
|
| 203 |
|
|
|
|
| 204 |
# Generate optical using segmentation if generator is available
|
| 205 |
colorized = None
|
| 206 |
if generator_model:
|
| 207 |
+
try:
|
| 208 |
+
colorized = generator_model.predict([sar_image, seg_mask])
|
| 209 |
+
colorized = colorized[0]
|
| 210 |
+
|
| 211 |
+
# Check if output is valid
|
| 212 |
+
if np.var(colorized) < 0.01:
|
| 213 |
+
st.warning("Generator produced low-variance output. Using segmentation-based colorization instead.")
|
| 214 |
+
colorized = None
|
| 215 |
+
except Exception as e:
|
| 216 |
+
st.error(f"Error in generator prediction: {e}")
|
| 217 |
+
|
| 218 |
+
# If generator failed or produced invalid output, create a colorized version from segmentation
|
| 219 |
+
if colorized is None and generator_model is not None:
|
| 220 |
+
st.info("Creating colorized image from segmentation mask")
|
| 221 |
+
# Create a colorized version based on segmentation
|
| 222 |
+
pred_class = np.argmax(seg_mask[0], axis=-1)
|
| 223 |
+
colorized = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.float32)
|
| 224 |
+
|
| 225 |
+
colors = {
|
| 226 |
+
0: [0/255, 100/255, 0/255], # Trees - Dark green
|
| 227 |
+
1: [255/255, 165/255, 0/255], # Shrubland - Orange
|
| 228 |
+
2: [144/255, 238/255, 144/255], # Grassland - Light green
|
| 229 |
+
3: [255/255, 255/255, 0/255], # Cropland - Yellow
|
| 230 |
+
4: [255/255, 0/255, 0/255], # Built-up - Red
|
| 231 |
+
5: [139/255, 69/255, 19/255], # Bare - Brown
|
| 232 |
+
6: [255/255, 255/255, 255/255], # Snow - White
|
| 233 |
+
7: [0/255, 0/255, 255/255], # Water - Blue
|
| 234 |
+
8: [0/255, 139/255, 139/255], # Wetland - Dark cyan
|
| 235 |
+
9: [0/255, 255/255, 0/255], # Mangroves - Bright green
|
| 236 |
+
10: [220/255, 220/255, 220/255] # Moss - Light grey
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
for class_idx, color in colors.items():
|
| 240 |
+
colorized[pred_class == class_idx] = color
|
| 241 |
+
|
| 242 |
+
# Convert to tanh range (-1 to 1) to match generator output format
|
| 243 |
+
colorized = colorized * 2 - 1
|
| 244 |
+
|
| 245 |
return seg_mask[0], colorized
|
| 246 |
|
| 247 |
|
| 248 |
+
|
| 249 |
# Visualize results for SAR to Optical Translation
|
| 250 |
def visualize_results(sar_image, seg_mask, colorized=None):
|
| 251 |
# ESA WorldCover colors
|