diff --git "a/app.py" "b/app.py"
--- "a/app.py"
+++ "b/app.py"
@@ -1,1391 +1,1397 @@
-import streamlit as st
-import tensorflow as tf
-from tensorflow.keras.models import Model
-from tensorflow.keras.layers import *
-from tensorflow.keras.optimizers import Adam
-import cv2
-import os
-import numpy as np
-import matplotlib.pyplot as plt
-from PIL import Image
-import io
-import base64
-import tempfile
-import zipfile
-import random
-import time
-import rasterio
-from rasterio.errors import RasterioIOError
-import h5py
-import json
-
-
-# Set page configuration
-st.set_page_config(
- page_title="SAR Image Pixel Level Classification",
- page_icon="🛰",
- layout="wide"
-)
-
-
-with st.sidebar:
- st.image('assets/logo2.png', width=150)
- st.title("About")
- st.markdown("""
- ### SAR Image Colorization
-
- This application uses deep learning models to segment and colorize Synthetic Aperture Radar (SAR) images into land cover classes.
-
- #### Features:
- - Load pre-trained U-Net,DeepLabV3+ or SegNet models
- - Process single SAR images
- - Batch process multiple images
- - Visualize Pixel Level Classification with ESA WorldCover color scheme
-
- #### Developed by:
- Varun & Mokshyagna
-
- #### Technologies:
- - TensorFlow/Keras
- - Streamlit
- - Rasterio
- - OpenCV
-
- #### Version:
- 1.0.0
- """)
-
- st.markdown("---")
- st.markdown("© 2025 | All Rights Reserved")
-
-
-def load_model_with_weights(model_path):
- """Load a model directly from an H5 file, preserving the original architecture"""
- # If model_path is a filename without path, prepend the models directory
- if not os.path.dirname(model_path) and not model_path.startswith('models/'):
- model_path = os.path.join('models', os.path.basename(model_path))
-
- try:
- # Try to load the complete model (architecture + weights)
- model = tf.keras.models.load_model(model_path, compile=False)
- print("Loaded complete model with architecture")
- return model
- except Exception as e:
- print(f"Could not load complete model: {str(e)}")
- print("Attempting to load just the weights into a matching architecture...")
-
- # Try to inspect the model file to determine architecture
- with h5py.File(model_path, 'r') as f:
- model_config = None
- if 'model_config' in f.attrs:
- model_config = json.loads(f.attrs['model_config'].decode('utf-8'))
-
- # If we found a model config, try to recreate it
- if model_config:
- try:
- model = tf.keras.models.model_from_json(json.dumps(model_config))
- model.load_weights(model_path)
- print("Successfully loaded model from config and weights")
- return model
- except Exception as e2:
- print(f"Failed to load from config: {str(e2)}")
-
- # If all else fails, return None
- return None
-st.markdown(
- """
-
- """,
- unsafe_allow_html=True
-)
-
-
-# Create twinkling stars
-stars_html = """
-
-"""
-for i in range(100):
- size = random.uniform(1, 3)
- top = random.uniform(0, 100)
- left = random.uniform(0, 100)
- duration = random.uniform(3, 8)
- opacity = random.uniform(0.2, 0.8)
-
- stars_html += f"""
-
- """
-stars_html += "
"
-
-st.markdown(stars_html, unsafe_allow_html=True)
-
-
-
-from PIL import Image
-import base64
-def get_image_as_base64(file_path):
- try:
- with open(file_path, "rb") as img_file:
- return base64.b64encode(img_file.read()).decode()
- except FileNotFoundError:
- # Return a placeholder or empty string if file not found
- st.warning(f"Image file not found: {file_path}")
- return ""
-
-# Replace 'path/to/your/logo.png' with your actual logo path
-logo_path = 'assets/logo2.png'
-logo_base64 = get_image_as_base64(logo_path)
-
-st.markdown(f"""
-

""", unsafe_allow_html=True)
-
-
-# Title and description
-st.markdown("""
-
-
- SAR Image Colorization
-
-
- Pixel Level Classification of Synthetic Aperture Radar images into land cover classes with deep learning
-
-
-""", unsafe_allow_html=True)
-
-
-
-# Create a card container
-st.markdown("", unsafe_allow_html=True)
-
-# Define the U-Net model
-def get_unet(input_shape=(256, 256, 1), drop_rate=0.3, classes=11):
- inputs = Input(input_shape)
-
- # Encoder
- conv1_1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
- batch1_1 = BatchNormalization()(conv1_1)
- conv1_2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch1_1)
- batch1_2 = BatchNormalization()(conv1_2)
- pool1 = MaxPooling2D(pool_size=(2, 2))(batch1_2)
-
- conv2_1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
- batch2_1 = BatchNormalization()(conv2_1)
- conv2_2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch2_1)
- batch2_2 = BatchNormalization()(conv2_2)
- pool2 = MaxPooling2D(pool_size=(2, 2))(batch2_2)
-
- conv3_1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
- batch3_1 = BatchNormalization()(conv3_1)
- conv3_2 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch3_1)
- batch3_2 = BatchNormalization()(conv3_2)
- pool3 = MaxPooling2D(pool_size=(2, 2))(batch3_2)
-
- conv4_1 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
- batch4_1 = BatchNormalization()(conv4_1)
- conv4_2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch4_1)
- batch4_2 = BatchNormalization()(conv4_2)
- drop4 = Dropout(drop_rate)(batch4_2)
- pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
-
- # Bridge
- conv5_1 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
- batch5_1 = BatchNormalization()(conv5_1)
- conv5_2 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch5_1)
- batch5_2 = BatchNormalization()(conv5_2)
- drop5 = Dropout(drop_rate)(batch5_2)
-
- # Decoder
- up6 = Conv2D(512, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(drop5))
- merge6 = concatenate([drop4, up6])
- conv6_1 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
- batch6_1 = BatchNormalization()(conv6_1)
- conv6_2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch6_1)
- batch6_2 = BatchNormalization()(conv6_2)
-
- up7 = Conv2D(256, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch6_2))
- merge7 = concatenate([batch3_2, up7])
- conv7_1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
- batch7_1 = BatchNormalization()(conv7_1)
- conv7_2 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch7_1)
- batch7_2 = BatchNormalization()(conv7_2)
-
- up8 = Conv2D(128, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch7_2))
- merge8 = concatenate([batch2_2, up8])
- conv8_1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
- batch8_1 = BatchNormalization()(conv8_1)
- conv8_2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch8_1)
- batch8_2 = BatchNormalization()(conv8_2)
-
- up9 = Conv2D(64, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch8_2))
- merge9 = concatenate([batch1_2, up9])
- conv9_1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
- batch9_1 = BatchNormalization()(conv9_1)
- conv9_2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch9_1)
- batch9_2 = BatchNormalization()(conv9_2)
-
- outputs = Conv2D(classes, 1, activation='softmax')(batch9_2)
-
- model = Model(inputs=inputs, outputs=outputs)
- model.compile(optimizer=Adam(learning_rate=1e-4),
- loss='categorical_crossentropy',
- metrics=['accuracy'])
-
- return model
-
-
-# Custom upsampling layer for dynamic resizing
-class BilinearUpsampling(Layer):
- def init(self, size=(1, 1), **kwargs):
- super(BilinearUpsampling, self).init(**kwargs)
- self.size = size
-
- def call(self, inputs):
- return tf.image.resize(inputs, self.size, method='bilinear')
-
- def compute_output_shape(self, input_shape):
- return (input_shape[0], self.size[0], self.size[1], input_shape[3])
-
- def get_config(self):
- config = super(BilinearUpsampling, self).get_config()
- config.update({'size': self.size})
- return config
-
-def DeepLabV3Plus(input_shape=(256, 256, 1), classes=11, output_stride=16):
- """
- DeepLabV3+ model with Xception backbone
-
- Args:
- input_shape: Shape of input images
- classes: Number of classes for segmentation
- output_stride: Output stride for dilated convolutions (16 or 8)
-
- Returns:
- model: DeepLabV3+ model
- """
- # Input layer
- inputs = Input(input_shape)
-
- # Ensure we're using the right dilation rates based on output_stride
- if output_stride == 16:
- atrous_rates = (6, 12, 18)
- elif output_stride == 8:
- atrous_rates = (12, 24, 36)
- else:
- raise ValueError("Output stride must be 8 or 16")
-
- # === ENCODER (BACKBONE) ===
- # Entry block
- x = Conv2D(32, 3, strides=(2, 2), padding='same', use_bias=False)(inputs)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
-
- x = Conv2D(64, 3, padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
-
- # Xception-like blocks with dilated convolutions
- # Block 1
- residual = Conv2D(128, 1, strides=(2, 2), padding='same', use_bias=False)(x)
- residual = BatchNormalization()(residual)
-
- x = SeparableConv2D(128, 3, padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = SeparableConv2D(128, 3, padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
- x = Add()([x, residual])
-
- # Block 2
- residual = Conv2D(256, 1, strides=(2, 2), padding='same', use_bias=False)(x)
- residual = BatchNormalization()(residual)
-
- x = Activation('relu')(x)
- x = SeparableConv2D(256, 3, padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = SeparableConv2D(256, 3, padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
- x = Add()([x, residual])
-
- # Save low_level_features for skip connection (1/4 of input size)
- low_level_features = x
-
- # Block 3
- residual = Conv2D(728, 1, strides=(2, 2), padding='same', use_bias=False)(x)
- residual = BatchNormalization()(residual)
-
- x = Activation('relu')(x)
- x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
- x = Add()([x, residual])
-
- # Middle flow - modified with dilated convolutions
- for i in range(16):
- residual = x
-
- x = Activation('relu')(x)
- x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
- x = BatchNormalization()(x)
-
- x = Add()([x, residual])
-
- # Exit flow (modified)
- x = Activation('relu')(x)
- x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = SeparableConv2D(1024, 3, padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
-
- # === ASPP (Atrous Spatial Pyramid Pooling) ===
- # 1x1 convolution branch
- aspp_out1 = Conv2D(256, 1, padding='same', use_bias=False)(x)
- aspp_out1 = BatchNormalization()(aspp_out1)
- aspp_out1 = Activation('relu')(aspp_out1)
-
- # 3x3 dilated convolution branches with different rates
- aspp_out2 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[0], use_bias=False)(x)
- aspp_out2 = BatchNormalization()(aspp_out2)
- aspp_out2 = Activation('relu')(aspp_out2)
-
- aspp_out3 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[1], use_bias=False)(x)
- aspp_out3 = BatchNormalization()(aspp_out3)
- aspp_out3 = Activation('relu')(aspp_out3)
-
- aspp_out4 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[2], use_bias=False)(x)
- aspp_out4 = BatchNormalization()(aspp_out4)
- aspp_out4 = Activation('relu')(aspp_out4)
-
- # Global pooling branch
- aspp_out5 = GlobalAveragePooling2D()(x)
- aspp_out5 = Reshape((1, 1, 1024))(aspp_out5) # Use 1024 to match x's channels
- aspp_out5 = Conv2D(256, 1, padding='same', use_bias=False)(aspp_out5)
- aspp_out5 = BatchNormalization()(aspp_out5)
- aspp_out5 = Activation('relu')(aspp_out5)
-
- # Get current shape of x
- _, height, width, _ = tf.keras.backend.int_shape(x)
- aspp_out5 = UpSampling2D(size=(height, width), interpolation='bilinear')(aspp_out5)
-
- # Concatenate all ASPP branches
- aspp_out = Concatenate()([aspp_out1, aspp_out2, aspp_out3, aspp_out4, aspp_out5])
-
- # Project ASPP output to 256 filters
- aspp_out = Conv2D(256, 1, padding='same', use_bias=False)(aspp_out)
- aspp_out = BatchNormalization()(aspp_out)
- aspp_out = Activation('relu')(aspp_out)
-
- # === DECODER ===
- # Process low-level features from Block 2 (1/4 size)
- low_level_features = Conv2D(48, 1, padding='same', use_bias=False)(low_level_features)
- low_level_features = BatchNormalization()(low_level_features)
- low_level_features = Activation('relu')(low_level_features)
-
- # Upsample ASPP output by 4x to match low level features size
- # Get shapes for verification
- low_level_shape = tf.keras.backend.int_shape(low_level_features)
-
- # Upsample to match low_level_features shape
- x = UpSampling2D(size=(low_level_shape[1] // tf.keras.backend.int_shape(aspp_out)[1],
- low_level_shape[2] // tf.keras.backend.int_shape(aspp_out)[2]),
- interpolation='bilinear')(aspp_out)
-
- # Concatenate with low-level features
- x = Concatenate()([x, low_level_features])
-
- # Final convolutions
- x = Conv2D(256, 3, padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
-
- x = Conv2D(256, 3, padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
-
- # Calculate upsampling size to original input size
- x_shape = tf.keras.backend.int_shape(x)
- upsampling_size = (input_shape[0] // x_shape[1], input_shape[1] // x_shape[2])
-
- # Upsample to original size
- x = UpSampling2D(size=upsampling_size, interpolation='bilinear')(x)
-
- # Final segmentation output
- outputs = Conv2D(classes, 1, padding='same', activation='softmax')(x)
-
- model = Model(inputs=inputs, outputs=outputs)
- model.compile(optimizer=Adam(learning_rate=1e-4),
- loss='categorical_crossentropy',
- metrics=['accuracy'])
-
- return model
-
-import tensorflow as tf
-from tensorflow.keras.models import Model
-from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPooling2D, UpSampling2D
-
-def SegNet(input_shape=(256, 256, 1), classes=11):
- """
- SegNet model for semantic segmentation
-
- Args:
- input_shape: Shape of input images
- classes: Number of classes for segmentation
-
- Returns:
- model: SegNet model
- """
- # Input layer
- inputs = Input(input_shape)
-
- # === ENCODER ===
- # Encoder block 1
- x = Conv2D(64, (3, 3), padding='same', use_bias=False)(inputs)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- # Regular MaxPooling without indices
- x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)
-
- # Encoder block 2
- x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)
-
- # Encoder block 3
- x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)
-
- # Encoder block 4
- x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)
-
- # Encoder block 5
- x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)
-
- # === DECODER ===
- # Using UpSampling2D instead of MaxUnpooling since TensorFlow doesn't support it
-
- # Decoder block 5
- x = UpSampling2D(size=(2, 2))(x)
- x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
-
- # Decoder block 4
- x = UpSampling2D(size=(2, 2))(x)
- x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
-
- # Decoder block 3
- x = UpSampling2D(size=(2, 2))(x)
- x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
-
- # Decoder block 2
- x = UpSampling2D(size=(2, 2))(x)
- x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
-
- # Decoder block 1
- x = UpSampling2D(size=(2, 2))(x)
- x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
- x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x)
- x = BatchNormalization()(x)
- x = Activation('relu')(x)
-
- # Output layer
- outputs = Conv2D(classes, (1, 1), padding='same', activation='softmax')(x)
-
- model = Model(inputs=inputs, outputs=outputs)
-
- return model
-
-
-
-
-class SARSegmentation:
- def _init_(self, img_rows=256, img_cols=256, drop_rate=0.5, model_type='unet'):
- self.img_rows = img_rows
- self.img_cols = img_cols
- self.drop_rate = drop_rate
- self.num_channels = 1 # Single-pol SAR
- self.model = None
- self.model_type = model_type.lower()
-
- # ESA WorldCover class definitions
- self.class_definitions = {
- 'trees': 10,
- 'shrubland': 20,
- 'grassland': 30,
- 'cropland': 40,
- 'built_up': 50,
- 'bare': 60,
- 'snow': 70,
- 'water': 80,
- 'wetland': 90,
- 'mangroves': 95,
- 'moss': 100
- }
- self.num_classes = len(self.class_definitions)
-
- # Class colors for visualization
- self.class_colors = {
- 0: [0, 100, 0], # Trees - Dark green
- 1: [255, 165, 0], # Shrubland - Orange
- 2: [144, 238, 144], # Grassland - Light green
- 3: [255, 255, 0], # Cropland - Yellow
- 4: [255, 0, 0], # Built-up - Red
- 5: [139, 69, 19], # Bare - Brown
- 6: [255, 255, 255], # Snow - White
- 7: [0, 0, 255], # Water - Blue
- 8: [0, 139, 139], # Wetland - Dark cyan
- 9: [0, 255, 0], # Mangroves - Bright green
- 10: [220, 220, 220] # Moss - Light grey
- }
-
- def load_sar_data(self, file_path_or_bytes, is_bytes=False):
- """Load SAR data from file path or bytes"""
- try:
- if is_bytes:
- # Create a temporary file to use with rasterio
- with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp:
- tmp.write(file_path_or_bytes)
- tmp_path = tmp.name
-
- with rasterio.open(tmp_path) as src:
- sar_data = src.read(1) # Read single band
- sar_data = np.expand_dims(sar_data, axis=-1)
-
- # Clean up the temporary file
- os.unlink(tmp_path)
- else:
- with rasterio.open(file_path_or_bytes) as src:
- sar_data = src.read(1) # Read single band
- sar_data = np.expand_dims(sar_data, axis=-1)
-
- # Resize if needed
- if sar_data.shape[:2] != (self.img_rows, self.img_cols):
- sar_data = cv2.resize(sar_data, (self.img_cols, self.img_rows))
- sar_data = np.expand_dims(sar_data, axis=-1)
-
- return sar_data
- except RasterioIOError:
- # Try to open as a regular image if rasterio fails
- if is_bytes:
- img = Image.open(io.BytesIO(file_path_or_bytes)).convert('L')
- else:
- img = Image.open(file_path_or_bytes).convert('L')
-
- img_array = np.array(img)
- if img_array.shape[:2] != (self.img_rows, self.img_cols):
- img_array = cv2.resize(img_array, (self.img_cols, self.img_rows))
-
- return np.expand_dims(img_array, axis=-1)
-
- def preprocess_sar(self, sar_data):
- """Preprocess SAR data"""
- # Check if data is already normalized (0-255 range)
- if np.max(sar_data) <= 255 and np.min(sar_data) >= 0:
- # Normalize to -1 to 1 range
- sar_normalized = (sar_data / 127.5) - 1
- else:
- # Assume it's in dB scale
- sar_clipped = np.clip(sar_data, -50, 20)
- sar_normalized = (sar_clipped - np.min(sar_clipped)) / (np.max(sar_clipped) - np.min(sar_clipped)) * 2 - 1
-
- return sar_normalized
-
- def one_hot_encode(self, labels):
- """Convert ESA WorldCover labels to one-hot encoded format"""
- encoded = np.zeros((labels.shape[0], labels.shape[1], self.num_classes))
-
- for i, value in enumerate(sorted(self.class_definitions.values())):
- encoded[:, :, i] = (labels == value)
-
- return encoded
-
- def load_trained_model(self, model_path):
- """Load a trained model from file"""
- try:
- # If model_path is a filename without path, prepend the models directory
- if not os.path.dirname(model_path) and not model_path.startswith('models/'):
- model_path = os.path.join('models', os.path.basename(model_path))
-
- # First try to load the complete model
- self.model = load_model_with_weights(model_path)
-
-
- if self.model is not None:
-
- has_dilated_convs = False
- for layer in self.model.layers:
- if 'conv' in layer.name.lower() and hasattr(layer, 'dilation_rate'):
- if isinstance(layer.dilation_rate, (list, tuple)):
- if any(rate > 1 for rate in layer.dilation_rate):
- has_dilated_convs = True
- break
- elif layer.dilation_rate > 1:
- has_dilated_convs = True
- break
-
- if has_dilated_convs:
- self.model_type = 'deeplabv3plus'
- print("Detected DeepLabV3+ model")
- # Check for SegNet architecture (typically has 5 encoder and 5 decoder blocks)
- elif len([l for l in self.model.layers if isinstance(l, MaxPooling2D)]) >= 5:
- self.model_type = 'segnet'
- print("Detected SegNet model")
- else:
- self.model_type = 'unet'
- print("Detected U-Net model")
-
- if self.model is None:
- # If that fails, try to create a model with the expected architecture
- if self.model_type == 'unet':
- self.model = get_unet(
- input_shape=(self.img_rows, self.img_cols, self.num_channels),
- drop_rate=self.drop_rate,
- classes=self.num_classes
- )
- elif self.model_type == 'deeplabv3plus':
- self.model = DeepLabV3Plus(
- input_shape=(self.img_rows, self.img_cols, self.num_channels),
- classes=self.num_classes
- )
- elif self.model_type == 'segnet':
- self.model = SegNet(
- input_shape=(self.img_rows, self.img_cols, self.num_channels),
- classes=self.num_classes
- )
- else:
- raise ValueError(f"Model type {self.model_type} not supported")
-
- # Try to load weights, allowing for mismatch
- self.model.load_weights(model_path, by_name=True, skip_mismatch=True)
-
- # Check if any weights were loaded
- if not any(np.any(w) for w in self.model.get_weights()):
- raise ValueError("No weights were loaded. The model architecture is incompatible.")
- except Exception as e:
- raise ValueError(f"Failed to load model: {str(e)}")
-
-
-
-
- def predict(self, sar_data):
- """Predict segmentation for new SAR data"""
- if self.model is None:
- raise ValueError("Model not trained. Call train() first or load a trained model.")
-
- # Preprocess input data
- sar_processed = self.preprocess_sar(sar_data)
-
- # Ensure correct shape
- if len(sar_processed.shape) == 3:
- sar_processed = np.expand_dims(sar_processed, axis=0)
-
- # Make prediction
- prediction = self.model.predict(sar_processed)
- return prediction
-
- def get_colored_prediction(self, prediction):
- """Convert prediction to colored image"""
- pred_class = np.argmax(prediction[0], axis=-1)
- colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8)
-
- for class_idx, color in self.class_colors.items():
- colored_pred[pred_class == class_idx] = color
-
- return colored_pred, pred_class
-
-
-
-def visualize_prediction(prediction, original_sar, figsize=(10, 4)):
- """Visualize segmentation prediction with ESA WorldCover colors"""
- # ESA WorldCover colors
- colors = {
- 0: [0, 100, 0], # Trees - Dark green
- 1: [255, 165, 0], # Shrubland - Orange
- 2: [144, 238, 144], # Grassland - Light green
- 3: [255, 255, 0], # Cropland - Yellow
- 4: [255, 0, 0], # Built-up - Red
- 5: [139, 69, 19], # Bare - Brown
- 6: [255, 255, 255], # Snow - White
- 7: [0, 0, 255], # Water - Blue
- 8: [0, 139, 139], # Wetland - Dark cyan
- 9: [0, 255, 0], # Mangroves - Bright green
- 10: [220, 220, 220] # Moss - Light grey
- }
-
- # Convert prediction to color image
- pred_class = np.argmax(prediction[0], axis=-1)
- colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8)
-
- for class_idx, color in colors.items():
- colored_pred[pred_class == class_idx] = color
-
- # Create overlay
- sar_rgb = cv2.cvtColor(original_sar[:,:,0], cv2.COLOR_GRAY2RGB)
- overlay = cv2.addWeighted(sar_rgb, 0.7, colored_pred, 0.3, 0)
-
- # Create figure
- fig, axes = plt.subplots(1, 3, figsize=figsize)
-
- # Original SAR
- axes[0].imshow(original_sar[:,:,0], cmap='gray')
- axes[0].set_title('Original SAR', color='white')
- axes[0].axis('off')
-
- # Prediction
- axes[1].imshow(colored_pred)
- axes[1].set_title('Prediction', color='white')
- axes[1].axis('off')
-
- # Overlay
- axes[2].imshow(overlay)
- axes[2].set_title('Colorized Output', color='white')
- axes[2].axis('off')
-
- # Set dark background
- fig.patch.set_facecolor('#0a0a1f')
- for ax in axes:
- ax.set_facecolor('#0a0a1f')
-
- plt.tight_layout()
-
- # Convert plot to image
- buf = io.BytesIO()
- plt.savefig(buf, format='png', facecolor='#0a0a1f', bbox_inches='tight')
- buf.seek(0)
- plt.close(fig)
-
- return buf
-
-def create_legend():
- """Create a legend for the land cover classes"""
- colors = {
- 'Trees': [0, 100, 0],
- 'Shrubland': [255, 165, 0],
- 'Grassland': [144, 238, 144],
- 'Cropland': [255, 255, 0],
- 'Built-up': [255, 0, 0],
- 'Bare': [139, 69, 19],
- 'Snow': [255, 255, 255],
- 'Water': [0, 0, 255],
- 'Wetland': [0, 139, 139],
- 'Mangroves': [0, 255, 0],
- 'Moss': [220, 220, 220]
- }
-
- fig, ax = plt.subplots(figsize=(8, 4))
- fig.patch.set_facecolor('#0a0a1f')
- ax.set_facecolor('#0a0a1f')
-
- # Create color patches
- for i, (class_name, color) in enumerate(colors.items()):
- ax.add_patch(plt.Rectangle((0, i), 0.5, 0.8, color=[c/255 for c in color]))
- ax.text(0.7, i + 0.4, class_name, color='white', fontsize=12)
-
- ax.set_xlim(0, 3)
- ax.set_ylim(-0.5, len(colors) - 0.5)
- ax.set_title('Land Cover Classes', color='white', fontsize=14)
- ax.axis('off')
-
- buf = io.BytesIO()
- plt.savefig(buf, format='png', facecolor='#0a0a1f', bbox_inches='tight')
- buf.seek(0)
- plt.close(fig)
-
- return buf
-
-# Initialize session state variables
-if 'model_loaded' not in st.session_state:
- st.session_state.model_loaded = False
-if 'segmentation' not in st.session_state:
- st.session_state.segmentation = SARSegmentation(img_rows=256, img_cols=256)
-if 'processed_images' not in st.session_state:
- st.session_state.processed_images = []
-
-# Create tabs
-tab1, tab2, tab3 = st.tabs(["📥 Load Model", " Process Single Image", "📁 Process Multiple Images"])
-
-# Add this code in tab1 before the model file uploader
-with tab1:
- st.markdown("
Load Segmentation Model
", unsafe_allow_html=True)
-
- # Add model type selection
- model_type = st.selectbox(
- "Select model architecture",
- ["U-Net", "DeepLabV3+", "SegNet"],
- index=0,
- help="Select the architecture of the model you're uploading"
- )
-
- # Update the model type in the segmentation object
- st.session_state.segmentation.model_type = model_type.lower().replace('-', '')
-
- col1, col2 = st.columns([3, 1])
-
- with col1:
- model_file = st.file_uploader("Upload model weights (.h5 file)", type=["h5"])
-
-
- with col2:
- st.markdown("
", unsafe_allow_html=True)
- if st.button("Load Model", key="load_model_btn"):
- if model_file is not None:
- with st.spinner("Loading model..."):
- # Save the uploaded model to a temporary file
- with tempfile.NamedTemporaryFile(delete=False, suffix='.h5') as tmp:
- tmp.write(model_file.getvalue())
- model_path = tmp.name
-
- try:
- # Load the model
- st.session_state.segmentation.load_trained_model(model_path)
- st.session_state.model_loaded = True
- st.success("Model loaded successfully!")
- except Exception as e:
- st.error(f"Error loading model: {str(e)}")
- finally:
- # Clean up the temporary file
- os.unlink(model_path)
- else:
- st.error("Please upload a model file (.h5)")
-
-
- if st.session_state.model_loaded:
- st.markdown("
", unsafe_allow_html=True)
- st.markdown("
Model Information
", unsafe_allow_html=True)
-
- col1, col2, col3 = st.columns(3)
- with col1:
- st.markdown("
", unsafe_allow_html=True)
- # Display the correct model architecture based on the detected model type
- model_arch_map = {
- 'unet': "U-Net",
- 'deeplabv3plus': "DeepLabV3+",
- 'segnet': "SegNet"
- }
- model_arch = model_arch_map.get(st.session_state.segmentation.model_type, "Unknown")
- st.markdown(f"
{model_arch}
", unsafe_allow_html=True)
- st.markdown("
Architecture
", unsafe_allow_html=True)
- st.markdown("
", unsafe_allow_html=True)
-
-
- with col2:
- st.markdown("
", unsafe_allow_html=True)
- st.markdown("
11
", unsafe_allow_html=True)
- st.markdown("
Land Cover Classes
", unsafe_allow_html=True)
- st.markdown("
", unsafe_allow_html=True)
-
- with col3:
- st.markdown("
", unsafe_allow_html=True)
- st.markdown("
256 x 256
", unsafe_allow_html=True)
- st.markdown("
Input Size
", unsafe_allow_html=True)
- st.markdown("
", unsafe_allow_html=True)
-
- # Display legend
- st.markdown("
Land Cover Classes
", unsafe_allow_html=True)
- legend_img = create_legend()
- st.image(legend_img, use_column_width=True)
-
- st.markdown("
", unsafe_allow_html=True)
- else:
- st.info("Please load a model to continue.")
-
-
-with tab2:
- st.markdown("
Process Single SAR Image
", unsafe_allow_html=True)
-
- if not st.session_state.model_loaded:
- st.warning("Please load a model in the 'Load Model' tab first.")
- else:
- st.markdown("
", unsafe_allow_html=True)
- uploaded_file = st.file_uploader("Upload a SAR image (.tif or common image formats)", type=["tif", "tiff", "png", "jpg", "jpeg"])
- st.markdown("
", unsafe_allow_html=True)
-
- if uploaded_file is not None:
- if st.button("Process Image", key="process_single_btn"):
- with st.spinner("Processing image..."):
- # Load and process the image
- try:
- sar_data = st.session_state.segmentation.load_sar_data(uploaded_file.getvalue(), is_bytes=True)
-
- # Normalize for visualization
- sar_normalized = sar_data.copy()
- min_val = np.min(sar_normalized)
- max_val = np.max(sar_normalized)
- sar_normalized = ((sar_normalized - min_val) / (max_val - min_val) * 255).astype(np.uint8)
-
- # Make prediction
- prediction = st.session_state.segmentation.predict(sar_data)
-
- # Visualize results
- result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1))
-
- # Display results
- st.markdown("
Segmentation Results
", unsafe_allow_html=True)
- st.image(result_img, use_column_width=True)
-
- # Add download button for the result
- btn = st.download_button(
- label="Download Result",
- data=result_img,
- file_name="segmentation_result.png",
- mime="image/png",
- key="download_single_result"
- )
- except Exception as e:
- st.error(f"Error processing image: {str(e)}")
-
-with tab3:
- st.markdown("
Process Multiple SAR Images
", unsafe_allow_html=True)
-
- if not st.session_state.model_loaded:
- st.warning("Please load a model in the 'Load Model' tab first.")
- else:
- st.markdown("
", unsafe_allow_html=True)
- uploaded_files = st.file_uploader("Upload SAR images or a ZIP file containing images", type=["tif", "tiff", "png", "jpg", "jpeg", "zip"], accept_multiple_files=True)
- st.markdown("
", unsafe_allow_html=True)
-
- col1, col2 = st.columns([3, 1])
-
- with col1:
- max_images = st.slider("Maximum number of images to display", min_value=1, max_value=20, value=10)
-
- with col2:
- st.markdown("
", unsafe_allow_html=True)
- process_btn = st.button("Process Images", key="process_multi_btn")
-
- if process_btn and uploaded_files:
- # Clear previous results
- st.session_state.processed_images = []
-
- # Process uploaded files
- with st.spinner("Processing images..."):
- # Create a temporary directory to extract zip files if needed
- with tempfile.TemporaryDirectory() as temp_dir:
- # Process each uploaded file
- image_files = []
-
- for uploaded_file in uploaded_files:
- if uploaded_file.name.lower().endswith('.zip'):
- # Extract zip file
- zip_path = os.path.join(temp_dir, uploaded_file.name)
- with open(zip_path, 'wb') as f:
- f.write(uploaded_file.getvalue())
-
- with zipfile.ZipFile(zip_path, 'r') as zip_ref:
- zip_ref.extractall(temp_dir)
-
- # Find all image files in the extracted directory
- for root, _, files in os.walk(temp_dir):
- for file in files:
- if file.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')):
- image_files.append(os.path.join(root, file))
- else:
- # Save the file to temp directory
- file_path = os.path.join(temp_dir, uploaded_file.name)
- with open(file_path, 'wb') as f:
- f.write(uploaded_file.getvalue())
- image_files.append(file_path)
-
- # If there are too many images, randomly select a subset
- if len(image_files) > max_images:
- st.info(f"Found {len(image_files)} images. Randomly selecting {max_images} images to display.")
- image_files = random.sample(image_files, max_images)
-
- # Process each image
- progress_bar = st.progress(0)
- for i, image_path in enumerate(image_files):
- try:
- # Update progress
- progress_bar.progress((i + 1) / len(image_files))
-
- # Load and process the image
- sar_data = st.session_state.segmentation.load_sar_data(image_path)
-
- # Normalize for visualization
- sar_normalized = sar_data.copy()
- min_val = np.min(sar_normalized)
- max_val = np.max(sar_normalized)
- sar_normalized = ((sar_normalized - min_val) / (max_val - min_val) * 255).astype(np.uint8)
-
- # Make prediction
- prediction = st.session_state.segmentation.predict(sar_data)
-
- # Visualize results
- result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1))
-
- # Add to processed images
- st.session_state.processed_images.append({
- 'filename': os.path.basename(image_path),
- 'result': result_img
- })
- except Exception as e:
- st.error(f"Error processing {os.path.basename(image_path)}: {str(e)}")
-
- # Clear progress bar
- progress_bar.empty()
-
- # Display results
- if st.session_state.processed_images:
- st.markdown("
Segmentation Results
", unsafe_allow_html=True)
-
- # Create a zip file with all results
- zip_buffer = io.BytesIO()
- with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
- for i, img_data in enumerate(st.session_state.processed_images):
- zip_file.writestr(f"result_{i+1}_{img_data['filename']}.png", img_data['result'].getvalue())
-
- # Add download button for all results
- st.download_button(
- label="Download All Results",
- data=zip_buffer.getvalue(),
- file_name="segmentation_results.zip",
- mime="application/zip",
- key="download_all_results"
- )
-
- # Display each result
- for i, img_data in enumerate(st.session_state.processed_images):
- st.markdown(f"
Image: {img_data['filename']}
", unsafe_allow_html=True)
- st.image(img_data['result'], use_column_width=True)
- st.markdown("
", unsafe_allow_html=True)
- else:
- st.warning("No images were successfully processed.")
- elif process_btn:
- st.warning("Please upload at least one image file or ZIP archive.")
-
-# Close the card container
-st.markdown("
", unsafe_allow_html=True)
-
-# Footer
-st.markdown("""
-
-
- SAR IMAGE COLORIZATION | VARUN & MOKSHYAGNA
-
-
+import streamlit as st
+import tensorflow as tf
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import *
+from tensorflow.keras.optimizers import Adam
+import cv2
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+from PIL import Image
+import io
+import base64
+import tempfile
+import zipfile
+import random
+import time
+import rasterio
+from rasterio.errors import RasterioIOError
+import h5py
+import json
+
+
+# Set page configuration
+st.set_page_config(
+ page_title="SAR Image Pixel Level Classification",
+ page_icon="🛰",
+ layout="wide"
+)
+
+
+with st.sidebar:
+ st.image('assets/logo2.png', width=150)
+ st.title("About")
+ st.markdown("""
+ ### SAR Image Colorization
+
+ This application uses deep learning models to segment and colorize Synthetic Aperture Radar (SAR) images into land cover classes.
+
+ #### Features:
+ - Load pre-trained U-Net,DeepLabV3+ or SegNet models
+ - Process single SAR images
+ - Batch process multiple images
+ - Visualize Pixel Level Classification with ESA WorldCover color scheme
+
+ #### Developed by:
+ Varun & Mokshyagna
+
+ #### Technologies:
+ - TensorFlow/Keras
+ - Streamlit
+ - Rasterio
+ - OpenCV
+
+ #### Version:
+ 1.0.0
+ """)
+
+ st.markdown("---")
+ st.markdown("© 2025 | All Rights Reserved")
+
+
def load_model_with_weights(model_path):
    """Load a Keras model from an H5 file, preserving the original architecture.

    Resolution strategy:
      1. Bare filenames are resolved against the local ``models/`` directory.
      2. Try to deserialize the complete model (architecture + weights).
      3. Fall back to rebuilding the architecture from the H5 file's embedded
         ``model_config`` attribute, then loading the weights into it.

    Returns:
        The loaded ``tf.keras.Model``, or ``None`` when every strategy fails.
    """
    # A plain filename is assumed to live under the models directory.
    if not os.path.dirname(model_path) and not model_path.startswith('models/'):
        model_path = os.path.join('models', os.path.basename(model_path))

    try:
        full_model = tf.keras.models.load_model(model_path, compile=False)
    except Exception as e:
        print(f"Could not load complete model: {str(e)}")
        print("Attempting to load just the weights into a matching architecture...")
    else:
        print("Loaded complete model with architecture")
        return full_model

    # Pull the architecture description (if present) out of the H5 attributes.
    model_config = None
    with h5py.File(model_path, 'r') as f:
        if 'model_config' in f.attrs:
            config_attr = f.attrs['model_config']
            # The attribute may be stored as bytes or str depending on the writer.
            raw_json = config_attr.decode('utf-8') if isinstance(config_attr, bytes) else config_attr
            model_config = json.loads(raw_json)

    if model_config:
        try:
            rebuilt = tf.keras.models.model_from_json(json.dumps(model_config))
            rebuilt.load_weights(model_path)
            print("Successfully loaded model from config and weights")
            return rebuilt
        except Exception as e2:
            print(f"Failed to load from config: {str(e2)}")

    # Every strategy failed.
    return None
+
+st.markdown(
+ """
+
+ """,
+ unsafe_allow_html=True
+)
+
+
+# Create twinkling stars
+stars_html = """
+
+"""
+for i in range(100):
+ size = random.uniform(1, 3)
+ top = random.uniform(0, 100)
+ left = random.uniform(0, 100)
+ duration = random.uniform(3, 8)
+ opacity = random.uniform(0.2, 0.8)
+
+ stars_html += f"""
+
+ """
+stars_html += "
"
+
+st.markdown(stars_html, unsafe_allow_html=True)
+
+
+
+from PIL import Image
+import base64
def get_image_as_base64(file_path):
    """Read *file_path* and return its bytes as a base64-encoded ASCII string.

    A missing file is reported with ``st.warning`` and yields ``""`` so the
    caller can embed the result in HTML unconditionally.
    """
    try:
        with open(file_path, "rb") as fh:
            raw = fh.read()
    except FileNotFoundError:
        st.warning(f"Image file not found: {file_path}")
        return ""
    return base64.b64encode(raw).decode()
+
+# Replace 'path/to/your/logo.png' with your actual logo path
+logo_path = 'assets/logo2.png'
+logo_base64 = get_image_as_base64(logo_path)
+
+st.markdown(f"""
+

""", unsafe_allow_html=True)
+
+
+# Title and description
+st.markdown("""
+
+
+ SAR Image Colorization
+
+
+ Pixel Level Classification of Synthetic Aperture Radar images into land cover classes with deep learning
+
+
+""", unsafe_allow_html=True)
+
+
+
+# Create a card container
+st.markdown("", unsafe_allow_html=True)
+
+# Define the U-Net model
def get_unet(input_shape=(256, 256, 1), drop_rate=0.3, classes=11):
    """Assemble the classic U-Net encoder/decoder for per-pixel classification.

    Args:
        input_shape: (rows, cols, channels) of the SAR input.
        drop_rate: dropout rate applied at the two deepest levels.
        classes: number of softmax output channels.

    Returns:
        A compiled Keras Model (Adam 1e-4, categorical cross-entropy).
    """

    def _double_conv(tensor, filters):
        # Two 3x3 conv + batch-norm pairs: the basic U-Net building block.
        tensor = Conv2D(filters, 3, activation='relu', padding='same',
                        kernel_initializer='he_normal')(tensor)
        tensor = BatchNormalization()(tensor)
        tensor = Conv2D(filters, 3, activation='relu', padding='same',
                        kernel_initializer='he_normal')(tensor)
        return BatchNormalization()(tensor)

    def _up_merge(tensor, skip, filters):
        # 2x upsample, reduce channels with a 2x2 conv, then concat the skip path.
        tensor = Conv2D(filters, 2, activation='relu', padding='same')(
            UpSampling2D(size=(2, 2))(tensor))
        return concatenate([skip, tensor])

    inputs = Input(input_shape)

    # Contracting path; each level's output is kept for its skip connection.
    # Dropout is applied only at the deepest encoder level and the bridge.
    skip1 = _double_conv(inputs, 64)
    skip2 = _double_conv(MaxPooling2D(pool_size=(2, 2))(skip1), 128)
    skip3 = _double_conv(MaxPooling2D(pool_size=(2, 2))(skip2), 256)
    skip4 = Dropout(drop_rate)(_double_conv(MaxPooling2D(pool_size=(2, 2))(skip3), 512))

    # Bridge at the bottom of the U.
    bridge = Dropout(drop_rate)(_double_conv(MaxPooling2D(pool_size=(2, 2))(skip4), 1024))

    # Expanding path mirrors the encoder, merging each skip connection.
    x = _double_conv(_up_merge(bridge, skip4, 512), 512)
    x = _double_conv(_up_merge(x, skip3, 256), 256)
    x = _double_conv(_up_merge(x, skip2, 128), 128)
    x = _double_conv(_up_merge(x, skip1, 64), 64)

    # Per-pixel class probabilities.
    outputs = Conv2D(classes, 1, activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    return model
+
+
+# Custom upsampling layer for dynamic resizing
class BilinearUpsampling(Layer):
    """Resize feature maps to a fixed (height, width) with bilinear interpolation.

    Unlike ``UpSampling2D`` (which scales by an integer factor), ``size`` here
    is the absolute target spatial size, enabling dynamic resizing.  The layer
    is serializable via ``get_config`` so saved models can be reloaded with
    ``custom_objects={'BilinearUpsampling': BilinearUpsampling}``.
    """

    # BUG FIX: the constructor was named ``init`` (missing dunder underscores),
    # so ``BilinearUpsampling(size=...)`` never set ``self.size`` and instead
    # passed the unknown ``size`` kwarg straight to ``Layer.__init__``.
    def __init__(self, size=(1, 1), **kwargs):
        super(BilinearUpsampling, self).__init__(**kwargs)
        self.size = size

    def call(self, inputs):
        # tf.image.resize treats ``size`` as the output (height, width).
        return tf.image.resize(inputs, self.size, method='bilinear')

    def compute_output_shape(self, input_shape):
        # (batch, target_h, target_w, channels) — channel count is unchanged.
        return (input_shape[0], self.size[0], self.size[1], input_shape[3])

    def get_config(self):
        config = super(BilinearUpsampling, self).get_config()
        config.update({'size': self.size})
        return config
+
def DeepLabV3Plus(input_shape=(256, 256, 1), classes=11, output_stride=16):
    """
    DeepLabV3+ model with a simplified Xception-style backbone.

    Encoder: three downsampling residual blocks (entry flow), 16 dilated
    residual blocks (middle flow), and an exit flow, followed by an ASPP
    module.  Decoder: the ASPP output is upsampled and fused with the
    1/4-resolution low-level features before the final softmax head.

    Args:
        input_shape: Shape of input images.
            NOTE(review): decoder upsampling factors are computed with integer
            division of static shapes, so spatial dims are assumed to be exact
            multiples of the downsampling factors (e.g. 256x256) — confirm
            before feeding other sizes.
        classes: Number of classes for segmentation
        output_stride: Output stride for dilated convolutions (16 or 8)

    Returns:
        model: compiled DeepLabV3+ model (Adam 1e-4, categorical cross-entropy)

    Raises:
        ValueError: if output_stride is not 8 or 16.
    """
    # Input layer
    inputs = Input(input_shape)

    # Ensure we're using the right dilation rates based on output_stride
    if output_stride == 16:
        atrous_rates = (6, 12, 18)
    elif output_stride == 8:
        atrous_rates = (12, 24, 36)
    else:
        raise ValueError("Output stride must be 8 or 16")

    # === ENCODER (BACKBONE) ===
    # Entry block: stride-2 stem halves the spatial resolution.
    x = Conv2D(32, 3, strides=(2, 2), padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(64, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Xception-like blocks with dilated convolutions.
    # Each block: separable convs + stride-2 pooling, with a 1x1 strided
    # projection on the residual branch to match shape.
    # Block 1
    residual = Conv2D(128, 1, strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = SeparableConv2D(128, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(128, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
    x = Add()([x, residual])

    # Block 2
    residual = Conv2D(256, 1, strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu')(x)
    x = SeparableConv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
    x = Add()([x, residual])

    # Save low_level_features for skip connection (1/4 of input size)
    # NOTE(review): taken AFTER block 2's stride-2 pool, so this is actually
    # 1/8 of the input — the decoder only depends on relative factors, so the
    # graph is still consistent.
    low_level_features = x

    # Block 3
    residual = Conv2D(728, 1, strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu')(x)
    x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
    x = Add()([x, residual])

    # Middle flow - 16 identity-residual blocks, dilation rate 2 throughout
    # to grow the receptive field without further downsampling.
    for i in range(16):
        residual = x

        x = Activation('relu')(x)
        x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
        x = BatchNormalization()(x)

        x = Add()([x, residual])

    # Exit flow (modified): widen to 1024 channels for the ASPP head.
    x = Activation('relu')(x)
    x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(1024, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)

    # === ASPP (Atrous Spatial Pyramid Pooling) ===
    # 1x1 convolution branch
    aspp_out1 = Conv2D(256, 1, padding='same', use_bias=False)(x)
    aspp_out1 = BatchNormalization()(aspp_out1)
    aspp_out1 = Activation('relu')(aspp_out1)

    # 3x3 dilated convolution branches with different rates
    aspp_out2 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[0], use_bias=False)(x)
    aspp_out2 = BatchNormalization()(aspp_out2)
    aspp_out2 = Activation('relu')(aspp_out2)

    aspp_out3 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[1], use_bias=False)(x)
    aspp_out3 = BatchNormalization()(aspp_out3)
    aspp_out3 = Activation('relu')(aspp_out3)

    aspp_out4 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[2], use_bias=False)(x)
    aspp_out4 = BatchNormalization()(aspp_out4)
    aspp_out4 = Activation('relu')(aspp_out4)

    # Global pooling branch (image-level context).
    aspp_out5 = GlobalAveragePooling2D()(x)
    # Hard-codes the exit-flow width: the Reshape ties this branch to the 1024
    # channels produced above — keep in sync if the exit flow changes.
    aspp_out5 = Reshape((1, 1, 1024))(aspp_out5)  # Use 1024 to match x's channels
    aspp_out5 = Conv2D(256, 1, padding='same', use_bias=False)(aspp_out5)
    aspp_out5 = BatchNormalization()(aspp_out5)
    aspp_out5 = Activation('relu')(aspp_out5)

    # Broadcast the 1x1 pooled context back to x's spatial size:
    # upsampling a 1x1 map by (height, width) yields exactly (height, width).
    _, height, width, _ = tf.keras.backend.int_shape(x)
    aspp_out5 = UpSampling2D(size=(height, width), interpolation='bilinear')(aspp_out5)

    # Concatenate all ASPP branches
    aspp_out = Concatenate()([aspp_out1, aspp_out2, aspp_out3, aspp_out4, aspp_out5])

    # Project ASPP output to 256 filters
    aspp_out = Conv2D(256, 1, padding='same', use_bias=False)(aspp_out)
    aspp_out = BatchNormalization()(aspp_out)
    aspp_out = Activation('relu')(aspp_out)

    # === DECODER ===
    # Project low-level features down to 48 channels so they don't dominate
    # the concatenation with the 256-channel ASPP output.
    low_level_features = Conv2D(48, 1, padding='same', use_bias=False)(low_level_features)
    low_level_features = BatchNormalization()(low_level_features)
    low_level_features = Activation('relu')(low_level_features)

    # Upsample ASPP output to match low level features size.
    # Get shapes for verification
    low_level_shape = tf.keras.backend.int_shape(low_level_features)

    # Integer-division factors: exact only when the static dims divide evenly.
    x = UpSampling2D(size=(low_level_shape[1] // tf.keras.backend.int_shape(aspp_out)[1],
                           low_level_shape[2] // tf.keras.backend.int_shape(aspp_out)[2]),
                     interpolation='bilinear')(aspp_out)

    # Concatenate with low-level features
    x = Concatenate()([x, low_level_features])

    # Final refinement convolutions after the skip fusion.
    x = Conv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Calculate upsampling size back to the original input resolution.
    x_shape = tf.keras.backend.int_shape(x)
    upsampling_size = (input_shape[0] // x_shape[1], input_shape[1] // x_shape[2])

    # Upsample to original size
    x = UpSampling2D(size=upsampling_size, interpolation='bilinear')(x)

    # Final segmentation output: per-pixel class probabilities.
    outputs = Conv2D(classes, 1, padding='same', activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    return model
+
+import tensorflow as tf
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPooling2D, UpSampling2D
+
def SegNet(input_shape=(256, 256, 1), classes=11):
    """
    SegNet model for semantic segmentation.

    VGG-style encoder (five conv stacks, each followed by 2x2 max pooling)
    mirrored by a decoder that uses plain ``UpSampling2D`` in place of
    max-unpooling with pooling indices (not available in stock TensorFlow).

    Args:
        input_shape: Shape of input images
        classes: Number of classes for segmentation

    Returns:
        model: SegNet model (uncompiled)
    """

    def _cbr(tensor, filters):
        # conv(3x3, no bias) -> batch-norm -> ReLU, the repeated SegNet unit.
        tensor = Conv2D(filters, (3, 3), padding='same', use_bias=False)(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation('relu')(tensor)

    inputs = Input(input_shape)
    x = inputs

    # === ENCODER ===
    # Filter widths per stack; each stack ends with a stride-2 max pool.
    encoder_stacks = (
        [64, 64],
        [128, 128],
        [256, 256, 256],
        [512, 512, 512],
        [512, 512, 512],
    )
    for widths in encoder_stacks:
        for filters in widths:
            x = _cbr(x, filters)
        x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)

    # === DECODER ===
    # Mirrors the encoder; each stack starts with a 2x upsample and the last
    # conv of each stack steps the channel count down toward the next level.
    decoder_stacks = (
        [512, 512, 512],
        [512, 512, 256],
        [256, 256, 128],
        [128, 64],
        [64, 64],
    )
    for widths in decoder_stacks:
        x = UpSampling2D(size=(2, 2))(x)
        for filters in widths:
            x = _cbr(x, filters)

    # Output layer: per-pixel softmax over the land-cover classes.
    outputs = Conv2D(classes, (1, 1), padding='same', activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)

    return model
+
+
+
+
class SARSegmentation:
    """End-to-end helper for pixel-level land-cover classification of SAR images.

    Wraps model loading (U-Net / DeepLabV3+ / SegNet), SAR input loading and
    normalization, prediction, and colorized rendering using the 11-class ESA
    WorldCover scheme.
    """

    def __init__(self, img_rows=256, img_cols=256, drop_rate=0.5, model_type='unet'):
        """Configure input geometry and the expected model architecture.

        Args:
            img_rows / img_cols: spatial size every input is resized to.
            drop_rate: dropout rate used only if a fallback U-Net is built.
            model_type: 'unet', 'deeplabv3plus' or 'segnet' (case-insensitive);
                may be overwritten by auto-detection in load_trained_model.
        """
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.drop_rate = drop_rate
        self.num_channels = 1 # Single-pol SAR
        self.model = None
        self.model_type = model_type.lower()

        # ESA WorldCover class definitions: name -> WorldCover raster code.
        self.class_definitions = {
            'trees': 10,
            'shrubland': 20,
            'grassland': 30,
            'cropland': 40,
            'built_up': 50,
            'bare': 60,
            'snow': 70,
            'water': 80,
            'wetland': 90,
            'mangroves': 95,
            'moss': 100
        }
        self.num_classes = len(self.class_definitions)  # 11 classes

        # Class colors for visualization, keyed by model output channel index
        # (i.e. position in the sorted list of WorldCover codes, see one_hot_encode).
        self.class_colors = {
            0: [0, 100, 0],       # Trees - Dark green
            1: [255, 165, 0],     # Shrubland - Orange
            2: [144, 238, 144],   # Grassland - Light green
            3: [255, 255, 0],     # Cropland - Yellow
            4: [255, 0, 0],       # Built-up - Red
            5: [139, 69, 19],     # Bare - Brown
            6: [255, 255, 255],   # Snow - White
            7: [0, 0, 255],       # Water - Blue
            8: [0, 139, 139],     # Wetland - Dark cyan
            9: [0, 255, 0],       # Mangroves - Bright green
            10: [220, 220, 220]   # Moss - Light grey
        }

    def load_sar_data(self, file_path_or_bytes, is_bytes=False):
        """Load single-band SAR data from a file path or raw bytes.

        Tries rasterio first (GeoTIFF); falls back to PIL grayscale for common
        image formats.  Output is always resized to (img_rows, img_cols).

        Args:
            file_path_or_bytes: filesystem path, or raw bytes when is_bytes=True.
            is_bytes: whether the first argument is an in-memory byte buffer.

        Returns:
            numpy array of shape (img_rows, img_cols, 1); dtype follows the
            source raster (not normalized here — see preprocess_sar).
        """
        try:
            if is_bytes:
                # rasterio needs a real file, so spill the bytes to a temp .tif
                with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp:
                    tmp.write(file_path_or_bytes)
                    tmp_path = tmp.name

                with rasterio.open(tmp_path) as src:
                    sar_data = src.read(1) # Read single band
                    sar_data = np.expand_dims(sar_data, axis=-1)

                # Clean up the temporary file
                os.unlink(tmp_path)
            else:
                with rasterio.open(file_path_or_bytes) as src:
                    sar_data = src.read(1) # Read single band
                    sar_data = np.expand_dims(sar_data, axis=-1)

            # Resize if needed (cv2.resize drops the channel axis, so re-add it)
            if sar_data.shape[:2] != (self.img_rows, self.img_cols):
                sar_data = cv2.resize(sar_data, (self.img_cols, self.img_rows))
                sar_data = np.expand_dims(sar_data, axis=-1)

            return sar_data
        except RasterioIOError:
            # Try to open as a regular image (PNG/JPEG) if rasterio fails
            if is_bytes:
                img = Image.open(io.BytesIO(file_path_or_bytes)).convert('L')
            else:
                img = Image.open(file_path_or_bytes).convert('L')

            img_array = np.array(img)
            if img_array.shape[:2] != (self.img_rows, self.img_cols):
                img_array = cv2.resize(img_array, (self.img_cols, self.img_rows))

            return np.expand_dims(img_array, axis=-1)

    def preprocess_sar(self, sar_data):
        """Normalize SAR data to the [-1, 1] range expected by the models.

        Values already in [0, 255] are scaled linearly; anything else is
        assumed to be dB-scaled backscatter (clipped to [-50, 20]) and
        min-max normalized.  NOTE(review): the dB assumption is a heuristic —
        confirm against the actual data products used.
        """
        # Check if data is already normalized (0-255 range)
        if np.max(sar_data) <= 255 and np.min(sar_data) >= 0:
            # Normalize to -1 to 1 range
            sar_normalized = (sar_data / 127.5) - 1
        else:
            # Assume it's in dB scale
            sar_clipped = np.clip(sar_data, -50, 20)
            sar_normalized = (sar_clipped - np.min(sar_clipped)) / (np.max(sar_clipped) - np.min(sar_clipped)) * 2 - 1

        return sar_normalized

    def one_hot_encode(self, labels):
        """Convert an ESA WorldCover label raster to one-hot format.

        Channel i corresponds to the i-th smallest WorldCover code
        (10, 20, ..., 95, 100), matching the class_colors index.

        Args:
            labels: (H, W) array of WorldCover codes.

        Returns:
            (H, W, num_classes) float array of 0/1 indicators.
        """
        encoded = np.zeros((labels.shape[0], labels.shape[1], self.num_classes))

        for i, value in enumerate(sorted(self.class_definitions.values())):
            encoded[:, :, i] = (labels == value)

        return encoded

    def load_trained_model(self, model_path):
        """Load a trained model and auto-detect its architecture.

        Strategy: load the full model via load_model_with_weights; if that
        succeeds, infer model_type from the graph (dilated convs ->
        DeepLabV3+, >=5 max-pool layers -> SegNet, else U-Net).  If loading
        fails entirely, rebuild the architecture named by self.model_type and
        load weights by layer name, tolerating mismatches.

        Raises:
            ValueError: wrapping any underlying failure, or when the fallback
                path loads no non-zero weights at all.
        """
        try:
            # If model_path is a filename without path, prepend the models directory
            if not os.path.dirname(model_path) and not model_path.startswith('models/'):
                model_path = os.path.join('models', os.path.basename(model_path))

            # First try to load the complete model
            self.model = load_model_with_weights(model_path)


            if self.model is not None:

                # Any conv layer with dilation_rate > 1 marks DeepLabV3+.
                has_dilated_convs = False
                for layer in self.model.layers:
                    if 'conv' in layer.name.lower() and hasattr(layer, 'dilation_rate'):
                        if isinstance(layer.dilation_rate, (list, tuple)):
                            if any(rate > 1 for rate in layer.dilation_rate):
                                has_dilated_convs = True
                                break
                        elif layer.dilation_rate > 1:
                            has_dilated_convs = True
                            break

                if has_dilated_convs:
                    self.model_type = 'deeplabv3plus'
                    print("Detected DeepLabV3+ model")
                # Check for SegNet architecture (typically has 5 encoder and 5 decoder blocks)
                # NOTE(review): the local U-Net has only 4 pools, so >=5 separates the two.
                elif len([l for l in self.model.layers if isinstance(l, MaxPooling2D)]) >= 5:
                    self.model_type = 'segnet'
                    print("Detected SegNet model")
                else:
                    self.model_type = 'unet'
                    print("Detected U-Net model")

            if self.model is None:
                # If that fails, try to create a model with the expected architecture
                if self.model_type == 'unet':
                    self.model = get_unet(
                        input_shape=(self.img_rows, self.img_cols, self.num_channels),
                        drop_rate=self.drop_rate,
                        classes=self.num_classes
                    )
                elif self.model_type == 'deeplabv3plus':
                    self.model = DeepLabV3Plus(
                        input_shape=(self.img_rows, self.img_cols, self.num_channels),
                        classes=self.num_classes
                    )
                elif self.model_type == 'segnet':
                    self.model = SegNet(
                        input_shape=(self.img_rows, self.img_cols, self.num_channels),
                        classes=self.num_classes
                    )
                else:
                    raise ValueError(f"Model type {self.model_type} not supported")

                # Try to load weights, allowing for mismatch
                self.model.load_weights(model_path, by_name=True, skip_mismatch=True)

                # Check if any weights were loaded (all-zero weights mean
                # nothing matched by name and the fallback is useless)
                if not any(np.any(w) for w in self.model.get_weights()):
                    raise ValueError("No weights were loaded. The model architecture is incompatible.")
        except Exception as e:
            raise ValueError(f"Failed to load model: {str(e)}")




    def predict(self, sar_data):
        """Predict segmentation for new SAR data.

        Args:
            sar_data: (H, W, 1) raw SAR array; preprocessed internally.

        Returns:
            (1, H, W, num_classes) softmax probability map.

        Raises:
            ValueError: if no model has been loaded.
        """
        if self.model is None:
            raise ValueError("Model not trained. Call train() first or load a trained model.")

        # Preprocess input data
        sar_processed = self.preprocess_sar(sar_data)

        # Ensure correct shape (add batch dimension)
        if len(sar_processed.shape) == 3:
            sar_processed = np.expand_dims(sar_processed, axis=0)

        # Make prediction
        prediction = self.model.predict(sar_processed)
        return prediction

    def get_colored_prediction(self, prediction):
        """Convert a prediction to an RGB image via the class color map.

        Args:
            prediction: (1, H, W, num_classes) probability map from predict().

        Returns:
            (colored_pred, pred_class): uint8 RGB image and (H, W) argmax map.
        """
        pred_class = np.argmax(prediction[0], axis=-1)
        colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8)

        for class_idx, color in self.class_colors.items():
            colored_pred[pred_class == class_idx] = color

        return colored_pred, pred_class
+
+
+
def visualize_prediction(prediction, original_sar, figsize=(10, 4)):
    """Render SAR input, predicted class map and blended overlay side by side.

    Args:
        prediction: model output of shape (1, H, W, num_classes).
        original_sar: uint8 SAR image of shape (H, W, 1) used for display.
        figsize: matplotlib figure size in inches.

    Returns:
        io.BytesIO holding the rendered PNG (cursor rewound to 0).
    """
    # ESA WorldCover palette, indexed by class id.
    palette = {
        0: [0, 100, 0],       # Trees - Dark green
        1: [255, 165, 0],     # Shrubland - Orange
        2: [144, 238, 144],   # Grassland - Light green
        3: [255, 255, 0],     # Cropland - Yellow
        4: [255, 0, 0],       # Built-up - Red
        5: [139, 69, 19],     # Bare - Brown
        6: [255, 255, 255],   # Snow - White
        7: [0, 0, 255],       # Water - Blue
        8: [0, 139, 139],     # Wetland - Dark cyan
        9: [0, 255, 0],       # Mangroves - Bright green
        10: [220, 220, 220]   # Moss - Light grey
    }

    # Argmax over classes, then paint each class id with its palette color.
    class_map = np.argmax(prediction[0], axis=-1)
    colour_img = np.zeros((class_map.shape[0], class_map.shape[1], 3), dtype=np.uint8)
    for class_id, rgb in palette.items():
        colour_img[class_map == class_id] = rgb

    # Blend the colorized map over the grayscale SAR image (70/30).
    grayscale = original_sar[:, :, 0]
    blended = cv2.addWeighted(cv2.cvtColor(grayscale, cv2.COLOR_GRAY2RGB),
                              0.7, colour_img, 0.3, 0)

    fig, axes = plt.subplots(1, 3, figsize=figsize)

    # One panel per (image, title, extra imshow kwargs).
    panels = (
        (grayscale, 'Original SAR', {'cmap': 'gray'}),
        (colour_img, 'Prediction', {}),
        (blended, 'Colorized Output', {}),
    )
    for ax, (img, title, kwargs) in zip(axes, panels):
        ax.imshow(img, **kwargs)
        ax.set_title(title, color='white')
        ax.axis('off')
        ax.set_facecolor('#0a0a1f')

    # Dark theme to match the app's background.
    fig.patch.set_facecolor('#0a0a1f')

    plt.tight_layout()

    # Serialize the figure to an in-memory PNG and free the figure.
    buf = io.BytesIO()
    plt.savefig(buf, format='png', facecolor='#0a0a1f', bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)

    return buf
+
def create_legend():
    """Draw a colour-swatch legend for the 11 ESA WorldCover classes.

    Returns:
        io.BytesIO holding the legend rendered as a PNG, rewound to byte 0.
    """
    # Class name -> RGB triplet, in display order (top to bottom).
    class_palette = (
        ('Trees', (0, 100, 0)),
        ('Shrubland', (255, 165, 0)),
        ('Grassland', (144, 238, 144)),
        ('Cropland', (255, 255, 0)),
        ('Built-up', (255, 0, 0)),
        ('Bare', (139, 69, 19)),
        ('Snow', (255, 255, 255)),
        ('Water', (0, 0, 255)),
        ('Wetland', (0, 139, 139)),
        ('Mangroves', (0, 255, 0)),
        ('Moss', (220, 220, 220)),
    )

    fig, ax = plt.subplots(figsize=(8, 4))
    fig.patch.set_facecolor('#0a0a1f')
    ax.set_facecolor('#0a0a1f')

    # One colour patch plus label per class, stacked vertically.
    for row, (class_name, rgb) in enumerate(class_palette):
        normalized = tuple(channel / 255 for channel in rgb)
        ax.add_patch(plt.Rectangle((0, row), 0.5, 0.8, color=normalized))
        ax.text(0.7, row + 0.4, class_name, color='white', fontsize=12)

    ax.set_xlim(0, 3)
    ax.set_ylim(-0.5, len(class_palette) - 0.5)
    ax.set_title('Land Cover Classes', color='white', fontsize=14)
    ax.axis('off')

    buf = io.BytesIO()
    fig.savefig(buf, format='png', facecolor='#0a0a1f', bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)

    return buf
+
# ---- Per-session state (survives Streamlit reruns) ----
if 'model_loaded' not in st.session_state:
    st.session_state['model_loaded'] = False
if 'segmentation' not in st.session_state:
    # One segmentation helper per browser session, fixed to 256x256 input.
    st.session_state['segmentation'] = SARSegmentation(img_rows=256, img_cols=256)
if 'processed_images' not in st.session_state:
    st.session_state['processed_images'] = []

# Top-level navigation tabs.
tab1, tab2, tab3 = st.tabs(["📥 Load Model", " Process Single Image", "📁 Process Multiple Images"])
+
# Add this code in tab1 before the model file uploader
with tab1:
    # NOTE(review): the original inline HTML passed to st.markdown() was
    # stripped somewhere upstream, leaving unterminated string literals.
    # Reconstructed with minimal markup — confirm against the intended styling.
    st.markdown("<h2>Load Segmentation Model</h2>", unsafe_allow_html=True)

    # Let the user declare which architecture the uploaded weights belong to.
    model_type = st.selectbox(
        "Select model architecture",
        ["U-Net", "DeepLabV3+", "SegNet"],
        index=0,
        help="Select the architecture of the model you're uploading"
    )

    # Normalise the display name to the internal key ("U-Net" -> "unet",
    # "DeepLabV3+" -> "deeplabv3plus"). The original only stripped '-', which
    # produced "deeplabv3+" and never matched model_arch_map below.
    st.session_state.segmentation.model_type = (
        model_type.lower().replace('-', '').replace('+', 'plus')
    )

    col1, col2 = st.columns([3, 1])

    with col1:
        model_file = st.file_uploader("Upload model weights (.h5 file)", type=["h5"])

    with col2:
        st.markdown("<br>", unsafe_allow_html=True)  # vertical spacer to align the button
        if st.button("Load Model", key="load_model_btn"):
            if model_file is not None:
                with st.spinner("Loading model..."):
                    # Persist the upload to disk: Keras loads weights from a path.
                    with tempfile.NamedTemporaryFile(delete=False, suffix='.h5') as tmp:
                        tmp.write(model_file.getvalue())
                        model_path = tmp.name

                    try:
                        st.session_state.segmentation.load_trained_model(model_path)
                        st.session_state.model_loaded = True
                        st.success("Model loaded successfully!")
                    except Exception as e:
                        st.error(f"Error loading model: {str(e)}")
                    finally:
                        # Always remove the temp file, even when loading failed.
                        os.unlink(model_path)
            else:
                st.error("Please upload a model file (.h5)")

    if st.session_state.model_loaded:
        st.markdown("<hr>", unsafe_allow_html=True)
        st.markdown("<h3>Model Information</h3>", unsafe_allow_html=True)

        col1, col2, col3 = st.columns(3)
        with col1:
            # Map the internal model-type key back to a display name.
            model_arch_map = {
                'unet': "U-Net",
                'deeplabv3plus': "DeepLabV3+",
                'segnet': "SegNet"
            }
            model_arch = model_arch_map.get(st.session_state.segmentation.model_type, "Unknown")
            st.markdown(f"<h4>{model_arch}</h4>", unsafe_allow_html=True)
            st.markdown("<p>Architecture</p>", unsafe_allow_html=True)

        with col2:
            st.markdown("<h4>11</h4>", unsafe_allow_html=True)
            st.markdown("<p>Land Cover Classes</p>", unsafe_allow_html=True)

        with col3:
            st.markdown("<h4>256 x 256</h4>", unsafe_allow_html=True)
            st.markdown("<p>Input Size</p>", unsafe_allow_html=True)

        # Display the class-colour legend below the info cards.
        st.markdown("<h3>Land Cover Classes</h3>", unsafe_allow_html=True)
        legend_img = create_legend()
        st.image(legend_img, use_column_width=True)
    else:
        st.info("Please load a model to continue.")
+
+
with tab2:
    # NOTE(review): original inline HTML was stripped upstream (unterminated
    # string literals); reconstructed with minimal markup.
    st.markdown("<h2>Process Single SAR Image</h2>", unsafe_allow_html=True)

    if not st.session_state.model_loaded:
        st.warning("Please load a model in the 'Load Model' tab first.")
    else:
        uploaded_file = st.file_uploader(
            "Upload a SAR image (.tif or common image formats)",
            type=["tif", "tiff", "png", "jpg", "jpeg"]
        )

        if uploaded_file is not None:
            if st.button("Process Image", key="process_single_btn"):
                with st.spinner("Processing image..."):
                    try:
                        sar_data = st.session_state.segmentation.load_sar_data(
                            uploaded_file.getvalue(), is_bytes=True
                        )

                        # Stretch to the full uint8 range for display; guard
                        # the constant-image case to avoid division by zero.
                        sar_normalized = sar_data.copy()
                        min_val = np.min(sar_normalized)
                        max_val = np.max(sar_normalized)
                        span = max_val - min_val
                        if span > 0:
                            sar_normalized = ((sar_normalized - min_val) / span * 255).astype(np.uint8)
                        else:
                            sar_normalized = np.zeros(sar_normalized.shape, dtype=np.uint8)

                        # Run inference and render the three-panel comparison.
                        prediction = st.session_state.segmentation.predict(sar_data)
                        result_img = visualize_prediction(
                            prediction, np.expand_dims(sar_normalized, axis=-1)
                        )

                        st.markdown("<h3>Segmentation Results</h3>", unsafe_allow_html=True)
                        st.image(result_img, use_column_width=True)

                        # Offer the rendered PNG for download.
                        st.download_button(
                            label="Download Result",
                            data=result_img,
                            file_name="segmentation_result.png",
                            mime="image/png",
                            key="download_single_result"
                        )
                    except Exception as e:
                        st.error(f"Error processing image: {str(e)}")
+
with tab3:
    # NOTE(review): original inline HTML was stripped upstream (unterminated
    # string literals); reconstructed with minimal markup.
    st.markdown("<h2>Process Multiple SAR Images</h2>", unsafe_allow_html=True)

    if not st.session_state.model_loaded:
        st.warning("Please load a model in the 'Load Model' tab first.")
    else:
        uploaded_files = st.file_uploader(
            "Upload SAR images or a ZIP file containing images",
            type=["tif", "tiff", "png", "jpg", "jpeg", "zip"],
            accept_multiple_files=True
        )

        col1, col2 = st.columns([3, 1])

        with col1:
            max_images = st.slider("Maximum number of images to display", min_value=1, max_value=20, value=10)

        with col2:
            st.markdown("<br>", unsafe_allow_html=True)  # vertical spacer to align the button
            process_btn = st.button("Process Images", key="process_multi_btn")

        if process_btn and uploaded_files:
            # Clear previous results
            st.session_state.processed_images = []

            with st.spinner("Processing images..."):
                with tempfile.TemporaryDirectory() as temp_dir:
                    # Stage every upload on disk first: loose images are saved
                    # as-is and archives are extracted in place.
                    for uploaded_file in uploaded_files:
                        file_path = os.path.join(temp_dir, uploaded_file.name)
                        with open(file_path, 'wb') as f:
                            f.write(uploaded_file.getvalue())
                        if uploaded_file.name.lower().endswith('.zip'):
                            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                                zip_ref.extractall(temp_dir)

                    # Walk the staging directory exactly once. The original
                    # walked it inside the per-zip loop, so with multiple zips
                    # (or zips mixed with loose images) earlier files were
                    # collected again and processed as duplicates.
                    image_files = []
                    for root, _, files in os.walk(temp_dir):
                        for file in files:
                            if file.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')):
                                image_files.append(os.path.join(root, file))

                    # If there are too many images, randomly select a subset.
                    if len(image_files) > max_images:
                        st.info(f"Found {len(image_files)} images. Randomly selecting {max_images} images to display.")
                        image_files = random.sample(image_files, max_images)

                    progress_bar = st.progress(0)
                    for i, image_path in enumerate(image_files):
                        try:
                            progress_bar.progress((i + 1) / len(image_files))

                            sar_data = st.session_state.segmentation.load_sar_data(image_path)

                            # Stretch to uint8 for display; guard the
                            # constant-image case against division by zero.
                            sar_normalized = sar_data.copy()
                            min_val = np.min(sar_normalized)
                            max_val = np.max(sar_normalized)
                            span = max_val - min_val
                            if span > 0:
                                sar_normalized = ((sar_normalized - min_val) / span * 255).astype(np.uint8)
                            else:
                                sar_normalized = np.zeros(sar_normalized.shape, dtype=np.uint8)

                            prediction = st.session_state.segmentation.predict(sar_data)
                            result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1))

                            st.session_state.processed_images.append({
                                'filename': os.path.basename(image_path),
                                'result': result_img
                            })
                        except Exception as e:
                            st.error(f"Error processing {os.path.basename(image_path)}: {str(e)}")

                    progress_bar.empty()

            # Display results
            if st.session_state.processed_images:
                st.markdown("<h3>Segmentation Results</h3>", unsafe_allow_html=True)

                # Bundle every rendered PNG into one downloadable archive.
                zip_buffer = io.BytesIO()
                with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
                    for i, img_data in enumerate(st.session_state.processed_images):
                        zip_file.writestr(f"result_{i+1}_{img_data['filename']}.png", img_data['result'].getvalue())

                st.download_button(
                    label="Download All Results",
                    data=zip_buffer.getvalue(),
                    file_name="segmentation_results.zip",
                    mime="application/zip",
                    key="download_all_results"
                )

                # Display each result
                for i, img_data in enumerate(st.session_state.processed_images):
                    st.markdown(f"<h4>Image: {img_data['filename']}</h4>", unsafe_allow_html=True)
                    st.image(img_data['result'], use_column_width=True)
            else:
                st.warning("No images were successfully processed.")
        elif process_btn:
            st.warning("Please upload at least one image file or ZIP archive.")
+
# Close the card container.
# NOTE(review): the original closing HTML markup was stripped upstream,
# leaving an unterminated string literal; reconstructed as a closing div.
st.markdown("</div>", unsafe_allow_html=True)

# Footer
st.markdown("""
<div style="text-align: center;">
    <p>SAR IMAGE COLORIZATION | VARUN &amp; MOKSHYAGNA</p>
</div>
""", unsafe_allow_html=True)
\ No newline at end of file