diff --git "a/app.py" "b/app.py" --- "a/app.py" +++ "b/app.py" @@ -1,1391 +1,1397 @@ -import streamlit as st -import tensorflow as tf -from tensorflow.keras.models import Model -from tensorflow.keras.layers import * -from tensorflow.keras.optimizers import Adam -import cv2 -import os -import numpy as np -import matplotlib.pyplot as plt -from PIL import Image -import io -import base64 -import tempfile -import zipfile -import random -import time -import rasterio -from rasterio.errors import RasterioIOError -import h5py -import json - - -# Set page configuration -st.set_page_config( - page_title="SAR Image Pixel Level Classification", - page_icon="🛰", - layout="wide" -) - - -with st.sidebar: - st.image('assets/logo2.png', width=150) - st.title("About") - st.markdown(""" - ### SAR Image Colorization - - This application uses deep learning models to segment and colorize Synthetic Aperture Radar (SAR) images into land cover classes. - - #### Features: - - Load pre-trained U-Net,DeepLabV3+ or SegNet models - - Process single SAR images - - Batch process multiple images - - Visualize Pixel Level Classification with ESA WorldCover color scheme - - #### Developed by: - Varun & Mokshyagna - - #### Technologies: - - TensorFlow/Keras - - Streamlit - - Rasterio - - OpenCV - - #### Version: - 1.0.0 - """) - - st.markdown("---") - st.markdown("© 2025 | All Rights Reserved") - - -def load_model_with_weights(model_path): - """Load a model directly from an H5 file, preserving the original architecture""" - # If model_path is a filename without path, prepend the models directory - if not os.path.dirname(model_path) and not model_path.startswith('models/'): - model_path = os.path.join('models', os.path.basename(model_path)) - - try: - # Try to load the complete model (architecture + weights) - model = tf.keras.models.load_model(model_path, compile=False) - print("Loaded complete model with architecture") - return model - except Exception as e: - print(f"Could not load complete model: {str(e)}") - print("Attempting to load just the weights into a matching architecture...") - - # Try to inspect the model file to determine architecture - with h5py.File(model_path, 'r') as f: - model_config = None - if 'model_config' in f.attrs: - model_config = json.loads(f.attrs['model_config'].decode('utf-8')) - - # If we found a model config, try to recreate it - if model_config: - try: - model = tf.keras.models.model_from_json(json.dumps(model_config)) - model.load_weights(model_path) - print("Successfully loaded model from config and weights") - return model - except Exception as e2: - print(f"Failed to load from config: {str(e2)}") - - # If all else fails, return None - return None -st.markdown( - """ - - """, - unsafe_allow_html=True -) - - -# Create twinkling stars -stars_html = """ -
-""" -for i in range(100): - size = random.uniform(1, 3) - top = random.uniform(0, 100) - left = random.uniform(0, 100) - duration = random.uniform(3, 8) - opacity = random.uniform(0.2, 0.8) - - stars_html += f""" -
- """ -stars_html += "
" - -st.markdown(stars_html, unsafe_allow_html=True) - - - -from PIL import Image -import base64 -def get_image_as_base64(file_path): - try: - with open(file_path, "rb") as img_file: - return base64.b64encode(img_file.read()).decode() - except FileNotFoundError: - # Return a placeholder or empty string if file not found - st.warning(f"Image file not found: {file_path}") - return "" - -# Replace 'path/to/your/logo.png' with your actual logo path -logo_path = 'assets/logo2.png' -logo_base64 = get_image_as_base64(logo_path) - -st.markdown(f"""
-
""", unsafe_allow_html=True) - - -# Title and description -st.markdown(""" -
-

- SAR Image Colorization -

-

- Pixel Level Classification of Synthetic Aperture Radar images into land cover classes with deep learning -

-
-""", unsafe_allow_html=True) - - - -# Create a card container -st.markdown("
", unsafe_allow_html=True) - -# Define the U-Net model -def get_unet(input_shape=(256, 256, 1), drop_rate=0.3, classes=11): - inputs = Input(input_shape) - - # Encoder - conv1_1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs) - batch1_1 = BatchNormalization()(conv1_1) - conv1_2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch1_1) - batch1_2 = BatchNormalization()(conv1_2) - pool1 = MaxPooling2D(pool_size=(2, 2))(batch1_2) - - conv2_1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1) - batch2_1 = BatchNormalization()(conv2_1) - conv2_2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch2_1) - batch2_2 = BatchNormalization()(conv2_2) - pool2 = MaxPooling2D(pool_size=(2, 2))(batch2_2) - - conv3_1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2) - batch3_1 = BatchNormalization()(conv3_1) - conv3_2 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch3_1) - batch3_2 = BatchNormalization()(conv3_2) - pool3 = MaxPooling2D(pool_size=(2, 2))(batch3_2) - - conv4_1 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3) - batch4_1 = BatchNormalization()(conv4_1) - conv4_2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch4_1) - batch4_2 = BatchNormalization()(conv4_2) - drop4 = Dropout(drop_rate)(batch4_2) - pool4 = MaxPooling2D(pool_size=(2, 2))(drop4) - - # Bridge - conv5_1 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4) - batch5_1 = BatchNormalization()(conv5_1) - conv5_2 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch5_1) - batch5_2 = BatchNormalization()(conv5_2) - drop5 = Dropout(drop_rate)(batch5_2) - - # Decoder - up6 = Conv2D(512, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(drop5)) - merge6 = concatenate([drop4, up6]) - conv6_1 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6) - batch6_1 = BatchNormalization()(conv6_1) - conv6_2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch6_1) - batch6_2 = BatchNormalization()(conv6_2) - - up7 = Conv2D(256, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch6_2)) - merge7 = concatenate([batch3_2, up7]) - conv7_1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7) - batch7_1 = BatchNormalization()(conv7_1) - conv7_2 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch7_1) - batch7_2 = BatchNormalization()(conv7_2) - - up8 = Conv2D(128, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch7_2)) - merge8 = concatenate([batch2_2, up8]) - conv8_1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8) - batch8_1 = BatchNormalization()(conv8_1) - conv8_2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch8_1) - batch8_2 = BatchNormalization()(conv8_2) - - up9 = Conv2D(64, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch8_2)) - merge9 = concatenate([batch1_2, up9]) - conv9_1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9) - batch9_1 = BatchNormalization()(conv9_1) - conv9_2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch9_1) - batch9_2 = BatchNormalization()(conv9_2) - - outputs = Conv2D(classes, 1, activation='softmax')(batch9_2) - - model = Model(inputs=inputs, outputs=outputs) - model.compile(optimizer=Adam(learning_rate=1e-4), - loss='categorical_crossentropy', - metrics=['accuracy']) - - return model - - -# Custom upsampling layer for dynamic resizing -class BilinearUpsampling(Layer): - def init(self, size=(1, 1), **kwargs): - super(BilinearUpsampling, self).init(**kwargs) - self.size = size - - def call(self, inputs): - return tf.image.resize(inputs, self.size, method='bilinear') - - def compute_output_shape(self, input_shape): - return (input_shape[0], self.size[0], self.size[1], input_shape[3]) - - def get_config(self): - config = super(BilinearUpsampling, self).get_config() - config.update({'size': self.size}) - return config - -def DeepLabV3Plus(input_shape=(256, 256, 1), classes=11, output_stride=16): - """ - DeepLabV3+ model with Xception backbone - - Args: - input_shape: Shape of input images - classes: Number of classes for segmentation - output_stride: Output stride for dilated convolutions (16 or 8) - - Returns: - model: DeepLabV3+ model - """ - # Input layer - inputs = Input(input_shape) - - # Ensure we're using the right dilation rates based on output_stride - if output_stride == 16: - atrous_rates = (6, 12, 18) - elif output_stride == 8: - atrous_rates = (12, 24, 36) - else: - raise ValueError("Output stride must be 8 or 16") - - # === ENCODER (BACKBONE) === - # Entry block - x = Conv2D(32, 3, strides=(2, 2), padding='same', use_bias=False)(inputs) - x = BatchNormalization()(x) - x = Activation('relu')(x) - - x = Conv2D(64, 3, padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - - # Xception-like blocks with dilated convolutions - # Block 1 - residual = Conv2D(128, 1, strides=(2, 2), padding='same', use_bias=False)(x) - residual = BatchNormalization()(residual) - - x = SeparableConv2D(128, 3, padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = SeparableConv2D(128, 3, padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = MaxPooling2D(3, strides=(2, 2), padding='same')(x) - x = Add()([x, residual]) - - # Block 2 - residual = Conv2D(256, 1, strides=(2, 2), padding='same', use_bias=False)(x) - residual = BatchNormalization()(residual) - - x = Activation('relu')(x) - x = SeparableConv2D(256, 3, padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = SeparableConv2D(256, 3, padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = MaxPooling2D(3, strides=(2, 2), padding='same')(x) - x = Add()([x, residual]) - - # Save low_level_features for skip connection (1/4 of input size) - low_level_features = x - - # Block 3 - residual = Conv2D(728, 1, strides=(2, 2), padding='same', use_bias=False)(x) - residual = BatchNormalization()(residual) - - x = Activation('relu')(x) - x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = MaxPooling2D(3, strides=(2, 2), padding='same')(x) - x = Add()([x, residual]) - - # Middle flow - modified with dilated convolutions - for i in range(16): - residual = x - - x = Activation('relu')(x) - x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x) - x = BatchNormalization()(x) - - x = Add()([x, residual]) - - # Exit flow (modified) - x = Activation('relu')(x) - x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = SeparableConv2D(1024, 3, padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - - # === ASPP (Atrous Spatial Pyramid Pooling) === - # 1x1 convolution branch - aspp_out1 = Conv2D(256, 1, padding='same', use_bias=False)(x) - aspp_out1 = BatchNormalization()(aspp_out1) - aspp_out1 = Activation('relu')(aspp_out1) - - # 3x3 dilated convolution branches with different rates - aspp_out2 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[0], use_bias=False)(x) - aspp_out2 = BatchNormalization()(aspp_out2) - aspp_out2 = Activation('relu')(aspp_out2) - - aspp_out3 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[1], use_bias=False)(x) - aspp_out3 = BatchNormalization()(aspp_out3) - aspp_out3 = Activation('relu')(aspp_out3) - - aspp_out4 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[2], use_bias=False)(x) - aspp_out4 = BatchNormalization()(aspp_out4) - aspp_out4 = Activation('relu')(aspp_out4) - - # Global pooling branch - aspp_out5 = GlobalAveragePooling2D()(x) - aspp_out5 = Reshape((1, 1, 1024))(aspp_out5) # Use 1024 to match x's channels - aspp_out5 = Conv2D(256, 1, padding='same', use_bias=False)(aspp_out5) - aspp_out5 = BatchNormalization()(aspp_out5) - aspp_out5 = Activation('relu')(aspp_out5) - - # Get current shape of x - _, height, width, _ = tf.keras.backend.int_shape(x) - aspp_out5 = UpSampling2D(size=(height, width), interpolation='bilinear')(aspp_out5) - - # Concatenate all ASPP branches - aspp_out = Concatenate()([aspp_out1, aspp_out2, aspp_out3, aspp_out4, aspp_out5]) - - # Project ASPP output to 256 filters - aspp_out = Conv2D(256, 1, padding='same', use_bias=False)(aspp_out) - aspp_out = BatchNormalization()(aspp_out) - aspp_out = Activation('relu')(aspp_out) - - # === DECODER === - # Process low-level features from Block 2 (1/4 size) - low_level_features = Conv2D(48, 1, padding='same', use_bias=False)(low_level_features) - low_level_features = BatchNormalization()(low_level_features) - low_level_features = Activation('relu')(low_level_features) - - # Upsample ASPP output by 4x to match low level features size - # Get shapes for verification - low_level_shape = tf.keras.backend.int_shape(low_level_features) - - # Upsample to match low_level_features shape - x = UpSampling2D(size=(low_level_shape[1] // tf.keras.backend.int_shape(aspp_out)[1], - low_level_shape[2] // tf.keras.backend.int_shape(aspp_out)[2]), - interpolation='bilinear')(aspp_out) - - # Concatenate with low-level features - x = Concatenate()([x, low_level_features]) - - # Final convolutions - x = Conv2D(256, 3, padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - - x = Conv2D(256, 3, padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - - # Calculate upsampling size to original input size - x_shape = tf.keras.backend.int_shape(x) - upsampling_size = (input_shape[0] // x_shape[1], input_shape[1] // x_shape[2]) - - # Upsample to original size - x = UpSampling2D(size=upsampling_size, interpolation='bilinear')(x) - - # Final segmentation output - outputs = Conv2D(classes, 1, padding='same', activation='softmax')(x) - - model = Model(inputs=inputs, outputs=outputs) - model.compile(optimizer=Adam(learning_rate=1e-4), - loss='categorical_crossentropy', - metrics=['accuracy']) - - return model - -import tensorflow as tf -from tensorflow.keras.models import Model -from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPooling2D, UpSampling2D - -def SegNet(input_shape=(256, 256, 1), classes=11): - """ - SegNet model for semantic segmentation - - Args: - input_shape: Shape of input images - classes: Number of classes for segmentation - - Returns: - model: SegNet model - """ - # Input layer - inputs = Input(input_shape) - - # === ENCODER === - # Encoder block 1 - x = Conv2D(64, (3, 3), padding='same', use_bias=False)(inputs) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - # Regular MaxPooling without indices - x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x) - - # Encoder block 2 - x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x) - - # Encoder block 3 - x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x) - - # Encoder block 4 - x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x) - - # Encoder block 5 - x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x) - - # === DECODER === - # Using UpSampling2D instead of MaxUnpooling since TensorFlow doesn't support it - - # Decoder block 5 - x = UpSampling2D(size=(2, 2))(x) - x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - - # Decoder block 4 - x = UpSampling2D(size=(2, 2))(x) - x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - - # Decoder block 3 - x = UpSampling2D(size=(2, 2))(x) - x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - - # Decoder block 2 - x = UpSampling2D(size=(2, 2))(x) - x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - - # Decoder block 1 - x = UpSampling2D(size=(2, 2))(x) - x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) - - # Output layer - outputs = Conv2D(classes, (1, 1), padding='same', activation='softmax')(x) - - model = Model(inputs=inputs, outputs=outputs) - - return model - - - - -class SARSegmentation: - def _init_(self, img_rows=256, img_cols=256, drop_rate=0.5, model_type='unet'): - self.img_rows = img_rows - self.img_cols = img_cols - self.drop_rate = drop_rate - self.num_channels = 1 # Single-pol SAR - self.model = None - self.model_type = model_type.lower() - - # ESA WorldCover class definitions - self.class_definitions = { - 'trees': 10, - 'shrubland': 20, - 'grassland': 30, - 'cropland': 40, - 'built_up': 50, - 'bare': 60, - 'snow': 70, - 'water': 80, - 'wetland': 90, - 'mangroves': 95, - 'moss': 100 - } - self.num_classes = len(self.class_definitions) - - # Class colors for visualization - self.class_colors = { - 0: [0, 100, 0], # Trees - Dark green - 1: [255, 165, 0], # Shrubland - Orange - 2: [144, 238, 144], # Grassland - Light green - 3: [255, 255, 0], # Cropland - Yellow - 4: [255, 0, 0], # Built-up - Red - 5: [139, 69, 19], # Bare - Brown - 6: [255, 255, 255], # Snow - White - 7: [0, 0, 255], # Water - Blue - 8: [0, 139, 139], # Wetland - Dark cyan - 9: [0, 255, 0], # Mangroves - Bright green - 10: [220, 220, 220] # Moss - Light grey - } - - def load_sar_data(self, file_path_or_bytes, is_bytes=False): - """Load SAR data from file path or bytes""" - try: - if is_bytes: - # Create a temporary file to use with rasterio - with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp: - tmp.write(file_path_or_bytes) - tmp_path = tmp.name - - with rasterio.open(tmp_path) as src: - sar_data = src.read(1) # Read single band - sar_data = np.expand_dims(sar_data, axis=-1) - - # Clean up the temporary file - os.unlink(tmp_path) - else: - with rasterio.open(file_path_or_bytes) as src: - sar_data = src.read(1) # Read single band - sar_data = np.expand_dims(sar_data, axis=-1) - - # Resize if needed - if sar_data.shape[:2] != (self.img_rows, self.img_cols): - sar_data = cv2.resize(sar_data, (self.img_cols, self.img_rows)) - sar_data = np.expand_dims(sar_data, axis=-1) - - return sar_data - except RasterioIOError: - # Try to open as a regular image if rasterio fails - if is_bytes: - img = Image.open(io.BytesIO(file_path_or_bytes)).convert('L') - else: - img = Image.open(file_path_or_bytes).convert('L') - - img_array = np.array(img) - if img_array.shape[:2] != (self.img_rows, self.img_cols): - img_array = cv2.resize(img_array, (self.img_cols, self.img_rows)) - - return np.expand_dims(img_array, axis=-1) - - def preprocess_sar(self, sar_data): - """Preprocess SAR data""" - # Check if data is already normalized (0-255 range) - if np.max(sar_data) <= 255 and np.min(sar_data) >= 0: - # Normalize to -1 to 1 range - sar_normalized = (sar_data / 127.5) - 1 - else: - # Assume it's in dB scale - sar_clipped = np.clip(sar_data, -50, 20) - sar_normalized = (sar_clipped - np.min(sar_clipped)) / (np.max(sar_clipped) - np.min(sar_clipped)) * 2 - 1 - - return sar_normalized - - def one_hot_encode(self, labels): - """Convert ESA WorldCover labels to one-hot encoded format""" - encoded = np.zeros((labels.shape[0], labels.shape[1], self.num_classes)) - - for i, value in enumerate(sorted(self.class_definitions.values())): - encoded[:, :, i] = (labels == value) - - return encoded - - def load_trained_model(self, model_path): - """Load a trained model from file""" - try: - # If model_path is a filename without path, prepend the models directory - if not os.path.dirname(model_path) and not model_path.startswith('models/'): - model_path = os.path.join('models', os.path.basename(model_path)) - - # First try to load the complete model - self.model = load_model_with_weights(model_path) - - - if self.model is not None: - - has_dilated_convs = False - for layer in self.model.layers: - if 'conv' in layer.name.lower() and hasattr(layer, 'dilation_rate'): - if isinstance(layer.dilation_rate, (list, tuple)): - if any(rate > 1 for rate in layer.dilation_rate): - has_dilated_convs = True - break - elif layer.dilation_rate > 1: - has_dilated_convs = True - break - - if has_dilated_convs: - self.model_type = 'deeplabv3plus' - print("Detected DeepLabV3+ model") - # Check for SegNet architecture (typically has 5 encoder and 5 decoder blocks) - elif len([l for l in self.model.layers if isinstance(l, MaxPooling2D)]) >= 5: - self.model_type = 'segnet' - print("Detected SegNet model") - else: - self.model_type = 'unet' - print("Detected U-Net model") - - if self.model is None: - # If that fails, try to create a model with the expected architecture - if self.model_type == 'unet': - self.model = get_unet( - input_shape=(self.img_rows, self.img_cols, self.num_channels), - drop_rate=self.drop_rate, - classes=self.num_classes - ) - elif self.model_type == 'deeplabv3plus': - self.model = DeepLabV3Plus( - input_shape=(self.img_rows, self.img_cols, self.num_channels), - classes=self.num_classes - ) - elif self.model_type == 'segnet': - self.model = SegNet( - input_shape=(self.img_rows, self.img_cols, self.num_channels), - classes=self.num_classes - ) - else: - raise ValueError(f"Model type {self.model_type} not supported") - - # Try to load weights, allowing for mismatch - self.model.load_weights(model_path, by_name=True, skip_mismatch=True) - - # Check if any weights were loaded - if not any(np.any(w) for w in self.model.get_weights()): - raise ValueError("No weights were loaded. The model architecture is incompatible.") - except Exception as e: - raise ValueError(f"Failed to load model: {str(e)}") - - - - - def predict(self, sar_data): - """Predict segmentation for new SAR data""" - if self.model is None: - raise ValueError("Model not trained. Call train() first or load a trained model.") - - # Preprocess input data - sar_processed = self.preprocess_sar(sar_data) - - # Ensure correct shape - if len(sar_processed.shape) == 3: - sar_processed = np.expand_dims(sar_processed, axis=0) - - # Make prediction - prediction = self.model.predict(sar_processed) - return prediction - - def get_colored_prediction(self, prediction): - """Convert prediction to colored image""" - pred_class = np.argmax(prediction[0], axis=-1) - colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8) - - for class_idx, color in self.class_colors.items(): - colored_pred[pred_class == class_idx] = color - - return colored_pred, pred_class - - - -def visualize_prediction(prediction, original_sar, figsize=(10, 4)): - """Visualize segmentation prediction with ESA WorldCover colors""" - # ESA WorldCover colors - colors = { - 0: [0, 100, 0], # Trees - Dark green - 1: [255, 165, 0], # Shrubland - Orange - 2: [144, 238, 144], # Grassland - Light green - 3: [255, 255, 0], # Cropland - Yellow - 4: [255, 0, 0], # Built-up - Red - 5: [139, 69, 19], # Bare - Brown - 6: [255, 255, 255], # Snow - White - 7: [0, 0, 255], # Water - Blue - 8: [0, 139, 139], # Wetland - Dark cyan - 9: [0, 255, 0], # Mangroves - Bright green - 10: [220, 220, 220] # Moss - Light grey - } - - # Convert prediction to color image - pred_class = np.argmax(prediction[0], axis=-1) - colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8) - - for class_idx, color in colors.items(): - colored_pred[pred_class == class_idx] = color - - # Create overlay - sar_rgb = cv2.cvtColor(original_sar[:,:,0], cv2.COLOR_GRAY2RGB) - overlay = cv2.addWeighted(sar_rgb, 0.7, colored_pred, 0.3, 0) - - # Create figure - fig, axes = plt.subplots(1, 3, figsize=figsize) - - # Original SAR - axes[0].imshow(original_sar[:,:,0], cmap='gray') - axes[0].set_title('Original SAR', color='white') - axes[0].axis('off') - - # Prediction - axes[1].imshow(colored_pred) - axes[1].set_title('Prediction', color='white') - axes[1].axis('off') - - # Overlay - axes[2].imshow(overlay) - axes[2].set_title('Colorized Output', color='white') - axes[2].axis('off') - - # Set dark background - fig.patch.set_facecolor('#0a0a1f') - for ax in axes: - ax.set_facecolor('#0a0a1f') - - plt.tight_layout() - - # Convert plot to image - buf = io.BytesIO() - plt.savefig(buf, format='png', facecolor='#0a0a1f', bbox_inches='tight') - buf.seek(0) - plt.close(fig) - - return buf - -def create_legend(): - """Create a legend for the land cover classes""" - colors = { - 'Trees': [0, 100, 0], - 'Shrubland': [255, 165, 0], - 'Grassland': [144, 238, 144], - 'Cropland': [255, 255, 0], - 'Built-up': [255, 0, 0], - 'Bare': [139, 69, 19], - 'Snow': [255, 255, 255], - 'Water': [0, 0, 255], - 'Wetland': [0, 139, 139], - 'Mangroves': [0, 255, 0], - 'Moss': [220, 220, 220] - } - - fig, ax = plt.subplots(figsize=(8, 4)) - fig.patch.set_facecolor('#0a0a1f') - ax.set_facecolor('#0a0a1f') - - # Create color patches - for i, (class_name, color) in enumerate(colors.items()): - ax.add_patch(plt.Rectangle((0, i), 0.5, 0.8, color=[c/255 for c in color])) - ax.text(0.7, i + 0.4, class_name, color='white', fontsize=12) - - ax.set_xlim(0, 3) - ax.set_ylim(-0.5, len(colors) - 0.5) - ax.set_title('Land Cover Classes', color='white', fontsize=14) - ax.axis('off') - - buf = io.BytesIO() - plt.savefig(buf, format='png', facecolor='#0a0a1f', bbox_inches='tight') - buf.seek(0) - plt.close(fig) - - return buf - -# Initialize session state variables -if 'model_loaded' not in st.session_state: - st.session_state.model_loaded = False -if 'segmentation' not in st.session_state: - st.session_state.segmentation = SARSegmentation(img_rows=256, img_cols=256) -if 'processed_images' not in st.session_state: - st.session_state.processed_images = [] - -# Create tabs -tab1, tab2, tab3 = st.tabs(["📥 Load Model", " Process Single Image", "📁 Process Multiple Images"]) - -# Add this code in tab1 before the model file uploader -with tab1: - st.markdown("

Load Segmentation Model

", unsafe_allow_html=True) - - # Add model type selection - model_type = st.selectbox( - "Select model architecture", - ["U-Net", "DeepLabV3+", "SegNet"], - index=0, - help="Select the architecture of the model you're uploading" - ) - - # Update the model type in the segmentation object - st.session_state.segmentation.model_type = model_type.lower().replace('-', '') - - col1, col2 = st.columns([3, 1]) - - with col1: - model_file = st.file_uploader("Upload model weights (.h5 file)", type=["h5"]) - - - with col2: - st.markdown("
", unsafe_allow_html=True) - if st.button("Load Model", key="load_model_btn"): - if model_file is not None: - with st.spinner("Loading model..."): - # Save the uploaded model to a temporary file - with tempfile.NamedTemporaryFile(delete=False, suffix='.h5') as tmp: - tmp.write(model_file.getvalue()) - model_path = tmp.name - - try: - # Load the model - st.session_state.segmentation.load_trained_model(model_path) - st.session_state.model_loaded = True - st.success("Model loaded successfully!") - except Exception as e: - st.error(f"Error loading model: {str(e)}") - finally: - # Clean up the temporary file - os.unlink(model_path) - else: - st.error("Please upload a model file (.h5)") - - - if st.session_state.model_loaded: - st.markdown("
", unsafe_allow_html=True) - st.markdown("

Model Information

", unsafe_allow_html=True) - - col1, col2, col3 = st.columns(3) - with col1: - st.markdown("
", unsafe_allow_html=True) - # Display the correct model architecture based on the detected model type - model_arch_map = { - 'unet': "U-Net", - 'deeplabv3plus': "DeepLabV3+", - 'segnet': "SegNet" - } - model_arch = model_arch_map.get(st.session_state.segmentation.model_type, "Unknown") - st.markdown(f"

{model_arch}

", unsafe_allow_html=True) - st.markdown("

Architecture

", unsafe_allow_html=True) - st.markdown("
", unsafe_allow_html=True) - - - with col2: - st.markdown("
", unsafe_allow_html=True) - st.markdown("

11

", unsafe_allow_html=True) - st.markdown("

Land Cover Classes

", unsafe_allow_html=True) - st.markdown("
", unsafe_allow_html=True) - - with col3: - st.markdown("
", unsafe_allow_html=True) - st.markdown("

256 x 256

", unsafe_allow_html=True) - st.markdown("

Input Size

", unsafe_allow_html=True) - st.markdown("
", unsafe_allow_html=True) - - # Display legend - st.markdown("

Land Cover Classes

", unsafe_allow_html=True) - legend_img = create_legend() - st.image(legend_img, use_column_width=True) - - st.markdown("
", unsafe_allow_html=True) - else: - st.info("Please load a model to continue.") - - -with tab2: - st.markdown("

Process Single SAR Image

", unsafe_allow_html=True) - - if not st.session_state.model_loaded: - st.warning("Please load a model in the 'Load Model' tab first.") - else: - st.markdown("
", unsafe_allow_html=True) - uploaded_file = st.file_uploader("Upload a SAR image (.tif or common image formats)", type=["tif", "tiff", "png", "jpg", "jpeg"]) - st.markdown("
", unsafe_allow_html=True) - - if uploaded_file is not None: - if st.button("Process Image", key="process_single_btn"): - with st.spinner("Processing image..."): - # Load and process the image - try: - sar_data = st.session_state.segmentation.load_sar_data(uploaded_file.getvalue(), is_bytes=True) - - # Normalize for visualization - sar_normalized = sar_data.copy() - min_val = np.min(sar_normalized) - max_val = np.max(sar_normalized) - sar_normalized = ((sar_normalized - min_val) / (max_val - min_val) * 255).astype(np.uint8) - - # Make prediction - prediction = st.session_state.segmentation.predict(sar_data) - - # Visualize results - result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1)) - - # Display results - st.markdown("

Segmentation Results

", unsafe_allow_html=True) - st.image(result_img, use_column_width=True) - - # Add download button for the result - btn = st.download_button( - label="Download Result", - data=result_img, - file_name="segmentation_result.png", - mime="image/png", - key="download_single_result" - ) - except Exception as e: - st.error(f"Error processing image: {str(e)}") - -with tab3: - st.markdown("

Process Multiple SAR Images

", unsafe_allow_html=True) - - if not st.session_state.model_loaded: - st.warning("Please load a model in the 'Load Model' tab first.") - else: - st.markdown("
", unsafe_allow_html=True) - uploaded_files = st.file_uploader("Upload SAR images or a ZIP file containing images", type=["tif", "tiff", "png", "jpg", "jpeg", "zip"], accept_multiple_files=True) - st.markdown("
", unsafe_allow_html=True) - - col1, col2 = st.columns([3, 1]) - - with col1: - max_images = st.slider("Maximum number of images to display", min_value=1, max_value=20, value=10) - - with col2: - st.markdown("
", unsafe_allow_html=True) - process_btn = st.button("Process Images", key="process_multi_btn") - - if process_btn and uploaded_files: - # Clear previous results - st.session_state.processed_images = [] - - # Process uploaded files - with st.spinner("Processing images..."): - # Create a temporary directory to extract zip files if needed - with tempfile.TemporaryDirectory() as temp_dir: - # Process each uploaded file - image_files = [] - - for uploaded_file in uploaded_files: - if uploaded_file.name.lower().endswith('.zip'): - # Extract zip file - zip_path = os.path.join(temp_dir, uploaded_file.name) - with open(zip_path, 'wb') as f: - f.write(uploaded_file.getvalue()) - - with zipfile.ZipFile(zip_path, 'r') as zip_ref: - zip_ref.extractall(temp_dir) - - # Find all image files in the extracted directory - for root, _, files in os.walk(temp_dir): - for file in files: - if file.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')): - image_files.append(os.path.join(root, file)) - else: - # Save the file to temp directory - file_path = os.path.join(temp_dir, uploaded_file.name) - with open(file_path, 'wb') as f: - f.write(uploaded_file.getvalue()) - image_files.append(file_path) - - # If there are too many images, randomly select a subset - if len(image_files) > max_images: - st.info(f"Found {len(image_files)} images. Randomly selecting {max_images} images to display.") - image_files = random.sample(image_files, max_images) - - # Process each image - progress_bar = st.progress(0) - for i, image_path in enumerate(image_files): - try: - # Update progress - progress_bar.progress((i + 1) / len(image_files)) - - # Load and process the image - sar_data = st.session_state.segmentation.load_sar_data(image_path) - - # Normalize for visualization - sar_normalized = sar_data.copy() - min_val = np.min(sar_normalized) - max_val = np.max(sar_normalized) - sar_normalized = ((sar_normalized - min_val) / (max_val - min_val) * 255).astype(np.uint8) - - # Make prediction - prediction = st.session_state.segmentation.predict(sar_data) - - # Visualize results - result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1)) - - # Add to processed images - st.session_state.processed_images.append({ - 'filename': os.path.basename(image_path), - 'result': result_img - }) - except Exception as e: - st.error(f"Error processing {os.path.basename(image_path)}: {str(e)}") - - # Clear progress bar - progress_bar.empty() - - # Display results - if st.session_state.processed_images: - st.markdown("

Segmentation Results

", unsafe_allow_html=True) - - # Create a zip file with all results - zip_buffer = io.BytesIO() - with zipfile.ZipFile(zip_buffer, 'w') as zip_file: - for i, img_data in enumerate(st.session_state.processed_images): - zip_file.writestr(f"result_{i+1}_{img_data['filename']}.png", img_data['result'].getvalue()) - - # Add download button for all results - st.download_button( - label="Download All Results", - data=zip_buffer.getvalue(), - file_name="segmentation_results.zip", - mime="application/zip", - key="download_all_results" - ) - - # Display each result - for i, img_data in enumerate(st.session_state.processed_images): - st.markdown(f"
Image: {img_data['filename']}
", unsafe_allow_html=True) - st.image(img_data['result'], use_column_width=True) - st.markdown("
", unsafe_allow_html=True) - else: - st.warning("No images were successfully processed.") - elif process_btn: - st.warning("Please upload at least one image file or ZIP archive.") - -# Close the card container -st.markdown("
", unsafe_allow_html=True) - -# Footer -st.markdown(""" -
-

- SAR IMAGE COLORIZATION | VARUN & MOKSHYAGNA -

-
+import streamlit as st +import tensorflow as tf +from tensorflow.keras.models import Model +from tensorflow.keras.layers import * +from tensorflow.keras.optimizers import Adam +import cv2 +import os +import numpy as np +import matplotlib.pyplot as plt +from PIL import Image +import io +import base64 +import tempfile +import zipfile +import random +import time +import rasterio +from rasterio.errors import RasterioIOError +import h5py +import json + + +# Set page configuration +st.set_page_config( + page_title="SAR Image Pixel Level Classification", + page_icon="🛰", + layout="wide" +) + + +with st.sidebar: + st.image('assets/logo2.png', width=150) + st.title("About") + st.markdown(""" + ### SAR Image Colorization + + This application uses deep learning models to segment and colorize Synthetic Aperture Radar (SAR) images into land cover classes. + + #### Features: + - Load pre-trained U-Net,DeepLabV3+ or SegNet models + - Process single SAR images + - Batch process multiple images + - Visualize Pixel Level Classification with ESA WorldCover color scheme + + #### Developed by: + Varun & Mokshyagna + + #### Technologies: + - TensorFlow/Keras + - Streamlit + - Rasterio + - OpenCV + + #### Version: + 1.0.0 + """) + + st.markdown("---") + st.markdown("© 2025 | All Rights Reserved") + + +def load_model_with_weights(model_path): + """Load a model directly from an H5 file, preserving the original architecture""" + # If model_path is a filename without path, prepend the models directory + if not os.path.dirname(model_path) and not model_path.startswith('models/'): + model_path = os.path.join('models', os.path.basename(model_path)) + + try: + # Try to load the complete model (architecture + weights) + model = tf.keras.models.load_model(model_path, compile=False) + print("Loaded complete model with architecture") + return model + except Exception as e: + print(f"Could not load complete model: {str(e)}") + print("Attempting to load just the weights into a matching architecture...") + + # Try to inspect the model file to determine architecture + with h5py.File(model_path, 'r') as f: + model_config = None + if 'model_config' in f.attrs: + # Handle both string and bytes attributes + config_attr = f.attrs['model_config'] + if isinstance(config_attr, bytes): + model_config = json.loads(config_attr.decode('utf-8')) + else: + model_config = json.loads(config_attr) + + # If we found a model config, try to recreate it + if model_config: + try: + model = tf.keras.models.model_from_json(json.dumps(model_config)) + model.load_weights(model_path) + print("Successfully loaded model from config and weights") + return model + except Exception as e2: + print(f"Failed to load from config: {str(e2)}") + + # If all else fails, return None + return None + +st.markdown( + """ + + """, + unsafe_allow_html=True +) + + +# Create twinkling stars +stars_html = """ +
+""" +for i in range(100): + size = random.uniform(1, 3) + top = random.uniform(0, 100) + left = random.uniform(0, 100) + duration = random.uniform(3, 8) + opacity = random.uniform(0.2, 0.8) + + stars_html += f""" +
+ """ +stars_html += "
" + +st.markdown(stars_html, unsafe_allow_html=True) + + + +from PIL import Image +import base64 +def get_image_as_base64(file_path): + try: + with open(file_path, "rb") as img_file: + return base64.b64encode(img_file.read()).decode() + except FileNotFoundError: + # Return a placeholder or empty string if file not found + st.warning(f"Image file not found: {file_path}") + return "" + +# Replace 'path/to/your/logo.png' with your actual logo path +logo_path = 'assets/logo2.png' +logo_base64 = get_image_as_base64(logo_path) + +st.markdown(f"""
+
""", unsafe_allow_html=True) + + +# Title and description +st.markdown(""" +
+

+ SAR Image Colorization +

+

+ Pixel Level Classification of Synthetic Aperture Radar images into land cover classes with deep learning +

+
+""", unsafe_allow_html=True) + + + +# Create a card container +st.markdown("
", unsafe_allow_html=True) + +# Define the U-Net model +def get_unet(input_shape=(256, 256, 1), drop_rate=0.3, classes=11): + inputs = Input(input_shape) + + # Encoder + conv1_1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs) + batch1_1 = BatchNormalization()(conv1_1) + conv1_2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch1_1) + batch1_2 = BatchNormalization()(conv1_2) + pool1 = MaxPooling2D(pool_size=(2, 2))(batch1_2) + + conv2_1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1) + batch2_1 = BatchNormalization()(conv2_1) + conv2_2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch2_1) + batch2_2 = BatchNormalization()(conv2_2) + pool2 = MaxPooling2D(pool_size=(2, 2))(batch2_2) + + conv3_1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2) + batch3_1 = BatchNormalization()(conv3_1) + conv3_2 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch3_1) + batch3_2 = BatchNormalization()(conv3_2) + pool3 = MaxPooling2D(pool_size=(2, 2))(batch3_2) + + conv4_1 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3) + batch4_1 = BatchNormalization()(conv4_1) + conv4_2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch4_1) + batch4_2 = BatchNormalization()(conv4_2) + drop4 = Dropout(drop_rate)(batch4_2) + pool4 = MaxPooling2D(pool_size=(2, 2))(drop4) + + # Bridge + conv5_1 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4) + batch5_1 = BatchNormalization()(conv5_1) + conv5_2 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch5_1) + batch5_2 = BatchNormalization()(conv5_2) + drop5 = Dropout(drop_rate)(batch5_2) + + # Decoder + up6 = Conv2D(512, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(drop5)) + merge6 = concatenate([drop4, up6]) + conv6_1 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6) + batch6_1 = BatchNormalization()(conv6_1) + conv6_2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch6_1) + batch6_2 = BatchNormalization()(conv6_2) + + up7 = Conv2D(256, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch6_2)) + merge7 = concatenate([batch3_2, up7]) + conv7_1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7) + batch7_1 = BatchNormalization()(conv7_1) + conv7_2 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch7_1) + batch7_2 = BatchNormalization()(conv7_2) + + up8 = Conv2D(128, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch7_2)) + merge8 = concatenate([batch2_2, up8]) + conv8_1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8) + batch8_1 = BatchNormalization()(conv8_1) + conv8_2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch8_1) + batch8_2 = BatchNormalization()(conv8_2) + + up9 = Conv2D(64, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch8_2)) + merge9 = concatenate([batch1_2, up9]) + conv9_1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9) + batch9_1 = BatchNormalization()(conv9_1) + conv9_2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch9_1) + batch9_2 = BatchNormalization()(conv9_2) + + outputs = Conv2D(classes, 1, activation='softmax')(batch9_2) + + model = Model(inputs=inputs, outputs=outputs) + model.compile(optimizer=Adam(learning_rate=1e-4), + loss='categorical_crossentropy', + metrics=['accuracy']) + + return model + + +# Custom upsampling layer for dynamic resizing +class BilinearUpsampling(Layer): + def init(self, size=(1, 1), **kwargs): + super(BilinearUpsampling, self).init(**kwargs) + self.size = size + + def call(self, inputs): + return tf.image.resize(inputs, self.size, method='bilinear') + + def compute_output_shape(self, input_shape): + return (input_shape[0], self.size[0], self.size[1], input_shape[3]) + + def get_config(self): + config = super(BilinearUpsampling, self).get_config() + config.update({'size': self.size}) + return config + +def DeepLabV3Plus(input_shape=(256, 256, 1), classes=11, output_stride=16): + """ + DeepLabV3+ model with Xception backbone + + Args: + input_shape: Shape of input images + classes: Number of classes for segmentation + output_stride: Output stride for dilated convolutions (16 or 8) + + Returns: + model: DeepLabV3+ model + """ + # Input layer + inputs = Input(input_shape) + + # Ensure we're using the right dilation rates based on output_stride + if output_stride == 16: + atrous_rates = (6, 12, 18) + elif output_stride == 8: + atrous_rates = (12, 24, 36) + else: + raise ValueError("Output stride must be 8 or 16") + + # === ENCODER (BACKBONE) === + # Entry block + x = Conv2D(32, 3, strides=(2, 2), padding='same', use_bias=False)(inputs) + x = BatchNormalization()(x) + x = Activation('relu')(x) + + x = Conv2D(64, 3, padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + + # Xception-like blocks with dilated convolutions + # Block 1 + residual = Conv2D(128, 1, strides=(2, 2), padding='same', use_bias=False)(x) + residual = BatchNormalization()(residual) + + x = SeparableConv2D(128, 3, padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = SeparableConv2D(128, 3, padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = MaxPooling2D(3, strides=(2, 2), padding='same')(x) + x = Add()([x, residual]) + + # Block 2 + residual = Conv2D(256, 1, strides=(2, 2), padding='same', use_bias=False)(x) + residual = BatchNormalization()(residual) + + x = Activation('relu')(x) + x = SeparableConv2D(256, 3, padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = SeparableConv2D(256, 3, padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = MaxPooling2D(3, strides=(2, 2), padding='same')(x) + x = Add()([x, residual]) + + # Save low_level_features for skip connection (1/4 of input size) + low_level_features = x + + # Block 3 + residual = Conv2D(728, 1, strides=(2, 2), padding='same', use_bias=False)(x) + residual = BatchNormalization()(residual) + + x = Activation('relu')(x) + x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = MaxPooling2D(3, strides=(2, 2), padding='same')(x) + x = Add()([x, residual]) + + # Middle flow - modified with dilated convolutions + for i in range(16): + residual = x + + x = Activation('relu')(x) + x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x) + x = BatchNormalization()(x) + + x = Add()([x, residual]) + + # Exit flow (modified) + x = Activation('relu')(x) + x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = SeparableConv2D(1024, 3, padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + + # === ASPP (Atrous Spatial Pyramid Pooling) === + # 1x1 convolution branch + aspp_out1 = Conv2D(256, 1, padding='same', use_bias=False)(x) + aspp_out1 = BatchNormalization()(aspp_out1) + aspp_out1 = Activation('relu')(aspp_out1) + + # 3x3 dilated convolution branches with different rates + aspp_out2 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[0], use_bias=False)(x) + aspp_out2 = BatchNormalization()(aspp_out2) + aspp_out2 = Activation('relu')(aspp_out2) + + aspp_out3 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[1], use_bias=False)(x) + aspp_out3 = BatchNormalization()(aspp_out3) + aspp_out3 = Activation('relu')(aspp_out3) + + aspp_out4 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[2], use_bias=False)(x) + aspp_out4 = BatchNormalization()(aspp_out4) + aspp_out4 = Activation('relu')(aspp_out4) + + # Global pooling branch + aspp_out5 = GlobalAveragePooling2D()(x) + aspp_out5 = Reshape((1, 1, 1024))(aspp_out5) # Use 1024 to match x's channels + aspp_out5 = Conv2D(256, 1, padding='same', use_bias=False)(aspp_out5) + aspp_out5 = BatchNormalization()(aspp_out5) + aspp_out5 = Activation('relu')(aspp_out5) + + # Get current shape of x + _, height, width, _ = tf.keras.backend.int_shape(x) + aspp_out5 = UpSampling2D(size=(height, width), interpolation='bilinear')(aspp_out5) + + # Concatenate all ASPP branches + aspp_out = Concatenate()([aspp_out1, aspp_out2, aspp_out3, aspp_out4, aspp_out5]) + + # Project ASPP output to 256 filters + aspp_out = Conv2D(256, 1, padding='same', use_bias=False)(aspp_out) + aspp_out = BatchNormalization()(aspp_out) + aspp_out = Activation('relu')(aspp_out) + + # === DECODER === + # Process low-level features from Block 2 (1/4 size) + low_level_features = Conv2D(48, 1, padding='same', use_bias=False)(low_level_features) + low_level_features = BatchNormalization()(low_level_features) + low_level_features = Activation('relu')(low_level_features) + + # Upsample ASPP output by 4x to match low level features size + # Get shapes for verification + low_level_shape = tf.keras.backend.int_shape(low_level_features) + + # Upsample to match low_level_features shape + x = UpSampling2D(size=(low_level_shape[1] // tf.keras.backend.int_shape(aspp_out)[1], + low_level_shape[2] // tf.keras.backend.int_shape(aspp_out)[2]), + interpolation='bilinear')(aspp_out) + + # Concatenate with low-level features + x = Concatenate()([x, low_level_features]) + + # Final convolutions + x = Conv2D(256, 3, padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + + x = Conv2D(256, 3, padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + + # Calculate upsampling size to original input size + x_shape = tf.keras.backend.int_shape(x) + upsampling_size = (input_shape[0] // x_shape[1], input_shape[1] // x_shape[2]) + + # Upsample to original size + x = UpSampling2D(size=upsampling_size, interpolation='bilinear')(x) + + # Final segmentation output + outputs = Conv2D(classes, 1, padding='same', activation='softmax')(x) + + model = Model(inputs=inputs, outputs=outputs) + model.compile(optimizer=Adam(learning_rate=1e-4), + loss='categorical_crossentropy', + metrics=['accuracy']) + + return model + +import tensorflow as tf +from tensorflow.keras.models import Model +from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPooling2D, UpSampling2D + +def SegNet(input_shape=(256, 256, 1), classes=11): + """ + SegNet model for semantic segmentation + + Args: + input_shape: Shape of input images + classes: Number of classes for segmentation + + Returns: + model: SegNet model + """ + # Input layer + inputs = Input(input_shape) + + # === ENCODER === + # Encoder block 1 + x = Conv2D(64, (3, 3), padding='same', use_bias=False)(inputs) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + # Regular MaxPooling without indices + x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x) + + # Encoder block 2 + x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x) + + # Encoder block 3 + x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x) + + # Encoder block 4 + x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x) + + # Encoder block 5 + x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x) + + # === DECODER === + # Using UpSampling2D instead of MaxUnpooling since TensorFlow doesn't support it + + # Decoder block 5 + x = UpSampling2D(size=(2, 2))(x) + x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + + # Decoder block 4 + x = UpSampling2D(size=(2, 2))(x) + x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + + # Decoder block 3 + x = UpSampling2D(size=(2, 2))(x) + x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + + # Decoder block 2 + x = UpSampling2D(size=(2, 2))(x) + x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + + # Decoder block 1 + x = UpSampling2D(size=(2, 2))(x) + x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x) + x = BatchNormalization()(x) + x = Activation('relu')(x) + + # Output layer + outputs = Conv2D(classes, (1, 1), padding='same', activation='softmax')(x) + + model = Model(inputs=inputs, outputs=outputs) + + return model + + + + +class SARSegmentation: + def __init__(self, img_rows=256, img_cols=256, drop_rate=0.5, model_type='unet'): + self.img_rows = img_rows + self.img_cols = img_cols + self.drop_rate = drop_rate + self.num_channels = 1 # Single-pol SAR + self.model = None + self.model_type = model_type.lower() + + # ESA WorldCover class definitions + self.class_definitions = { + 'trees': 10, + 'shrubland': 20, + 'grassland': 30, + 'cropland': 40, + 'built_up': 50, + 'bare': 60, + 'snow': 70, + 'water': 80, + 'wetland': 90, + 'mangroves': 95, + 'moss': 100 + } + self.num_classes = len(self.class_definitions) + + # Class colors for visualization + self.class_colors = { + 0: [0, 100, 0], # Trees - Dark green + 1: [255, 165, 0], # Shrubland - Orange + 2: [144, 238, 144], # Grassland - Light green + 3: [255, 255, 0], # Cropland - Yellow + 4: [255, 0, 0], # Built-up - Red + 5: [139, 69, 19], # Bare - Brown + 6: [255, 255, 255], # Snow - White + 7: [0, 0, 255], # Water - Blue + 8: [0, 139, 139], # Wetland - Dark cyan + 9: [0, 255, 0], # Mangroves - Bright green + 10: [220, 220, 220] # Moss - Light grey + } + + def load_sar_data(self, file_path_or_bytes, is_bytes=False): + """Load SAR data from file path or bytes""" + try: + if is_bytes: + # Create a temporary file to use with rasterio + with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp: + tmp.write(file_path_or_bytes) + tmp_path = tmp.name + + with rasterio.open(tmp_path) as src: + sar_data = src.read(1) # Read single band + sar_data = np.expand_dims(sar_data, axis=-1) + + # Clean up the temporary file + os.unlink(tmp_path) + else: + with rasterio.open(file_path_or_bytes) as src: + sar_data = src.read(1) # Read single band + sar_data = np.expand_dims(sar_data, axis=-1) + + # Resize if needed + if sar_data.shape[:2] != (self.img_rows, self.img_cols): + sar_data = cv2.resize(sar_data, (self.img_cols, self.img_rows)) + sar_data = np.expand_dims(sar_data, axis=-1) + + return sar_data + except RasterioIOError: + # Try to open as a regular image if rasterio fails + if is_bytes: + img = Image.open(io.BytesIO(file_path_or_bytes)).convert('L') + else: + img = Image.open(file_path_or_bytes).convert('L') + + img_array = np.array(img) + if img_array.shape[:2] != (self.img_rows, self.img_cols): + img_array = cv2.resize(img_array, (self.img_cols, self.img_rows)) + + return np.expand_dims(img_array, axis=-1) + + def preprocess_sar(self, sar_data): + """Preprocess SAR data""" + # Check if data is already normalized (0-255 range) + if np.max(sar_data) <= 255 and np.min(sar_data) >= 0: + # Normalize to -1 to 1 range + sar_normalized = (sar_data / 127.5) - 1 + else: + # Assume it's in dB scale + sar_clipped = np.clip(sar_data, -50, 20) + sar_normalized = (sar_clipped - np.min(sar_clipped)) / (np.max(sar_clipped) - np.min(sar_clipped)) * 2 - 1 + + return sar_normalized + + def one_hot_encode(self, labels): + """Convert ESA WorldCover labels to one-hot encoded format""" + encoded = np.zeros((labels.shape[0], labels.shape[1], self.num_classes)) + + for i, value in enumerate(sorted(self.class_definitions.values())): + encoded[:, :, i] = (labels == value) + + return encoded + + def load_trained_model(self, model_path): + """Load a trained model from file""" + try: + # If model_path is a filename without path, prepend the models directory + if not os.path.dirname(model_path) and not model_path.startswith('models/'): + model_path = os.path.join('models', os.path.basename(model_path)) + + # First try to load the complete model + self.model = load_model_with_weights(model_path) + + + if self.model is not None: + + has_dilated_convs = False + for layer in self.model.layers: + if 'conv' in layer.name.lower() and hasattr(layer, 'dilation_rate'): + if isinstance(layer.dilation_rate, (list, tuple)): + if any(rate > 1 for rate in layer.dilation_rate): + has_dilated_convs = True + break + elif layer.dilation_rate > 1: + has_dilated_convs = True + break + + if has_dilated_convs: + self.model_type = 'deeplabv3plus' + print("Detected DeepLabV3+ model") + # Check for SegNet architecture (typically has 5 encoder and 5 decoder blocks) + elif len([l for l in self.model.layers if isinstance(l, MaxPooling2D)]) >= 5: + self.model_type = 'segnet' + print("Detected SegNet model") + else: + self.model_type = 'unet' + print("Detected U-Net model") + + if self.model is None: + # If that fails, try to create a model with the expected architecture + if self.model_type == 'unet': + self.model = get_unet( + input_shape=(self.img_rows, self.img_cols, self.num_channels), + drop_rate=self.drop_rate, + classes=self.num_classes + ) + elif self.model_type == 'deeplabv3plus': + self.model = DeepLabV3Plus( + input_shape=(self.img_rows, self.img_cols, self.num_channels), + classes=self.num_classes + ) + elif self.model_type == 'segnet': + self.model = SegNet( + input_shape=(self.img_rows, self.img_cols, self.num_channels), + classes=self.num_classes + ) + else: + raise ValueError(f"Model type {self.model_type} not supported") + + # Try to load weights, allowing for mismatch + self.model.load_weights(model_path, by_name=True, skip_mismatch=True) + + # Check if any weights were loaded + if not any(np.any(w) for w in self.model.get_weights()): + raise ValueError("No weights were loaded. The model architecture is incompatible.") + except Exception as e: + raise ValueError(f"Failed to load model: {str(e)}") + + + + + def predict(self, sar_data): + """Predict segmentation for new SAR data""" + if self.model is None: + raise ValueError("Model not trained. Call train() first or load a trained model.") + + # Preprocess input data + sar_processed = self.preprocess_sar(sar_data) + + # Ensure correct shape + if len(sar_processed.shape) == 3: + sar_processed = np.expand_dims(sar_processed, axis=0) + + # Make prediction + prediction = self.model.predict(sar_processed) + return prediction + + def get_colored_prediction(self, prediction): + """Convert prediction to colored image""" + pred_class = np.argmax(prediction[0], axis=-1) + colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8) + + for class_idx, color in self.class_colors.items(): + colored_pred[pred_class == class_idx] = color + + return colored_pred, pred_class + + + +def visualize_prediction(prediction, original_sar, figsize=(10, 4)): + """Visualize segmentation prediction with ESA WorldCover colors""" + # ESA WorldCover colors + colors = { + 0: [0, 100, 0], # Trees - Dark green + 1: [255, 165, 0], # Shrubland - Orange + 2: [144, 238, 144], # Grassland - Light green + 3: [255, 255, 0], # Cropland - Yellow + 4: [255, 0, 0], # Built-up - Red + 5: [139, 69, 19], # Bare - Brown + 6: [255, 255, 255], # Snow - White + 7: [0, 0, 255], # Water - Blue + 8: [0, 139, 139], # Wetland - Dark cyan + 9: [0, 255, 0], # Mangroves - Bright green + 10: [220, 220, 220] # Moss - Light grey + } + + # Convert prediction to color image + pred_class = np.argmax(prediction[0], axis=-1) + colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8) + + for class_idx, color in colors.items(): + colored_pred[pred_class == class_idx] = color + + # Create overlay + sar_rgb = cv2.cvtColor(original_sar[:,:,0], cv2.COLOR_GRAY2RGB) + overlay = cv2.addWeighted(sar_rgb, 0.7, colored_pred, 0.3, 0) + + # Create figure + fig, axes = plt.subplots(1, 3, figsize=figsize) + + # Original SAR + axes[0].imshow(original_sar[:,:,0], cmap='gray') + axes[0].set_title('Original SAR', color='white') + axes[0].axis('off') + + # Prediction + axes[1].imshow(colored_pred) + axes[1].set_title('Prediction', color='white') + axes[1].axis('off') + + # Overlay + axes[2].imshow(overlay) + axes[2].set_title('Colorized Output', color='white') + axes[2].axis('off') + + # Set dark background + fig.patch.set_facecolor('#0a0a1f') + for ax in axes: + ax.set_facecolor('#0a0a1f') + + plt.tight_layout() + + # Convert plot to image + buf = io.BytesIO() + plt.savefig(buf, format='png', facecolor='#0a0a1f', bbox_inches='tight') + buf.seek(0) + plt.close(fig) + + return buf + +def create_legend(): + """Create a legend for the land cover classes""" + colors = { + 'Trees': [0, 100, 0], + 'Shrubland': [255, 165, 0], + 'Grassland': [144, 238, 144], + 'Cropland': [255, 255, 0], + 'Built-up': [255, 0, 0], + 'Bare': [139, 69, 19], + 'Snow': [255, 255, 255], + 'Water': [0, 0, 255], + 'Wetland': [0, 139, 139], + 'Mangroves': [0, 255, 0], + 'Moss': [220, 220, 220] + } + + fig, ax = plt.subplots(figsize=(8, 4)) + fig.patch.set_facecolor('#0a0a1f') + ax.set_facecolor('#0a0a1f') + + # Create color patches + for i, (class_name, color) in enumerate(colors.items()): + ax.add_patch(plt.Rectangle((0, i), 0.5, 0.8, color=[c/255 for c in color])) + ax.text(0.7, i + 0.4, class_name, color='white', fontsize=12) + + ax.set_xlim(0, 3) + ax.set_ylim(-0.5, len(colors) - 0.5) + ax.set_title('Land Cover Classes', color='white', fontsize=14) + ax.axis('off') + + buf = io.BytesIO() + plt.savefig(buf, format='png', facecolor='#0a0a1f', bbox_inches='tight') + buf.seek(0) + plt.close(fig) + + return buf + +# Initialize session state variables +if 'model_loaded' not in st.session_state: + st.session_state.model_loaded = False +if 'segmentation' not in st.session_state: + st.session_state.segmentation = SARSegmentation(img_rows=256, img_cols=256) +if 'processed_images' not in st.session_state: + st.session_state.processed_images = [] + +# Create tabs +tab1, tab2, tab3 = st.tabs(["📥 Load Model", " Process Single Image", "📁 Process Multiple Images"]) + +# Add this code in tab1 before the model file uploader +with tab1: + st.markdown("

Load Segmentation Model

", unsafe_allow_html=True) + + # Add model type selection + model_type = st.selectbox( + "Select model architecture", + ["U-Net", "DeepLabV3+", "SegNet"], + index=0, + help="Select the architecture of the model you're uploading" + ) + + # Update the model type in the segmentation object + st.session_state.segmentation.model_type = model_type.lower().replace('-', '') + + col1, col2 = st.columns([3, 1]) + + with col1: + model_file = st.file_uploader("Upload model weights (.h5 file)", type=["h5"]) + + + with col2: + st.markdown("
", unsafe_allow_html=True) + if st.button("Load Model", key="load_model_btn"): + if model_file is not None: + with st.spinner("Loading model..."): + # Save the uploaded model to a temporary file + with tempfile.NamedTemporaryFile(delete=False, suffix='.h5') as tmp: + tmp.write(model_file.getvalue()) + model_path = tmp.name + + try: + # Load the model + st.session_state.segmentation.load_trained_model(model_path) + st.session_state.model_loaded = True + st.success("Model loaded successfully!") + except Exception as e: + st.error(f"Error loading model: {str(e)}") + finally: + # Clean up the temporary file + os.unlink(model_path) + else: + st.error("Please upload a model file (.h5)") + + + if st.session_state.model_loaded: + st.markdown("
", unsafe_allow_html=True) + st.markdown("

Model Information

", unsafe_allow_html=True) + + col1, col2, col3 = st.columns(3) + with col1: + st.markdown("
", unsafe_allow_html=True) + # Display the correct model architecture based on the detected model type + model_arch_map = { + 'unet': "U-Net", + 'deeplabv3plus': "DeepLabV3+", + 'segnet': "SegNet" + } + model_arch = model_arch_map.get(st.session_state.segmentation.model_type, "Unknown") + st.markdown(f"

{model_arch}

", unsafe_allow_html=True) + st.markdown("

Architecture

", unsafe_allow_html=True) + st.markdown("
", unsafe_allow_html=True) + + + with col2: + st.markdown("
", unsafe_allow_html=True) + st.markdown("

11

", unsafe_allow_html=True) + st.markdown("

Land Cover Classes

", unsafe_allow_html=True) + st.markdown("
", unsafe_allow_html=True) + + with col3: + st.markdown("
", unsafe_allow_html=True) + st.markdown("

256 x 256

", unsafe_allow_html=True) + st.markdown("

Input Size

", unsafe_allow_html=True) + st.markdown("
", unsafe_allow_html=True) + + # Display legend + st.markdown("

Land Cover Classes

", unsafe_allow_html=True) + legend_img = create_legend() + st.image(legend_img, use_column_width=True) + + st.markdown("
", unsafe_allow_html=True) + else: + st.info("Please load a model to continue.") + + +with tab2: + st.markdown("

Process Single SAR Image

", unsafe_allow_html=True) + + if not st.session_state.model_loaded: + st.warning("Please load a model in the 'Load Model' tab first.") + else: + st.markdown("
", unsafe_allow_html=True) + uploaded_file = st.file_uploader("Upload a SAR image (.tif or common image formats)", type=["tif", "tiff", "png", "jpg", "jpeg"]) + st.markdown("
", unsafe_allow_html=True) + + if uploaded_file is not None: + if st.button("Process Image", key="process_single_btn"): + with st.spinner("Processing image..."): + # Load and process the image + try: + sar_data = st.session_state.segmentation.load_sar_data(uploaded_file.getvalue(), is_bytes=True) + + # Normalize for visualization + sar_normalized = sar_data.copy() + min_val = np.min(sar_normalized) + max_val = np.max(sar_normalized) + sar_normalized = ((sar_normalized - min_val) / (max_val - min_val) * 255).astype(np.uint8) + + # Make prediction + prediction = st.session_state.segmentation.predict(sar_data) + + # Visualize results + result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1)) + + # Display results + st.markdown("

Segmentation Results

", unsafe_allow_html=True) + st.image(result_img, use_column_width=True) + + # Add download button for the result + btn = st.download_button( + label="Download Result", + data=result_img, + file_name="segmentation_result.png", + mime="image/png", + key="download_single_result" + ) + except Exception as e: + st.error(f"Error processing image: {str(e)}") + +with tab3: + st.markdown("

Process Multiple SAR Images

", unsafe_allow_html=True) + + if not st.session_state.model_loaded: + st.warning("Please load a model in the 'Load Model' tab first.") + else: + st.markdown("
", unsafe_allow_html=True) + uploaded_files = st.file_uploader("Upload SAR images or a ZIP file containing images", type=["tif", "tiff", "png", "jpg", "jpeg", "zip"], accept_multiple_files=True) + st.markdown("
", unsafe_allow_html=True) + + col1, col2 = st.columns([3, 1]) + + with col1: + max_images = st.slider("Maximum number of images to display", min_value=1, max_value=20, value=10) + + with col2: + st.markdown("
", unsafe_allow_html=True) + process_btn = st.button("Process Images", key="process_multi_btn") + + if process_btn and uploaded_files: + # Clear previous results + st.session_state.processed_images = [] + + # Process uploaded files + with st.spinner("Processing images..."): + # Create a temporary directory to extract zip files if needed + with tempfile.TemporaryDirectory() as temp_dir: + # Process each uploaded file + image_files = [] + + for uploaded_file in uploaded_files: + if uploaded_file.name.lower().endswith('.zip'): + # Extract zip file + zip_path = os.path.join(temp_dir, uploaded_file.name) + with open(zip_path, 'wb') as f: + f.write(uploaded_file.getvalue()) + + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall(temp_dir) + + # Find all image files in the extracted directory + for root, _, files in os.walk(temp_dir): + for file in files: + if file.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')): + image_files.append(os.path.join(root, file)) + else: + # Save the file to temp directory + file_path = os.path.join(temp_dir, uploaded_file.name) + with open(file_path, 'wb') as f: + f.write(uploaded_file.getvalue()) + image_files.append(file_path) + + # If there are too many images, randomly select a subset + if len(image_files) > max_images: + st.info(f"Found {len(image_files)} images. Randomly selecting {max_images} images to display.") + image_files = random.sample(image_files, max_images) + + # Process each image + progress_bar = st.progress(0) + for i, image_path in enumerate(image_files): + try: + # Update progress + progress_bar.progress((i + 1) / len(image_files)) + + # Load and process the image + sar_data = st.session_state.segmentation.load_sar_data(image_path) + + # Normalize for visualization + sar_normalized = sar_data.copy() + min_val = np.min(sar_normalized) + max_val = np.max(sar_normalized) + sar_normalized = ((sar_normalized - min_val) / (max_val - min_val) * 255).astype(np.uint8) + + # Make prediction + prediction = st.session_state.segmentation.predict(sar_data) + + # Visualize results + result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1)) + + # Add to processed images + st.session_state.processed_images.append({ + 'filename': os.path.basename(image_path), + 'result': result_img + }) + except Exception as e: + st.error(f"Error processing {os.path.basename(image_path)}: {str(e)}") + + # Clear progress bar + progress_bar.empty() + + # Display results + if st.session_state.processed_images: + st.markdown("

Segmentation Results

", unsafe_allow_html=True) + + # Create a zip file with all results + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, 'w') as zip_file: + for i, img_data in enumerate(st.session_state.processed_images): + zip_file.writestr(f"result_{i+1}_{img_data['filename']}.png", img_data['result'].getvalue()) + + # Add download button for all results + st.download_button( + label="Download All Results", + data=zip_buffer.getvalue(), + file_name="segmentation_results.zip", + mime="application/zip", + key="download_all_results" + ) + + # Display each result + for i, img_data in enumerate(st.session_state.processed_images): + st.markdown(f"
Image: {img_data['filename']}
", unsafe_allow_html=True) + st.image(img_data['result'], use_column_width=True) + st.markdown("
", unsafe_allow_html=True) + else: + st.warning("No images were successfully processed.") + elif process_btn: + st.warning("Please upload at least one image file or ZIP archive.") + +# Close the card container +st.markdown("
", unsafe_allow_html=True) + +# Footer +st.markdown(""" +
+

+ SAR IMAGE COLORIZATION | VARUN & MOKSHYAGNA +

+
""", unsafe_allow_html=True) \ No newline at end of file