| | import streamlit as st |
| | import tensorflow as tf |
| | from tensorflow.keras.models import Model |
| | from tensorflow.keras.layers import * |
| | from tensorflow.keras.optimizers import Adam |
| | import cv2 |
| | import os |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | from PIL import Image |
| | import io |
| | import base64 |
| | import tempfile |
| | import zipfile |
| | import random |
| | import time |
| | import rasterio |
| | from rasterio.errors import RasterioIOError |
| | import h5py |
| | import json |
| |
|
| | |
# Streamlit page setup — must be the first st.* call executed in the script.
st.set_page_config(
    page_title="SAR Image Colorization",
    page_icon="🛰",
    layout="wide"
)
| |
|
| |
|
def display_image(image_path):
    """Display an image in Streamlit, handling GeoTIFF and common formats.

    GeoTIFFs are read with rasterio (falling back to PIL if rasterio fails);
    everything else is opened directly with PIL. Multi-band rasters are
    composited to RGB and non-uint8 data is min-max scaled to 0-255.

    Args:
        image_path: Path to the image file on disk.
    """
    try:
        if os.path.exists(image_path):
            if image_path.lower().endswith(('.tif', '.tiff')):
                try:
                    with rasterio.open(image_path) as src:
                        img_data = src.read(1)

                        if src.count > 1:
                            if src.count >= 3:
                                # Use the first three bands as RGB.
                                img_data = np.dstack([src.read(i) for i in range(1, 4)])
                            else:
                                # Two bands: duplicate band 2 to fill the third channel.
                                img_data = np.dstack([src.read(1), src.read(2), src.read(2)])
                        else:
                            # Single band: replicate to a grayscale RGB stack.
                            img_data = np.dstack([img_data, img_data, img_data])

                        if img_data.dtype != np.uint8:
                            # Min-max scale to 0-255. Guard against a constant
                            # image, where max == min would divide by zero.
                            lo = np.min(img_data)
                            hi = np.max(img_data)
                            if hi > lo:
                                img_data = (img_data - lo) / (hi - lo) * 255
                            else:
                                img_data = np.zeros_like(img_data, dtype=np.float32)
                            img_data = img_data.astype(np.uint8)

                        st.image(img_data, use_container_width=True)
                except Exception:
                    # rasterio could not read it; try PIL as a plain TIFF.
                    try:
                        img = Image.open(image_path)
                        st.image(img, use_container_width=True)
                    except Exception as pil_error:
                        st.error(f"Failed to load image: {str(pil_error)}")
            else:
                # Non-TIFF formats go straight to PIL.
                img = Image.open(image_path)
                st.image(img, use_container_width=True)
        else:
            st.info(f"Image file not found: {image_path}")
    except Exception as e:
        st.error(f"Error loading image: {str(e)}")
| |
|
| | |
| |
|
| | |
@st.cache_resource
def setup_gpu():
    """Enable memory growth on every visible GPU; report what was found."""
    devices = tf.config.experimental.list_physical_devices('GPU')
    if not devices:
        return "No GPUs found. Running on CPU."
    for device in devices:
        # Grow GPU memory on demand instead of grabbing it all up front.
        tf.config.experimental.set_memory_growth(device, True)
    return f"GPU setup complete. Found {len(devices)} GPU(s)."
| |
|
| | |
def get_esa_colors():
    """Return the RGB palette for the 11 ESA WorldCover classes, keyed by index 0-10."""
    palette = [
        [0, 100, 0],       # 0: trees
        [255, 165, 0],     # 1: shrubland
        [144, 238, 144],   # 2: grassland
        [255, 255, 0],     # 3: cropland
        [255, 0, 0],       # 4: built-up
        [139, 69, 19],     # 5: bare
        [255, 255, 255],   # 6: snow
        [0, 0, 255],       # 7: water
        [0, 139, 139],     # 8: wetland
        [0, 255, 0],       # 9: mangroves
        [220, 220, 220],   # 10: moss
    ]
    return dict(enumerate(palette))
| |
|
| | |
def visualize_with_ground_truth(sar_image, ground_truth, prediction):
    """Visualize SAR image with ground truth and prediction using ESA WorldCover colors.

    Args:
        sar_image: Single-channel SAR array (H, W, 1); assumed normalized to
            [-1, 1] — TODO confirm for all callers.
        ground_truth: Label map (H, W, 1) with either 0-10 indices or raw
            ESA WorldCover values (10, 20, ... 100).
        prediction: Batched softmax output; only item 0 is visualized.

    Returns:
        (png_buffer, colored_gt, colored_pred, overlay)
    """
    colors = get_esa_colors()

    # Collapse softmax output (batch item 0) to a per-pixel class index,
    # then paint each class with its palette color.
    pred_class = np.argmax(prediction[0], axis=-1)
    colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8)

    for class_idx, color in colors.items():
        colored_pred[pred_class == class_idx] = color

    # Ground truth arrives as a single-channel label map.
    gt_class = ground_truth[:,:,0].astype(np.int32)

    if np.max(gt_class) > 10:
        # Labels are raw ESA WorldCover values rather than 0-10 indices;
        # remap via the class definitions kept in session state.
        gt_mapped = np.zeros_like(gt_class)
        class_values = sorted(st.session_state.segmentation.class_definitions.values())
        for i, val in enumerate(class_values):
            gt_mapped[gt_class == val] = i
        gt_class = gt_mapped

    colored_gt = np.zeros((gt_class.shape[0], gt_class.shape[1], 3), dtype=np.uint8)

    for class_idx, color in colors.items():
        colored_gt[gt_class == class_idx] = color

    # Replicate the single SAR channel to RGB and rescale [-1, 1] -> [0, 255].
    sar_rgb = np.repeat(sar_image[:, :, 0:1], 3, axis=2)

    sar_rgb = ((sar_rgb + 1) / 2 * 255).astype(np.uint8)

    # Blend class colors over the SAR backdrop (70% SAR / 30% classes).
    overlay = cv2.addWeighted(
        sar_rgb,
        0.7,
        colored_pred,
        0.3,
        0
    )

    # Match figure colors to the current Streamlit theme.
    bg_color = '#0a0a1f' if st.session_state.theme == 'dark' else '#ffffff'
    text_color = 'white' if st.session_state.theme == 'dark' else 'black'

    fig, axes = plt.subplots(1, 4, figsize=(16, 4))

    axes[0].imshow(sar_rgb, cmap='gray')
    axes[0].set_title('Original SAR', color=text_color)
    axes[0].axis('off')

    axes[1].imshow(colored_gt)
    axes[1].set_title('Ground Truth', color=text_color)
    axes[1].axis('off')

    axes[2].imshow(colored_pred)
    axes[2].set_title('Prediction', color=text_color)
    axes[2].axis('off')

    axes[3].imshow(overlay)
    axes[3].set_title('Colorized Output', color=text_color)
    axes[3].axis('off')

    fig.patch.set_facecolor(bg_color)
    for ax in axes:
        ax.set_facecolor(bg_color)

    plt.tight_layout()

    # Render to an in-memory PNG so Streamlit can display/download it.
    buf = io.BytesIO()
    plt.savefig(buf, format='png', facecolor=bg_color, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)

    return buf, colored_gt, colored_pred, overlay
| |
|
| |
|
| | |
@st.cache_resource
def load_models(unet_weights_path, generator_path=None):
    """Load the U-Net segmentation model and, optionally, the GAN generator.

    Returns a (unet, generator) tuple. The generator is None when no path is
    supplied or loading fails (a Streamlit error is shown in that case).
    """
    unet = get_unet(input_shape=(256, 256, 1), classes=11)
    unet.load_weights(unet_weights_path)

    if not generator_path:
        return unet, None

    try:
        generator = tf.keras.models.load_model(generator_path)
    except Exception as e:
        st.error(f"Error loading generator model: {e}")
        generator = None

    return unet, generator
| |
|
| | |
def preprocess_sar_for_optical(sar_data):
    """Normalize SAR backscatter to [-1, 1] for the colorization models.

    Values are clipped to the typical Sentinel-1 dB range [-50, 20] and then
    min-max scaled to [-1, 1]. A constant image (max == min after clipping)
    maps to all zeros instead of producing NaNs via division by zero.

    Args:
        sar_data: NumPy array of SAR values (any shape).

    Returns:
        Array of the same shape, scaled to [-1, 1].
    """
    sar_clipped = np.clip(sar_data, -50, 20)
    lo = np.min(sar_clipped)
    hi = np.max(sar_clipped)
    if hi == lo:
        # Degenerate (constant) input: the original formula divided by zero.
        return np.zeros_like(sar_clipped, dtype=np.float64)
    return (sar_clipped - lo) / (hi - lo) * 2 - 1
| |
|
| | |
def load_sar_image(file, img_size=(256, 256)):
    """Load an uploaded SAR GeoTIFF into a normalized, batched array.

    Args:
        file: File-like object (e.g. Streamlit UploadedFile) with getbuffer().
        img_size: Target (width, height) for cv2.resize.

    Returns:
        (batched array of shape (1, H, W, 1), unbatched image), or
        (None, None) on failure (a Streamlit error is shown).
    """
    # rasterio needs a real path on disk, so spill the upload to a temp file.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as tmp_file:
        tmp_file.write(file.getbuffer())
        tmp_file_path = tmp_file.name

    try:
        with rasterio.open(tmp_file_path) as src:
            image = src.read(1)
            image = cv2.resize(image, img_size)
            image = np.expand_dims(image, axis=-1)

            # Normalize to [-1, 1] as expected by the models.
            image = preprocess_sar_for_optical(image)
            return np.expand_dims(image, axis=0), image
    except Exception as e:
        st.error(f"Error loading SAR image: {e}")
        return None, None
    finally:
        # Remove the temp file regardless of success.
        os.unlink(tmp_file_path)
| |
|
| | |
def process_image(sar_image, unet_model, generator_model=None):
    """Run segmentation (and optional GAN colorization) on a batched SAR image.

    Args:
        sar_image: Batched input array for the models.
        unet_model: Segmentation model with a .predict method.
        generator_model: Optional colorization model taking [sar, mask].

    Returns:
        (segmentation mask for batch item 0, colorized image or None).
    """
    seg_mask = unet_model.predict(sar_image)

    if generator_model:
        # The generator conditions on both the SAR input and the mask.
        colorized = generator_model.predict([sar_image, seg_mask])[0]
    else:
        colorized = None

    return seg_mask[0], colorized
| |
|
| | |
def visualize_results(sar_image, seg_mask, colorized=None):
    """Build display images from a segmentation mask (and optional colorization).

    Args:
        sar_image: Single-channel SAR array (H, W, 1); assumed normalized to
            [-1, 1] — TODO confirm for all callers.
        seg_mask: Unbatched softmax output (H, W, classes).
        colorized: Optional colorized image, passed through unchanged.

    Returns:
        (sar_rgb, colored_pred, overlay, colorized)
    """
    colors = get_esa_colors()

    # Per-pixel class index -> ESA palette color.
    pred_class = np.argmax(seg_mask, axis=-1)
    colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8)

    for class_idx, color in colors.items():
        colored_pred[pred_class == class_idx] = color

    # Replicate the single SAR channel to RGB and rescale [-1, 1] -> [0, 255].
    sar_rgb = np.repeat(sar_image[:, :, 0:1], 3, axis=2)

    sar_rgb = ((sar_rgb + 1) / 2 * 255).astype(np.uint8)

    # Blend class colors over the SAR backdrop (70% SAR / 30% classes).
    overlay = cv2.addWeighted(
        sar_rgb,
        0.7,
        colored_pred,
        0.3,
        0
    )

    return sar_rgb, colored_pred, overlay, colorized
| |
|
| |
|
| | |
def load_model_with_weights(model_path):
    """Load a model directly from an H5 file, preserving the original architecture.

    Tries three strategies in order:
      1. Load the complete model (architecture + weights).
      2. Rebuild from the JSON config stored in the H5 file, then load weights.
      3. Build the currently selected architecture and load matching weights
         by layer name.

    Returns the loaded model, or None if every strategy fails.
    """
    # Bare filenames are resolved under the local models/ directory.
    if not os.path.dirname(model_path) and not model_path.startswith('models/'):
        model_path = os.path.join('models', os.path.basename(model_path))

    try:
        # Strategy 1: full model load.
        import tensorflow as tf
        # First character of the Keras version string, e.g. '3' for Keras 3.
        keras_version = tf.keras.__version__[0]

        if keras_version == '3':
            # Keras 3 requires custom layers to be registered explicitly.
            custom_objects = {
                'BilinearUpsampling': BilinearUpsampling
            }
            model = tf.keras.models.load_model(model_path, compile=False, custom_objects=custom_objects)
        else:
            model = tf.keras.models.load_model(model_path, compile=False)

        print("Loaded complete model with architecture")
        return model
    except Exception as e:
        print(f"Could not load complete model: {str(e)}")
        print("Attempting to load just the weights into a matching architecture...")

    # Strategy 2: rebuild the architecture from the embedded JSON config.
    try:
        with h5py.File(model_path, 'r') as f:
            model_config = None
            if 'model_config' in f.attrs:
                # NOTE(review): .decode assumes a bytes attribute; newer h5py
                # versions may return str here — confirm against the files in use.
                model_config = json.loads(f.attrs['model_config'].decode('utf-8'))

            if model_config:
                try:
                    model = tf.keras.models.model_from_json(json.dumps(model_config))
                    model.load_weights(model_path)
                    print("Successfully loaded model from config and weights")
                    return model
                except Exception as e2:
                    print(f"Failed to load from config: {str(e2)}")
    except Exception as e3:
        print(f"Failed to inspect model file: {str(e3)}")

    # Strategy 3: fresh architecture of the configured type + partial weights.
    try:
        if st.session_state.segmentation.model_type == 'unet':
            model = get_unet(
                input_shape=(256, 256, 1),
                drop_rate=0.3,
                classes=11
            )
        elif st.session_state.segmentation.model_type == 'deeplabv3plus':
            model = DeepLabV3Plus(
                input_shape=(256, 256, 1),
                classes=11
            )
        elif st.session_state.segmentation.model_type == 'segnet':
            model = SegNet(
                input_shape=(256, 256, 1),
                classes=11
            )

        # Partial load: mismatched layers are skipped rather than failing.
        model.load_weights(model_path, by_name=True, skip_mismatch=True)
        print("Created new model and loaded compatible weights")
        return model
    except Exception as e4:
        print(f"Failed to create new model and load weights: {str(e4)}")

    # Every strategy failed.
    return None
| |
|
| | |
def create_legend():
    """Create a legend for the land cover classes.

    Returns:
        BytesIO buffer containing the rendered PNG legend.
    """
    # Class name -> RGB color; mirrors the palette in get_esa_colors().
    colors = {
        'Trees': [0, 100, 0],
        'Shrubland': [255, 165, 0],
        'Grassland': [144, 238, 144],
        'Cropland': [255, 255, 0],
        'Built-up': [255, 0, 0],
        'Bare': [139, 69, 19],
        'Snow': [255, 255, 255],
        'Water': [0, 0, 255],
        'Wetland': [0, 139, 139],
        'Mangroves': [0, 255, 0],
        'Moss': [220, 220, 220]
    }

    # Match figure colors to the current Streamlit theme.
    bg_color = '#0a0a1f' if st.session_state.theme == 'dark' else '#ffffff'
    text_color = 'white' if st.session_state.theme == 'dark' else 'black'

    fig, ax = plt.subplots(figsize=(8, 4))
    fig.patch.set_facecolor(bg_color)
    ax.set_facecolor(bg_color)

    # One color swatch + label per class, stacked vertically.
    for i, (class_name, color) in enumerate(colors.items()):
        ax.add_patch(plt.Rectangle((0, i), 0.5, 0.8, color=[c/255 for c in color]))
        ax.text(0.7, i + 0.4, class_name, color=text_color, fontsize=12)

    ax.set_xlim(0, 3)
    ax.set_ylim(-0.5, len(colors) - 0.5)
    ax.set_title('Land Cover Classes', color=text_color, fontsize=14)
    ax.axis('off')

    # Render to an in-memory PNG for Streamlit.
    buf = io.BytesIO()
    plt.savefig(buf, format='png', facecolor=bg_color, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)

    return buf
| |
|
| |
|
| | |
| | |
def visualize_prediction(prediction, original_sar, figsize=(10, 4)):
    """Visualize segmentation prediction with ESA WorldCover colors.

    Args:
        prediction: Batched softmax output; only item 0 is visualized.
        original_sar: Single-channel SAR array (H, W, 1).
        figsize: Matplotlib figure size.

    Returns:
        BytesIO buffer containing the rendered PNG.
    """
    colors = get_esa_colors()

    # Per-pixel class index (batch item 0) -> ESA palette color.
    pred_class = np.argmax(prediction[0], axis=-1)
    colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8)

    for class_idx, color in colors.items():
        colored_pred[pred_class == class_idx] = color

    # NOTE(review): cv2.cvtColor/addWeighted require original_sar[:,:,0] to
    # have a dtype compatible with uint8 colored_pred; a [-1, 1] float image
    # would fail here — confirm what callers actually pass.
    sar_rgb = cv2.cvtColor(original_sar[:,:,0], cv2.COLOR_GRAY2RGB)
    overlay = cv2.addWeighted(sar_rgb, 0.7, colored_pred, 0.3, 0)

    fig, axes = plt.subplots(1, 3, figsize=figsize)

    # Match figure colors to the current Streamlit theme.
    bg_color = '#0a0a1f' if st.session_state.theme == 'dark' else '#ffffff'
    text_color = 'white' if st.session_state.theme == 'dark' else 'black'

    axes[0].imshow(original_sar[:,:,0], cmap='gray')
    axes[0].set_title('Original SAR', color=text_color)
    axes[0].axis('off')

    axes[1].imshow(colored_pred)
    axes[1].set_title('Prediction', color=text_color)
    axes[1].axis('off')

    axes[2].imshow(overlay)
    axes[2].set_title('Colorized Output', color=text_color)
    axes[2].axis('off')

    fig.patch.set_facecolor(bg_color)
    for ax in axes:
        ax.set_facecolor(bg_color)

    plt.tight_layout()

    # Render to an in-memory PNG for Streamlit.
    buf = io.BytesIO()
    plt.savefig(buf, format='png', facecolor=bg_color, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)
    return buf
| |
|
| | |
| |
|
| | |
def get_unet(input_shape=(256, 256, 1), drop_rate=0.3, classes=11):
    """Build a standard U-Net for semantic segmentation.

    Args:
        input_shape: Input image shape (H, W, C).
        drop_rate: Dropout rate applied at the two deepest blocks.
        classes: Number of output classes (softmax channels).

    Returns:
        Compiled Keras model (Adam 1e-4, categorical cross-entropy, accuracy).
    """
    inputs = Input(input_shape)

    # Encoder block 1: 64 filters.
    conv1_1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
    batch1_1 = BatchNormalization()(conv1_1)
    conv1_2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch1_1)
    batch1_2 = BatchNormalization()(conv1_2)
    pool1 = MaxPooling2D(pool_size=(2, 2))(batch1_2)

    # Encoder block 2: 128 filters.
    conv2_1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
    batch2_1 = BatchNormalization()(conv2_1)
    conv2_2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch2_1)
    batch2_2 = BatchNormalization()(conv2_2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(batch2_2)

    # Encoder block 3: 256 filters.
    conv3_1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
    batch3_1 = BatchNormalization()(conv3_1)
    conv3_2 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch3_1)
    batch3_2 = BatchNormalization()(conv3_2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(batch3_2)

    # Encoder block 4: 512 filters, with dropout before pooling.
    conv4_1 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
    batch4_1 = BatchNormalization()(conv4_1)
    conv4_2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch4_1)
    batch4_2 = BatchNormalization()(conv4_2)
    drop4 = Dropout(drop_rate)(batch4_2)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck: 1024 filters with dropout.
    conv5_1 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
    batch5_1 = BatchNormalization()(conv5_1)
    conv5_2 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch5_1)
    batch5_2 = BatchNormalization()(conv5_2)
    drop5 = Dropout(drop_rate)(batch5_2)

    # Decoder block 1: upsample and concatenate with encoder block 4.
    up6 = Conv2D(512, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(drop5))
    merge6 = concatenate([drop4, up6])
    conv6_1 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
    batch6_1 = BatchNormalization()(conv6_1)
    conv6_2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch6_1)
    batch6_2 = BatchNormalization()(conv6_2)

    # Decoder block 2: skip connection from encoder block 3.
    up7 = Conv2D(256, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch6_2))
    merge7 = concatenate([batch3_2, up7])
    conv7_1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
    batch7_1 = BatchNormalization()(conv7_1)
    conv7_2 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch7_1)
    batch7_2 = BatchNormalization()(conv7_2)

    # Decoder block 3: skip connection from encoder block 2.
    up8 = Conv2D(128, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch7_2))
    merge8 = concatenate([batch2_2, up8])
    conv8_1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
    batch8_1 = BatchNormalization()(conv8_1)
    conv8_2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch8_1)
    batch8_2 = BatchNormalization()(conv8_2)

    # Decoder block 4: skip connection from encoder block 1.
    up9 = Conv2D(64, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch8_2))
    merge9 = concatenate([batch1_2, up9])
    conv9_1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
    batch9_1 = BatchNormalization()(conv9_1)
    conv9_2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch9_1)
    batch9_2 = BatchNormalization()(conv9_2)

    # Per-pixel class probabilities.
    outputs = Conv2D(classes, 1, activation='softmax')(batch9_2)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    return model
| |
|
| | |
class BilinearUpsampling(Layer):
    """Keras layer that bilinearly resizes its input to a fixed spatial size.

    Note: `size` is the absolute target (height, width) passed to
    tf.image.resize — not a scaling factor, despite the layer's name.
    """

    def __init__(self, size=(1, 1), **kwargs):
        super(BilinearUpsampling, self).__init__(**kwargs)
        self.size = size  # target (height, width)

    def call(self, inputs):
        # Resize to the absolute target size with bilinear interpolation.
        return tf.image.resize(inputs, self.size, method='bilinear')

    def compute_output_shape(self, input_shape):
        # Batch and channel dims are preserved; spatial dims become `size`.
        return (input_shape[0], self.size[0], self.size[1], input_shape[3])

    def get_config(self):
        # Serialize `size` so the layer can be rebuilt on model load.
        config = super(BilinearUpsampling, self).get_config()
        config.update({'size': self.size})
        return config
| |
|
| | |
def DeepLabV3Plus(input_shape=(256, 256, 1), classes=11, output_stride=16):
    """
    DeepLabV3+ model with Xception backbone

    Args:
        input_shape: Shape of input images
        classes: Number of classes for segmentation
        output_stride: Output stride for dilated convolutions (16 or 8)

    Returns:
        model: DeepLabV3+ model (compiled: Adam 1e-4, categorical cross-entropy)
    """
    inputs = Input(input_shape)

    # ASPP dilation rates depend on the chosen output stride.
    if output_stride == 16:
        atrous_rates = (6, 12, 18)
    elif output_stride == 8:
        atrous_rates = (12, 24, 36)
    else:
        raise ValueError("Output stride must be 8 or 16")

    # Entry flow: stem convolutions (first one downsamples by 2).
    x = Conv2D(32, 3, strides=(2, 2), padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(64, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Entry flow block 1: separable convs + pooling with a projected residual.
    residual = Conv2D(128, 1, strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = SeparableConv2D(128, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(128, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
    x = Add()([x, residual])

    # Entry flow block 2: 256 filters.
    residual = Conv2D(256, 1, strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu')(x)
    x = SeparableConv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
    x = Add()([x, residual])

    # Saved here for the decoder's skip connection.
    low_level_features = x

    # Entry flow block 3: 728 filters.
    residual = Conv2D(728, 1, strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu')(x)
    x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
    x = Add()([x, residual])

    # Middle flow: 16 identity-residual blocks with dilated separable convs.
    for i in range(16):
        residual = x

        x = Activation('relu')(x)
        x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
        x = BatchNormalization()(x)

        x = Add()([x, residual])

    # Exit flow: widen to 1024 channels.
    x = Activation('relu')(x)
    x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(1024, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)

    # ASPP branch 1: 1x1 convolution.
    aspp_out1 = Conv2D(256, 1, padding='same', use_bias=False)(x)
    aspp_out1 = BatchNormalization()(aspp_out1)
    aspp_out1 = Activation('relu')(aspp_out1)

    # ASPP branches 2-4: 3x3 convolutions at increasing dilation rates.
    aspp_out2 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[0], use_bias=False)(x)
    aspp_out2 = BatchNormalization()(aspp_out2)
    aspp_out2 = Activation('relu')(aspp_out2)

    aspp_out3 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[1], use_bias=False)(x)
    aspp_out3 = BatchNormalization()(aspp_out3)
    aspp_out3 = Activation('relu')(aspp_out3)

    aspp_out4 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[2], use_bias=False)(x)
    aspp_out4 = BatchNormalization()(aspp_out4)
    aspp_out4 = Activation('relu')(aspp_out4)

    # ASPP branch 5: image-level pooling, then project and re-expand.
    aspp_out5 = GlobalAveragePooling2D()(x)
    aspp_out5 = Reshape((1, 1, 1024))(aspp_out5)
    aspp_out5 = Conv2D(256, 1, padding='same', use_bias=False)(aspp_out5)
    aspp_out5 = BatchNormalization()(aspp_out5)
    aspp_out5 = Activation('relu')(aspp_out5)

    # Upsample the pooled branch back to the backbone feature-map size.
    _, height, width, _ = tf.keras.backend.int_shape(x)
    aspp_out5 = UpSampling2D(size=(height, width), interpolation='bilinear')(aspp_out5)

    # Fuse all five ASPP branches and project to 256 channels.
    aspp_out = Concatenate()([aspp_out1, aspp_out2, aspp_out3, aspp_out4, aspp_out5])

    aspp_out = Conv2D(256, 1, padding='same', use_bias=False)(aspp_out)
    aspp_out = BatchNormalization()(aspp_out)
    aspp_out = Activation('relu')(aspp_out)

    # Decoder: project the low-level features to 48 channels.
    low_level_features = Conv2D(48, 1, padding='same', use_bias=False)(low_level_features)
    low_level_features = BatchNormalization()(low_level_features)
    low_level_features = Activation('relu')(low_level_features)

    # Upsample ASPP output to the low-level feature-map resolution.
    low_level_shape = tf.keras.backend.int_shape(low_level_features)

    x = UpSampling2D(size=(low_level_shape[1] // tf.keras.backend.int_shape(aspp_out)[1],
                           low_level_shape[2] // tf.keras.backend.int_shape(aspp_out)[2]),
                     interpolation='bilinear')(aspp_out)

    # Concatenate with the low-level skip features and refine.
    x = Concatenate()([x, low_level_features])

    x = Conv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Final upsampling back to the input resolution.
    x_shape = tf.keras.backend.int_shape(x)
    upsampling_size = (input_shape[0] // x_shape[1], input_shape[1] // x_shape[2])

    x = UpSampling2D(size=upsampling_size, interpolation='bilinear')(x)

    # Per-pixel class probabilities.
    outputs = Conv2D(classes, 1, padding='same', activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    return model
| |
|
| | |
def _segnet_conv_block(x, filter_sizes):
    """Apply one (Conv2D -> BatchNorm -> ReLU) stage per entry in filter_sizes."""
    for filters in filter_sizes:
        x = Conv2D(filters, (3, 3), padding='same', use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return x


def SegNet(input_shape=(256, 256, 1), classes=11):
    """
    SegNet model for semantic segmentation

    Args:
        input_shape: Shape of input images
        classes: Number of classes for segmentation

    Returns:
        model: SegNet model, compiled (Adam 1e-4, categorical cross-entropy)
        consistently with get_unet and DeepLabV3Plus
    """
    inputs = Input(input_shape)
    x = inputs

    # Encoder: five VGG-16-style conv blocks, each followed by 2x2 pooling.
    encoder_blocks = (
        [64, 64],
        [128, 128],
        [256, 256, 256],
        [512, 512, 512],
        [512, 512, 512],
    )
    for filter_sizes in encoder_blocks:
        x = _segnet_conv_block(x, filter_sizes)
        x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)

    # Decoder: mirror the encoder, upsampling before each conv block.
    decoder_blocks = (
        [512, 512, 512],
        [512, 512, 256],
        [256, 256, 128],
        [128, 64],
        [64, 64],
    )
    for filter_sizes in decoder_blocks:
        x = UpSampling2D(size=(2, 2))(x)
        x = _segnet_conv_block(x, filter_sizes)

    # Per-pixel class probabilities.
    outputs = Conv2D(classes, (1, 1), padding='same', activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)
    # Compile for consistency with the other model builders in this file;
    # previously SegNet was the only builder returning an uncompiled model.
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    return model
| |
|
| | |
| |
|
| | class SARSegmentation: |
    def __init__(self, img_rows=256, img_cols=256, drop_rate=0.5, model_type='unet'):
        """Set up segmentation configuration.

        Args:
            img_rows: Input image height in pixels.
            img_cols: Input image width in pixels.
            drop_rate: Dropout rate used when building a U-Net.
            model_type: 'unet', 'deeplabv3plus' or 'segnet' (case-insensitive).
        """
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.drop_rate = drop_rate
        self.num_channels = 1  # SAR input is single-channel
        self.model = None  # populated by load_trained_model()
        self.model_type = model_type.lower()

        # ESA WorldCover class name -> raw label value found in the masks.
        self.class_definitions = {
            'trees': 10,
            'shrubland': 20,
            'grassland': 30,
            'cropland': 40,
            'built_up': 50,
            'bare': 60,
            'snow': 70,
            'water': 80,
            'wetland': 90,
            'mangroves': 95,
            'moss': 100
        }
        self.num_classes = len(self.class_definitions)

        # Class index (0-10) -> RGB display color.
        self.class_colors = get_esa_colors()
| |
|
| | def load_sar_data(self, file_path_or_bytes, is_bytes=False): |
| | """Load SAR data from file path or bytes""" |
| | try: |
| | if is_bytes: |
| | |
| | with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp: |
| | tmp.write(file_path_or_bytes) |
| | tmp_path = tmp.name |
| | |
| | try: |
| | with rasterio.open(tmp_path) as src: |
| | sar_data = src.read(1) |
| | sar_data = np.expand_dims(sar_data, axis=-1) |
| | except Exception as e: |
| | |
| | img = Image.open(tmp_path).convert('L') |
| | sar_data = np.array(img) |
| | sar_data = np.expand_dims(sar_data, axis=-1) |
| | |
| | |
| | os.unlink(tmp_path) |
| | else: |
| | try: |
| | with rasterio.open(file_path_or_bytes) as src: |
| | sar_data = src.read(1) |
| | sar_data = np.expand_dims(sar_data, axis=-1) |
| | except RasterioIOError: |
| | |
| | img = Image.open(file_path_or_bytes).convert('L') |
| | sar_data = np.array(img) |
| | sar_data = np.expand_dims(sar_data, axis=-1) |
| | |
| | |
| | if sar_data.shape[:2] != (self.img_rows, self.img_cols): |
| | sar_data = cv2.resize(sar_data, (self.img_cols, self.img_rows)) |
| | sar_data = np.expand_dims(sar_data, axis=-1) |
| | |
| | return sar_data |
| | except Exception as e: |
| | raise ValueError(f"Failed to load SAR data: {str(e)}") |
| |
|
| | def preprocess_sar(self, sar_data): |
| | """Preprocess SAR data""" |
| | |
| | if np.max(sar_data) <= 255 and np.min(sar_data) >= 0: |
| | |
| | sar_normalized = (sar_data / 127.5) - 1 |
| | else: |
| | |
| | sar_clipped = np.clip(sar_data, -50, 20) |
| | sar_normalized = (sar_clipped - np.min(sar_clipped)) / (np.max(sar_clipped) - np.min(sar_clipped)) * 2 - 1 |
| | |
| | return sar_normalized |
| |
|
| | def one_hot_encode(self, labels): |
| | """Convert ESA WorldCover labels to one-hot encoded format""" |
| | encoded = np.zeros((labels.shape[0], labels.shape[1], self.num_classes)) |
| | |
| | for i, value in enumerate(sorted(self.class_definitions.values())): |
| | encoded[:, :, i] = (labels == value) |
| | |
| | return encoded |
| |
|
    def load_trained_model(self, model_path):
        """Load a trained model from file.

        Delegates to load_model_with_weights() first; if that returns a model,
        the architecture type is inferred from its layers. Otherwise a fresh
        architecture of the configured type is built and weights are loaded
        by layer name.

        Raises:
            ValueError: if no usable model/weights could be loaded.
        """
        try:
            # Bare filenames are resolved under the local models/ directory.
            if not os.path.dirname(model_path) and not model_path.startswith('models/'):
                model_path = os.path.join('models', os.path.basename(model_path))

            # Multi-strategy loader (full model / config+weights / partial).
            self.model = load_model_with_weights(model_path)

            if self.model is not None:
                # Infer the architecture type from the loaded layers.
                has_dilated_convs = False
                for layer in self.model.layers:
                    if 'conv' in layer.name.lower() and hasattr(layer, 'dilation_rate'):
                        if isinstance(layer.dilation_rate, (list, tuple)):
                            if any(rate > 1 for rate in layer.dilation_rate):
                                has_dilated_convs = True
                                break
                        elif layer.dilation_rate > 1:
                            has_dilated_convs = True
                            break

                if has_dilated_convs:
                    # Dilated convolutions are the DeepLabV3+ signature here.
                    self.model_type = 'deeplabv3plus'
                    print("Detected DeepLabV3+ model")
                elif len([l for l in self.model.layers if isinstance(l, MaxPooling2D)]) >= 5:
                    # SegNet's encoder pools five times; U-Net pools four.
                    self.model_type = 'segnet'
                    print("Detected SegNet model")
                else:
                    self.model_type = 'unet'
                    print("Detected U-Net model")

            if self.model is None:
                # Fallback: build the configured architecture from scratch.
                if self.model_type == 'unet':
                    self.model = get_unet(
                        input_shape=(self.img_rows, self.img_cols, self.num_channels),
                        drop_rate=self.drop_rate,
                        classes=self.num_classes
                    )
                elif self.model_type == 'deeplabv3plus':
                    self.model = DeepLabV3Plus(
                        input_shape=(self.img_rows, self.img_cols, self.num_channels),
                        classes=self.num_classes
                    )
                elif self.model_type == 'segnet':
                    self.model = SegNet(
                        input_shape=(self.img_rows, self.img_cols, self.num_channels),
                        classes=self.num_classes
                    )
                else:
                    raise ValueError(f"Model type {self.model_type} not supported")

                # Partial load: mismatched layers are skipped, not fatal.
                self.model.load_weights(model_path, by_name=True, skip_mismatch=True)

                # Sanity check: all-zero weights mean nothing actually loaded.
                if not any(np.any(w) for w in self.model.get_weights()):
                    raise ValueError("No weights were loaded. The model architecture is incompatible.")
        except Exception as e:
            raise ValueError(f"Failed to load model: {str(e)}")
| |
|
def predict(self, sar_data):
    """Run the segmentation model on a single SAR image.

    Args:
        sar_data: Raw SAR array, (H, W) or (H, W, 1).

    Returns:
        Model output with a leading batch axis.

    Raises:
        ValueError: If no model has been trained or loaded yet.
    """
    if self.model is None:
        raise ValueError("Model not trained. Call train() first or load a trained model.")

    # Normalize into the model's expected input range.
    batch = self.preprocess_sar(sar_data)

    # The network expects a batch dimension; add one for a single image.
    if batch.ndim == 3:
        batch = np.expand_dims(batch, axis=0)

    return self.model.predict(batch)
| |
|
def get_colored_prediction(self, prediction):
    """Map a softmax prediction to an RGB image using the class palette.

    Args:
        prediction: Model output of shape (1, H, W, num_classes).

    Returns:
        Tuple of (uint8 RGB image (H, W, 3), per-pixel class-index map).
    """
    class_map = prediction[0].argmax(axis=-1)
    height, width = class_map.shape
    rgb = np.zeros((height, width, 3), dtype=np.uint8)

    # Paint every pixel of each class with its palette color.
    for class_idx, palette_color in self.class_colors.items():
        rgb[class_map == class_idx] = palette_color

    return rgb, class_map
| |
|
| | |
| |
|
| | |
| | |
# Initialize Streamlit session-state defaults on first run. Each key is only
# set when missing so reruns preserve user state.
if 'app_mode' not in st.session_state:
    st.session_state.app_mode = "SAR Colorization"
if 'model_loaded' not in st.session_state:
    st.session_state.model_loaded = False
if 'segmentation' not in st.session_state:
    # Heavy object: constructed once per session, reused across reruns.
    st.session_state.segmentation = SARSegmentation(img_rows=256, img_cols=256)
if 'processed_images' not in st.session_state:
    st.session_state.processed_images = []
if 'theme' not in st.session_state:
    st.session_state.theme = "dark"
| | |
def set_app_style(app_mode):
    """Inject per-application CSS into the page.

    "SAR Colorization" gets a dark space theme; "SAR to Optical Translation"
    gets a light theme. Both share the same background image with a themed
    blurred overlay.

    Args:
        app_mode: Current application name from the sidebar radio.
    """
    if app_mode == "SAR Colorization":
        # Dark theme.
        st.markdown(
            """
            <style>
            .stApp {
                background-color: #0a0a1f;
                color: white;
            }

            .main {
                background-image: url("https://images.unsplash.com/photo-1451187580459-43490279c0fa?ixlib=rb-1.2.1&auto=format&fit=crop&w=1352&q=80");
                background-size: cover;
                background-position: center;
                background-repeat: no-repeat;
                background-attachment: fixed;
                position: relative;
            }

            .main::before {
                content: "";
                position: absolute;
                top: 0;
                left: 0;
                width: 100%;
                height: 100%;
                background-color: rgba(10, 10, 31, 0.7);
                backdrop-filter: blur(5px);
                z-index: -1;
            }

            /* Rest of your dark theme CSS */
            /* ... */
            </style>
            """,
            unsafe_allow_html=True
        )
    elif app_mode == "SAR to Optical Translation":
        # Light theme.
        st.markdown(
            """
            <style>
            .stApp {
                background-color: #f8f9fa;
                color: #333;
            }

            .main {
                background-image: url("https://images.unsplash.com/photo-1451187580459-43490279c0fa?ixlib=rb-1.2.1&auto=format&fit=crop&w=1352&q=80");
                background-size: cover;
                background-position: center;
                background-repeat: no-repeat;
                background-attachment: fixed;
                position: relative;
            }

            .main::before {
                content: "";
                position: absolute;
                top: 0;
                left: 0;
                width: 100%;
                height: 100%;
                background-color: rgba(248, 249, 250, 0.7);
                backdrop-filter: blur(5px);
                z-index: -1;
            }

            /* Adjust text colors for light theme */
            h1, h2, h3, h4, h5, h6 {
                color: #333 !important;
            }

            p, span, div, label {
                color: #333 !important;
            }

            /* Adjust card styling for light theme */
            .card {
                background-color: rgba(255, 255, 255, 0.7) !important;
                border: 1px solid rgba(147, 51, 234, 0.3) !important;
            }

            /* Adjust metric card styling for light theme */
            .metric-card {
                background-color: rgba(255, 255, 255, 0.7) !important;
            }

            .metric-value {
                color: #7c3aed !important;
            }

            .metric-label {
                color: #333 !important;
            }

            /* Rest of your light theme adjustments */
            /* ... */
            </style>
            """,
            unsafe_allow_html=True
        )
| | |
| | |
def create_stars_html(num_stars=100):
    """Return HTML for a field of randomly placed twinkling stars.

    Each star is an absolutely positioned ``<div class="star">`` with random
    size, position, animation duration, and opacity (driven by the CSS
    custom properties ``--duration`` and ``--opacity``).

    Args:
        num_stars: Number of star divs to generate.

    Returns:
        str: A ``<div class="stars">...</div>`` HTML fragment.
    """
    def _star():
        # One randomly styled star fragment.
        size = random.uniform(1, 3)
        top = random.uniform(0, 100)
        left = random.uniform(0, 100)
        duration = random.uniform(3, 8)
        opacity = random.uniform(0.2, 0.8)
        return f"""
        <div class="star" style="
            width: {size}px;
            height: {size}px;
            top: {top}%;
            left: {left}%;
            --duration: {duration}s;
            --opacity: {opacity};
        "></div>
        """

    # str.join avoids the quadratic cost of repeated string concatenation.
    return '<div class="stars">' + "".join(_star() for _ in range(num_stars)) + "</div>"
| |
|
| | |
def add_logo(logo_path='assets/logo2.png'):
    """Render the project logo pinned to the top-left corner of the page.

    The file is base64-embedded so no static-file route is needed. A missing
    file produces a Streamlit warning instead of crashing.

    Args:
        logo_path: Path to the PNG logo file.
    """
    try:
        with open(logo_path, "rb") as logo_file:
            encoded_logo = base64.b64encode(logo_file.read()).decode()
        st.markdown(
            f"""<div style="position: absolute; top: 0.5rem; left: 1rem; z-index: 999;">
            <img src="data:image/png;base64,{encoded_logo}" width="150px"></div>""",
            unsafe_allow_html=True
        )
    except FileNotFoundError:
        st.warning(f"Logo file not found: {logo_path}")
| |
|
| | |
| |
|
| | |
# Render the animated starfield background (styled via the injected CSS).
st.markdown(create_stars_html(), unsafe_allow_html=True)
| |
|
| |
|
| | |
# Sidebar: app switcher, theme toggle, per-app info/config, and class legend.
with st.sidebar:
    st.image('assets/logo2.png', width=150)

    st.title("Applications")
    app_mode = st.radio(
        "Select Application",
        ["SAR Colorization", "SAR to Optical Translation"]
    )

    st.session_state.app_mode = app_mode

    st.markdown("---")
    st.title("Appearance")
    theme = st.radio(
        "Select Theme",
        ["Dark", "Light"]
    )
    # Apply the CSS for the currently selected application.
    set_app_style(st.session_state.app_mode)

    # Rerun once when the theme changes so the new styling takes effect.
    if theme.lower() != st.session_state.theme:
        st.session_state.theme = theme.lower()
        st.rerun()

    st.markdown("---")

    if st.session_state.app_mode == "SAR Colorization":
        st.title("About")
        st.markdown("""
        ### SAR Image Colorization

        This application uses deep learning models to segment and colorize Synthetic Aperture Radar (SAR) images into land cover classes.

        #### Features:
        - Load pre-trained U-Net,DeepLabV3+ or SegNet models
        - Process single SAR images
        - Batch process multiple images
        - Visualize Pixel Level Classification with ESA WorldCover color scheme

        #### Developed by:
        Varun & Mokshyagna
        (NRSC, ISRO)

        #### Technologies:
        - TensorFlow/Keras
        - Streamlit
        - Rasterio
        - OpenCV

        #### Version:
        1.0.0
        """)

    elif st.session_state.app_mode == "SAR to Optical Translation":
        st.header("Model Configuration")

        # Fixed on-disk model locations for the translation pipeline.
        unet_weights_path = "models/unet_model.h5"
        generator_path = "models/final_generator.keras"

        st.info(f"U-Net Weights Path: {unet_weights_path}")

        use_generator = st.checkbox("Use Generator Model for Colorization", value=True)
        if use_generator:
            st.info(f"Generator Model Path: {generator_path}")
        else:
            generator_path = None

        if st.button("Load Models"):
            with st.spinner("Loading models..."):
                # setup_gpu / load_models are defined elsewhere in this file.
                gpu_status = setup_gpu()
                st.info(gpu_status)

                try:
                    unet_model, generator_model = load_models(unet_weights_path, generator_path if use_generator else None)
                    st.session_state['unet_model'] = unet_model
                    st.session_state['generator_model'] = generator_model
                    st.success("Models loaded successfully!")
                except Exception as e:
                    st.error(f"Error loading models: {e}")

    # Color legend for the ESA WorldCover classes (RGB swatches).
    st.header("ESA WorldCover Classes")
    class_info = {
        'Trees': [0, 100, 0],
        'Shrubland': [255, 165, 0],
        'Grassland': [144, 238, 144],
        'Cropland': [255, 255, 0],
        'Built-up': [255, 0, 0],
        'Bare': [139, 69, 19],
        'Snow': [255, 255, 255],
        'Water': [0, 0, 255],
        'Wetland': [0, 139, 139],
        'Mangroves': [0, 255, 0],
        'Moss': [220, 220, 220]
    }

    for class_name, color in class_info.items():
        st.markdown(
            f'<div style="display: flex; align-items: center;">'
            f'<div style="width: 20px; height: 20px; background-color: rgb({color[0]}, {color[1]}, {color[2]}); margin-right: 10px;"></div>'
            f'<span>{class_name}</span>'
            f'</div>',
            unsafe_allow_html=True
        )

    st.markdown("---")
    st.markdown("© 2025 | All Rights Reserved")
| |
|
| | |
# --- Main page: SAR Colorization application ---
if st.session_state.app_mode == "SAR Colorization":
    # Page title banner.
    st.markdown("""
    <div style="text-align: center; margin-bottom: 2rem; position: relative; z-index: 100;">
        <h1 style="color: #a78bfa; font-size: 3rem; font-weight: bold; text-shadow: 0 0 10px rgba(167, 139, 250, 0.5);">
            SAR Image Colorization
        </h1>
        <p style="color: #bfdbfe; font-size: 1.2rem;">
            Pixel Level Classification of Synthetic Aperture Radar images into land cover classes with deep learning
        </p>
    </div>
    """, unsafe_allow_html=True)

    # Open the main content card (closed at the end of this branch).
    st.markdown("<div class='card'>", unsafe_allow_html=True)

    # Workflow tabs for this application.
    tab1, tab2, tab3, tab4 = st.tabs(["📥 Load Model", "🖼️ Process Single Image", "📁 Process Multiple Images", "🔍 Sample Images"])
| |
|
| |
|
| | |
# Tab 1: choose an architecture and load its pre-trained weights.
with tab1:
    st.markdown("<h3 style='color: #a78bfa;'>Load Segmentation Model</h3>", unsafe_allow_html=True)

    model_type = st.selectbox(
        "Select model architecture",
        ["U-Net", "DeepLabV3+", "SegNet"],
        index=0,
        help="Select the architecture of the model to load"
    )

    # Normalize the display name to the internal key (e.g. "U-Net" -> "unet").
    st.session_state.segmentation.model_type = model_type.lower().replace('-', '')

    # Known on-disk weight locations per architecture.
    model_paths = {
        "unet": "models/unet_model.h5",
        "deeplabv3+": "models/deeplabv3plus_model.h5",
        "deeplabv3plus": "models/deeplabv3plus_model.h5",
        "segnet": "models/segnet_model.h5"
    }

    selected_model_path = model_paths[st.session_state.segmentation.model_type]

    st.info(f"Model will be loaded from: {selected_model_path}")

    if st.button("Load Model", key="load_model_btn"):
        with st.spinner(f"Loading {model_type} model..."):
            try:
                st.session_state.segmentation.load_trained_model(selected_model_path)
                st.session_state.model_loaded = True
                st.success("Model loaded successfully!")
            except Exception as e:
                st.error(f"Error loading model: {str(e)}")

    # Summary card once a model is available.
    if st.session_state.model_loaded:
        st.markdown("<div class='card'>", unsafe_allow_html=True)
        st.markdown("<h4 style='color: #a78bfa;'>Model Information</h4>", unsafe_allow_html=True)

        col1, col2, col3 = st.columns(3)
        with col1:
            st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
            # Internal key -> display name.
            model_arch_map = {
                'unet': "U-Net",
                'deeplabv3plus': "DeepLabV3+",
                'segnet': "SegNet"
            }
            model_arch = model_arch_map.get(st.session_state.segmentation.model_type, "Unknown")
            st.markdown(f"<p class='metric-value'>{model_arch}</p>", unsafe_allow_html=True)
            st.markdown("<p class='metric-label'>Architecture</p>", unsafe_allow_html=True)
            st.markdown("</div>", unsafe_allow_html=True)

        with col2:
            st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
            st.markdown("<p class='metric-value'>11</p>", unsafe_allow_html=True)
            st.markdown("<p class='metric-label'>Land Cover Classes</p>", unsafe_allow_html=True)
            st.markdown("</div>", unsafe_allow_html=True)

        with col3:
            st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
            st.markdown("<p class='metric-value'>256 x 256</p>", unsafe_allow_html=True)
            st.markdown("<p class='metric-label'>Input Size</p>", unsafe_allow_html=True)
            st.markdown("</div>", unsafe_allow_html=True)

        # Class color legend (create_legend is defined elsewhere in this file).
        st.markdown("<h4 style='color: #a78bfa; margin-top: 20px;'>Land Cover Classes</h4>", unsafe_allow_html=True)
        legend_img = create_legend()
        st.image(legend_img, use_container_width=True)

        st.markdown("</div>", unsafe_allow_html=True)
    else:
        st.info("Please load a model to continue.")
| |
|
| | |
# Tab 2: run inference on one uploaded SAR image, optionally scored
# against an uploaded ground-truth map.
with tab2:
    st.markdown("<h3 style='color: #a78bfa;'>Process Single SAR Image</h3>", unsafe_allow_html=True)

    if not st.session_state.model_loaded:
        st.warning("Please load a model in the 'Load Model' tab first.")
    else:
        st.markdown("<div class='upload-box'>", unsafe_allow_html=True)
        col1, col2 = st.columns(2)

        with col1:
            uploaded_file = st.file_uploader(
                "Upload a SAR image (.tif or common image formats)",
                type=["tif", "tiff", "png", "jpg", "jpeg"],
                key="single_sar_uploader"
            )

        with col2:
            # Optional ground truth enables pixel-accuracy reporting.
            ground_truth_file = st.file_uploader(
                "Upload ground truth (optional)",
                type=["tif", "tiff", "png", "jpg", "jpeg"],
                key="single_gt_uploader"
            )

        st.markdown("</div>", unsafe_allow_html=True)

        if uploaded_file is not None:
            if st.button("Process Image", key="process_single_btn"):
                with st.spinner("Processing image..."):
                    try:
                        sar_data = st.session_state.segmentation.load_sar_data(uploaded_file.getvalue(), is_bytes=True)

                        # Min-max scale a copy to uint8 for display purposes.
                        sar_normalized = sar_data.copy()
                        min_val = np.min(sar_normalized)
                        max_val = np.max(sar_normalized)
                        sar_normalized = ((sar_normalized - min_val) / (max_val - min_val) * 255).astype(np.uint8)

                        prediction = st.session_state.segmentation.predict(sar_data)

                        if ground_truth_file is not None:
                            try:
                                gt_data = st.session_state.segmentation.load_sar_data(ground_truth_file.getvalue(), is_bytes=True)

                                # Rescale for visualization if outside [0, 1].
                                if np.max(sar_normalized) > 1 or np.min(sar_normalized) < 0:
                                    sar_for_viz = (sar_normalized.astype(np.float32) / 127.5) - 1
                                else:
                                    sar_for_viz = sar_normalized

                                result_buf, colored_gt, colored_pred, overlay = visualize_with_ground_truth(
                                    sar_for_viz,
                                    gt_data,
                                    prediction
                                )

                                st.markdown("<h4 style='color: #a78bfa;'>Segmentation Results with Ground Truth</h4>", unsafe_allow_html=True)
                                st.image(result_buf, use_container_width=True)

                                # Pixel accuracy against the ground-truth map.
                                pred_class = np.argmax(prediction[0], axis=-1)
                                gt_class = gt_data[:,:,0].astype(np.int32)

                                # Raw ESA codes (10, 20, ...) are remapped to 0..N-1 indices.
                                if np.max(gt_class) > 10:
                                    gt_mapped = np.zeros_like(gt_class)
                                    class_values = sorted(st.session_state.segmentation.class_definitions.values())
                                    for i, val in enumerate(class_values):
                                        gt_mapped[gt_class == val] = i
                                    gt_class = gt_mapped

                                accuracy = np.mean(pred_class == gt_class) * 100

                                st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                                st.markdown(f"<p class='metric-value'>{accuracy:.2f}%</p>", unsafe_allow_html=True)
                                st.markdown("<p class='metric-label'>Pixel Accuracy</p>", unsafe_allow_html=True)
                                st.markdown("</div>", unsafe_allow_html=True)

                                btn = st.download_button(
                                    label="Download Result",
                                    data=result_buf,
                                    file_name="segmentation_result_with_gt.png",
                                    mime="image/png",
                                    key="download_single_result_with_gt"
                                )
                            except Exception as e:
                                st.error(f"Error processing ground truth: {str(e)}")
                                # NOTE(review): original indentation was lost; this
                                # prediction-only fallback is placed in the except
                                # path — confirm intended placement against history.
                                result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1))
                                st.markdown("<h4 style='color: #a78bfa;'>Segmentation Results</h4>", unsafe_allow_html=True)
                                st.image(result_img, use_container_width=True)

                                btn = st.download_button(
                                    label="Download Result",
                                    data=result_img,
                                    file_name="segmentation_result.png",
                                    mime="image/png",
                                    key="download_single_result"
                                )
                        else:
                            # No ground truth: prediction-only visualization.
                            result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1))
                            st.markdown("<h4 style='color: #a78bfa;'>Segmentation Results</h4>", unsafe_allow_html=True)
                            st.image(result_img, use_container_width=True)

                            btn = st.download_button(
                                label="Download Result",
                                data=result_img,
                                file_name="segmentation_result.png",
                                mime="image/png",
                                key="download_single_result"
                            )
                    except Exception as e:
                        st.error(f"Error processing image: {str(e)}")
| |
|
| | |
# Tab 3: batch inference over many SAR images (or ZIPs of images),
# optionally scored against matching ground-truth files.
with tab3:
    st.markdown("<h3 style='color: #a78bfa;'>Process Multiple SAR Images</h3>", unsafe_allow_html=True)

    if not st.session_state.model_loaded:
        st.warning("Please load a model in the 'Load Model' tab first.")
    else:
        st.markdown("<div class='upload-box'>", unsafe_allow_html=True)

        use_gt = st.checkbox("Include ground truth data", value=False)

        col1, col2 = st.columns(2)

        with col1:
            uploaded_files = st.file_uploader(
                "Upload SAR images or a ZIP file containing images",
                type=["tif", "tiff", "png", "jpg", "jpeg", "zip"],
                accept_multiple_files=True,
                key="batch_sar_uploader"
            )

        gt_files = None
        if use_gt:
            with col2:
                gt_files = st.file_uploader(
                    "Upload ground truth images or a ZIP file (must match SAR filenames)",
                    type=["tif", "tiff", "png", "jpg", "jpeg", "zip"],
                    accept_multiple_files=True,
                    key="batch_gt_uploader"
                )
                st.info("Ground truth filenames should match SAR image filenames")

        st.markdown("</div>", unsafe_allow_html=True)

        col1, col2 = st.columns([3, 1])

        with col1:
            max_images = st.slider("Maximum number of images to display", min_value=1, max_value=20, value=10)

        with col2:
            st.markdown("<br>", unsafe_allow_html=True)
            process_btn = st.button("Process Images", key="process_multi_btn")

        if process_btn and uploaded_files:
            st.session_state.processed_images = []

            with st.spinner("Processing images..."):
                # Stage uploads (and extracted ZIP contents) in a throwaway directory.
                with tempfile.TemporaryDirectory() as temp_dir:
                    sar_image_files = []
                    gt_image_files = {}  # basename -> ground-truth path

                    for uploaded_file in uploaded_files:
                        if uploaded_file.name.lower().endswith('.zip'):
                            # Extract the archive and collect every image inside it.
                            zip_path = os.path.join(temp_dir, uploaded_file.name)
                            with open(zip_path, 'wb') as f:
                                f.write(uploaded_file.getvalue())

                            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                                zip_ref.extractall(os.path.join(temp_dir, 'sar'))

                            for root, _, files in os.walk(os.path.join(temp_dir, 'sar')):
                                for file in files:
                                    if file.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')):
                                        sar_image_files.append(os.path.join(root, file))
                        else:
                            # Plain image upload: write it under the sar/ staging dir.
                            file_path = os.path.join(temp_dir, 'sar', uploaded_file.name)
                            os.makedirs(os.path.dirname(file_path), exist_ok=True)
                            with open(file_path, 'wb') as f:
                                f.write(uploaded_file.getvalue())
                            sar_image_files.append(file_path)

                    # Same staging logic for ground-truth uploads, keyed by basename.
                    if use_gt and gt_files:
                        for gt_file in gt_files:
                            if gt_file.name.lower().endswith('.zip'):
                                zip_path = os.path.join(temp_dir, gt_file.name)
                                with open(zip_path, 'wb') as f:
                                    f.write(gt_file.getvalue())

                                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                                    zip_ref.extractall(os.path.join(temp_dir, 'gt'))

                                for root, _, files in os.walk(os.path.join(temp_dir, 'gt')):
                                    for file in files:
                                        if file.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')):
                                            gt_path = os.path.join(root, file)
                                            gt_image_files[os.path.basename(file)] = gt_path
                            else:
                                file_path = os.path.join(temp_dir, 'gt', gt_file.name)
                                os.makedirs(os.path.dirname(file_path), exist_ok=True)
                                with open(file_path, 'wb') as f:
                                    f.write(gt_file.getvalue())
                                gt_image_files[os.path.basename(gt_file.name)] = file_path

                    # Cap displayed images to keep the page responsive.
                    if len(sar_image_files) > max_images:
                        st.info(f"Found {len(sar_image_files)} images. Randomly selecting {max_images} images to display.")
                        sar_image_files = random.sample(sar_image_files, max_images)

                    progress_bar = st.progress(0)

                    if use_gt and gt_image_files:
                        overall_accuracy = []

                    for i, image_path in enumerate(sar_image_files):
                        try:
                            progress_bar.progress((i + 1) / len(sar_image_files))

                            sar_data = st.session_state.segmentation.load_sar_data(image_path)

                            # Min-max scale a copy to uint8 for display purposes.
                            sar_normalized = sar_data.copy()
                            min_val = np.min(sar_normalized)
                            max_val = np.max(sar_normalized)
                            sar_normalized = ((sar_normalized - min_val) / (max_val - min_val) * 255).astype(np.uint8)

                            prediction = st.session_state.segmentation.predict(sar_data)

                            image_basename = os.path.basename(image_path)
                            has_gt = image_basename in gt_image_files

                            if has_gt and use_gt:
                                gt_path = gt_image_files[image_basename]
                                gt_data = st.session_state.segmentation.load_sar_data(gt_path)

                                # Rescale for visualization if outside [0, 1].
                                if np.max(sar_normalized) > 1 or np.min(sar_normalized) < 0:
                                    sar_for_viz = (sar_normalized.astype(np.float32) / 127.5) - 1
                                else:
                                    sar_for_viz = sar_normalized

                                result_buf, colored_gt, colored_pred, overlay = visualize_with_ground_truth(
                                    sar_for_viz,
                                    gt_data,
                                    prediction
                                )

                                pred_class = np.argmax(prediction[0], axis=-1)
                                gt_class = gt_data[:,:,0].astype(np.int32)

                                # Raw ESA codes (10, 20, ...) are remapped to 0..N-1 indices.
                                # BUG FIX: this inner loop previously reused `i`, clobbering
                                # the outer image index and corrupting the progress bar.
                                if np.max(gt_class) > 10:
                                    gt_mapped = np.zeros_like(gt_class)
                                    class_values = sorted(st.session_state.segmentation.class_definitions.values())
                                    for class_idx, val in enumerate(class_values):
                                        gt_mapped[gt_class == val] = class_idx
                                    gt_class = gt_mapped

                                accuracy = np.mean(pred_class == gt_class) * 100
                                overall_accuracy.append(accuracy)

                                st.session_state.processed_images.append({
                                    'filename': os.path.basename(image_path),
                                    'result': result_buf,
                                    'accuracy': accuracy
                                })
                            else:
                                result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1))

                                st.session_state.processed_images.append({
                                    'filename': os.path.basename(image_path),
                                    'result': result_img
                                })
                        except Exception as e:
                            st.error(f"Error processing {os.path.basename(image_path)}: {str(e)}")

                    progress_bar.empty()

                    if st.session_state.processed_images:
                        st.markdown("<h4 style='color: #a78bfa;'>Segmentation Results</h4>", unsafe_allow_html=True)

                        # Average accuracy is only meaningful when GT was matched.
                        if use_gt and 'overall_accuracy' in locals() and overall_accuracy:
                            avg_accuracy = np.mean(overall_accuracy)
                            st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                            st.markdown(f"<p class='metric-value'>{avg_accuracy:.2f}%</p>", unsafe_allow_html=True)
                            st.markdown("<p class='metric-label'>Average Pixel Accuracy</p>", unsafe_allow_html=True)
                            st.markdown("</div>", unsafe_allow_html=True)

                        # Bundle every result image into a single downloadable ZIP.
                        zip_buffer = io.BytesIO()
                        with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
                            for idx, img_data in enumerate(st.session_state.processed_images):
                                zip_file.writestr(f"result_{idx+1}_{img_data['filename']}.png", img_data['result'].getvalue())

                        st.download_button(
                            label="Download All Results",
                            data=zip_buffer.getvalue(),
                            file_name="segmentation_results.zip",
                            mime="application/zip",
                            key="download_all_results"
                        )

                        # Render each result with its optional accuracy figure.
                        for img_data in st.session_state.processed_images:
                            st.markdown(f"<h5 style='color: #bfdbfe;'>Image: {img_data['filename']}</h5>", unsafe_allow_html=True)

                            if 'accuracy' in img_data:
                                st.markdown(f"<p style='color: #a78bfa;'>Pixel Accuracy: {img_data['accuracy']:.2f}%</p>", unsafe_allow_html=True)

                            st.image(img_data['result'], use_container_width=True)
                            st.markdown("<hr style='border-color: rgba(147, 51, 234, 0.3);'>", unsafe_allow_html=True)
                    else:
                        st.warning("No images were successfully processed.")
        elif process_btn:
            st.warning("Please upload at least one image file or ZIP archive.")
| |
|
| | |
# Tab 4: browse bundled sample images (SAR / optical / label triplets) and
# run the loaded model on a selected sample.
with tab4:
    st.markdown("<h3 style='color: #a78bfa;'>Sample Images</h3>", unsafe_allow_html=True)

    if not st.session_state.model_loaded:
        st.warning("Please load a model in the 'Load Model' tab first.")
    else:
        st.markdown("<div class='card'>", unsafe_allow_html=True)

        # NOTE: redundant local `import os` removed — os is imported at module level.
        sample_dir = "samples/SAR"
        if os.path.exists(sample_dir):
            sample_files = [f for f in os.listdir(sample_dir) if f.endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg'))]
        else:
            # First run: create the expected sample directory layout.
            os.makedirs(sample_dir, exist_ok=True)
            os.makedirs("samples/OPTICAL", exist_ok=True)
            os.makedirs("samples/LABELS", exist_ok=True)
            sample_files = []

        if sample_files:
            selected_sample = st.selectbox(
                "Select a sample image",
                sample_files,
                key="sample_selector"
            )

            # Side-by-side: SAR input, optical ground truth, label map.
            col1, col2, col3 = st.columns(3)

            with col1:
                st.subheader("SAR Image")
                sar_path = os.path.join("samples/SAR", selected_sample)
                display_image(sar_path)

            with col2:
                st.subheader("Optical Image (Ground Truth)")
                opt_path = os.path.join("samples/OPTICAL", selected_sample)
                if os.path.exists(opt_path):
                    display_image(opt_path)
                else:
                    st.info("No matching optical image found")

            with col3:
                st.subheader("Label Image")
                samples_dir = "samples"

                # Label directory naming varies between datasets; try them all.
                possible_label_dirs = [
                    os.path.join(samples_dir, "labels"),
                    os.path.join(samples_dir, "label"),
                    os.path.join(samples_dir, "LABELS"),
                    os.path.join(samples_dir, "LABEL"),
                    os.path.join(samples_dir, "Labels"),
                    os.path.join(samples_dir, "Label"),
                    os.path.join(samples_dir, "gt"),
                    os.path.join(samples_dir, "GT"),
                    os.path.join(samples_dir, "ground_truth"),
                    os.path.join(samples_dir, "groundtruth")
                ]

                label_path = None
                base_name = os.path.splitext(selected_sample)[0]

                for dir_path in possible_label_dirs:
                    if not os.path.exists(dir_path):
                        continue

                    # 1) Exact filename match.
                    exact_path = os.path.join(dir_path, selected_sample)
                    if os.path.exists(exact_path):
                        label_path = exact_path
                        break

                    # 2) Same stem with any common image extension.
                    for ext in ['.tif', '.tiff', '.png', '.jpg', '.jpeg', '.TIF', '.TIFF', '.PNG', '.JPG', '.JPEG']:
                        test_path = os.path.join(dir_path, base_name + ext)
                        if os.path.exists(test_path):
                            label_path = test_path
                            break

                    # 3) Case-insensitive stem comparison as a last resort.
                    if not label_path:
                        for file in os.listdir(dir_path):
                            if os.path.splitext(file)[0].lower() == base_name.lower():
                                label_path = os.path.join(dir_path, file)
                                break

                    if label_path:
                        break

                if label_path and os.path.exists(label_path):
                    try:
                        if label_path.lower().endswith(('.tif', '.tiff')):
                            # GeoTIFF labels are colorized with the ESA palette.
                            with rasterio.open(label_path) as src:
                                label_data = src.read(1)

                            colors = get_esa_colors()
                            colored_label = np.zeros((label_data.shape[0], label_data.shape[1], 3), dtype=np.uint8)

                            for class_idx, color in colors.items():
                                if np.max(label_data) > 10:
                                    # Raw ESA codes (10, 20, ...): map the code whose
                                    # sorted position equals this class index.
                                    class_values = sorted(st.session_state.segmentation.class_definitions.values())
                                    for ci, val in enumerate(class_values):
                                        if class_idx == ci:
                                            colored_label[label_data == val] = color
                                else:
                                    # Already 0..N-1 indexed labels.
                                    colored_label[label_data == class_idx] = color

                            st.image(colored_label, use_container_width=True)
                        else:
                            display_image(label_path)
                    except Exception as e:
                        st.error(f"Error displaying label image: {str(e)}")
                        # Fall back to raw display if colorization fails.
                        display_image(label_path)
                else:
                    st.info("No matching label image found")

            if st.button("Process Selected Sample", key="process_sample_btn"):
                with st.spinner("Processing sample image..."):
                    try:
                        sar_data = st.session_state.segmentation.load_sar_data(sar_path)

                        # Min-max scale a copy to uint8 for display purposes.
                        sar_normalized = sar_data.copy()
                        min_val = np.min(sar_normalized)
                        max_val = np.max(sar_normalized)
                        sar_normalized = ((sar_normalized - min_val) / (max_val - min_val) * 255).astype(np.uint8)

                        prediction = st.session_state.segmentation.predict(sar_data)

                        # BUG FIX: label_path can be None when no label was found;
                        # os.path.exists(None) raises TypeError. Guard for None first.
                        if label_path and os.path.exists(label_path):
                            gt_data = st.session_state.segmentation.load_sar_data(label_path)

                            # Rescale for visualization if outside [0, 1].
                            if np.max(sar_normalized) > 1 or np.min(sar_normalized) < 0:
                                sar_for_viz = (sar_normalized.astype(np.float32) / 127.5) - 1
                            else:
                                sar_for_viz = sar_normalized

                            result_buf, colored_gt, colored_pred, overlay = visualize_with_ground_truth(
                                sar_for_viz,
                                gt_data,
                                prediction
                            )

                            st.markdown("<h4 style='color: #a78bfa;'>Segmentation Results with Ground Truth</h4>", unsafe_allow_html=True)
                            st.image(result_buf, use_container_width=True)

                            # Pixel accuracy against the label map.
                            pred_class = np.argmax(prediction[0], axis=-1)
                            gt_class = gt_data[:,:,0].astype(np.int32)

                            # Raw ESA codes (10, 20, ...) are remapped to 0..N-1 indices.
                            if np.max(gt_class) > 10:
                                gt_mapped = np.zeros_like(gt_class)
                                class_values = sorted(st.session_state.segmentation.class_definitions.values())
                                for i, val in enumerate(class_values):
                                    gt_mapped[gt_class == val] = i
                                gt_class = gt_mapped

                            accuracy = np.mean(pred_class == gt_class) * 100

                            st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                            st.markdown(f"<p class='metric-value'>{accuracy:.2f}%</p>", unsafe_allow_html=True)
                            st.markdown("<p class='metric-label'>Pixel Accuracy</p>", unsafe_allow_html=True)
                            st.markdown("</div>", unsafe_allow_html=True)

                            btn = st.download_button(
                                label="Download Result",
                                data=result_buf,
                                file_name=f"sample_result_{selected_sample}.png",
                                mime="image/png",
                                key="download_sample_result_with_gt"
                            )
                        else:
                            # No label available: prediction-only visualization.
                            result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1))
                            st.markdown("<h4 style='color: #a78bfa;'>Segmentation Results</h4>", unsafe_allow_html=True)
                            st.image(result_img, use_container_width=True)

                            btn = st.download_button(
                                label="Download Result",
                                data=result_img,
                                file_name=f"sample_result_{selected_sample}.png",
                                mime="image/png",
                                key="download_sample_result"
                            )
                    except Exception as e:
                        st.error(f"Error processing sample image: {str(e)}")
        else:
            st.info("No sample images found. Please add some images to the 'samples/SAR' directory.")
| | |
| | |
| |
|
| | |
| | st.markdown("</div>", unsafe_allow_html=True) |
| |
|
# ---------------------------------------------------------------------------
# Mode: "SAR to Optical Translation"
# Renders the page banner, checks that the models were loaded from the
# sidebar (stored in st.session_state), and creates the three workflow tabs.
# ---------------------------------------------------------------------------
elif st.session_state.app_mode == "SAR to Optical Translation":
    # Page title / subtitle banner.
    st.markdown("""
    <div style="text-align: center; margin-bottom: 2rem; position: relative; z-index: 100;">
        <h1 style="color: #a78bfa; font-size: 3rem; font-weight: bold; text-shadow: 0 0 10px rgba(167, 139, 250, 0.5);">
            SAR to Optical Translation
        </h1>
        <p style="color: #bfdbfe; font-size: 1.2rem;">
            Convert Synthetic Aperture Radar images to optical-like imagery using deep learning
        </p>
    </div>
    """, unsafe_allow_html=True)

    # Open the styled content card (closed at the end of this mode's section).
    st.markdown("<div class='card'>", unsafe_allow_html=True)

    # Models are loaded elsewhere (sidebar) into session state; presence of
    # 'unet_model' is used as the "loaded" flag.
    models_loaded = 'unet_model' in st.session_state

    if not models_loaded:
        st.warning("Please load the models from the sidebar first.")
    else:
        st.success("Models loaded successfully! You can now process SAR images.")

    # Three workflows: single image, batch, and bundled sample images.
    tab1, tab2, tab3 = st.tabs(["Process Single Image", "Batch Processing", "Sample Images"])
| |
|
| | |
    # -----------------------------------------------------------------------
    # Tab 1: process a single uploaded SAR image, optionally comparing the
    # generated optical-like output against an uploaded ground-truth image
    # (PSNR / SSIM metrics).
    # -----------------------------------------------------------------------
    with tab1:
        st.markdown("<h3 style='color: #a78bfa;'>Upload SAR Image</h3>", unsafe_allow_html=True)

        col1, col2 = st.columns(2)

        with col1:
            st.markdown("<div class='upload-box'>", unsafe_allow_html=True)
            uploaded_file = st.file_uploader(
                "Upload a SAR image (.tif or common image formats)",
                type=["tif", "tiff", "png", "jpg", "jpeg"],
                key="sar_optical_uploader"
            )
            st.markdown("</div>", unsafe_allow_html=True)

        with col2:
            st.markdown("<div class='upload-box'>", unsafe_allow_html=True)
            gt_file = st.file_uploader(
                "Upload ground truth optical image (optional)",
                type=["tif", "tiff", "png", "jpg", "jpeg"],
                key="optical_gt_uploader"
            )
            st.markdown("</div>", unsafe_allow_html=True)

        if uploaded_file is not None:
            if st.button("Generate Optical-like Image", key="generate_optical_btn"):
                with st.spinner("Processing image..."):
                    try:
                        # load_sar_image / process_image / visualize_results are
                        # helpers defined elsewhere in this file.
                        sar_batch, sar_image = load_sar_image(uploaded_file)

                        if sar_batch is not None:
                            # Run segmentation (U-Net) and, if available, the
                            # GAN generator for colorization.
                            seg_mask, colorized = process_image(
                                sar_batch,
                                st.session_state['unet_model'],
                                st.session_state.get('generator_model')
                            )

                            sar_rgb, colored_pred, overlay, colorized_img = visualize_results(
                                sar_image, seg_mask, colorized
                            )

                            st.header("Results")

                            if gt_file is not None:
                                try:
                                    # Persist the upload so rasterio can open it by path.
                                    with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as tmp_file:
                                        tmp_file.write(gt_file.getbuffer())
                                        tmp_file_path = tmp_file.name

                                    try:
                                        with rasterio.open(tmp_file_path) as src:
                                            gt_image = src.read()

                                            st.info(f"Ground truth shape: {gt_image.shape}, dtype: {gt_image.dtype}, min: {np.min(gt_image)}, max: {np.max(gt_image)}")

                                            if gt_image.shape[0] == 3:
                                                # (bands, H, W) -> (H, W, 3)
                                                gt_image = np.transpose(gt_image, (1, 2, 0))
                                            else:
                                                # Single band: replicate to 3 channels.
                                                gt_image = src.read(1)

                                                if np.all(gt_image == 0) or np.all(gt_image == 1):
                                                    st.warning("Ground truth image appears to be blank (all zeros or ones)")

                                                gt_image = np.expand_dims(gt_image, axis=-1)
                                                gt_image = np.repeat(gt_image, 3, axis=-1)
                                    except Exception as rasterio_error:
                                        # Fall back to PIL for non-GeoTIFF formats.
                                        st.warning(f"Rasterio failed: {str(rasterio_error)}. Trying PIL...")
                                        try:
                                            gt_image = np.array(Image.open(tmp_file_path).convert('RGB'))

                                            st.info(f"Ground truth shape (PIL): {gt_image.shape}, dtype: {gt_image.dtype}, min: {np.min(gt_image)}, max: {np.max(gt_image)}")

                                            if np.all(gt_image > 250):
                                                st.warning("Ground truth image appears to be all white")
                                        except Exception as pil_error:
                                            st.error(f"Both rasterio and PIL failed to load the ground truth: {str(pil_error)}")
                                            raise

                                    # NOTE(review): if both loaders raise above, this unlink is
                                    # skipped and the temp file leaks — consider try/finally.
                                    os.unlink(tmp_file_path)

                                    # Match the model's 256x256 working resolution.
                                    if gt_image.shape[:2] != (256, 256):
                                        gt_image = cv2.resize(gt_image, (256, 256))

                                    # Normalize any dtype/range into uint8 [0, 255].
                                    if gt_image.dtype != np.uint8:
                                        if np.max(gt_image) > 1.0 and np.max(gt_image) <= 255:
                                            gt_image = gt_image.astype(np.uint8)
                                        elif np.max(gt_image) <= 1.0:
                                            gt_image = (gt_image * 255).astype(np.uint8)
                                        else:
                                            # Min-max rescale for e.g. 16-bit imagery.
                                            gt_min, gt_max = np.min(gt_image), np.max(gt_image)
                                            gt_image = ((gt_image - gt_min) / (gt_max - gt_min) * 255).astype(np.uint8)

                                    # 4-panel comparison figure: SAR | GT | segmentation | generated.
                                    fig, axes = plt.subplots(1, 4, figsize=(16, 4))

                                    axes[0].imshow(sar_rgb, cmap='gray')
                                    axes[0].set_title('Original SAR', color='white')
                                    axes[0].axis('off')

                                    axes[1].imshow(gt_image)
                                    axes[1].set_title('Ground Truth', color='white')
                                    axes[1].axis('off')

                                    axes[2].imshow(colored_pred)
                                    axes[2].set_title('Segmentation', color='white')
                                    axes[2].axis('off')

                                    if colorized_img is not None:
                                        # Generator output is tanh-scaled [-1, 1]; map to [0, 1].
                                        colorized_display = (colorized_img * 0.5) + 0.5
                                        axes[3].imshow(colorized_display)
                                    else:
                                        axes[3].imshow(overlay)
                                    axes[3].set_title('Generated Image', color='white')
                                    axes[3].axis('off')

                                    # Match the app's dark theme.
                                    fig.patch.set_facecolor('#0a0a1f')
                                    for ax in axes:
                                        ax.set_facecolor('#0a0a1f')

                                    plt.tight_layout()

                                    st.pyplot(fig)

                                    # Quality metrics against the ground truth.
                                    if colorized_img is not None:
                                        colorized_norm = (colorized_img * 0.5) + 0.5
                                        gt_norm = gt_image.astype(np.float32) / 255.0

                                        mse = np.mean((colorized_norm - gt_norm) ** 2)
                                        # NOTE(review): mse == 0 (identical images) yields a
                                        # division by zero / inf here; the batch path guards
                                        # this with `if mse > 0` — consider the same guard.
                                        psnr = 20 * np.log10(1.0 / np.sqrt(mse))

                                        from skimage.metrics import structural_similarity as ssim

                                        try:
                                            # Keep the SSIM window odd and within image bounds.
                                            min_dim = min(colorized_norm.shape[0], colorized_norm.shape[1])
                                            win_size = min(7, min_dim - (min_dim % 2) + 1)

                                            ssim_value = ssim(
                                                colorized_norm,
                                                gt_norm,
                                                win_size=win_size,
                                                channel_axis=2,
                                                data_range=1.0
                                            )
                                        except Exception as e:
                                            st.warning(f"Could not calculate SSIM: {str(e)}")
                                            ssim_value = 0.0

                                        # Metric cards.
                                        col1, col2 = st.columns(2)
                                        with col1:
                                            st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                                            st.markdown(f"<p class='metric-value'>{psnr:.2f}</p>", unsafe_allow_html=True)
                                            st.markdown("<p class='metric-label'>PSNR (dB)</p>", unsafe_allow_html=True)
                                            st.markdown("</div>", unsafe_allow_html=True)

                                        with col2:
                                            st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                                            st.markdown(f"<p class='metric-value'>{ssim_value:.4f}</p>", unsafe_allow_html=True)
                                            st.markdown("<p class='metric-label'>SSIM</p>", unsafe_allow_html=True)
                                            st.markdown("</div>", unsafe_allow_html=True)
                                except Exception as e:
                                    st.error(f"Error processing ground truth: {str(e)}")

                                # Side-by-side images (shown regardless of GT success).
                                col1, col2, col3 = st.columns(3)

                                with col1:
                                    st.subheader("Original SAR Image")
                                    st.image(sar_rgb, use_container_width=True)

                                with col2:
                                    st.subheader("Predicted Segmentation")
                                    st.image(colored_pred, use_container_width=True)

                                with col3:
                                    st.subheader("Colorized SAR")
                                    st.image(overlay, use_container_width=True)
                            else:
                                # No ground truth: show the three result panels only.
                                col1, col2, col3 = st.columns(3)

                                with col1:
                                    st.subheader("Original SAR Image")
                                    st.image(sar_rgb, use_container_width=True)

                                with col2:
                                    st.subheader("Predicted Segmentation")
                                    st.image(colored_pred, use_container_width=True)

                                with col3:
                                    st.subheader("Colorized SAR")
                                    st.image(overlay, use_container_width=True)

                            # Dedicated view + downloads for the generated image.
                            if colorized_img is not None:
                                st.header("Translated Optical Image")
                                colorized_display = (colorized_img * 0.5) + 0.5

                                fig, ax = plt.subplots(figsize=(6, 6))
                                ax.imshow(colorized_display)
                                ax.axis('off')

                                st.pyplot(fig, use_container_width=False)

                                col1, col2 = st.columns(2)

                                with col1:
                                    seg_buf = io.BytesIO()
                                    plt.imsave(seg_buf, colored_pred, format='png')
                                    seg_buf.seek(0)

                                    st.download_button(
                                        label="Download Segmentation",
                                        data=seg_buf,
                                        file_name="segmentation.png",
                                        mime="image/png",
                                        key="download_seg"
                                    )

                                with col2:
                                    gen_buf = io.BytesIO()
                                    plt.imsave(gen_buf, colorized_display, format='png')
                                    gen_buf.seek(0)

                                    st.download_button(
                                        label="Download Optical-like Image",
                                        data=gen_buf,
                                        file_name="optical_like.png",
                                        mime="image/png",
                                        key="download_optical"
                                    )
                    except Exception as e:
                        st.error(f"Error processing image: {str(e)}")
| | |
| | |
    # -----------------------------------------------------------------------
    # Tab 2: batch processing. Accepts multiple SAR images (or ZIPs), an
    # optional matching set of ground-truth images, runs the models on each,
    # accumulates PSNR/SSIM where ground truth exists, and offers per-image
    # and bundled-ZIP downloads of the results.
    # -----------------------------------------------------------------------
    with tab2:
        st.markdown("<h3 style='color: #a78bfa;'>Batch Process SAR Images</h3>", unsafe_allow_html=True)

        st.markdown("<div class='upload-box'>", unsafe_allow_html=True)

        use_gt = st.checkbox("Include ground truth data", value=False)

        col1, col2 = st.columns(2)

        with col1:
            batch_files = st.file_uploader(
                "Upload SAR images or a ZIP file containing images",
                type=["tif", "tiff", "png", "jpg", "jpeg", "zip"],
                accept_multiple_files=True,
                key="batch_sar_optical_uploader"
            )

        # GT uploader only appears when the checkbox is ticked; pairing is by filename.
        batch_gt_files = None
        if use_gt:
            with col2:
                batch_gt_files = st.file_uploader(
                    "Upload ground truth optical images or a ZIP file (must match SAR filenames)",
                    type=["tif", "tiff", "png", "jpg", "jpeg", "zip"],
                    accept_multiple_files=True,
                    key="batch_optical_gt_uploader"
                )
                st.info("Ground truth filenames should match SAR image filenames")

        st.markdown("</div>", unsafe_allow_html=True)

        col1, col2 = st.columns([3, 1])

        with col1:
            max_images = st.slider("Maximum number of images to display", min_value=1, max_value=20, value=5)

        with col2:
            st.markdown("<br>", unsafe_allow_html=True)
            batch_process_btn = st.button("Process Images", key="batch_process_btn")

        if batch_process_btn and batch_files:
            # Reset previous results on every run (both branches clear the list).
            if 'batch_results' not in st.session_state:
                st.session_state.batch_results = []
            else:
                st.session_state.batch_results = []

            with st.spinner("Processing images..."):
                # Stage all uploads on disk so they can be opened by path.
                with tempfile.TemporaryDirectory() as temp_dir:
                    sar_image_files = []
                    gt_image_files = {}  # basename -> path, used to pair GT with SAR

                    # Collect SAR inputs: extract ZIPs, copy loose files.
                    for uploaded_file in batch_files:
                        if uploaded_file.name.lower().endswith('.zip'):
                            zip_path = os.path.join(temp_dir, uploaded_file.name)
                            with open(zip_path, 'wb') as f:
                                f.write(uploaded_file.getvalue())

                            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                                zip_ref.extractall(os.path.join(temp_dir, 'sar'))

                            for root, _, files in os.walk(os.path.join(temp_dir, 'sar')):
                                for file in files:
                                    if file.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')):
                                        sar_image_files.append(os.path.join(root, file))
                        else:
                            file_path = os.path.join(temp_dir, 'sar', uploaded_file.name)
                            os.makedirs(os.path.dirname(file_path), exist_ok=True)
                            with open(file_path, 'wb') as f:
                                f.write(uploaded_file.getvalue())
                            sar_image_files.append(file_path)

                    # Collect ground-truth inputs the same way.
                    if use_gt and batch_gt_files:
                        for gt_file in batch_gt_files:
                            if gt_file.name.lower().endswith('.zip'):
                                zip_path = os.path.join(temp_dir, gt_file.name)
                                with open(zip_path, 'wb') as f:
                                    f.write(gt_file.getvalue())

                                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                                    zip_ref.extractall(os.path.join(temp_dir, 'gt'))

                                for root, _, files in os.walk(os.path.join(temp_dir, 'gt')):
                                    for file in files:
                                        if file.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')):
                                            gt_path = os.path.join(root, file)
                                            gt_image_files[os.path.basename(file)] = gt_path
                            else:
                                file_path = os.path.join(temp_dir, 'gt', gt_file.name)
                                os.makedirs(os.path.dirname(file_path), exist_ok=True)
                                with open(file_path, 'wb') as f:
                                    f.write(gt_file.getvalue())
                                gt_image_files[os.path.basename(gt_file.name)] = file_path

                    # Cap the number of displayed images (random subset).
                    if len(sar_image_files) > max_images:
                        st.info(f"Found {len(sar_image_files)} images. Randomly selecting {max_images} images to display.")
                        sar_image_files = random.sample(sar_image_files, max_images)

                    progress_bar = st.progress(0)

                    # Aggregate metrics across the batch (only if GT is in play).
                    if use_gt and gt_image_files:
                        overall_psnr = []
                        overall_ssim = []

                    for i, image_path in enumerate(sar_image_files):
                        try:
                            progress_bar.progress((i + 1) / len(sar_image_files))

                            with open(image_path, 'rb') as f:
                                file_bytes = f.read()

                            sar_batch, sar_image = load_sar_image(io.BytesIO(file_bytes))

                            if sar_batch is not None:
                                seg_mask, colorized = process_image(
                                    sar_batch,
                                    st.session_state['unet_model'],
                                    st.session_state.get('generator_model')
                                )

                                sar_rgb, colored_pred, overlay, colorized_img = visualize_results(
                                    sar_image, seg_mask, colorized
                                )

                                # Pair GT by basename.
                                image_basename = os.path.basename(image_path)
                                has_gt = image_basename in gt_image_files

                                if has_gt and use_gt:
                                    gt_path = gt_image_files[image_basename]
                                    try:
                                        with rasterio.open(gt_path) as src:
                                            gt_image = src.read()

                                            if gt_image.shape[0] == 3:
                                                # (bands, H, W) -> (H, W, 3)
                                                gt_image = np.transpose(gt_image, (1, 2, 0))
                                            else:
                                                # Single band: replicate to 3 channels.
                                                gt_image = src.read(1)

                                                gt_image = np.expand_dims(gt_image, axis=-1)
                                                gt_image = np.repeat(gt_image, 3, axis=-1)
                                    except Exception as rasterio_error:
                                        try:
                                            gt_image = np.array(Image.open(gt_path).convert('RGB'))
                                        except Exception as pil_error:
                                            st.error(f"Both rasterio and PIL failed to load the ground truth: {str(pil_error)}")
                                            raise

                                    if gt_image.shape[:2] != (256, 256):
                                        gt_image = cv2.resize(gt_image, (256, 256))

                                    # Normalize any dtype/range into uint8 [0, 255].
                                    if gt_image.dtype != np.uint8:
                                        if np.max(gt_image) > 1.0 and np.max(gt_image) <= 255:
                                            gt_image = gt_image.astype(np.uint8)
                                        elif np.max(gt_image) <= 1.0:
                                            gt_image = (gt_image * 255).astype(np.uint8)
                                        else:
                                            gt_min, gt_max = np.min(gt_image), np.max(gt_image)
                                            gt_image = ((gt_image - gt_min) / (gt_max - gt_min) * 255).astype(np.uint8)

                                    # 4-panel figure: SAR | GT | segmentation | generated.
                                    fig, axes = plt.subplots(1, 4, figsize=(16, 4))

                                    axes[0].imshow(sar_rgb, cmap='gray')
                                    axes[0].set_title('Original SAR', color='white')
                                    axes[0].axis('off')

                                    axes[1].imshow(gt_image)
                                    axes[1].set_title('Ground Truth', color='white')
                                    axes[1].axis('off')

                                    axes[2].imshow(colored_pred)
                                    axes[2].set_title('Segmentation', color='white')
                                    axes[2].axis('off')

                                    if colorized_img is not None:
                                        # tanh output [-1, 1] -> [0, 1] for display.
                                        colorized_display = (colorized_img * 0.5) + 0.5
                                        axes[3].imshow(colorized_display)
                                    else:
                                        axes[3].imshow(overlay)
                                    axes[3].set_title('Generated Image', color='white')
                                    axes[3].axis('off')

                                    fig.patch.set_facecolor('#0a0a1f')
                                    for ax in axes:
                                        ax.set_facecolor('#0a0a1f')

                                    plt.tight_layout()

                                    # Render to PNG bytes; close to free the figure.
                                    result_buf = io.BytesIO()
                                    plt.savefig(result_buf, format='png', facecolor='#0a0a1f', bbox_inches='tight')
                                    result_buf.seek(0)
                                    plt.close(fig)

                                    # Per-image metrics (0.0 means "not computed").
                                    metrics = {'psnr': 0.0, 'ssim': 0.0}
                                    if colorized_img is not None:
                                        try:
                                            colorized_norm = (colorized_img * 0.5) + 0.5
                                            gt_norm = gt_image.astype(np.float32) / 255.0

                                            mse = np.mean((colorized_norm - gt_norm) ** 2)
                                            if mse > 0:  # guard against log10(inf) on identical images
                                                psnr = 20 * np.log10(1.0 / np.sqrt(mse))
                                                metrics['psnr'] = psnr
                                                overall_psnr.append(psnr)

                                            from skimage.metrics import structural_similarity as ssim

                                            # Keep the SSIM window odd and within bounds.
                                            min_dim = min(colorized_norm.shape[0], colorized_norm.shape[1])
                                            win_size = min(7, min_dim - (min_dim % 2) + 1)

                                            ssim_value = ssim(
                                                colorized_norm,
                                                gt_norm,
                                                win_size=win_size,
                                                channel_axis=2,
                                                data_range=1.0
                                            )
                                            metrics['ssim'] = ssim_value
                                            overall_ssim.append(ssim_value)
                                        except Exception as e:
                                            st.warning(f"Could not calculate metrics for {os.path.basename(image_path)}: {str(e)}")

                                    # Standalone generated-image PNG for download.
                                    gen_buf = io.BytesIO()
                                    if colorized_img is not None:
                                        plt.imsave(gen_buf, colorized_display, format='png')
                                    else:
                                        plt.imsave(gen_buf, overlay, format='png')
                                    gen_buf.seek(0)

                                    st.session_state.batch_results.append({
                                        'filename': os.path.basename(image_path),
                                        'result': result_buf,
                                        'generated': gen_buf,
                                        'metrics': metrics
                                    })
                                else:
                                    # No ground truth for this image: 3-panel figure.
                                    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

                                    axes[0].imshow(sar_rgb, cmap='gray')
                                    axes[0].set_title('Original SAR', color='white')
                                    axes[0].axis('off')

                                    axes[1].imshow(colored_pred)
                                    axes[1].set_title('Segmentation', color='white')
                                    axes[1].axis('off')

                                    if colorized_img is not None:
                                        colorized_display = (colorized_img * 0.5) + 0.5
                                        axes[2].imshow(colorized_display)
                                    else:
                                        axes[2].imshow(overlay)
                                    axes[2].set_title('Generated Image', color='white')
                                    axes[2].axis('off')

                                    fig.patch.set_facecolor('#0a0a1f')
                                    for ax in axes:
                                        ax.set_facecolor('#0a0a1f')

                                    plt.tight_layout()

                                    result_buf = io.BytesIO()
                                    plt.savefig(result_buf, format='png', facecolor='#0a0a1f', bbox_inches='tight')
                                    result_buf.seek(0)
                                    plt.close(fig)

                                    gen_buf = io.BytesIO()
                                    if colorized_img is not None:
                                        plt.imsave(gen_buf, colorized_display, format='png')
                                    else:
                                        plt.imsave(gen_buf, overlay, format='png')
                                    gen_buf.seek(0)

                                    # No 'metrics' key here — display code checks for it.
                                    st.session_state.batch_results.append({
                                        'filename': os.path.basename(image_path),
                                        'result': result_buf,
                                        'generated': gen_buf
                                    })
                        except Exception as e:
                            st.error(f"Error processing {os.path.basename(image_path)}: {str(e)}")

                    progress_bar.empty()

            # ---- Results gallery ----
            if st.session_state.batch_results:
                st.markdown("<h4 style='color: #a78bfa;'>Translation Results</h4>", unsafe_allow_html=True)

                # Batch-average metrics, only when GT was actually used.
                # NOTE(review): `'overall_psnr' in locals()` works here because this
                # runs at module scope; it guards against the list never being created.
                if use_gt and 'overall_psnr' in locals() and overall_psnr:
                    avg_psnr = np.mean(overall_psnr)
                    avg_ssim = np.mean(overall_ssim)

                    col1, col2 = st.columns(2)
                    with col1:
                        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                        st.markdown(f"<p class='metric-value'>{avg_psnr:.2f}</p>", unsafe_allow_html=True)
                        st.markdown("<p class='metric-label'>Average PSNR (dB)</p>", unsafe_allow_html=True)
                        st.markdown("</div>", unsafe_allow_html=True)

                    with col2:
                        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                        st.markdown(f"<p class='metric-value'>{avg_ssim:.4f}</p>", unsafe_allow_html=True)
                        st.markdown("<p class='metric-label'>Average SSIM</p>", unsafe_allow_html=True)
                        st.markdown("</div>", unsafe_allow_html=True)

                # Bundle everything into one downloadable ZIP.
                zip_buffer = io.BytesIO()
                with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
                    for i, result in enumerate(st.session_state.batch_results):
                        zip_file.writestr(f"result_{i+1}_{result['filename']}.png", result['result'].getvalue())
                        zip_file.writestr(f"generated_{i+1}_{result['filename']}.png", result['generated'].getvalue())

                st.download_button(
                    label="Download All Results",
                    data=zip_buffer.getvalue(),
                    file_name="translation_results.zip",
                    mime="application/zip",
                    key="download_all_translation_results"
                )

                # Per-image cards with metrics and individual downloads.
                for i, result in enumerate(st.session_state.batch_results):
                    st.markdown(f"<h5 style='color: #bfdbfe;'>Image: {result['filename']}</h5>", unsafe_allow_html=True)

                    if 'metrics' in result:
                        col1, col2 = st.columns(2)
                        with col1:
                            st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                            if 'psnr' in result['metrics']:
                                st.markdown(f"<p class='metric-value'>{result['metrics']['psnr']:.2f}</p>", unsafe_allow_html=True)
                            else:
                                st.markdown("<p class='metric-value'>N/A</p>", unsafe_allow_html=True)
                            st.markdown("<p class='metric-label'>PSNR (dB)</p>", unsafe_allow_html=True)
                            st.markdown("</div>", unsafe_allow_html=True)

                        with col2:
                            st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                            if 'ssim' in result['metrics']:
                                st.markdown(f"<p class='metric-value'>{result['metrics']['ssim']:.4f}</p>", unsafe_allow_html=True)
                            else:
                                st.markdown("<p class='metric-value'>N/A</p>", unsafe_allow_html=True)
                            st.markdown("<p class='metric-label'>SSIM</p>", unsafe_allow_html=True)
                            st.markdown("</div>", unsafe_allow_html=True)

                    st.image(result['result'], use_container_width=True)

                    col1, col2 = st.columns(2)
                    with col1:
                        st.download_button(
                            label="Download Visualization",
                            data=result['result'].getvalue(),
                            file_name=f"result_{result['filename']}.png",
                            mime="image/png",
                            key=f"download_viz_{i}"
                        )

                    with col2:
                        st.download_button(
                            label="Download Generated Image",
                            data=result['generated'].getvalue(),
                            file_name=f"generated_{result['filename']}.png",
                            mime="image/png",
                            key=f"download_gen_{i}"
                        )

                    st.markdown("<hr style='border-color: rgba(147, 51, 234, 0.3);'>", unsafe_allow_html=True)
            else:
                st.warning("No images were successfully processed.")
        elif batch_process_btn:
            st.warning("Please upload at least one image file or ZIP archive.")
| | |
    # -----------------------------------------------------------------------
    # Tab 3: browse the bundled sample images (samples/SAR), preview the
    # matching optical ground truth if present (samples/OPTICAL), and run
    # the translation pipeline on the selected sample.
    # -----------------------------------------------------------------------
    with tab3:
        st.markdown("<h3 style='color: #a78bfa;'>Sample Images</h3>", unsafe_allow_html=True)
        st.markdown("<div class='card'>", unsafe_allow_html=True)

        # NOTE(review): redundant — os is already imported at the top of the file.
        import os
        sample_dir = "samples/SAR"
        if os.path.exists(sample_dir):
            sample_files = [f for f in os.listdir(sample_dir) if f.endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg'))]
        else:
            # First run: create the expected sample directory layout.
            os.makedirs(sample_dir, exist_ok=True)
            os.makedirs("samples/OPTICAL", exist_ok=True)
            os.makedirs("samples/LABELS", exist_ok=True)
            sample_files = []

        if sample_files and 'unet_model' in st.session_state:
            selected_sample = st.selectbox(
                "Select a sample image",
                sample_files,
                key="optical_sample_selector"
            )

            # Preview the SAR sample and (if available) its optical ground truth.
            col1, col2 = st.columns(2)

            with col1:
                st.subheader("SAR Image")
                sar_path = os.path.join("samples/SAR", selected_sample)
                display_image(sar_path)

            with col2:
                st.subheader("Optical Image (Ground Truth)")
                # GT is paired with the SAR sample by identical filename.
                opt_path = os.path.join("samples/OPTICAL", selected_sample)
                if os.path.exists(opt_path):
                    display_image(opt_path)
                else:
                    st.info("No matching optical image found")

            if st.button("Generate Optical-like Image", key="process_optical_sample_btn"):
                with st.spinner("Processing sample image..."):
                    try:
                        with open(sar_path, 'rb') as f:
                            file_bytes = f.read()

                        sar_batch, sar_image = load_sar_image(io.BytesIO(file_bytes))

                        if sar_batch is not None:
                            seg_mask, colorized = process_image(
                                sar_batch,
                                st.session_state['unet_model'],
                                st.session_state.get('generator_model')
                            )

                            sar_rgb, colored_pred, overlay, colorized_img = visualize_results(
                                sar_image, seg_mask, colorized
                            )

                            has_gt = os.path.exists(opt_path)

                            if has_gt:
                                # Load GT: rasterio first, PIL as fallback.
                                try:
                                    with rasterio.open(opt_path) as src:
                                        gt_image = src.read()

                                        if gt_image.shape[0] == 3:
                                            # (bands, H, W) -> (H, W, 3)
                                            gt_image = np.transpose(gt_image, (1, 2, 0))
                                        else:
                                            # Single band: replicate to 3 channels.
                                            gt_image = src.read(1)

                                            gt_image = np.expand_dims(gt_image, axis=-1)
                                            gt_image = np.repeat(gt_image, 3, axis=-1)
                                except Exception as rasterio_error:
                                    try:
                                        gt_image = np.array(Image.open(opt_path).convert('RGB'))
                                    except Exception as pil_error:
                                        st.error(f"Both rasterio and PIL failed to load the ground truth: {str(pil_error)}")
                                        raise

                                if gt_image.shape[:2] != (256, 256):
                                    gt_image = cv2.resize(gt_image, (256, 256))

                                # Normalize any dtype/range into uint8 [0, 255].
                                if gt_image.dtype != np.uint8:
                                    if np.max(gt_image) > 1.0 and np.max(gt_image) <= 255:
                                        gt_image = gt_image.astype(np.uint8)
                                    elif np.max(gt_image) <= 1.0:
                                        gt_image = (gt_image * 255).astype(np.uint8)
                                    else:
                                        gt_min, gt_max = np.min(gt_image), np.max(gt_image)
                                        gt_image = ((gt_image - gt_min) / (gt_max - gt_min) * 255).astype(np.uint8)

                                # 4-panel comparison figure.
                                fig, axes = plt.subplots(1, 4, figsize=(16, 4))

                                axes[0].imshow(sar_rgb, cmap='gray')
                                axes[0].set_title('Original SAR', color='white')
                                axes[0].axis('off')

                                axes[1].imshow(gt_image)
                                axes[1].set_title('Ground Truth', color='white')
                                axes[1].axis('off')

                                axes[2].imshow(colored_pred)
                                axes[2].set_title('Segmentation', color='white')
                                axes[2].axis('off')

                                if colorized_img is not None:
                                    # tanh output [-1, 1] -> [0, 1] for display.
                                    colorized_display = (colorized_img * 0.5) + 0.5
                                    axes[3].imshow(colorized_display)
                                else:
                                    axes[3].imshow(overlay)
                                axes[3].set_title('Generated Image', color='white')
                                axes[3].axis('off')

                                fig.patch.set_facecolor('#0a0a1f')
                                for ax in axes:
                                    ax.set_facecolor('#0a0a1f')

                                plt.tight_layout()

                                st.pyplot(fig)

                                # PSNR / SSIM against the sample's ground truth.
                                if colorized_img is not None:
                                    colorized_norm = (colorized_img * 0.5) + 0.5
                                    gt_norm = gt_image.astype(np.float32) / 255.0

                                    mse = np.mean((colorized_norm - gt_norm) ** 2)
                                    # NOTE(review): unguarded against mse == 0 (division by
                                    # zero) — the batch path uses `if mse > 0`.
                                    psnr = 20 * np.log10(1.0 / np.sqrt(mse))

                                    from skimage.metrics import structural_similarity as ssim

                                    try:
                                        # Keep the SSIM window odd and within bounds.
                                        min_dim = min(colorized_norm.shape[0], colorized_norm.shape[1])
                                        win_size = min(7, min_dim - (min_dim % 2) + 1)

                                        ssim_value = ssim(
                                            colorized_norm,
                                            gt_norm,
                                            win_size=win_size,
                                            channel_axis=2,
                                            data_range=1.0
                                        )
                                    except Exception as e:
                                        st.warning(f"Could not calculate SSIM: {str(e)}")
                                        ssim_value = 0.0

                                    col1, col2 = st.columns(2)
                                    with col1:
                                        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                                        st.markdown(f"<p class='metric-value'>{psnr:.2f}</p>", unsafe_allow_html=True)
                                        st.markdown("<p class='metric-label'>PSNR (dB)</p>", unsafe_allow_html=True)
                                        st.markdown("</div>", unsafe_allow_html=True)

                                    with col2:
                                        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                                        st.markdown(f"<p class='metric-value'>{ssim_value:.4f}</p>", unsafe_allow_html=True)
                                        st.markdown("<p class='metric-label'>SSIM</p>", unsafe_allow_html=True)
                                        st.markdown("</div>", unsafe_allow_html=True)
                            else:
                                # No ground truth: three result panels only.
                                col1, col2, col3 = st.columns(3)

                                with col1:
                                    st.subheader("Original SAR Image")
                                    st.image(sar_rgb, use_container_width=True)

                                with col2:
                                    st.subheader("Predicted Segmentation")
                                    st.image(colored_pred, use_container_width=True)

                                with col3:
                                    st.subheader("Colorized SAR")
                                    if colorized_img is not None:
                                        colorized_display = (colorized_img * 0.5) + 0.5
                                        st.image(colorized_display, use_container_width=True)
                                    else:
                                        st.image(overlay, use_container_width=True)

                            # Downloads for segmentation and generated image.
                            col1, col2 = st.columns(2)

                            with col1:
                                seg_buf = io.BytesIO()
                                plt.imsave(seg_buf, colored_pred, format='png')
                                seg_buf.seek(0)

                                st.download_button(
                                    label="Download Segmentation",
                                    data=seg_buf,
                                    file_name=f"sample_segmentation_{selected_sample}.png",
                                    mime="image/png",
                                    key="download_sample_seg"
                                )

                            with col2:
                                gen_buf = io.BytesIO()
                                if colorized_img is not None:
                                    plt.imsave(gen_buf, (colorized_img * 0.5) + 0.5, format='png')
                                else:
                                    plt.imsave(gen_buf, overlay, format='png')
                                gen_buf.seek(0)

                                st.download_button(
                                    label="Download Optical-like Image",
                                    data=gen_buf,
                                    file_name=f"sample_optical_{selected_sample}.png",
                                    mime="image/png",
                                    key="download_sample_optical"
                                )
                    except Exception as e:
                        st.error(f"Error processing sample image: {str(e)}")
        elif not sample_files:
            st.info("No sample images found. Please add some images to the 'samples/SAR' directory.")
        else:
            st.warning("Please load the models from the sidebar first.")

    # Close the mode-level content card opened at the top of this section.
    st.markdown("</div>", unsafe_allow_html=True)
| |
|
| | |
# Page footer credit banner (rendered after the mode-specific content).
st.markdown("""
<div style="text-align: center; margin-top: 2rem; padding: 1rem; background-color: rgba(0, 0, 0, 0.3); border-radius: 0.5rem;">
    <p style="color: #bfdbfe; font-size: 0.9rem;">
        SAR IMAGE PROCESSING | VARUN & MOKSHYAGNA
    </p>
</div>
""", unsafe_allow_html=True)
| | |
| |
|
def create_stars_html():
    """Build the HTML for the twinkling-stars background effect.

    Returns:
        str: a single HTML string containing one ``.stars`` container with
        100 absolutely positioned ``.star`` divs. Each star gets a random
        size (1-3 px), position (0-100% top/left), and per-star animation
        parameters passed as CSS custom properties (``--duration`` 3-8 s,
        ``--opacity`` 0.2-0.8) that the ``twinkle`` keyframes defined in
        setup_page_style() consume.
    """
    # Collect fragments and join once instead of repeated string
    # concatenation; the loop index was unused, so use `_`.
    parts = ['\n    <div class="stars">\n    ']
    for _ in range(100):
        size = random.uniform(1, 3)
        top = random.uniform(0, 100)
        left = random.uniform(0, 100)
        duration = random.uniform(3, 8)
        opacity = random.uniform(0.2, 0.8)

        parts.append(f"""
        <div class="star" style="
            width: {size}px;
            height: {size}px;
            top: {top}%;
            left: {left}%;
            --duration: {duration}s;
            --opacity: {opacity};
        "></div>
        """)
    parts.append("</div>")
    return "".join(parts)
| |
|
| | |
| | |
def setup_page_style():
    """Inject the app-wide CSS for the currently selected theme.

    Reads the theme name from ``st.session_state`` — defaulting to
    ``"dark"`` when the key has not been initialized yet, so a first run
    before any theme toggle exists cannot raise ``AttributeError`` — and
    writes a single ``<style>`` tag via ``st.markdown``.

    The CSS has three parts:
      * ``common_css`` — rules shared by both themes (twinkling-stars
        animation, tab/button styling, vertical spacing).
      * ``dark_css``  — deep-space background with purple/blue accents.
      * ``light_css`` — keeps the same dark background but uses brighter
        accent colors and forces white text for readability.
    """

    # Rules applied regardless of theme.
    common_css = """
    /* Create twinkling stars effect */
    @keyframes twinkle {
        0%, 100% { opacity: 0.2; }
        50% { opacity: 1; }
    }

    .stars {
        position: fixed;
        top: 0;
        left: 0;
        width: 100%;
        height: 100%;
        pointer-events: none;
        z-index: -1;
    }

    .star {
        position: absolute;
        background-color: white;
        border-radius: 50%;
        animation: twinkle var(--duration) infinite;
        opacity: var(--opacity);
    }

    /* Tab styling */
    .stTabs [data-baseweb="tab-list"] {
        gap: 24px !important;
        border-radius: 0.5rem;
        padding: 0.8rem;
        margin-bottom: 3rem !important;
        display: flex;
        justify-content: center !important;
        width: 100%;
    }

    .stTabs [data-baseweb="tab"] {
        height: 5rem !important;
        white-space: pre-wrap;
        border-radius: 0.5rem;
        font-weight: 600 !important;
        font-size: 1.6rem !important;
        padding: 0 25px !important;
        display: flex;
        align-items: center;
        justify-content: center;
        min-width: 200px !important;
    }

    /* Add more space between tab panels */
    .stTabs [data-baseweb="tab-panel"] {
        padding-top: 3rem !important;
        padding-bottom: 3rem !important;
    }

    /* Button styling */
    .stButton>button {
        border: none;
        border-radius: 0.5rem;
        padding: 0.8rem 1.5rem !important;
        font-weight: 500;
        font-size: 1.2rem !important;
        margin-top: 1.5rem !important;
        margin-bottom: 1.5rem !important;
    }

    /* Spacing */
    .element-container {
        margin-bottom: 2.5rem !important;
    }

    h3 {
        margin-top: 3rem !important;
        margin-bottom: 2rem !important;
        font-size: 1.8rem !important;
    }

    h4 {
        margin-top: 2.5rem !important;
        margin-bottom: 1.5rem !important;
        font-size: 1.5rem !important;
    }

    h5 {
        margin-top: 2rem !important;
        margin-bottom: 1.5rem !important;
        font-size: 1.3rem !important;
    }

    img {
        margin-top: 1.5rem !important;
        margin-bottom: 2.5rem !important;
    }

    .stProgress > div {
        margin-top: 2rem !important;
        margin-bottom: 2rem !important;
    }

    .stSlider {
        padding-top: 1.5rem !important;
        padding-bottom: 2.5rem !important;
    }

    .row-widget {
        margin-top: 1.5rem !important;
        margin-bottom: 2.5rem !important;
    }
    """

    # Dark theme: space photo background with purple/blue accents.
    dark_css = """
    .stApp {
        background-color: #0a0a1f;
        color: white;
    }

    .main {
        background-image: url("https://images.unsplash.com/photo-1451187580459-43490279c0fa?ixlib=rb-1.2.1&auto=format&fit=crop&w=1352&q=80");
        background-size: cover;
        background-position: center;
        background-repeat: no-repeat;
        background-attachment: fixed;
        position: relative;
    }

    .main::before {
        content: "";
        position: absolute;
        top: 0;
        left: 0;
        width: 100%;
        height: 100%;
        background-color: rgba(10, 10, 31, 0.7);
        backdrop-filter: blur(5px);
        z-index: -1;
    }

    /* Title styling */
    h1.title {
        background: linear-gradient(to right, #a78bfa, #ec4899, #3b82f6);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        background-clip: text;
        color: transparent;
        font-size: 3rem !important;
        font-weight: bold !important;
        text-align: center !important;
        margin-bottom: 0.5rem !important;
        display: block !important;
        position: relative !important;
        z-index: 10 !important;
    }

    p.subtitle {
        color: #bfdbfe !important;
        font-size: 1.2rem !important;
        text-align: center !important;
        margin-bottom: 2rem !important;
        position: relative !important;
        z-index: 10 !important;
    }

    /* Tab styling */
    .stTabs [data-baseweb="tab-list"] {
        background-color: rgba(0, 0, 0, 0.3);
    }

    .stTabs [data-baseweb="tab"] {
        background-color: transparent;
        color: white;
    }

    .stTabs [aria-selected="true"] {
        background-color: rgba(147, 51, 234, 0.5) !important;
        transform: scale(1.05);
        transition: all 0.2s ease;
    }

    /* Card and box styling */
    .upload-box {
        border: 2px dashed rgba(147, 51, 234, 0.5);
        border-radius: 1rem;
        padding: 4rem !important;
        text-align: center;
        margin-bottom: 3rem !important;
    }

    .card {
        background-color: rgba(0, 0, 0, 0.3);
        border: 1px solid rgba(147, 51, 234, 0.3);
        border-radius: 1rem;
        padding: 2.5rem !important;
        backdrop-filter: blur(10px);
        margin-bottom: 3rem !important;
    }

    /* Button styling */
    .stButton>button {
        background: linear-gradient(to right, #7c3aed, #2563eb);
        color: white;
    }

    .stButton>button:hover {
        background: linear-gradient(to right, #6d28d9, #1d4ed8);
    }

    .download-btn {
        background-color: #2563eb !important;
    }

    .stSlider>div>div>div {
        background-color: #7c3aed;
    }

    /* Metrics styling */
    .plot-container {
        background-color: rgba(0, 0, 0, 0.3);
        border-radius: 1rem;
        padding: 2rem !important;
        margin-bottom: 3rem !important;
    }

    .metric-card {
        background-color: rgba(0, 0, 0, 0.3);
        border: 1px solid rgba(147, 51, 234, 0.3);
        border-radius: 0.5rem;
        padding: 1.5rem !important;
        text-align: center;
        margin-bottom: 2rem !important;
    }

    .metric-value {
        font-size: 2rem !important;
        font-weight: bold;
        color: #a78bfa;
    }

    .metric-label {
        font-size: 1.1rem !important;
        color: #bfdbfe;
    }

    /* Form elements */
    .stFileUploader > div {
        background-color: rgba(0, 0, 0, 0.3) !important;
        border: 1px dashed rgba(147, 51, 234, 0.5) !important;
        padding: 2rem !important;
        margin-bottom: 2rem !important;
    }

    .stSelectbox > div > div {
        background-color: rgba(0, 0, 0, 0.3) !important;
        border: 1px solid rgba(147, 51, 234, 0.3) !important;
    }
    """

    # "Light" theme: same dark background, brighter purple accents, and
    # text forced to white so everything stays readable.
    light_css = """
    /* Keep the same dark background */
    .stApp {
        background-color: #0a0a1f;
    }

    .main {
        background-image: url("https://images.unsplash.com/photo-1451187580459-43490279c0fa?ixlib=rb-1.2.1&auto=format&fit=crop&w=1352&q=80");
        background-size: cover;
        background-position: center;
        background-repeat: no-repeat;
        background-attachment: fixed;
        position: relative;
    }

    .main::before {
        content: "";
        position: absolute;
        top: 0;
        left: 0;
        width: 100%;
        height: 100%;
        background-color: rgba(10, 10, 31, 0.7);
        backdrop-filter: blur(5px);
        z-index: -1;
    }

    /* Make all text white/light */
    p, span, label, div, h1, h2, h3, h4, h5, h6, li {
        color: white !important;
    }

    /* Title styling - brighter gradient for better visibility */
    h1.title {
        background: linear-gradient(to right, #d8b4fe, #f9a8d4, #93c5fd);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        background-clip: text;
        color: transparent;
        font-size: 3rem !important;
        font-weight: bold !important;
        text-align: center !important;
        margin-bottom: 0.5rem !important;
        display: block !important;
        position: relative !important;
        z-index: 10 !important;
    }

    p.subtitle {
        color: #e0e7ff !important; /* Lighter purple */
        font-size: 1.2rem !important;
        text-align: center !important;
        margin-bottom: 2rem !important;
        position: relative !important;
        z-index: 10 !important;
    }

    /* Tab styling - brighter for better visibility */
    .stTabs [data-baseweb="tab-list"] {
        background-color: rgba(0, 0, 0, 0.3);
    }

    .stTabs [data-baseweb="tab"] {
        background-color: transparent;
        color: white !important;
    }

    .stTabs [aria-selected="true"] {
        background-color: rgba(167, 139, 250, 0.5) !important; /* Brighter purple */
        transform: scale(1.05);
        transition: all 0.2s ease;
    }

    /* Card and box styling - brighter borders */
    .upload-box {
        border: 2px dashed rgba(167, 139, 250, 0.7); /* Brighter purple */
        border-radius: 1rem;
        padding: 4rem !important;
        text-align: center;
        margin-bottom: 3rem !important;
    }

    .card {
        background-color: rgba(0, 0, 0, 0.3);
        border: 1px solid rgba(167, 139, 250, 0.5); /* Brighter purple */
        border-radius: 1rem;
        padding: 2.5rem !important;
        backdrop-filter: blur(10px);
        margin-bottom: 3rem !important;
    }

    /* Button styling - brighter gradient */
    .stButton>button {
        background: linear-gradient(to right, #a78bfa, #60a5fa);
        color: white;
    }

    .stButton>button:hover {
        background: linear-gradient(to right, #8b5cf6, #3b82f6);
    }

    .download-btn {
        background-color: #60a5fa !important;
    }

    .stSlider>div>div>div {
        background-color: #a78bfa;
    }

    /* Metrics styling - brighter accents */
    .plot-container {
        background-color: rgba(0, 0, 0, 0.3);
        border-radius: 1rem;
        padding: 2rem !important;
        margin-bottom: 3rem !important;
    }

    .metric-card {
        background-color: rgba(0, 0, 0, 0.3);
        border: 1px solid rgba(167, 139, 250, 0.5); /* Brighter purple */
        border-radius: 0.5rem;
        padding: 1.5rem !important;
        text-align: center;
        margin-bottom: 2rem !important;
    }

    .metric-value {
        font-size: 2rem !important;
        font-weight: bold;
        color: #d8b4fe; /* Brighter purple */
    }

    .metric-label {
        font-size: 1.1rem !important;
        color: #e0e7ff; /* Lighter purple */
    }

    /* Form elements - brighter borders */
    .stFileUploader > div {
        background-color: rgba(0, 0, 0, 0.3) !important;
        border: 1px dashed rgba(167, 139, 250, 0.7) !important; /* Brighter purple */
        padding: 2rem !important;
        margin-bottom: 2rem !important;
    }

    .stSelectbox > div > div {
        background-color: rgba(0, 0, 0, 0.3) !important;
        border: 1px solid rgba(167, 139, 250, 0.5) !important; /* Brighter purple */
    }

    /* Make sure all text inputs have white text */
    input, textarea {
        color: white !important;
    }

    /* Ensure sidebar text is white */
    .css-1d391kg, .css-1lcbmhc {
        color: white !important;
    }

    /* Make sure plot text is visible on dark background */
    .js-plotly-plot .plotly .main-svg text {
        fill: white !important;
    }

    /* Keep stars visible in light theme */
    .star {
        background-color: white;
        opacity: 0.8;
    }

    /* Make sure all streamlit elements have white text */
    .stMarkdown, .stText, .stCode, .stTextInput, .stTextArea, .stSelectbox, .stMultiselect,
    .stSlider, .stCheckbox, .stRadio, .stNumber, .stDate, .stTime, .stDateInput, .stTimeInput {
        color: white !important;
    }

    /* Ensure dropdown options are visible */
    .stSelectbox ul li {
        color: black !important;
    }
    """

    # .get() with a "dark" default avoids an AttributeError on the very
    # first run, before any theme selector has populated
    # st.session_state.theme.
    theme_css = dark_css if st.session_state.get("theme", "dark") == "dark" else light_css
    st.markdown(f"<style>{common_css}{theme_css}</style>", unsafe_allow_html=True)
| |
|
| | |
| |
|
if __name__ == "__main__":
    # Seed the theme BEFORE styling: setup_page_style() reads
    # st.session_state.theme, which would otherwise raise AttributeError on
    # the very first run (no other visible code initializes it earlier).
    if 'theme' not in st.session_state:
        st.session_state.theme = "dark"

    # Inject the theme CSS for the whole page.
    setup_page_style()

    # Configure the GPU before any TF/Keras objects are created (the
    # SARSegmentation instance below builds on TensorFlow, and GPU settings
    # such as memory growth must be applied first).
    setup_gpu()

    # Initialize the remaining session-state defaults; each guard keeps
    # values alive across Streamlit reruns without clobbering them.
    if 'app_mode' not in st.session_state:
        st.session_state.app_mode = "SAR Colorization"
    if 'model_loaded' not in st.session_state:
        st.session_state.model_loaded = False
    if 'segmentation' not in st.session_state:
        # 256x256 is the model's expected input resolution.
        st.session_state.segmentation = SARSegmentation(img_rows=256, img_cols=256)
    if 'processed_images' not in st.session_state:
        st.session_state.processed_images = []
| |
|
| |
|
| |
|
| |
|