import streamlit as st
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import Adam
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import io
import base64
import tempfile
import zipfile
import random
import time
import rasterio
from rasterio.errors import RasterioIOError
import h5py
import json

# Set page configuration
st.set_page_config(
    page_title="SAR Image Colorization",
    page_icon="🛰",
    layout="wide"
)


def display_image(image_path):
    """Display an image in the Streamlit app, handling GeoTIFF and standard formats.

    Args:
        image_path: Path to the image file. ``.tif``/``.tiff`` files are read
            with rasterio (multi-band aware); everything else goes through PIL.

    Side effects:
        Renders the image via ``st.image`` on success, or an ``st.error`` /
        ``st.info`` message on failure. Never raises.
    """
    try:
        if not os.path.exists(image_path):
            st.info(f"Image file not found: {image_path}")
            return

        if image_path.lower().endswith(('.tif', '.tiff')):
            # Use rasterio for TIF files
            try:
                with rasterio.open(image_path) as src:
                    if src.count >= 3:
                        # RGB (or more) bands: take the first three.
                        img_data = np.dstack([src.read(i) for i in range(1, 4)])
                    elif src.count == 2:
                        # Two bands: duplicate the second band to fake RGB.
                        img_data = np.dstack([src.read(1), src.read(2), src.read(2)])
                    else:
                        # Single band: replicate it into a grayscale RGB image.
                        band = src.read(1)
                        img_data = np.dstack([band, band, band])

                    # Stretch to 0-255 for display. Guard against a constant
                    # image, where max == min would divide by zero.
                    if img_data.dtype != np.uint8:
                        value_range = np.ptp(img_data)
                        if value_range == 0:
                            img_data = np.zeros_like(img_data, dtype=np.uint8)
                        else:
                            img_data = (img_data - np.min(img_data)) / value_range * 255
                            img_data = img_data.astype(np.uint8)
                    st.image(img_data, use_container_width=True)
            except Exception:
                # Rasterio failed (corrupt header, unsupported layout, ...):
                # fall back to PIL before giving up.
                try:
                    st.image(Image.open(image_path), use_container_width=True)
                except Exception as pil_error:
                    st.error(f"Failed to load image: {str(pil_error)}")
        else:
            # Use PIL for other formats (PNG/JPEG/...).
            st.image(Image.open(image_path), use_container_width=True)
    except Exception as e:
        st.error(f"Error loading image: {str(e)}")


# ==================== UTILITY FUNCTIONS ====================
# GPU setup for SAR to Optical Translation
@st.cache_resource
def setup_gpu():
    """Enable GPU memory growth on all visible devices; return a status string."""
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        return f"GPU setup complete. Found {len(gpus)} GPU(s)."
    return "No GPUs found. Running on CPU."


# ESA WorldCover colors dictionary - used in multiple functions
def get_esa_colors():
    """Return the class-index -> RGB color mapping used for all visualizations."""
    return {
        0: [0, 100, 0],        # Trees - Dark green
        1: [255, 165, 0],      # Shrubland - Orange
        2: [144, 238, 144],    # Grassland - Light green
        3: [255, 255, 0],      # Cropland - Yellow
        4: [255, 0, 0],        # Built-up - Red
        5: [139, 69, 19],      # Bare - Brown
        6: [255, 255, 255],    # Snow - White
        7: [0, 0, 255],        # Water - Blue
        8: [0, 139, 139],      # Wetland - Dark cyan
        9: [0, 255, 0],        # Mangroves - Bright green
        10: [220, 220, 220]    # Moss - Light grey
    }


# When visualizing ground truth, use the same color mapping as for predictions
def visualize_with_ground_truth(sar_image, ground_truth, prediction):
    """Visualize SAR image with ground truth and prediction using ESA WorldCover colors.

    Args:
        sar_image: (H, W, 1) SAR image normalized to [-1, 1].
        ground_truth: (H, W, 1) label image; values > 10 are treated as raw ESA
            WorldCover codes and remapped to 0-10 indices.
        prediction: (1, H, W, classes) softmax output.

    Returns:
        Tuple of (PNG buffer, colored ground truth, colored prediction, overlay).
    """
    colors = get_esa_colors()

    # Convert prediction to color image
    pred_class = np.argmax(prediction[0], axis=-1)
    colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8)
    for class_idx, color in colors.items():
        colored_pred[pred_class == class_idx] = color

    # Convert ground truth to color image using the same color scheme
    gt_class = ground_truth[:, :, 0].astype(np.int32)
    # Normalize ground truth to match prediction classes if needed
    if np.max(gt_class) > 10:  # If using ESA WorldCover values
        # Map ESA values to 0-10 indices
        gt_mapped = np.zeros_like(gt_class)
        class_values = sorted(st.session_state.segmentation.class_definitions.values())
        for i, val in enumerate(class_values):
            gt_mapped[gt_class == val] = i
        gt_class = gt_mapped

    colored_gt = np.zeros((gt_class.shape[0], gt_class.shape[1], 3), dtype=np.uint8)
    for class_idx, color in colors.items():
        colored_gt[gt_class == class_idx] = color

    # Create overlay for SAR with prediction
    sar_rgb = np.repeat(sar_image[:, :, 0:1], 3, axis=2)
    # Normalize to 0-255 for visualization (input assumed in [-1, 1])
    sar_rgb = ((sar_rgb + 1) / 2 * 255).astype(np.uint8)
    overlay = cv2.addWeighted(sar_rgb, 0.7, colored_pred, 0.3, 0)

    # Set background color based on theme
    bg_color = '#0a0a1f' if st.session_state.theme == 'dark' else '#ffffff'
    text_color = 'white' if st.session_state.theme == 'dark' else 'black'

    # Create figure
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))

    # Original SAR
    axes[0].imshow(sar_rgb, cmap='gray')
    axes[0].set_title('Original SAR', color=text_color)
    axes[0].axis('off')

    # Ground Truth
    axes[1].imshow(colored_gt)
    axes[1].set_title('Ground Truth', color=text_color)
    axes[1].axis('off')

    # Prediction
    axes[2].imshow(colored_pred)
    axes[2].set_title('Prediction', color=text_color)
    axes[2].axis('off')

    # Overlay
    axes[3].imshow(overlay)
    axes[3].set_title('Colorized Output', color=text_color)
    axes[3].axis('off')

    # Set background color
    fig.patch.set_facecolor(bg_color)
    for ax in axes:
        ax.set_facecolor(bg_color)
    plt.tight_layout()

    # Convert plot to image
    buf = io.BytesIO()
    plt.savefig(buf, format='png', facecolor=bg_color, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)
    return buf, colored_gt, colored_pred, overlay


# Load models for SAR to Optical Translation
@st.cache_resource
def load_models(unet_weights_path, generator_path=None):
    """Load the U-Net (weights only) and optionally the generator model.

    Returns:
        (unet, generator) where generator is None if no path was given or
        loading failed (the error is shown in the UI).
    """
    # Load U-Net model
    unet = get_unet(input_shape=(256, 256, 1), classes=11)
    unet.load_weights(unet_weights_path)

    # Load generator model if path is provided
    generator = None
    if generator_path:
        try:
            generator = tf.keras.models.load_model(generator_path)
        except Exception as e:
            st.error(f"Error loading generator model: {e}")
    return unet, generator


# Preprocess SAR data for SAR to Optical Translation
def preprocess_sar_for_optical(sar_data):
    """Preprocess SAR data (assumed dB scale) into the [-1, 1] range."""
    # Data is assumed to be in dB scale
    sar_clipped = np.clip(sar_data, -50, 20)
    lo = np.min(sar_clipped)
    hi = np.max(sar_clipped)
    if hi == lo:
        # FIX: a constant-valued image previously divided by zero here;
        # map it to the midpoint of the [-1, 1] range instead.
        return np.zeros_like(sar_clipped, dtype=np.float64)
    sar_normalized = (sar_clipped - lo) / (hi - lo) * 2 - 1
    return sar_normalized
# Load SAR image for SAR to Optical Translation
def load_sar_image(file, img_size=(256, 256)):
    """Load an uploaded SAR GeoTIFF into a batched, preprocessed array.

    Returns:
        (batch, image): (1, H, W, 1) batch and (H, W, 1) image in [-1, 1],
        or (None, None) on failure (error shown in the UI).
    """
    # Create a temporary file to save the uploaded file
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as tmp_file:
        tmp_file.write(file.getbuffer())
        tmp_file_path = tmp_file.name
    try:
        with rasterio.open(tmp_file_path) as src:
            image = src.read(1)
            image = cv2.resize(image, img_size)
            image = np.expand_dims(image, axis=-1)
            # Preprocess the image
            image = preprocess_sar_for_optical(image)
            return np.expand_dims(image, axis=0), image
    except Exception as e:
        st.error(f"Error loading SAR image: {e}")
        return None, None
    finally:
        # Clean up the temporary file
        os.unlink(tmp_file_path)


# Process image with models for SAR to Optical Translation
def process_image(sar_image, unet_model, generator_model=None):
    """Run segmentation and (optionally) colorization.

    Returns:
        (seg_mask, colorized) — seg_mask is (H, W, classes); colorized is the
        generator output for the first batch item, or None without a generator.
    """
    # Get segmentation using U-Net
    seg_mask = unet_model.predict(sar_image)

    # Generate optical using segmentation if generator is available
    colorized = None
    if generator_model:
        colorized = generator_model.predict([sar_image, seg_mask])
        colorized = colorized[0]
    return seg_mask[0], colorized


# Visualize results for SAR to Optical Translation
def visualize_results(sar_image, seg_mask, colorized=None):
    """Build the visualization arrays for a segmentation result.

    Returns:
        (sar_rgb, colored_pred, overlay, colorized) as uint8 arrays
        (colorized is passed through unchanged).
    """
    colors = get_esa_colors()

    # Convert prediction to color image
    pred_class = np.argmax(seg_mask, axis=-1)
    colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8)
    for class_idx, color in colors.items():
        colored_pred[pred_class == class_idx] = color

    # Create overlay
    sar_rgb = np.repeat(sar_image[:, :, 0:1], 3, axis=2)
    # Normalize to 0-255 for visualization (input assumed in [-1, 1])
    sar_rgb = ((sar_rgb + 1) / 2 * 255).astype(np.uint8)
    overlay = cv2.addWeighted(sar_rgb, 0.7, colored_pred, 0.3, 0)
    return sar_rgb, colored_pred, overlay, colorized


# Load model with weights - handles different model loading scenarios
def load_model_with_weights(model_path):
    """Load a model directly from an H5 file, preserving the original architecture.

    Strategy: (1) load the complete model; (2) rebuild from the embedded
    model_config and load weights; (3) build a fresh architecture from
    st.session_state.segmentation.model_type and load compatible weights.
    Returns None if everything fails.
    """
    # If model_path is a filename without path, prepend the models directory
    if not os.path.dirname(model_path) and not model_path.startswith('models/'):
        model_path = os.path.join('models', os.path.basename(model_path))

    try:
        # Try to load the complete model (architecture + weights)
        # For Keras 3 compatibility
        import tensorflow as tf
        # FIX: take the major version via split('.') — indexing [0] only
        # returned the first character of the version string.
        keras_version = tf.keras.__version__.split('.')[0]
        if keras_version == '3':
            # For Keras 3, try to load with custom_objects to handle compatibility issues
            custom_objects = {
                'BilinearUpsampling': BilinearUpsampling
            }
            model = tf.keras.models.load_model(model_path, compile=False, custom_objects=custom_objects)
        else:
            # For older Keras versions
            model = tf.keras.models.load_model(model_path, compile=False)
        print("Loaded complete model with architecture")
        return model
    except Exception as e:
        print(f"Could not load complete model: {str(e)}")
        print("Attempting to load just the weights into a matching architecture...")

    # Try to inspect the model file to determine architecture
    try:
        with h5py.File(model_path, 'r') as f:
            model_config = None
            if 'model_config' in f.attrs:
                raw_config = f.attrs['model_config']
                # FIX: h5py >= 3.0 returns string attributes as str; calling
                # .decode() unconditionally raised AttributeError there.
                if isinstance(raw_config, bytes):
                    raw_config = raw_config.decode('utf-8')
                model_config = json.loads(raw_config)
        # If we found a model config, try to recreate it
        if model_config:
            try:
                model = tf.keras.models.model_from_json(json.dumps(model_config))
                model.load_weights(model_path)
                print("Successfully loaded model from config and weights")
                return model
            except Exception as e2:
                print(f"Failed to load from config: {str(e2)}")
    except Exception as e3:
        print(f"Failed to inspect model file: {str(e3)}")

    # If all else fails, create a new model and try to load weights
    try:
        # Create a new model based on the model_type in session state
        if st.session_state.segmentation.model_type == 'unet':
            model = get_unet(input_shape=(256, 256, 1), drop_rate=0.3, classes=11)
        elif st.session_state.segmentation.model_type == 'deeplabv3plus':
            model = DeepLabV3Plus(input_shape=(256, 256, 1), classes=11)
        elif st.session_state.segmentation.model_type == 'segnet':
            model = SegNet(input_shape=(256, 256, 1), classes=11)
        # Try to load weights with skip_mismatch
        model.load_weights(model_path, by_name=True, skip_mismatch=True)
        print("Created new model and loaded compatible weights")
        return model
    except Exception as e4:
        print(f"Failed to create new model and load weights: {str(e4)}")
        # If all else fails, return None
        return None


# Create a legend for the land cover classes
def create_legend():
    """Create a legend for the land cover classes; returns a PNG buffer."""
    colors = {
        'Trees': [0, 100, 0],
        'Shrubland': [255, 165, 0],
        'Grassland': [144, 238, 144],
        'Cropland': [255, 255, 0],
        'Built-up': [255, 0, 0],
        'Bare': [139, 69, 19],
        'Snow': [255, 255, 255],
        'Water': [0, 0, 255],
        'Wetland': [0, 139, 139],
        'Mangroves': [0, 255, 0],
        'Moss': [220, 220, 220]
    }

    # Set background color based on theme
    bg_color = '#0a0a1f' if st.session_state.theme == 'dark' else '#ffffff'
    text_color = 'white' if st.session_state.theme == 'dark' else 'black'

    fig, ax = plt.subplots(figsize=(8, 4))
    fig.patch.set_facecolor(bg_color)
    ax.set_facecolor(bg_color)

    # Create color patches
    for i, (class_name, color) in enumerate(colors.items()):
        ax.add_patch(plt.Rectangle((0, i), 0.5, 0.8, color=[c / 255 for c in color]))
        ax.text(0.7, i + 0.4, class_name, color=text_color, fontsize=12)

    ax.set_xlim(0, 3)
    ax.set_ylim(-0.5, len(colors) - 0.5)
    ax.set_title('Land Cover Classes', color=text_color, fontsize=14)
    ax.axis('off')

    buf = io.BytesIO()
    plt.savefig(buf, format='png', facecolor=bg_color, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)
    return buf


# Visualize segmentation prediction
# Update the visualize_prediction function to support both themes
def visualize_prediction(prediction, original_sar, figsize=(10, 4)):
    """Visualize segmentation prediction with ESA WorldCover colors.

    Returns:
        PNG buffer with SAR / prediction / overlay panels.
    """
    colors = get_esa_colors()

    # Convert prediction to color image
    pred_class = np.argmax(prediction[0], axis=-1)
    colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8)
    for class_idx, color in colors.items():
        colored_pred[pred_class == class_idx] = color

    # Create overlay
    sar_rgb = cv2.cvtColor(original_sar[:, :, 0], cv2.COLOR_GRAY2RGB)
    overlay = cv2.addWeighted(sar_rgb, 0.7, colored_pred, 0.3, 0)

    # Create figure
    fig, axes = plt.subplots(1, 3, figsize=figsize)

    # Set background color based on theme
    bg_color = '#0a0a1f' if st.session_state.theme == 'dark' else '#ffffff'
    text_color = 'white' if st.session_state.theme == 'dark' else 'black'

    # Original SAR
    axes[0].imshow(original_sar[:, :, 0], cmap='gray')
    axes[0].set_title('Original SAR', color=text_color)
    axes[0].axis('off')

    # Prediction
    axes[1].imshow(colored_pred)
    axes[1].set_title('Prediction', color=text_color)
    axes[1].axis('off')

    # Overlay
    axes[2].imshow(overlay)
    axes[2].set_title('Colorized Output', color=text_color)
    axes[2].axis('off')

    # Set background color
    fig.patch.set_facecolor(bg_color)
    for ax in axes:
        ax.set_facecolor(bg_color)
    plt.tight_layout()

    # Convert plot to image
    buf = io.BytesIO()
    plt.savefig(buf, format='png', facecolor=bg_color, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)
    return buf
# ==================== MODEL DEFINITIONS ====================

# Define the U-Net model
def get_unet(input_shape=(256, 256, 1), drop_rate=0.3, classes=11):
    """Build and compile a standard U-Net for per-pixel classification.

    Args:
        input_shape: Input image shape (H, W, channels).
        drop_rate: Dropout rate applied at encoder level 4 and the bridge.
        classes: Number of output classes (softmax over last axis).

    Returns:
        Compiled Keras Model (Adam 1e-4, categorical crossentropy).
    """
    inputs = Input(input_shape)
    # Encoder — four down levels; skip connections taken before each pool
    conv1_1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
    batch1_1 = BatchNormalization()(conv1_1)
    conv1_2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch1_1)
    batch1_2 = BatchNormalization()(conv1_2)
    pool1 = MaxPooling2D(pool_size=(2, 2))(batch1_2)
    conv2_1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
    batch2_1 = BatchNormalization()(conv2_1)
    conv2_2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch2_1)
    batch2_2 = BatchNormalization()(conv2_2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(batch2_2)
    conv3_1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
    batch3_1 = BatchNormalization()(conv3_1)
    conv3_2 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch3_1)
    batch3_2 = BatchNormalization()(conv3_2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(batch3_2)
    conv4_1 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
    batch4_1 = BatchNormalization()(conv4_1)
    conv4_2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch4_1)
    batch4_2 = BatchNormalization()(conv4_2)
    drop4 = Dropout(drop_rate)(batch4_2)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
    # Bridge
    conv5_1 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
    batch5_1 = BatchNormalization()(conv5_1)
    conv5_2 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch5_1)
    batch5_2 = BatchNormalization()(conv5_2)
    drop5 = Dropout(drop_rate)(batch5_2)
    # Decoder — upsample + 2x2 conv, then concatenate the matching encoder skip
    up6 = Conv2D(512, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(drop5))
    merge6 = concatenate([drop4, up6])
    conv6_1 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
    batch6_1 = BatchNormalization()(conv6_1)
    conv6_2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch6_1)
    batch6_2 = BatchNormalization()(conv6_2)
    up7 = Conv2D(256, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch6_2))
    merge7 = concatenate([batch3_2, up7])
    conv7_1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
    batch7_1 = BatchNormalization()(conv7_1)
    conv7_2 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch7_1)
    batch7_2 = BatchNormalization()(conv7_2)
    up8 = Conv2D(128, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch7_2))
    merge8 = concatenate([batch2_2, up8])
    conv8_1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
    batch8_1 = BatchNormalization()(conv8_1)
    conv8_2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch8_1)
    batch8_2 = BatchNormalization()(conv8_2)
    up9 = Conv2D(64, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch8_2))
    merge9 = concatenate([batch1_2, up9])
    conv9_1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
    batch9_1 = BatchNormalization()(conv9_1)
    conv9_2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch9_1)
    batch9_2 = BatchNormalization()(conv9_2)
    outputs = Conv2D(classes, 1, activation='softmax')(batch9_2)
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model


# Custom upsampling layer for dynamic resizing
class BilinearUpsampling(Layer):
    """Keras layer that bilinearly resizes its input to a fixed (H, W) size."""

    def __init__(self, size=(1, 1), **kwargs):
        super(BilinearUpsampling, self).__init__(**kwargs)
        # Target spatial size, NOT a scale factor (passed straight to tf.image.resize)
        self.size = size

    def call(self, inputs):
        return tf.image.resize(inputs, self.size, method='bilinear')

    def compute_output_shape(self, input_shape):
        # Batch and channel dims preserved; spatial dims become self.size
        return (input_shape[0], self.size[0], self.size[1], input_shape[3])

    def get_config(self):
        # Serialize 'size' so the layer can be restored via custom_objects
        config = super(BilinearUpsampling, self).get_config()
        config.update({'size': self.size})
        return config


# DeepLabV3+ model definition
def DeepLabV3Plus(input_shape=(256, 256, 1), classes=11, output_stride=16):
    """
    DeepLabV3+ model with Xception backbone
    Args:
        input_shape: Shape of input images
        classes: Number of classes for segmentation
        output_stride: Output stride for dilated convolutions (16 or 8)
    Returns:
        model: DeepLabV3+ model
    """
    # Input layer
    inputs = Input(input_shape)
    # Ensure we're using the right dilation rates based on output_stride
    if output_stride == 16:
        atrous_rates = (6, 12, 18)
    elif output_stride == 8:
        atrous_rates = (12, 24, 36)
    else:
        raise ValueError("Output stride must be 8 or 16")
    # === ENCODER (BACKBONE) ===
    # Entry block
    x = Conv2D(32, 3, strides=(2, 2), padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(64, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Xception-like blocks with dilated convolutions
    # Block 1
    residual = Conv2D(128, 1, strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)
    x = SeparableConv2D(128, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(128, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
    x = Add()([x, residual])
    # Block 2
    residual = Conv2D(256, 1, strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)
    x = Activation('relu')(x)
    x = SeparableConv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
    x = Add()([x, residual])
    # Save low_level_features for skip connection (1/4 of input size)
    low_level_features = x
    # Block 3
    residual = Conv2D(728, 1, strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)
    x = Activation('relu')(x)
    x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
    x = Add()([x, residual])
    # Middle flow - modified with dilated convolutions
    for i in range(16):
        residual = x
        x = Activation('relu')(x)
        x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Add()([x, residual])
    # Exit flow (modified)
    x = Activation('relu')(x)
    x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(1024, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    # === ASPP (Atrous Spatial Pyramid Pooling) ===
    # 1x1 convolution branch
    aspp_out1 = Conv2D(256, 1, padding='same', use_bias=False)(x)
    aspp_out1 = BatchNormalization()(aspp_out1)
    aspp_out1 = Activation('relu')(aspp_out1)
    # 3x3 dilated convolution branches with different rates
    aspp_out2 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[0], use_bias=False)(x)
    aspp_out2 = BatchNormalization()(aspp_out2)
    aspp_out2 = Activation('relu')(aspp_out2)
    aspp_out3 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[1], use_bias=False)(x)
    aspp_out3 = BatchNormalization()(aspp_out3)
    aspp_out3 = Activation('relu')(aspp_out3)
    aspp_out4 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[2], use_bias=False)(x)
    aspp_out4 = BatchNormalization()(aspp_out4)
    aspp_out4 = Activation('relu')(aspp_out4)
    # Global pooling branch
    aspp_out5 = GlobalAveragePooling2D()(x)
    aspp_out5 = Reshape((1, 1, 1024))(aspp_out5)  # Use 1024 to match x's channels
    aspp_out5 = Conv2D(256, 1, padding='same', use_bias=False)(aspp_out5)
    aspp_out5 = BatchNormalization()(aspp_out5)
    aspp_out5 = Activation('relu')(aspp_out5)
    # Get current shape of x; upsampling a 1x1 map by (height, width) restores x's spatial size
    _, height, width, _ = tf.keras.backend.int_shape(x)
    aspp_out5 = UpSampling2D(size=(height, width), interpolation='bilinear')(aspp_out5)
    # Concatenate all ASPP branches
    aspp_out = Concatenate()([aspp_out1, aspp_out2, aspp_out3, aspp_out4, aspp_out5])
    # Project ASPP output to 256 filters
    aspp_out = Conv2D(256, 1, padding='same', use_bias=False)(aspp_out)
    aspp_out = BatchNormalization()(aspp_out)
    aspp_out = Activation('relu')(aspp_out)
    # === DECODER ===
    # Process low-level features from Block 2 (1/4 size)
    low_level_features = Conv2D(48, 1, padding='same', use_bias=False)(low_level_features)
    low_level_features = BatchNormalization()(low_level_features)
    low_level_features = Activation('relu')(low_level_features)
    # Upsample ASPP output by 4x to match low level features size
    # Get shapes for verification
    low_level_shape = tf.keras.backend.int_shape(low_level_features)
    # Upsample to match low_level_features shape (integer factor derived from static shapes)
    x = UpSampling2D(size=(low_level_shape[1] // tf.keras.backend.int_shape(aspp_out)[1],
                           low_level_shape[2] // tf.keras.backend.int_shape(aspp_out)[2]),
                     interpolation='bilinear')(aspp_out)
    # Concatenate with low-level features
    x = Concatenate()([x, low_level_features])
    # Final convolutions
    x = Conv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Calculate upsampling size to original input size
    x_shape = tf.keras.backend.int_shape(x)
    upsampling_size = (input_shape[0] // x_shape[1], input_shape[1] // x_shape[2])
    # Upsample to original size
    x = UpSampling2D(size=upsampling_size, interpolation='bilinear')(x)
    # Final segmentation output
    outputs = Conv2D(classes, 1, padding='same', activation='softmax')(x)
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model


# SegNet model definition
def SegNet(input_shape=(256, 256, 1), classes=11):
    """
    SegNet model for semantic segmentation
    Args:
        input_shape: Shape of input images
        classes: Number of classes for segmentation
    Returns:
        model: SegNet model

    NOTE(review): unlike get_unet and DeepLabV3Plus, this model is returned
    UNCOMPILED — callers must compile it before training. Confirm whether
    that asymmetry is intentional.
    """
    # Input layer
    inputs = Input(input_shape)
    # === ENCODER ===
    # Encoder block 1
    x = Conv2D(64, (3, 3), padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Regular MaxPooling without indices
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)
    # Encoder block 2
    x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)
    # Encoder block 3
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)
    # Encoder block 4
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)
    # Encoder block 5
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)
    # === DECODER ===
    # Using UpSampling2D instead of MaxUnpooling since TensorFlow doesn't support it
    # Decoder block 5
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Decoder block 4
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Decoder block 3
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Decoder block 2
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Decoder block 1
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Output layer
    outputs = Conv2D(classes, (1, 1), padding='same', activation='softmax')(x)
    model = Model(inputs=inputs, outputs=outputs)
    return model


# ==================== SAR SEGMENTATION CLASS ====================
class SARSegmentation:
    """Wrapper around a SAR land-cover segmentation model (U-Net / DeepLabV3+ / SegNet)."""

    def __init__(self, img_rows=256, img_cols=256, drop_rate=0.5, model_type='unet'):
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.drop_rate = drop_rate
        self.num_channels = 1  # Single-pol SAR
        self.model = None
        self.model_type = model_type.lower()
        # ESA WorldCover class definitions (name -> raw ESA label value)
        self.class_definitions = {
            'trees': 10,
            'shrubland': 20,
            'grassland': 30,
            'cropland': 40,
            'built_up': 50,
            'bare': 60,
            'snow': 70,
            'water': 80,
            'wetland': 90,
            'mangroves': 95,
            'moss': 100
        }
        self.num_classes = len(self.class_definitions)
        # Class colors for visualization (index -> RGB)
        self.class_colors = get_esa_colors()

    def load_sar_data(self, file_path_or_bytes, is_bytes=False):
        """Load SAR data from file path or bytes.

        Returns an (img_rows, img_cols, 1) array; falls back to PIL grayscale
        if rasterio cannot read the input. Raises ValueError on failure.
        """
        try:
            if is_bytes:
                # Create a temporary file to use with rasterio
                with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp:
                    tmp.write(file_path_or_bytes)
                    tmp_path = tmp.name
                try:
                    with rasterio.open(tmp_path) as src:
                        sar_data = src.read(1)  # Read single band
                        sar_data = np.expand_dims(sar_data, axis=-1)
                except Exception as e:
                    # If rasterio fails, try PIL
                    img = Image.open(tmp_path).convert('L')
                    sar_data = np.array(img)
                    sar_data = np.expand_dims(sar_data, axis=-1)
                # Clean up the temporary file
                os.unlink(tmp_path)
            else:
                try:
                    with rasterio.open(file_path_or_bytes) as src:
                        sar_data = src.read(1)  # Read single band
                        sar_data = np.expand_dims(sar_data, axis=-1)
                except RasterioIOError:
                    # Try to open as a regular image if rasterio fails
                    img = Image.open(file_path_or_bytes).convert('L')
                    sar_data = np.array(img)
                    sar_data = np.expand_dims(sar_data, axis=-1)
            # Resize if needed (cv2.resize drops the channel dim, so re-add it)
            if sar_data.shape[:2] != (self.img_rows, self.img_cols):
                sar_data = cv2.resize(sar_data, (self.img_cols, self.img_rows))
                sar_data = np.expand_dims(sar_data, axis=-1)
            return sar_data
        except Exception as e:
            raise ValueError(f"Failed to load SAR data: {str(e)}")

    def preprocess_sar(self, sar_data):
        """Preprocess SAR data into [-1, 1].

        Heuristic: values already in [0, 255] are treated as 8-bit intensities;
        anything else is assumed to be in dB scale and min-max normalized.
        """
        # Check if data is already normalized (0-255 range)
        if np.max(sar_data) <= 255 and np.min(sar_data) >= 0:
            # Normalize to -1 to 1 range
            sar_normalized = (sar_data / 127.5) - 1
        else:
            # Assume it's in dB scale
            sar_clipped = np.clip(sar_data, -50, 20)
            sar_normalized = (sar_clipped - np.min(sar_clipped)) / (np.max(sar_clipped) - np.min(sar_clipped)) * 2 - 1
        return sar_normalized

    def one_hot_encode(self, labels):
        """Convert ESA WorldCover labels to one-hot encoded format.

        Channel order follows the sorted raw label values (10, 20, ..., 100).
        """
        encoded = np.zeros((labels.shape[0], labels.shape[1], self.num_classes))
        for i, value in enumerate(sorted(self.class_definitions.values())):
            encoded[:, :, i] = (labels == value)
        return encoded

    def load_trained_model(self, model_path):
        """Load a trained model from file.

        Tries load_model_with_weights first and auto-detects the architecture
        (dilated convs -> deeplabv3plus, >= 5 pooling layers -> segnet, else
        unet). If that fails, builds a fresh model of self.model_type and loads
        weights by name. Raises ValueError if nothing usable was loaded.
        """
        try:
            # If model_path is a filename without path, prepend the models directory
            if not os.path.dirname(model_path) and not model_path.startswith('models/'):
                model_path = os.path.join('models', os.path.basename(model_path))
            # First try to load the complete model
            self.model = load_model_with_weights(model_path)
            if self.model is not None:
                # Detect architecture from the loaded layers
                has_dilated_convs = False
                for layer in self.model.layers:
                    if 'conv' in layer.name.lower() and hasattr(layer, 'dilation_rate'):
                        if isinstance(layer.dilation_rate, (list, tuple)):
                            if any(rate > 1 for rate in layer.dilation_rate):
                                has_dilated_convs = True
                                break
                        elif layer.dilation_rate > 1:
                            has_dilated_convs = True
                            break
                if has_dilated_convs:
                    self.model_type = 'deeplabv3plus'
                    print("Detected DeepLabV3+ model")
                # Check for SegNet architecture (typically has 5 encoder and 5 decoder blocks)
                elif len([l for l in self.model.layers if isinstance(l, MaxPooling2D)]) >= 5:
                    self.model_type = 'segnet'
                    print("Detected SegNet model")
                else:
                    self.model_type = 'unet'
                    print("Detected U-Net model")
            if self.model is None:
                # If that fails, try to create a model with the expected architecture
                if self.model_type == 'unet':
                    self.model = get_unet(
                        input_shape=(self.img_rows, self.img_cols, self.num_channels),
                        drop_rate=self.drop_rate,
                        classes=self.num_classes
                    )
                elif self.model_type == 'deeplabv3plus':
                    self.model = DeepLabV3Plus(
                        input_shape=(self.img_rows, self.img_cols, self.num_channels),
                        classes=self.num_classes
                    )
                elif self.model_type == 'segnet':
                    self.model = SegNet(
                        input_shape=(self.img_rows, self.img_cols, self.num_channels),
                        classes=self.num_classes
                    )
                else:
                    raise ValueError(f"Model type {self.model_type} not supported")
                # Try to load weights, allowing for mismatch
                self.model.load_weights(model_path, by_name=True, skip_mismatch=True)
                # Check if any weights were loaded (all-zero weights => nothing matched)
                if not any(np.any(w) for w in self.model.get_weights()):
                    raise ValueError("No weights were loaded. The model architecture is incompatible.")
        except Exception as e:
            raise ValueError(f"Failed to load model: {str(e)}")

    def predict(self, sar_data):
        """Predict segmentation for new SAR data; returns (1, H, W, classes)."""
        if self.model is None:
            raise ValueError("Model not trained. Call train() first or load a trained model.")
        # Preprocess input data
        sar_processed = self.preprocess_sar(sar_data)
        # Ensure correct shape (add batch dim if missing)
        if len(sar_processed.shape) == 3:
            sar_processed = np.expand_dims(sar_processed, axis=0)
        # Make prediction
        prediction = self.model.predict(sar_processed)
        return prediction

    def get_colored_prediction(self, prediction):
        """Convert prediction to colored image; returns (rgb_image, class_map)."""
        pred_class = np.argmax(prediction[0], axis=-1)
        colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8)
        for class_idx, color in self.class_colors.items():
            colored_pred[pred_class == class_idx] = color
        return colored_pred, pred_class


# ==================== UI SETUP AND STYLING ====================

# Initialize session state variables
if 'app_mode' not in st.session_state:
    st.session_state.app_mode = "SAR Colorization"
if 'model_loaded' not in st.session_state:
    st.session_state.model_loaded = False
if 'segmentation' not in st.session_state:
    st.session_state.segmentation = SARSegmentation(img_rows=256, img_cols=256)
if 'processed_images' not in st.session_state:
    st.session_state.processed_images = []
if 'theme' not in st.session_state:
    st.session_state.theme = "dark"  # Default theme


# Apply a single consistent style for the entire app
def set_app_style(app_mode):
    """Inject per-mode CSS into the page via st.markdown.

    NOTE(review): the CSS payloads appear to have been stripped from this copy
    of the file (the triple-quoted strings are empty) — restore the original
    style sheets before shipping.
    """
    if app_mode == "SAR Colorization":
        # Dark theme styling for SAR Colorization
        st.markdown(
            """ """,
            unsafe_allow_html=True
        )
    elif app_mode == "SAR to Optical Translation":
        # Light theme styling for SAR to Optical Translation
        st.markdown(
            """ """,
            unsafe_allow_html=True
        )


# Create twinkling stars
def create_stars_html(num_stars=100):
    """Build HTML for a twinkling-star background.

    Each star gets a random size, position, animation duration and opacity.
    NOTE(review): the per-star <div> markup was reduced to bare newlines in
    this copy of the file; the template referencing size/top/left/duration/
    opacity must be restored for the effect to render. FIX: the closing
    fragment was an unterminated string literal (a lone '"' split across a
    line break) — replaced with a minimal valid literal.
    """
    stars_html = """
"""
    for i in range(num_stars):
        size = random.uniform(1, 3)
        top = random.uniform(0, 100)
        left = random.uniform(0, 100)
        duration = random.uniform(3, 8)
        opacity = random.uniform(0.2, 0.8)
        stars_html += f"""
"""
    stars_html += "\n"
    return stars_html


# Add logo
def add_logo(logo_path='assets/logo2.png'):
    """Render the app logo as a base64-embedded image in the page.

    NOTE(review): the <img> HTML template was stripped from this copy, so
    logo_base64 is currently unused; restore the markup from the original
    source.
    """
    try:
        with open(logo_path, "rb") as img_file:
            logo_base64 = base64.b64encode(img_file.read()).decode()
        st.markdown(
            f"""
""",
            unsafe_allow_html=True
        )
    except FileNotFoundError:
        st.warning(f"Logo file not found: {logo_path}")
""", unsafe_allow_html=True ) except FileNotFoundError: st.warning(f"Logo file not found: {logo_path}") # ==================== MAIN APP LOGIC ==================== # Add stars to background st.markdown(create_stars_html(), unsafe_allow_html=True) # Add app mode selector to sidebar with st.sidebar: st.image('assets/logo2.png', width=150) # Add the mode selector st.title("Applications") app_mode = st.radio( "Select Application", ["SAR Colorization", "SAR to Optical Translation"] ) # Make sure this line is present to update the session state st.session_state.app_mode = app_mode # Add theme selector st.markdown("---") st.title("Appearance") theme = st.radio( "Select Theme", ["Dark", "Light"] ) set_app_style(st.session_state.app_mode) # Make sure theme is properly stored in lowercase if theme.lower() != st.session_state.theme: st.session_state.theme = theme.lower() st.rerun() # Force a rerun to apply the new theme st.markdown("---") # Sidebar content for SAR Colorization app if st.session_state.app_mode == "SAR Colorization": st.title("About") st.markdown(""" ### SAR Image Colorization This application uses deep learning models to segment and colorize Synthetic Aperture Radar (SAR) images into land cover classes. 
#### Features: - Load pre-trained U-Net,DeepLabV3+ or SegNet models - Process single SAR images - Batch process multiple images - Visualize Pixel Level Classification with ESA WorldCover color scheme #### Developed by: Varun & Mokshyagna (NRSC, ISRO) #### Technologies: - TensorFlow/Keras - Streamlit - Rasterio - OpenCV #### Version: 1.0.0 """) # Sidebar content for the SAR to Optical app elif st.session_state.app_mode == "SAR to Optical Translation": st.header("Model Configuration") # Predefined model paths unet_weights_path = "models/unet_model.h5" generator_path = "models/final_generator.keras" # Display the paths that will be used st.info(f"U-Net Weights Path: {unet_weights_path}") use_generator = st.checkbox("Use Generator Model for Colorization", value=True) if use_generator: st.info(f"Generator Model Path: {generator_path}") else: generator_path = None # Load models button if st.button("Load Models"): with st.spinner("Loading models..."): gpu_status = setup_gpu() st.info(gpu_status) try: unet_model, generator_model = load_models(unet_weights_path, generator_path if use_generator else None) st.session_state['unet_model'] = unet_model st.session_state['generator_model'] = generator_model st.success("Models loaded successfully!") except Exception as e: st.error(f"Error loading models: {e}") # Class information st.header("ESA WorldCover Classes") class_info = { 'Trees': [0, 100, 0], 'Shrubland': [255, 165, 0], 'Grassland': [144, 238, 144], 'Cropland': [255, 255, 0], 'Built-up': [255, 0, 0], 'Bare': [139, 69, 19], 'Snow': [255, 255, 255], 'Water': [0, 0, 255], 'Wetland': [0, 139, 139], 'Mangroves': [0, 255, 0], 'Moss': [220, 220, 220] } for class_name, color in class_info.items(): st.markdown( f'
' f'
' f'{class_name}' f'
', unsafe_allow_html=True ) st.markdown("---") st.markdown("© 2025 | All Rights Reserved") # Main content area - conditional rendering based on app mode if st.session_state.app_mode == "SAR Colorization": # SAR Colorization app st.markdown("""

SAR Image Colorization

Pixel Level Classification of Synthetic Aperture Radar images into land cover classes with deep learning

""", unsafe_allow_html=True) # Create a card container st.markdown("
", unsafe_allow_html=True) # Create tabs # Create tabs tab1, tab2, tab3, tab4 = st.tabs(["📥 Load Model", "🖼️ Process Single Image", "📁 Process Multiple Images", "🔍 Sample Images"]) # Tab 1: Load Model with tab1: st.markdown("

Load Segmentation Model

", unsafe_allow_html=True) # Add model type selection model_type = st.selectbox( "Select model architecture", ["U-Net", "DeepLabV3+", "SegNet"], index=0, help="Select the architecture of the model to load" ) # Update the model type in the segmentation object st.session_state.segmentation.model_type = model_type.lower().replace('-', '') # Define predefined model paths based on selected architecture model_paths = { "unet": "models/unet_model.h5", "deeplabv3+": "models/deeplabv3plus_model.h5", # Add this key to match the session state "deeplabv3plus": "models/deeplabv3plus_model.h5", # Keep this as a fallback "segnet": "models/segnet_model.h5" } selected_model_path = model_paths[st.session_state.segmentation.model_type] # Display the path that will be used st.info(f"Model will be loaded from: {selected_model_path}") # Load model button if st.button("Load Model", key="load_model_btn"): with st.spinner(f"Loading {model_type} model..."): try: # Load the model from the predefined path st.session_state.segmentation.load_trained_model(selected_model_path) st.session_state.model_loaded = True st.success("Model loaded successfully!") except Exception as e: st.error(f"Error loading model: {str(e)}") # Display model information if loaded if st.session_state.model_loaded: st.markdown("
", unsafe_allow_html=True) st.markdown("

Model Information

", unsafe_allow_html=True) col1, col2, col3 = st.columns(3) with col1: st.markdown("
", unsafe_allow_html=True) # Display the correct model architecture based on the detected model type model_arch_map = { 'unet': "U-Net", 'deeplabv3plus': "DeepLabV3+", 'segnet': "SegNet" } model_arch = model_arch_map.get(st.session_state.segmentation.model_type, "Unknown") st.markdown(f"

{model_arch}

", unsafe_allow_html=True) st.markdown("

Architecture

", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) with col2: st.markdown("
", unsafe_allow_html=True) st.markdown("

11

", unsafe_allow_html=True) st.markdown("

Land Cover Classes

", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) with col3: st.markdown("
", unsafe_allow_html=True) st.markdown("

256 x 256

", unsafe_allow_html=True) st.markdown("

Input Size

", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) # Display legend st.markdown("

Land Cover Classes

", unsafe_allow_html=True) legend_img = create_legend() st.image(legend_img, use_container_width=True) st.markdown("
", unsafe_allow_html=True) else: st.info("Please load a model to continue.") # Tab 2: Process Single Image with tab2: st.markdown("

Process Single SAR Image

", unsafe_allow_html=True) if not st.session_state.model_loaded: st.warning("Please load a model in the 'Load Model' tab first.") else: st.markdown("
", unsafe_allow_html=True) col1, col2 = st.columns(2) with col1: uploaded_file = st.file_uploader( "Upload a SAR image (.tif or common image formats)", type=["tif", "tiff", "png", "jpg", "jpeg"], key="single_sar_uploader" ) with col2: # Add ground truth upload option ground_truth_file = st.file_uploader( "Upload ground truth (optional)", type=["tif", "tiff", "png", "jpg", "jpeg"], key="single_gt_uploader" ) st.markdown("
", unsafe_allow_html=True) if uploaded_file is not None: if st.button("Process Image", key="process_single_btn"): with st.spinner("Processing image..."): # Load and process the image try: sar_data = st.session_state.segmentation.load_sar_data(uploaded_file.getvalue(), is_bytes=True) # Normalize for visualization sar_normalized = sar_data.copy() min_val = np.min(sar_normalized) max_val = np.max(sar_normalized) sar_normalized = ((sar_normalized - min_val) / (max_val - min_val) * 255).astype(np.uint8) # Make prediction prediction = st.session_state.segmentation.predict(sar_data) # Process ground truth if provided if ground_truth_file is not None: try: # Load ground truth gt_data = st.session_state.segmentation.load_sar_data(ground_truth_file.getvalue(), is_bytes=True) # Ensure SAR is properly normalized for visualization if np.max(sar_normalized) > 1 or np.min(sar_normalized) < 0: # If it's in 0-255 range, normalize to -1 to 1 sar_for_viz = (sar_normalized.astype(np.float32) / 127.5) - 1 else: # It's already normalized properly sar_for_viz = sar_normalized # Create visualization with ground truth using ESA WorldCover colors result_buf, colored_gt, colored_pred, overlay = visualize_with_ground_truth( sar_for_viz, gt_data, prediction ) # Display results with metrics st.markdown("

Segmentation Results with Ground Truth

", unsafe_allow_html=True) st.image(result_buf, use_container_width=True) # Calculate metrics pred_class = np.argmax(prediction[0], axis=-1) gt_class = gt_data[:,:,0].astype(np.int32) # Normalize ground truth to match prediction classes if needed if np.max(gt_class) > 10: # If using ESA WorldCover values # Map ESA values to 0-10 indices gt_mapped = np.zeros_like(gt_class) class_values = sorted(st.session_state.segmentation.class_definitions.values()) for i, val in enumerate(class_values): gt_mapped[gt_class == val] = i gt_class = gt_mapped accuracy = np.mean(pred_class == gt_class) * 100 # Display metrics st.markdown("
", unsafe_allow_html=True) st.markdown(f"

{accuracy:.2f}%

", unsafe_allow_html=True) st.markdown("

Pixel Accuracy

", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) # Add download button for the result btn = st.download_button( label="Download Result", data=result_buf, file_name="segmentation_result_with_gt.png", mime="image/png", key="download_single_result_with_gt" ) except Exception as e: st.error(f"Error processing ground truth: {str(e)}") # Fall back to regular visualization without ground truth result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1)) st.markdown("

Segmentation Results

", unsafe_allow_html=True) st.image(result_img, use_container_width=True) # Add download button for the result btn = st.download_button( label="Download Result", data=result_img, file_name="segmentation_result.png", mime="image/png", key="download_single_result" ) else: # Regular visualization without ground truth result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1)) st.markdown("

Segmentation Results

", unsafe_allow_html=True) st.image(result_img, use_container_width=True) # Add download button for the result btn = st.download_button( label="Download Result", data=result_img, file_name="segmentation_result.png", mime="image/png", key="download_single_result" ) except Exception as e: st.error(f"Error processing image: {str(e)}") # Tab 3: Process Multiple Images with tab3: st.markdown("

Process Multiple SAR Images

", unsafe_allow_html=True) if not st.session_state.model_loaded: st.warning("Please load a model in the 'Load Model' tab first.") else: st.markdown("
", unsafe_allow_html=True) # Add option for ground truth use_gt = st.checkbox("Include ground truth data", value=False) col1, col2 = st.columns(2) with col1: uploaded_files = st.file_uploader( "Upload SAR images or a ZIP file containing images", type=["tif", "tiff", "png", "jpg", "jpeg", "zip"], accept_multiple_files=True, key="batch_sar_uploader" ) # Add ground truth uploader if option is selected gt_files = None if use_gt: with col2: gt_files = st.file_uploader( "Upload ground truth images or a ZIP file (must match SAR filenames)", type=["tif", "tiff", "png", "jpg", "jpeg", "zip"], accept_multiple_files=True, key="batch_gt_uploader" ) st.info("Ground truth filenames should match SAR image filenames") st.markdown("
", unsafe_allow_html=True) col1, col2 = st.columns([3, 1]) with col1: max_images = st.slider("Maximum number of images to display", min_value=1, max_value=20, value=10) with col2: st.markdown("
", unsafe_allow_html=True) process_btn = st.button("Process Images", key="process_multi_btn") if process_btn and uploaded_files: # Clear previous results st.session_state.processed_images = [] # Process uploaded files with st.spinner("Processing images..."): # Create a temporary directory to extract zip files if needed with tempfile.TemporaryDirectory() as temp_dir: # Process each uploaded file sar_image_files = [] gt_image_files = {} # Dictionary to map SAR filenames to GT filenames # Process SAR files for uploaded_file in uploaded_files: if uploaded_file.name.lower().endswith('.zip'): # Extract zip file zip_path = os.path.join(temp_dir, uploaded_file.name) with open(zip_path, 'wb') as f: f.write(uploaded_file.getvalue()) with zipfile.ZipFile(zip_path, 'r') as zip_ref: zip_ref.extractall(os.path.join(temp_dir, 'sar')) # Find all image files in the extracted directory for root, _, files in os.walk(os.path.join(temp_dir, 'sar')): for file in files: if file.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')): sar_image_files.append(os.path.join(root, file)) else: # Save the file to temp directory file_path = os.path.join(temp_dir, 'sar', uploaded_file.name) os.makedirs(os.path.dirname(file_path), exist_ok=True) with open(file_path, 'wb') as f: f.write(uploaded_file.getvalue()) sar_image_files.append(file_path) # Process ground truth files if provided if use_gt and gt_files: for gt_file in gt_files: if gt_file.name.lower().endswith('.zip'): # Extract zip file zip_path = os.path.join(temp_dir, gt_file.name) with open(zip_path, 'wb') as f: f.write(gt_file.getvalue()) with zipfile.ZipFile(zip_path, 'r') as zip_ref: zip_ref.extractall(os.path.join(temp_dir, 'gt')) # Find all image files in the extracted directory for root, _, files in os.walk(os.path.join(temp_dir, 'gt')): for file in files: if file.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')): # Map GT file to SAR file by filename gt_path = os.path.join(root, file) 
gt_image_files[os.path.basename(file)] = gt_path else: # Save the file to temp directory file_path = os.path.join(temp_dir, 'gt', gt_file.name) os.makedirs(os.path.dirname(file_path), exist_ok=True) with open(file_path, 'wb') as f: f.write(gt_file.getvalue()) gt_image_files[os.path.basename(gt_file.name)] = file_path # If there are too many images, randomly select a subset if len(sar_image_files) > max_images: st.info(f"Found {len(sar_image_files)} images. Randomly selecting {max_images} images to display.") sar_image_files = random.sample(sar_image_files, max_images) # Process each image progress_bar = st.progress(0) # Track overall metrics if ground truth is provided if use_gt and gt_image_files: overall_accuracy = [] for i, image_path in enumerate(sar_image_files): try: # Update progress progress_bar.progress((i + 1) / len(sar_image_files)) # Load and process the SAR image sar_data = st.session_state.segmentation.load_sar_data(image_path) # Normalize for visualization sar_normalized = sar_data.copy() min_val = np.min(sar_normalized) max_val = np.max(sar_normalized) sar_normalized = ((sar_normalized - min_val) / (max_val - min_val) * 255).astype(np.uint8) # Make prediction prediction = st.session_state.segmentation.predict(sar_data) # Check if we have a matching ground truth file image_basename = os.path.basename(image_path) has_gt = image_basename in gt_image_files if has_gt and use_gt: # Load ground truth gt_path = gt_image_files[image_basename] gt_data = st.session_state.segmentation.load_sar_data(gt_path) if np.max(sar_normalized) > 1 or np.min(sar_normalized) < 0: # If it's in 0-255 range, normalize to -1 to 1 sar_for_viz = (sar_normalized.astype(np.float32) / 127.5) - 1 else: # It's already normalized properly sar_for_viz = sar_normalized # Create visualization with ground truth using ESA WorldCover colors result_buf, colored_gt, colored_pred, overlay = visualize_with_ground_truth( sar_for_viz, gt_data, prediction ) # Calculate metrics pred_class = 
np.argmax(prediction[0], axis=-1) gt_class = gt_data[:,:,0].astype(np.int32) # Normalize ground truth to match prediction classes if needed if np.max(gt_class) > 10: # If using ESA WorldCover values # Map ESA values to 0-10 indices gt_mapped = np.zeros_like(gt_class) class_values = sorted(st.session_state.segmentation.class_definitions.values()) for i, val in enumerate(class_values): gt_mapped[gt_class == val] = i gt_class = gt_mapped accuracy = np.mean(pred_class == gt_class) * 100 overall_accuracy.append(accuracy) # Add to processed images with metrics st.session_state.processed_images.append({ 'filename': os.path.basename(image_path), 'result': result_buf, 'accuracy': accuracy }) else: # Regular visualization without ground truth result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1)) # Add to processed images st.session_state.processed_images.append({ 'filename': os.path.basename(image_path), 'result': result_img }) except Exception as e: st.error(f"Error processing {os.path.basename(image_path)}: {str(e)}") # Clear progress bar progress_bar.empty() # Display results if st.session_state.processed_images: st.markdown("

Segmentation Results

", unsafe_allow_html=True) # Display overall metrics if ground truth was provided if use_gt and 'overall_accuracy' in locals() and overall_accuracy: avg_accuracy = np.mean(overall_accuracy) st.markdown("
", unsafe_allow_html=True) st.markdown(f"

{avg_accuracy:.2f}%

", unsafe_allow_html=True) st.markdown("

Average Pixel Accuracy

", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) # Create a zip file with all results zip_buffer = io.BytesIO() with zipfile.ZipFile(zip_buffer, 'w') as zip_file: for i, img_data in enumerate(st.session_state.processed_images): zip_file.writestr(f"result_{i+1}_{img_data['filename']}.png", img_data['result'].getvalue()) # Add download button for all results st.download_button( label="Download All Results", data=zip_buffer.getvalue(), file_name="segmentation_results.zip", mime="application/zip", key="download_all_results" ) # Display each result for i, img_data in enumerate(st.session_state.processed_images): st.markdown(f"
Image: {img_data['filename']}
", unsafe_allow_html=True) # Display accuracy if available if 'accuracy' in img_data: st.markdown(f"

Pixel Accuracy: {img_data['accuracy']:.2f}%

", unsafe_allow_html=True) st.image(img_data['result'], use_container_width=True) st.markdown("
", unsafe_allow_html=True) else: st.warning("No images were successfully processed.") elif process_btn: st.warning("Please upload at least one image file or ZIP archive.") # Tab 4: Sample Images with tab4: st.markdown("

Sample Images

", unsafe_allow_html=True) if not st.session_state.model_loaded: st.warning("Please load a model in the 'Load Model' tab first.") else: st.markdown("
", unsafe_allow_html=True) # Get list of sample images import os sample_dir = "samples/SAR" if os.path.exists(sample_dir): sample_files = [f for f in os.listdir(sample_dir) if f.endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg'))] else: os.makedirs(sample_dir, exist_ok=True) os.makedirs("samples/OPTICAL", exist_ok=True) os.makedirs("samples/LABELS", exist_ok=True) sample_files = [] if sample_files: # Create a dropdown to select sample images selected_sample = st.selectbox( "Select a sample image", sample_files, key="sample_selector" ) # Display the selected sample col1, col2, col3 = st.columns(3) with col1: st.subheader("SAR Image") sar_path = os.path.join("samples/SAR", selected_sample) display_image(sar_path) with col2: st.subheader("Optical Image (Ground Truth)") # Try to find matching optical image opt_path = os.path.join("samples/OPTICAL", selected_sample) if os.path.exists(opt_path): display_image(opt_path) else: st.info("No matching optical image found") # # Add this debugging code where you're trying to load the label image with col3: st.subheader("Label Image") samples_dir = "samples" # Try multiple possible label directories possible_label_dirs = [ os.path.join(samples_dir, "labels"), os.path.join(samples_dir, "label"), os.path.join(samples_dir, "LABELS"), os.path.join(samples_dir, "LABEL"), os.path.join(samples_dir, "Labels"), os.path.join(samples_dir, "Label"), os.path.join(samples_dir, "gt"), os.path.join(samples_dir, "GT"), os.path.join(samples_dir, "ground_truth"), os.path.join(samples_dir, "groundtruth") ] # Try to find the label file label_path = None base_name = os.path.splitext(selected_sample)[0] # Try different extensions in all possible directories for dir_path in possible_label_dirs: if not os.path.exists(dir_path): continue # Try exact match first exact_path = os.path.join(dir_path, selected_sample) if os.path.exists(exact_path): label_path = exact_path break # Try different extensions for ext in ['.tif', '.tiff', '.png', '.jpg', '.jpeg', 
'.TIF', '.TIFF', '.PNG', '.JPG', '.JPEG']: test_path = os.path.join(dir_path, base_name + ext) if os.path.exists(test_path): label_path = test_path break # Try case-insensitive match if not label_path: for file in os.listdir(dir_path): if os.path.splitext(file)[0].lower() == base_name.lower(): label_path = os.path.join(dir_path, file) break if label_path: break # Display the label image if found # Replace the current label display code with this if label_path and os.path.exists(label_path): try: # For ESA WorldCover labels, we need special handling if label_path.lower().endswith(('.tif', '.tiff')): with rasterio.open(label_path) as src: label_data = src.read(1) # Read first band # Convert ESA WorldCover labels to colored image colors = get_esa_colors() # This function should be defined in your code colored_label = np.zeros((label_data.shape[0], label_data.shape[1], 3), dtype=np.uint8) # Map ESA values to colors for class_idx, color in colors.items(): # If using ESA WorldCover values (10, 20, 30, etc.) 
if np.max(label_data) > 10: # Map ESA values to 0-10 indices class_values = sorted(st.session_state.segmentation.class_definitions.values()) for i, val in enumerate(class_values): if class_idx == i: colored_label[label_data == val] = color else: # Direct mapping if values are already 0-10 colored_label[label_data == class_idx] = color st.image(colored_label, use_container_width=True) else: # For regular image formats display_image(label_path) except Exception as e: st.error(f"Error displaying label image: {str(e)}") # Fallback to regular display display_image(label_path) else: st.info("No matching label image found") # Add a button to process the selected sample if st.button("Process Selected Sample", key="process_sample_btn"): with st.spinner("Processing sample image..."): try: # Load and process the SAR image sar_data = st.session_state.segmentation.load_sar_data(sar_path) # Normalize for visualization sar_normalized = sar_data.copy() min_val = np.min(sar_normalized) max_val = np.max(sar_normalized) sar_normalized = ((sar_normalized - min_val) / (max_val - min_val) * 255).astype(np.uint8) # Make prediction prediction = st.session_state.segmentation.predict(sar_data) # Check if label image exists for comparison if os.path.exists(label_path): # Load label image as ground truth gt_data = st.session_state.segmentation.load_sar_data(label_path) # Ensure SAR is properly normalized for visualization if np.max(sar_normalized) > 1 or np.min(sar_normalized) < 0: # If it's in 0-255 range, normalize to -1 to 1 sar_for_viz = (sar_normalized.astype(np.float32) / 127.5) - 1 else: # It's already normalized properly sar_for_viz = sar_normalized # Create visualization with ground truth using ESA WorldCover colors result_buf, colored_gt, colored_pred, overlay = visualize_with_ground_truth( sar_for_viz, gt_data, prediction ) # Display results with metrics st.markdown("

Segmentation Results with Ground Truth

", unsafe_allow_html=True) st.image(result_buf, use_container_width=True) # Calculate metrics pred_class = np.argmax(prediction[0], axis=-1) gt_class = gt_data[:,:,0].astype(np.int32) # Normalize ground truth to match prediction classes if needed if np.max(gt_class) > 10: # If using ESA WorldCover values # Map ESA values to 0-10 indices gt_mapped = np.zeros_like(gt_class) class_values = sorted(st.session_state.segmentation.class_definitions.values()) for i, val in enumerate(class_values): gt_mapped[gt_class == val] = i gt_class = gt_mapped accuracy = np.mean(pred_class == gt_class) * 100 # Display metrics st.markdown("
", unsafe_allow_html=True) st.markdown(f"

{accuracy:.2f}%

", unsafe_allow_html=True) st.markdown("

Pixel Accuracy

", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) # Add download button for the result btn = st.download_button( label="Download Result", data=result_buf, file_name=f"sample_result_{selected_sample}.png", mime="image/png", key="download_sample_result_with_gt" ) else: # Regular visualization without ground truth result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1)) st.markdown("

Segmentation Results

", unsafe_allow_html=True) st.image(result_img, use_container_width=True) # Add download button for the result btn = st.download_button( label="Download Result", data=result_img, file_name=f"sample_result_{selected_sample}.png", mime="image/png", key="download_sample_result" ) except Exception as e: st.error(f"Error processing sample image: {str(e)}") else: st.info("No sample images found. Please add some images to the 'samples/SAR' directory.") # Close the card container st.markdown("
", unsafe_allow_html=True) elif st.session_state.app_mode == "SAR to Optical Translation": # SAR to Optical Translation app st.markdown("""

SAR to Optical Translation

Convert Synthetic Aperture Radar images to optical-like imagery using deep learning

""", unsafe_allow_html=True) # Create a card container st.markdown("
", unsafe_allow_html=True) # Check if models are loaded models_loaded = 'unet_model' in st.session_state if not models_loaded: st.warning("Please load the models from the sidebar first.") else: st.success("Models loaded successfully! You can now process SAR images.") # Create tabs for single image and batch processing # Create tabs for single image and batch processing tab1, tab2, tab3 = st.tabs(["Process Single Image", "Batch Processing", "Sample Images"]) with tab1: st.markdown("

Upload SAR Image

", unsafe_allow_html=True) # Create two columns for SAR and optional ground truth # Create two columns for SAR and optional ground truth col1, col2 = st.columns(2) with col1: st.markdown("
", unsafe_allow_html=True) uploaded_file = st.file_uploader( "Upload a SAR image (.tif or common image formats)", type=["tif", "tiff", "png", "jpg", "jpeg"], key="sar_optical_uploader" ) st.markdown("
", unsafe_allow_html=True) # Add ground truth upload option with col2: st.markdown("
", unsafe_allow_html=True) gt_file = st.file_uploader( "Upload ground truth optical image (optional)", type=["tif", "tiff", "png", "jpg", "jpeg"], key="optical_gt_uploader" ) st.markdown("
", unsafe_allow_html=True) if uploaded_file is not None: # Process button if st.button("Generate Optical-like Image", key="generate_optical_btn"): with st.spinner("Processing image..."): try: # Load and process the SAR image sar_batch, sar_image = load_sar_image(uploaded_file) if sar_batch is not None: # Process with models seg_mask, colorized = process_image( sar_batch, st.session_state['unet_model'], st.session_state.get('generator_model') ) # Visualize results sar_rgb, colored_pred, overlay, colorized_img = visualize_results( sar_image, seg_mask, colorized ) # Display results st.header("Results") # If ground truth is provided, include it in visualization if gt_file is not None: try: # Load ground truth image with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as tmp_file: tmp_file.write(gt_file.getbuffer()) tmp_file_path = tmp_file.name try: # Try to open with rasterio with rasterio.open(tmp_file_path) as src: gt_image = src.read() # Debug info st.info(f"Ground truth shape: {gt_image.shape}, dtype: {gt_image.dtype}, min: {np.min(gt_image)}, max: {np.max(gt_image)}") if gt_image.shape[0] == 3: # RGB image gt_image = np.transpose(gt_image, (1, 2, 0)) else: # Single band gt_image = src.read(1) # Check if the image is all zeros or all ones if np.all(gt_image == 0) or np.all(gt_image == 1): st.warning("Ground truth image appears to be blank (all zeros or ones)") # Convert to RGB for display gt_image = np.expand_dims(gt_image, axis=-1) gt_image = np.repeat(gt_image, 3, axis=-1) except Exception as rasterio_error: st.warning(f"Rasterio failed: {str(rasterio_error)}. 
Trying PIL...") try: # If rasterio fails, try PIL gt_image = np.array(Image.open(tmp_file_path).convert('RGB')) # Debug info st.info(f"Ground truth shape (PIL): {gt_image.shape}, dtype: {gt_image.dtype}, min: {np.min(gt_image)}, max: {np.max(gt_image)}") # Check if the image is all white if np.all(gt_image > 250): st.warning("Ground truth image appears to be all white") except Exception as pil_error: st.error(f"Both rasterio and PIL failed to load the ground truth: {str(pil_error)}") raise # Clean up the temporary file os.unlink(tmp_file_path) # Resize if needed if gt_image.shape[:2] != (256, 256): gt_image = cv2.resize(gt_image, (256, 256)) # Normalize if needed - make sure values are in 0-255 range for display if gt_image.dtype != np.uint8: if np.max(gt_image) > 1.0 and np.max(gt_image) <= 255: gt_image = gt_image.astype(np.uint8) elif np.max(gt_image) <= 1.0: gt_image = (gt_image * 255).astype(np.uint8) else: # Scale to 0-255 gt_min, gt_max = np.min(gt_image), np.max(gt_image) gt_image = ((gt_image - gt_min) / (gt_max - gt_min) * 255).astype(np.uint8) # Create 4-panel visualization with ground truth fig, axes = plt.subplots(1, 4, figsize=(16, 4)) # Original SAR axes[0].imshow(sar_rgb, cmap='gray') axes[0].set_title('Original SAR', color='white') axes[0].axis('off') # Ground Truth axes[1].imshow(gt_image) axes[1].set_title('Ground Truth', color='white') axes[1].axis('off') # Segmentation axes[2].imshow(colored_pred) axes[2].set_title('Segmentation', color='white') axes[2].axis('off') # Generated Image if colorized_img is not None: # Convert from -1,1 to 0,1 range colorized_display = (colorized_img * 0.5) + 0.5 axes[3].imshow(colorized_display) else: axes[3].imshow(overlay) axes[3].set_title('Generated Image', color='white') axes[3].axis('off') # Set dark background fig.patch.set_facecolor('#0a0a1f') for ax in axes: ax.set_facecolor('#0a0a1f') plt.tight_layout() # Display the figure st.pyplot(fig) # Calculate metrics if ground truth is provided if colorized_img is 
not None: # Normalize both images to 0-1 range for comparison colorized_norm = (colorized_img * 0.5) + 0.5 gt_norm = gt_image.astype(np.float32) / 255.0 # Calculate PSNR mse = np.mean((colorized_norm - gt_norm) ** 2) psnr = 20 * np.log10(1.0 / np.sqrt(mse)) # Calculate SSIM from skimage.metrics import structural_similarity as ssim try: # Check image dimensions and set appropriate window size min_dim = min(colorized_norm.shape[0], colorized_norm.shape[1]) win_size = min(7, min_dim - (min_dim % 2) + 1) # Ensure it's odd and smaller than min dimension ssim_value = ssim( colorized_norm, gt_norm, win_size=win_size, # Explicitly set window size channel_axis=2, # Specify channel axis for RGB images data_range=1.0 ) except Exception as e: st.warning(f"Could not calculate SSIM: {str(e)}") ssim_value = 0.0 # Default value if calculation fails # Display metrics col1, col2 = st.columns(2) with col1: st.markdown("
", unsafe_allow_html=True) st.markdown(f"

{psnr:.2f}

", unsafe_allow_html=True) st.markdown("

PSNR (dB)

", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) with col2: st.markdown("
", unsafe_allow_html=True) st.markdown(f"

{ssim_value:.4f}

", unsafe_allow_html=True) st.markdown("

SSIM

", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) except Exception as e: st.error(f"Error processing ground truth: {str(e)}") # Fall back to regular visualization col1, col2, col3 = st.columns(3) with col1: st.subheader("Original SAR Image") st.image(sar_rgb, use_container_width=True) with col2: st.subheader("Predicted Segmentation") st.image(colored_pred, use_container_width=True) with col3: st.subheader("Colorized SAR") st.image(overlay, use_container_width=True) else: # Regular 3-panel visualization without ground truth col1, col2, col3 = st.columns(3) with col1: st.subheader("Original SAR Image") st.image(sar_rgb, use_container_width=True) with col2: st.subheader("Predicted Segmentation") st.image(colored_pred, use_container_width=True) with col3: st.subheader("Colorized SAR") st.image(overlay, use_container_width=True) # Display colorized image if available if colorized_img is not None: st.header("Translated Optical Image") # Convert from -1,1 to 0,1 range colorized_display = (colorized_img * 0.5) + 0.5 # Create a figure with controlled size fig, ax = plt.subplots(figsize=(6, 6)) ax.imshow(colorized_display) ax.axis('off') # Use the figure for display instead of direct image st.pyplot(fig, use_container_width=False) # Add download buttons col1, col2 = st.columns(2) with col1: # Save segmentation image seg_buf = io.BytesIO() plt.imsave(seg_buf, colored_pred, format='png') seg_buf.seek(0) st.download_button( label="Download Segmentation", data=seg_buf, file_name="segmentation.png", mime="image/png", key="download_seg" ) with col2: # Save generated image gen_buf = io.BytesIO() plt.imsave(gen_buf, colorized_display, format='png') gen_buf.seek(0) st.download_button( label="Download Optical-like Image", data=gen_buf, file_name="optical_like.png", mime="image/png", key="download_optical" ) except Exception as e: st.error(f"Error processing image: {str(e)}") # Batch processing tab with tab2: st.markdown("

Batch Process SAR Images

", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) # Add option for ground truth use_gt = st.checkbox("Include ground truth data", value=False) col1, col2 = st.columns(2) with col1: batch_files = st.file_uploader( "Upload SAR images or a ZIP file containing images", type=["tif", "tiff", "png", "jpg", "jpeg", "zip"], accept_multiple_files=True, key="batch_sar_optical_uploader" ) # Add ground truth uploader if option is selected batch_gt_files = None if use_gt: with col2: batch_gt_files = st.file_uploader( "Upload ground truth optical images or a ZIP file (must match SAR filenames)", type=["tif", "tiff", "png", "jpg", "jpeg", "zip"], accept_multiple_files=True, key="batch_optical_gt_uploader" ) st.info("Ground truth filenames should match SAR image filenames") st.markdown("
", unsafe_allow_html=True) col1, col2 = st.columns([3, 1]) with col1: max_images = st.slider("Maximum number of images to display", min_value=1, max_value=20, value=5) with col2: st.markdown("
", unsafe_allow_html=True) batch_process_btn = st.button("Process Images", key="batch_process_btn") if batch_process_btn and batch_files: # Clear previous results if 'batch_results' not in st.session_state: st.session_state.batch_results = [] else: st.session_state.batch_results = [] # Process uploaded files with st.spinner("Processing images..."): # Create a temporary directory to extract zip files if needed with tempfile.TemporaryDirectory() as temp_dir: # Process each uploaded file sar_image_files = [] gt_image_files = {} # Dictionary to map SAR filenames to GT filenames # Process SAR files for uploaded_file in batch_files: if uploaded_file.name.lower().endswith('.zip'): # Extract zip file zip_path = os.path.join(temp_dir, uploaded_file.name) with open(zip_path, 'wb') as f: f.write(uploaded_file.getvalue()) with zipfile.ZipFile(zip_path, 'r') as zip_ref: zip_ref.extractall(os.path.join(temp_dir, 'sar')) # Find all image files in the extracted directory for root, _, files in os.walk(os.path.join(temp_dir, 'sar')): for file in files: if file.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')): sar_image_files.append(os.path.join(root, file)) else: # Save the file to temp directory file_path = os.path.join(temp_dir, 'sar', uploaded_file.name) os.makedirs(os.path.dirname(file_path), exist_ok=True) with open(file_path, 'wb') as f: f.write(uploaded_file.getvalue()) sar_image_files.append(file_path) # Process ground truth files if provided if use_gt and batch_gt_files: for gt_file in batch_gt_files: if gt_file.name.lower().endswith('.zip'): # Extract zip file zip_path = os.path.join(temp_dir, gt_file.name) with open(zip_path, 'wb') as f: f.write(gt_file.getvalue()) with zipfile.ZipFile(zip_path, 'r') as zip_ref: zip_ref.extractall(os.path.join(temp_dir, 'gt')) # Find all image files in the extracted directory for root, _, files in os.walk(os.path.join(temp_dir, 'gt')): for file in files: if file.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')): # 
Map GT file to SAR file by filename gt_path = os.path.join(root, file) gt_image_files[os.path.basename(file)] = gt_path else: # Save the file to temp directory file_path = os.path.join(temp_dir, 'gt', gt_file.name) os.makedirs(os.path.dirname(file_path), exist_ok=True) with open(file_path, 'wb') as f: f.write(gt_file.getvalue()) gt_image_files[os.path.basename(gt_file.name)] = file_path # If there are too many images, randomly select a subset if len(sar_image_files) > max_images: st.info(f"Found {len(sar_image_files)} images. Randomly selecting {max_images} images to display.") sar_image_files = random.sample(sar_image_files, max_images) # Process each image progress_bar = st.progress(0) # Track overall metrics if ground truth is provided if use_gt and gt_image_files: overall_psnr = [] overall_ssim = [] for i, image_path in enumerate(sar_image_files): try: # Update progress progress_bar.progress((i + 1) / len(sar_image_files)) # Load and process the SAR image with open(image_path, 'rb') as f: file_bytes = f.read() sar_batch, sar_image = load_sar_image(io.BytesIO(file_bytes)) if sar_batch is not None: # Process with models seg_mask, colorized = process_image( sar_batch, st.session_state['unet_model'], st.session_state.get('generator_model') ) # Visualize results sar_rgb, colored_pred, overlay, colorized_img = visualize_results( sar_image, seg_mask, colorized ) # Check if we have a matching ground truth file image_basename = os.path.basename(image_path) has_gt = image_basename in gt_image_files if has_gt and use_gt: # Load ground truth gt_path = gt_image_files[image_basename] try: # Try to open with rasterio with rasterio.open(gt_path) as src: gt_image = src.read() if gt_image.shape[0] == 3: # RGB image gt_image = np.transpose(gt_image, (1, 2, 0)) else: # Single band gt_image = src.read(1) # Convert to RGB for display gt_image = np.expand_dims(gt_image, axis=-1) gt_image = np.repeat(gt_image, 3, axis=-1) except Exception as rasterio_error: try: # If rasterio fails, 
try PIL gt_image = np.array(Image.open(gt_path).convert('RGB')) except Exception as pil_error: st.error(f"Both rasterio and PIL failed to load the ground truth: {str(pil_error)}") raise # Resize if needed if gt_image.shape[:2] != (256, 256): gt_image = cv2.resize(gt_image, (256, 256)) # Normalize if needed - make sure values are in 0-255 range for display if gt_image.dtype != np.uint8: if np.max(gt_image) > 1.0 and np.max(gt_image) <= 255: gt_image = gt_image.astype(np.uint8) elif np.max(gt_image) <= 1.0: gt_image = (gt_image * 255).astype(np.uint8) else: # Scale to 0-255 gt_min, gt_max = np.min(gt_image), np.max(gt_image) gt_image = ((gt_image - gt_min) / (gt_max - gt_min) * 255).astype(np.uint8) # Create visualization with ground truth fig, axes = plt.subplots(1, 4, figsize=(16, 4)) # Original SAR axes[0].imshow(sar_rgb, cmap='gray') axes[0].set_title('Original SAR', color='white') axes[0].axis('off') # Ground Truth axes[1].imshow(gt_image) axes[1].set_title('Ground Truth', color='white') axes[1].axis('off') # Segmentation axes[2].imshow(colored_pred) axes[2].set_title('Segmentation', color='white') axes[2].axis('off') # Generated Image if colorized_img is not None: # Convert from -1,1 to 0,1 range colorized_display = (colorized_img * 0.5) + 0.5 axes[3].imshow(colorized_display) else: axes[3].imshow(overlay) axes[3].set_title('Generated Image', color='white') axes[3].axis('off') # Set dark background fig.patch.set_facecolor('#0a0a1f') for ax in axes: ax.set_facecolor('#0a0a1f') plt.tight_layout() # Convert plot to image result_buf = io.BytesIO() plt.savefig(result_buf, format='png', facecolor='#0a0a1f', bbox_inches='tight') result_buf.seek(0) plt.close(fig) # Calculate metrics if colorized image is available metrics = {'psnr': 0.0, 'ssim': 0.0} # Default values if colorized_img is not None: try: # Normalize both images to 0-1 range for comparison colorized_norm = (colorized_img * 0.5) + 0.5 gt_norm = gt_image.astype(np.float32) / 255.0 # Calculate PSNR mse = 
np.mean((colorized_norm - gt_norm) ** 2) if mse > 0: psnr = 20 * np.log10(1.0 / np.sqrt(mse)) metrics['psnr'] = psnr overall_psnr.append(psnr) # Calculate SSIM with explicit window size from skimage.metrics import structural_similarity as ssim # Check image dimensions and set appropriate window size min_dim = min(colorized_norm.shape[0], colorized_norm.shape[1]) win_size = min(7, min_dim - (min_dim % 2) + 1) # Ensure it's odd and smaller than min dimension ssim_value = ssim( colorized_norm, gt_norm, win_size=win_size, # Explicitly set window size channel_axis=2, # Specify channel axis for RGB images data_range=1.0 ) metrics['ssim'] = ssim_value overall_ssim.append(ssim_value) except Exception as e: st.warning(f"Could not calculate metrics for {os.path.basename(image_path)}: {str(e)}") # Save generated image for download gen_buf = io.BytesIO() if colorized_img is not None: plt.imsave(gen_buf, colorized_display, format='png') else: plt.imsave(gen_buf, overlay, format='png') gen_buf.seek(0) # Add to batch results st.session_state.batch_results.append({ 'filename': os.path.basename(image_path), 'result': result_buf, 'generated': gen_buf, 'metrics': metrics }) else: # Regular visualization without ground truth fig, axes = plt.subplots(1, 3, figsize=(12, 4)) # Original SAR axes[0].imshow(sar_rgb, cmap='gray') axes[0].set_title('Original SAR', color='white') axes[0].axis('off') # Segmentation axes[1].imshow(colored_pred) axes[1].set_title('Segmentation', color='white') axes[1].axis('off') # Generated Image if colorized_img is not None: # Convert from -1,1 to 0,1 range colorized_display = (colorized_img * 0.5) + 0.5 axes[2].imshow(colorized_display) else: axes[2].imshow(overlay) axes[2].set_title('Generated Image', color='white') axes[2].axis('off') # Set dark background fig.patch.set_facecolor('#0a0a1f') for ax in axes: ax.set_facecolor('#0a0a1f') plt.tight_layout() # Convert plot to image result_buf = io.BytesIO() plt.savefig(result_buf, format='png', 
facecolor='#0a0a1f', bbox_inches='tight') result_buf.seek(0) plt.close(fig) # Save generated image for download gen_buf = io.BytesIO() if colorized_img is not None: plt.imsave(gen_buf, colorized_display, format='png') else: plt.imsave(gen_buf, overlay, format='png') gen_buf.seek(0) # Add to batch results st.session_state.batch_results.append({ 'filename': os.path.basename(image_path), 'result': result_buf, 'generated': gen_buf }) except Exception as e: st.error(f"Error processing {os.path.basename(image_path)}: {str(e)}") # Clear progress bar progress_bar.empty() # Display results if st.session_state.batch_results: st.markdown("

Translation Results

", unsafe_allow_html=True) # Display overall metrics if ground truth was provided if use_gt and 'overall_psnr' in locals() and overall_psnr: avg_psnr = np.mean(overall_psnr) avg_ssim = np.mean(overall_ssim) col1, col2 = st.columns(2) with col1: st.markdown("
", unsafe_allow_html=True) st.markdown(f"

{avg_psnr:.2f}

", unsafe_allow_html=True) st.markdown("

Average PSNR (dB)

", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) with col2: st.markdown("
", unsafe_allow_html=True) st.markdown(f"

{avg_ssim:.4f}

", unsafe_allow_html=True) st.markdown("

Average SSIM

", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) # Create a zip file with all results zip_buffer = io.BytesIO() with zipfile.ZipFile(zip_buffer, 'w') as zip_file: for i, result in enumerate(st.session_state.batch_results): # Add visualization zip_file.writestr(f"result_{i+1}_{result['filename']}.png", result['result'].getvalue()) # Add generated image zip_file.writestr(f"generated_{i+1}_{result['filename']}.png", result['generated'].getvalue()) # Add download button for all results st.download_button( label="Download All Results", data=zip_buffer.getvalue(), file_name="translation_results.zip", mime="application/zip", key="download_all_translation_results" ) # Display each result for i, result in enumerate(st.session_state.batch_results): st.markdown(f"
Image: {result['filename']}
", unsafe_allow_html=True) # Display metrics if available if 'metrics' in result: col1, col2 = st.columns(2) with col1: st.markdown("
", unsafe_allow_html=True) if 'psnr' in result['metrics']: st.markdown(f"

{result['metrics']['psnr']:.2f}

", unsafe_allow_html=True) else: st.markdown("

N/A

", unsafe_allow_html=True) st.markdown("

PSNR (dB)

", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) with col2: st.markdown("
", unsafe_allow_html=True) if 'ssim' in result['metrics']: st.markdown(f"

{result['metrics']['ssim']:.4f}

", unsafe_allow_html=True) else: st.markdown("

N/A

", unsafe_allow_html=True) st.markdown("

SSIM

", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) st.image(result['result'], use_container_width=True) # Add download button for individual result col1, col2 = st.columns(2) with col1: st.download_button( label="Download Visualization", data=result['result'].getvalue(), file_name=f"result_{result['filename']}.png", mime="image/png", key=f"download_viz_{i}" ) with col2: st.download_button( label="Download Generated Image", data=result['generated'].getvalue(), file_name=f"generated_{result['filename']}.png", mime="image/png", key=f"download_gen_{i}" ) st.markdown("
", unsafe_allow_html=True) else: st.warning("No images were successfully processed.") elif batch_process_btn: st.warning("Please upload at least one image file or ZIP archive.") # Tab 3: Sample Images with tab3: st.markdown("

Sample Images

", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) # Get list of sample images import os sample_dir = "samples/SAR" if os.path.exists(sample_dir): sample_files = [f for f in os.listdir(sample_dir) if f.endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg'))] else: os.makedirs(sample_dir, exist_ok=True) os.makedirs("samples/OPTICAL", exist_ok=True) os.makedirs("samples/LABELS", exist_ok=True) sample_files = [] if sample_files and 'unet_model' in st.session_state: # Create a dropdown to select sample images selected_sample = st.selectbox( "Select a sample image", sample_files, key="optical_sample_selector" ) # Display the selected sample col1, col2 = st.columns(2) with col1: st.subheader("SAR Image") sar_path = os.path.join("samples/SAR", selected_sample) display_image(sar_path) with col2: st.subheader("Optical Image (Ground Truth)") # Try to find matching optical image opt_path = os.path.join("samples/OPTICAL", selected_sample) if os.path.exists(opt_path): display_image(opt_path) else: st.info("No matching optical image found") # Add a button to process the selected sample if st.button("Generate Optical-like Image", key="process_optical_sample_btn"): with st.spinner("Processing sample image..."): try: # Load the SAR image with open(sar_path, 'rb') as f: file_bytes = f.read() sar_batch, sar_image = load_sar_image(io.BytesIO(file_bytes)) if sar_batch is not None: # Process with models seg_mask, colorized = process_image( sar_batch, st.session_state['unet_model'], st.session_state.get('generator_model') ) # Visualize results sar_rgb, colored_pred, overlay, colorized_img = visualize_results( sar_image, seg_mask, colorized ) # Check if ground truth exists has_gt = os.path.exists(opt_path) if has_gt: # Load ground truth try: # Try to open with rasterio with rasterio.open(opt_path) as src: gt_image = src.read() if gt_image.shape[0] == 3: # RGB image gt_image = np.transpose(gt_image, (1, 2, 0)) else: # Single band gt_image = src.read(1) # Convert to RGB for display # Convert to RGB for display 
gt_image = np.expand_dims(gt_image, axis=-1) gt_image = np.repeat(gt_image, 3, axis=-1) except Exception as rasterio_error: try: # If rasterio fails, try PIL gt_image = np.array(Image.open(opt_path).convert('RGB')) except Exception as pil_error: st.error(f"Both rasterio and PIL failed to load the ground truth: {str(pil_error)}") raise # Resize if needed if gt_image.shape[:2] != (256, 256): gt_image = cv2.resize(gt_image, (256, 256)) # Normalize if needed - make sure values are in 0-255 range for display if gt_image.dtype != np.uint8: if np.max(gt_image) > 1.0 and np.max(gt_image) <= 255: gt_image = gt_image.astype(np.uint8) elif np.max(gt_image) <= 1.0: gt_image = (gt_image * 255).astype(np.uint8) else: # Scale to 0-255 gt_min, gt_max = np.min(gt_image), np.max(gt_image) gt_image = ((gt_image - gt_min) / (gt_max - gt_min) * 255).astype(np.uint8) # Create 4-panel visualization with ground truth fig, axes = plt.subplots(1, 4, figsize=(16, 4)) # Original SAR axes[0].imshow(sar_rgb, cmap='gray') axes[0].set_title('Original SAR', color='white') axes[0].axis('off') # Ground Truth axes[1].imshow(gt_image) axes[1].set_title('Ground Truth', color='white') axes[1].axis('off') # Segmentation axes[2].imshow(colored_pred) axes[2].set_title('Segmentation', color='white') axes[2].axis('off') # Generated Image if colorized_img is not None: # Convert from -1,1 to 0,1 range colorized_display = (colorized_img * 0.5) + 0.5 axes[3].imshow(colorized_display) else: axes[3].imshow(overlay) axes[3].set_title('Generated Image', color='white') axes[3].axis('off') # Set dark background fig.patch.set_facecolor('#0a0a1f') for ax in axes: ax.set_facecolor('#0a0a1f') plt.tight_layout() # Display the figure st.pyplot(fig) # Calculate metrics if colorized image is available if colorized_img is not None: # Normalize both images to 0-1 range for comparison colorized_norm = (colorized_img * 0.5) + 0.5 gt_norm = gt_image.astype(np.float32) / 255.0 # Calculate PSNR mse = np.mean((colorized_norm - 
gt_norm) ** 2) psnr = 20 * np.log10(1.0 / np.sqrt(mse)) # Calculate SSIM from skimage.metrics import structural_similarity as ssim try: # Check image dimensions and set appropriate window size min_dim = min(colorized_norm.shape[0], colorized_norm.shape[1]) win_size = min(7, min_dim - (min_dim % 2) + 1) # Ensure it's odd and smaller than min dimension ssim_value = ssim( colorized_norm, gt_norm, win_size=win_size, # Explicitly set window size channel_axis=2, # Specify channel axis for RGB images data_range=1.0 ) except Exception as e: st.warning(f"Could not calculate SSIM: {str(e)}") ssim_value = 0.0 # Default value if calculation fails # Display metrics col1, col2 = st.columns(2) with col1: st.markdown("
", unsafe_allow_html=True) st.markdown(f"

{psnr:.2f}

", unsafe_allow_html=True) st.markdown("

PSNR (dB)

", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) with col2: st.markdown("
", unsafe_allow_html=True) st.markdown(f"

{ssim_value:.4f}

", unsafe_allow_html=True) st.markdown("

SSIM

", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) else: # Regular 3-panel visualization without ground truth col1, col2, col3 = st.columns(3) with col1: st.subheader("Original SAR Image") st.image(sar_rgb, use_container_width=True) with col2: st.subheader("Predicted Segmentation") st.image(colored_pred, use_container_width=True) with col3: st.subheader("Colorized SAR") if colorized_img is not None: # Convert from -1,1 to 0,1 range colorized_display = (colorized_img * 0.5) + 0.5 st.image(colorized_display, use_container_width=True) else: st.image(overlay, use_container_width=True) # Add download buttons col1, col2 = st.columns(2) with col1: # Save segmentation image seg_buf = io.BytesIO() plt.imsave(seg_buf, colored_pred, format='png') seg_buf.seek(0) st.download_button( label="Download Segmentation", data=seg_buf, file_name=f"sample_segmentation_{selected_sample}.png", mime="image/png", key="download_sample_seg" ) with col2: # Save generated image gen_buf = io.BytesIO() if colorized_img is not None: plt.imsave(gen_buf, (colorized_img * 0.5) + 0.5, format='png') else: plt.imsave(gen_buf, overlay, format='png') gen_buf.seek(0) st.download_button( label="Download Optical-like Image", data=gen_buf, file_name=f"sample_optical_{selected_sample}.png", mime="image/png", key="download_sample_optical" ) except Exception as e: st.error(f"Error processing sample image: {str(e)}") elif not sample_files: st.info("No sample images found. Please add some images to the 'samples/SAR' directory.") else: st.warning("Please load the models from the sidebar first.") # Close the card container st.markdown("
", unsafe_allow_html=True) # Footer st.markdown("""

SAR IMAGE PROCESSING | VARUN & MOKSHYAGNA

""", unsafe_allow_html=True) # ==================== UTILITY FUNCTIONS ==================== def create_stars_html(): """Create twinkling stars effect for background""" stars_html = """
""" for i in range(100): size = random.uniform(1, 3) top = random.uniform(0, 100) left = random.uniform(0, 100) duration = random.uniform(3, 8) opacity = random.uniform(0.2, 0.8) stars_html += f"""
""" stars_html += "
" return stars_html # This function is called at the beginning of the app to set up the page # Update the setup_page_style function to support both themes def setup_page_style(): """Set up the page style with CSS based on selected theme""" # Common CSS for both themes common_css = """ /* Create twinkling stars effect */ @keyframes twinkle { 0%, 100% { opacity: 0.2; } 50% { opacity: 1; } } .stars { position: fixed; top: 0; left: 0; width: 100%; height: 100%; pointer-events: none; z-index: -1; } .star { position: absolute; background-color: white; border-radius: 50%; animation: twinkle var(--duration) infinite; opacity: var(--opacity); } /* Tab styling */ .stTabs [data-baseweb="tab-list"] { gap: 24px !important; border-radius: 0.5rem; padding: 0.8rem; margin-bottom: 3rem !important; display: flex; justify-content: center !important; width: 100%; } .stTabs [data-baseweb="tab"] { height: 5rem !important; white-space: pre-wrap; border-radius: 0.5rem; font-weight: 600 !important; font-size: 1.6rem !important; padding: 0 25px !important; display: flex; align-items: center; justify-content: center; min-width: 200px !important; } /* Add more space between tab panels */ .stTabs [data-baseweb="tab-panel"] { padding-top: 3rem !important; padding-bottom: 3rem !important; } /* Button styling */ .stButton>button { border: none; border-radius: 0.5rem; padding: 0.8rem 1.5rem !important; font-weight: 500; font-size: 1.2rem !important; margin-top: 1.5rem !important; margin-bottom: 1.5rem !important; } /* Spacing */ .element-container { margin-bottom: 2.5rem !important; } h3 { margin-top: 3rem !important; margin-bottom: 2rem !important; font-size: 1.8rem !important; } h4 { margin-top: 2.5rem !important; margin-bottom: 1.5rem !important; font-size: 1.5rem !important; } h5 { margin-top: 2rem !important; margin-bottom: 1.5rem !important; font-size: 1.3rem !important; } img { margin-top: 1.5rem !important; margin-bottom: 2.5rem !important; } .stProgress > div { margin-top: 2rem 
!important; margin-bottom: 2rem !important; } .stSlider { padding-top: 1.5rem !important; padding-bottom: 2.5rem !important; } .row-widget { margin-top: 1.5rem !important; margin-bottom: 2.5rem !important; } """ # Dark theme CSS dark_css = """ .stApp { background-color: #0a0a1f; color: white; } .main { background-image: url("https://images.unsplash.com/photo-1451187580459-43490279c0fa?ixlib=rb-1.2.1&auto=format&fit=crop&w=1352&q=80"); background-size: cover; background-position: center; background-repeat: no-repeat; background-attachment: fixed; position: relative; } .main::before { content: ""; position: absolute; top: 0; left: 0; width: 100%; height: 100%; background-color: rgba(10, 10, 31, 0.7); backdrop-filter: blur(5px); z-index: -1; } /* Title styling */ h1.title { background: linear-gradient(to right, #a78bfa, #ec4899, #3b82f6); -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; color: transparent; font-size: 3rem !important; font-weight: bold !important; text-align: center !important; margin-bottom: 0.5rem !important; display: block !important; position: relative !important; z-index: 10 !important; } p.subtitle { color: #bfdbfe !important; font-size: 1.2rem !important; text-align: center !important; margin-bottom: 2rem !important; position: relative !important; z-index: 10 !important; } /* Tab styling */ .stTabs [data-baseweb="tab-list"] { background-color: rgba(0, 0, 0, 0.3); } .stTabs [data-baseweb="tab"] { background-color: transparent; color: white; } .stTabs [aria-selected="true"] { background-color: rgba(147, 51, 234, 0.5) !important; transform: scale(1.05); transition: all 0.2s ease; } /* Card and box styling */ .upload-box { border: 2px dashed rgba(147, 51, 234, 0.5); border-radius: 1rem; padding: 4rem !important; text-align: center; margin-bottom: 3rem !important; } .card { background-color: rgba(0, 0, 0, 0.3); border: 1px solid rgba(147, 51, 234, 0.3); border-radius: 1rem; padding: 2.5rem !important; 
backdrop-filter: blur(10px); margin-bottom: 3rem !important; } /* Button styling */ .stButton>button { background: linear-gradient(to right, #7c3aed, #2563eb); color: white; } .stButton>button:hover { background: linear-gradient(to right, #6d28d9, #1d4ed8); } .download-btn { background-color: #2563eb !important; } .stSlider>div>div>div { background-color: #7c3aed; } /* Metrics styling */ .plot-container { background-color: rgba(0, 0, 0, 0.3); border-radius: 1rem; padding: 2rem !important; margin-bottom: 3rem !important; } .metric-card { background-color: rgba(0, 0, 0, 0.3); border: 1px solid rgba(147, 51, 234, 0.3); border-radius: 0.5rem; padding: 1.5rem !important; text-align: center; margin-bottom: 2rem !important; } .metric-value { font-size: 2rem !important; font-weight: bold; color: #a78bfa; } .metric-label { font-size: 1.1rem !important; color: #bfdbfe; } /* Form elements */ .stFileUploader > div { background-color: rgba(0, 0, 0, 0.3) !important; border: 1px dashed rgba(147, 51, 234, 0.5) !important; padding: 2rem !important; margin-bottom: 2rem !important; } .stSelectbox > div > div { background-color: rgba(0, 0, 0, 0.3) !important; border: 1px solid rgba(147, 51, 234, 0.3) !important; } """ # Light theme CSS # Light theme CSS - simplified with cream/whitish background and no background image # Light theme CSS - keeps dark background but uses light text light_css = """ /* Keep the same dark background */ .stApp { background-color: #0a0a1f; } .main { background-image: url("https://images.unsplash.com/photo-1451187580459-43490279c0fa?ixlib=rb-1.2.1&auto=format&fit=crop&w=1352&q=80"); background-size: cover; background-position: center; background-repeat: no-repeat; background-attachment: fixed; position: relative; } .main::before { content: ""; position: absolute; top: 0; left: 0; width: 100%; height: 100%; background-color: rgba(10, 10, 31, 0.7); backdrop-filter: blur(5px); z-index: -1; } /* Make all text white/light */ p, span, label, div, h1, h2, h3, h4, 
h5, h6, li { color: white !important; } /* Title styling - brighter gradient for better visibility */ h1.title { background: linear-gradient(to right, #d8b4fe, #f9a8d4, #93c5fd); -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; color: transparent; font-size: 3rem !important; font-weight: bold !important; text-align: center !important; margin-bottom: 0.5rem !important; display: block !important; position: relative !important; z-index: 10 !important; } p.subtitle { color: #e0e7ff !important; /* Lighter purple */ font-size: 1.2rem !important; text-align: center !important; margin-bottom: 2rem !important; position: relative !important; z-index: 10 !important; } /* Tab styling - brighter for better visibility */ .stTabs [data-baseweb="tab-list"] { background-color: rgba(0, 0, 0, 0.3); } .stTabs [data-baseweb="tab"] { background-color: transparent; color: white !important; } .stTabs [aria-selected="true"] { background-color: rgba(167, 139, 250, 0.5) !important; /* Brighter purple */ transform: scale(1.05); transition: all 0.2s ease; } /* Card and box styling - brighter borders */ .upload-box { border: 2px dashed rgba(167, 139, 250, 0.7); /* Brighter purple */ border-radius: 1rem; padding: 4rem !important; text-align: center; margin-bottom: 3rem !important; } .card { background-color: rgba(0, 0, 0, 0.3); border: 1px solid rgba(167, 139, 250, 0.5); /* Brighter purple */ border-radius: 1rem; padding: 2.5rem !important; backdrop-filter: blur(10px); margin-bottom: 3rem !important; } /* Button styling - brighter gradient */ .stButton>button { background: linear-gradient(to right, #a78bfa, #60a5fa); color: white; } .stButton>button:hover { background: linear-gradient(to right, #8b5cf6, #3b82f6); } .download-btn { background-color: #60a5fa !important; } .stSlider>div>div>div { background-color: #a78bfa; } /* Metrics styling - brighter accents */ .plot-container { background-color: rgba(0, 0, 0, 0.3); border-radius: 1rem; padding: 2rem 
!important; margin-bottom: 3rem !important; } .metric-card { background-color: rgba(0, 0, 0, 0.3); border: 1px solid rgba(167, 139, 250, 0.5); /* Brighter purple */ border-radius: 0.5rem; padding: 1.5rem !important; text-align: center; margin-bottom: 2rem !important; } .metric-value { font-size: 2rem !important; font-weight: bold; color: #d8b4fe; /* Brighter purple */ } .metric-label { font-size: 1.1rem !important; color: #e0e7ff; /* Lighter purple */ } /* Form elements - brighter borders */ .stFileUploader > div { background-color: rgba(0, 0, 0, 0.3) !important; border: 1px dashed rgba(167, 139, 250, 0.7) !important; /* Brighter purple */ padding: 2rem !important; margin-bottom: 2rem !important; } .stSelectbox > div > div { background-color: rgba(0, 0, 0, 0.3) !important; border: 1px solid rgba(167, 139, 250, 0.5) !important; /* Brighter purple */ } /* Make sure all text inputs have white text */ input, textarea { color: white !important; } /* Ensure sidebar text is white */ .css-1d391kg, .css-1lcbmhc { color: white !important; } /* Make sure plot text is visible on dark background */ .js-plotly-plot .plotly .main-svg text { fill: white !important; } /* Keep stars visible in light theme */ .star { background-color: white; opacity: 0.8; } /* Make sure all streamlit elements have white text */ .stMarkdown, .stText, .stCode, .stTextInput, .stTextArea, .stSelectbox, .stMultiselect, .stSlider, .stCheckbox, .stRadio, .stNumber, .stDate, .stTime, .stDateInput, .stTimeInput { color: white !important; } /* Ensure dropdown options are visible */ .stSelectbox ul li { color: black !important; } """ # Apply the appropriate CSS based on the selected theme # Apply the appropriate CSS based on the selected theme if st.session_state.theme == "dark": st.markdown(f"", unsafe_allow_html=True) else: st.markdown(f"", unsafe_allow_html=True) # ==================== MAIN EXECUTION ==================== if __name__ == "__main__": # Set up page style setup_page_style() # Initialize GPU if 
available setup_gpu() # Initialize session state variables if 'app_mode' not in st.session_state: st.session_state.app_mode = "SAR Colorization" if 'model_loaded' not in st.session_state: st.session_state.model_loaded = False if 'segmentation' not in st.session_state: st.session_state.segmentation = SARSegmentation(img_rows=256, img_cols=256) if 'processed_images' not in st.session_state: st.session_state.processed_images = []