VarunRavichander's picture
Update app.py
7db78b6 verified
import streamlit as st
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import Adam
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import io
import base64
import tempfile
import zipfile
import random
import time
import rasterio
from rasterio.errors import RasterioIOError
import h5py
import json
# Set page configuration
# Browser-tab title/icon and a wide layout for the whole app.
st.set_page_config(
    layout="wide",
    page_title="SAR Image Colorization",
    page_icon="🛰",
)
def display_image(image_path):
    """Show an image file in the Streamlit page.

    GeoTIFFs (``.tif``/``.tiff``) are read with rasterio and turned into a
    3-channel array: the first three bands when there are at least three,
    (band 1, band 2, band 2) for two-band rasters, and band 1 repeated three
    times for single-band rasters. Non-``uint8`` data is min-max scaled to
    0-255. Constant rasters are rendered black and NaN pixels as 0, instead of
    the 0/0 NaNs the scaling would otherwise produce. If the rasterio path
    fails, the file is retried with PIL. All other formats go straight to PIL.

    Problems are reported in the page via ``st.info`` / ``st.error``; nothing
    is raised.

    Args:
        image_path: Path of the image file to display.
    """
    try:
        if not os.path.exists(image_path):
            st.info(f"Image file not found: {image_path}")
            return

        if not image_path.lower().endswith(('.tif', '.tiff')):
            # Use PIL for other formats
            img = Image.open(image_path)
            st.image(img, use_container_width=True)
            return

        try:
            with rasterio.open(image_path) as src:
                if src.count >= 3:
                    img_data = np.dstack([src.read(i) for i in range(1, 4)])
                elif src.count == 2:
                    # For 2-band images, duplicate the second band
                    img_data = np.dstack([src.read(1), src.read(2), src.read(2)])
                else:
                    band = src.read(1)
                    img_data = np.dstack([band, band, band])

            if img_data.dtype != np.uint8:
                # Work in float64 so integer rasters cannot overflow during scaling.
                scaled = img_data.astype(np.float64)
                lo, hi = np.nanmin(scaled), np.nanmax(scaled)
                if hi > lo:
                    scaled = (scaled - lo) / (hi - lo) * 255
                else:
                    # Constant (or all-NaN) raster: avoid dividing by a zero range.
                    scaled = np.zeros_like(scaled)
                img_data = np.nan_to_num(scaled, nan=0.0).astype(np.uint8)

            st.image(img_data, use_container_width=True)
        except Exception:
            # Fall back to PIL
            try:
                img = Image.open(image_path)
                st.image(img, use_container_width=True)
            except Exception as pil_error:
                st.error(f"Failed to load image: {str(pil_error)}")
    except Exception as e:
        st.error(f"Error loading image: {str(e)}")
# ==================== UTILITY FUNCTIONS ====================
# GPU setup for SAR to Optical Translation
@st.cache_resource
def setup_gpu():
    """Turn on memory growth for every visible GPU.

    Wrapped in ``st.cache_resource`` so the configuration is done once rather
    than on every Streamlit rerun.

    Returns:
        A human-readable status message describing what was found.
    """
    devices = tf.config.experimental.list_physical_devices('GPU')
    if not devices:
        return "No GPUs found. Running on CPU."
    for device in devices:
        tf.config.experimental.set_memory_growth(device, True)
    return f"GPU setup complete. Found {len(devices)} GPU(s)."
# ESA WorldCover colors dictionary - used in multiple functions
def get_esa_colors():
    """Return the ESA WorldCover display palette keyed by model class index.

    Keys are the model output indices 0-10; values are ``[R, G, B]`` lists in
    the 0-255 range. A new dict (with new lists) is built on every call, so
    callers may modify the result without affecting other callers.
    """
    palette = (
        [0, 100, 0],        # 0  Trees - dark green
        [255, 165, 0],      # 1  Shrubland - orange
        [144, 238, 144],    # 2  Grassland - light green
        [255, 255, 0],      # 3  Cropland - yellow
        [255, 0, 0],        # 4  Built-up - red
        [139, 69, 19],      # 5  Bare - brown
        [255, 255, 255],    # 6  Snow - white
        [0, 0, 255],        # 7  Water - blue
        [0, 139, 139],      # 8  Wetland - dark cyan
        [0, 255, 0],        # 9  Mangroves - bright green
        [220, 220, 220],    # 10 Moss - light grey
    )
    return dict(enumerate(palette))
# When visualizing ground truth, use the same color mapping as for predictions
def visualize_with_ground_truth(sar_image, ground_truth, prediction):
    """Render SAR input, ground truth, prediction and overlay side by side.

    Both label maps are coloured with ``get_esa_colors`` so they are directly
    comparable.

    Args:
        sar_image: SAR image of shape (H, W, C); channel 0 is displayed. Values
            are expected in [-1, 1]. They are clipped to that range before the
            8-bit conversion, so out-of-range pixels saturate instead of
            wrapping around.
        ground_truth: Label raster of shape (H, W, C). Channel 0 holds either
            class indices 0-10 or raw ESA WorldCover values. A maximum above 10
            is taken to mean raw values; these are remapped to indices through
            ``st.session_state.segmentation.class_definitions``. Pixels whose
            value is not one of those classes become index 0.
        prediction: Batched class scores of shape (N, H, W, num_classes); only
            the first item is used.

    Returns:
        ``(png_buffer, colored_gt, colored_pred, overlay)``. ``png_buffer`` is
        an ``io.BytesIO`` (rewound to 0) holding the rendered four-panel figure.
    """
    colors = get_esa_colors()

    # Prediction: class index per pixel -> RGB
    pred_class = np.argmax(prediction[0], axis=-1)
    colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8)
    for class_idx, color in colors.items():
        colored_pred[pred_class == class_idx] = color

    # Ground truth: same colour scheme as the prediction
    gt_class = ground_truth[:, :, 0].astype(np.int32)
    # Index-encoded labels never exceed 10, so a larger maximum means raw ESA values.
    # NOTE(review): a raw raster containing only value 10 (trees) is not remapped.
    if np.max(gt_class) > 10:
        gt_mapped = np.zeros_like(gt_class)
        class_values = sorted(st.session_state.segmentation.class_definitions.values())
        for i, val in enumerate(class_values):
            gt_mapped[gt_class == val] = i
        gt_class = gt_mapped
    colored_gt = np.zeros((gt_class.shape[0], gt_class.shape[1], 3), dtype=np.uint8)
    for class_idx, color in colors.items():
        colored_gt[gt_class == class_idx] = color

    # SAR in [-1, 1] -> 8-bit grey RGB; clip first so the uint8 cast cannot wrap.
    sar_rgb = np.repeat(sar_image[:, :, 0:1], 3, axis=2)
    sar_rgb = ((np.clip(sar_rgb, -1, 1) + 1) / 2 * 255).astype(np.uint8)
    overlay = cv2.addWeighted(sar_rgb, 0.7, colored_pred, 0.3, 0)

    # Set background color based on theme
    bg_color = '#0a0a1f' if st.session_state.theme == 'dark' else '#ffffff'
    text_color = 'white' if st.session_state.theme == 'dark' else 'black'

    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    panels = (
        (sar_rgb, 'Original SAR'),
        (colored_gt, 'Ground Truth'),
        (colored_pred, 'Prediction'),
        (overlay, 'Colorized Output'),
    )
    for ax, (panel, title) in zip(axes, panels):
        ax.imshow(panel)
        ax.set_title(title, color=text_color)
        ax.axis('off')
        ax.set_facecolor(bg_color)
    fig.patch.set_facecolor(bg_color)
    fig.tight_layout()

    # Convert plot to image
    buf = io.BytesIO()
    fig.savefig(buf, format='png', facecolor=bg_color, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)
    return buf, colored_gt, colored_pred, overlay
# Load models for SAR to Optical Translation
@st.cache_resource
def load_models(unet_weights_path, generator_path=None):
    """Create the SAR-to-optical models; cached across Streamlit reruns.

    Args:
        unet_weights_path: Weights file for a 256x256x1, 11-class U-Net.
        generator_path: Optional saved generator model; a falsy value skips it.

    Returns:
        ``(unet, generator)``. ``generator`` is ``None`` when no path was given
        or when loading failed; a failure is reported with ``st.error``.
    """
    segmenter = get_unet(input_shape=(256, 256, 1), classes=11)
    segmenter.load_weights(unet_weights_path)

    if not generator_path:
        return segmenter, None

    try:
        generator = tf.keras.models.load_model(generator_path)
    except Exception as e:
        st.error(f"Error loading generator model: {e}")
        generator = None
    return segmenter, generator
# Preprocess SAR data for SAR to Optical Translation
def preprocess_sar_for_optical(sar_data):
    """Clip SAR backscatter to [-50, 20] and min-max scale it to [-1, 1].

    The input is assumed to be in dB. After clipping, the smallest value maps
    to -1 and the largest to 1. For a constant input the range is zero; every
    pixel is then returned as -1 instead of the NaNs a 0/0 division would give.

    Args:
        sar_data: Array of SAR values (any shape).

    Returns:
        Float array of the same shape with values in [-1, 1].
    """
    sar_clipped = np.clip(sar_data, -50, 20)
    lo = np.min(sar_clipped)
    span = np.max(sar_clipped) - lo
    if span == 0:
        # Constant input: (x - lo) is 0 everywhere, so any non-zero divisor yields -1.
        span = 1
    return (sar_clipped - lo) / span * 2 - 1
# Load SAR image for SAR to Optical Translation
def load_sar_image(file, img_size=(256, 256)):
    """Read band 1 of an uploaded raster and prepare it as model input.

    The upload is written to a temporary ``.tif`` file so rasterio can open it.
    The file is always deleted afterwards.

    Args:
        file: Uploaded file object exposing ``getbuffer()``.
        img_size: Size handed to ``cv2.resize``, which interprets it as
            ``(width, height)``.

    Returns:
        ``(batch, image)``. ``image`` is the resized band with a trailing
        channel axis, scaled by ``preprocess_sar_for_optical``. ``batch`` is
        ``image`` with a leading batch axis. Returns ``(None, None)`` after
        showing ``st.error`` if anything fails.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as handle:
        handle.write(file.getbuffer())
        temp_path = handle.name

    try:
        with rasterio.open(temp_path) as dataset:
            band = dataset.read(1)
        resized = cv2.resize(band, img_size)
        prepared = preprocess_sar_for_optical(resized[..., np.newaxis])
        return prepared[np.newaxis, ...], prepared
    except Exception as e:
        st.error(f"Error loading SAR image: {e}")
        return None, None
    finally:
        os.unlink(temp_path)
# Process image with models for SAR to Optical Translation
def process_image(sar_image, unet_model, generator_model=None):
    """Run segmentation and, optionally, SAR-to-optical generation on one batch.

    Args:
        sar_image: Batched model input, passed unchanged to
            ``unet_model.predict``.
        unet_model: Model whose ``predict`` returns per-pixel class scores.
        generator_model: Optional model called as
            ``predict([sar_image, seg_mask])``. Only ``None`` disables it; the
            check is ``is not None``, so an object that happens to be falsy is
            still used.

    Returns:
        ``(seg_mask, colorized)`` for the first batch item. ``colorized`` is
        ``None`` when no generator is given.
    """
    seg_mask = unet_model.predict(sar_image)

    colorized = None
    if generator_model is not None:
        # The generator is conditioned on both the SAR input and the U-Net output.
        colorized = generator_model.predict([sar_image, seg_mask])[0]

    return seg_mask[0], colorized
# Visualize results for SAR to Optical Translation
def visualize_results(sar_image, seg_mask, colorized=None):
    """Build display images for one segmentation result.

    Args:
        sar_image: SAR image of shape (H, W, C); channel 0 is used. Values are
            expected in [-1, 1]. They are clipped to that range before the 8-bit
            conversion, so out-of-range pixels saturate instead of wrapping
            around in the uint8 cast.
        seg_mask: Per-pixel class scores of shape (H, W, num_classes).
        colorized: Optional generator output; returned unchanged.

    Returns:
        ``(sar_rgb, colored_pred, overlay, colorized)``:
        - ``sar_rgb``: the SAR channel as 3-channel uint8.
        - ``colored_pred``: the argmax classes in ESA WorldCover colours.
        - ``overlay``: a 70/30 blend of the two.
    """
    colors = get_esa_colors()

    # Convert prediction to color image
    pred_class = np.argmax(seg_mask, axis=-1)
    colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8)
    for class_idx, color in colors.items():
        colored_pred[pred_class == class_idx] = color

    # [-1, 1] -> [0, 255]; clip first so the uint8 cast cannot wrap around.
    sar_rgb = np.repeat(sar_image[:, :, 0:1], 3, axis=2)
    sar_rgb = ((np.clip(sar_rgb, -1, 1) + 1) / 2 * 255).astype(np.uint8)

    overlay = cv2.addWeighted(sar_rgb, 0.7, colored_pred, 0.3, 0)
    return sar_rgb, colored_pred, overlay, colorized
# Load model with weights - handles different model loading scenarios
def load_model_with_weights(model_path):
    """Load a segmentation model from an H5 file, trying several strategies in turn.

    1. ``tf.keras.models.load_model`` (architecture + weights). The custom
       ``BilinearUpsampling`` layer is always registered via
       ``custom_objects``, so models that use it can be deserialized on any
       Keras version.
    2. Rebuild the architecture from the ``model_config`` attribute stored in
       the H5 file (with the same ``custom_objects``), then load the weights.
    3. Build a fresh architecture selected by
       ``st.session_state.segmentation.model_type`` and load the weights that
       match by layer name.

    Args:
        model_path: Path to the ``.h5`` file. A bare filename (no directory
            part) is looked up inside the ``models`` directory.

    Returns:
        The loaded Keras model, or ``None`` if every strategy failed.
        Failures are reported with ``print`` rather than raised.
    """
    if not os.path.dirname(model_path):
        model_path = os.path.join('models', model_path)

    custom_objects = {'BilinearUpsampling': BilinearUpsampling}

    # Strategy 1: complete saved model (architecture + weights).
    try:
        model = tf.keras.models.load_model(model_path, compile=False, custom_objects=custom_objects)
        print("Loaded complete model with architecture")
        return model
    except Exception as e:
        print(f"Could not load complete model: {str(e)}")
        print("Attempting to load just the weights into a matching architecture...")

    # Strategy 2: architecture JSON stored in the H5 file attributes.
    try:
        with h5py.File(model_path, 'r') as f:
            raw_config = f.attrs.get('model_config')
        if raw_config is not None:
            # h5py < 3 returns bytes and h5py >= 3 returns str. json.loads accepts
            # both, whereas an unconditional .decode() fails on str.
            model_config = json.loads(raw_config)
            try:
                model = tf.keras.models.model_from_json(
                    json.dumps(model_config), custom_objects=custom_objects
                )
                model.load_weights(model_path)
                print("Successfully loaded model from config and weights")
                return model
            except Exception as e2:
                print(f"Failed to load from config: {str(e2)}")
    except Exception as e3:
        print(f"Failed to inspect model file: {str(e3)}")

    # Strategy 3: fresh architecture + name-matched weights.
    try:
        model_type = st.session_state.segmentation.model_type
        if model_type == 'unet':
            model = get_unet(input_shape=(256, 256, 1), drop_rate=0.3, classes=11)
        elif model_type == 'deeplabv3plus':
            model = DeepLabV3Plus(input_shape=(256, 256, 1), classes=11)
        elif model_type == 'segnet':
            model = SegNet(input_shape=(256, 256, 1), classes=11)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")
        model.load_weights(model_path, by_name=True, skip_mismatch=True)
        print("Created new model and loaded compatible weights")
        return model
    except Exception as e4:
        print(f"Failed to create new model and load weights: {str(e4)}")

    # If all else fails, return None
    return None
# Create a legend for the land cover classes
def create_legend():
    """Render a legend of the ESA WorldCover land-cover classes as a PNG.

    Swatch colours come from ``get_esa_colors`` in class-index order, so the
    legend cannot drift from the prediction colouring. Background and text
    colours follow ``st.session_state.theme``.

    Returns:
        An ``io.BytesIO`` rewound to position 0 and holding the PNG image.
    """
    class_names = [
        'Trees', 'Shrubland', 'Grassland', 'Cropland', 'Built-up', 'Bare',
        'Snow', 'Water', 'Wetland', 'Mangroves', 'Moss',
    ]
    esa_colors = get_esa_colors()
    colors = {name: esa_colors[idx] for idx, name in enumerate(class_names)}

    # Set background color based on theme
    bg_color = '#0a0a1f' if st.session_state.theme == 'dark' else '#ffffff'
    text_color = 'white' if st.session_state.theme == 'dark' else 'black'

    fig, ax = plt.subplots(figsize=(8, 4))
    fig.patch.set_facecolor(bg_color)
    ax.set_facecolor(bg_color)

    # One swatch + label per class, stacked upward from y=0
    for i, (class_name, color) in enumerate(colors.items()):
        ax.add_patch(plt.Rectangle((0, i), 0.5, 0.8, color=[c / 255 for c in color]))
        ax.text(0.7, i + 0.4, class_name, color=text_color, fontsize=12)

    ax.set_xlim(0, 3)
    # The top swatch spans [n-1, n-0.2]. An upper limit of n-0.5 clipped it, so
    # extend the limit to n.
    ax.set_ylim(-0.5, len(colors))
    ax.set_title('Land Cover Classes', color=text_color, fontsize=14)
    ax.axis('off')

    buf = io.BytesIO()
    fig.savefig(buf, format='png', facecolor=bg_color, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)
    return buf
# Visualize segmentation prediction (theme-aware: dark or light background)
def visualize_prediction(prediction, original_sar, figsize=(10, 4)):
    """Render SAR input, predicted land cover and an overlay side by side.

    Args:
        prediction: Batched class scores of shape (N, H, W, num_classes); only
            the first item is used.
        original_sar: SAR image of shape (H, W, C); channel 0 is displayed.
            For the overlay, non-``uint8`` data is min-max scaled to 0-255,
            because ``cv2.addWeighted`` needs both inputs to share a dtype
            with the uint8 prediction image.
        figsize: Matplotlib figure size in inches.

    Returns:
        An ``io.BytesIO`` rewound to 0 and holding the rendered three-panel PNG.
    """
    colors = get_esa_colors()

    # Convert prediction to color image
    pred_class = np.argmax(prediction[0], axis=-1)
    colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8)
    for class_idx, color in colors.items():
        colored_pred[pred_class == class_idx] = color

    # Create overlay; bring the SAR band to uint8 first (e.g. float/uint16 GeoTIFFs)
    sar_band = original_sar[:, :, 0]
    if sar_band.dtype != np.uint8:
        band = sar_band.astype(np.float64)
        lo, hi = np.nanmin(band), np.nanmax(band)
        band = (band - lo) / (hi - lo) * 255 if hi > lo else np.zeros_like(band)
        sar_band = np.nan_to_num(band, nan=0.0).astype(np.uint8)
    sar_rgb = cv2.cvtColor(sar_band, cv2.COLOR_GRAY2RGB)
    overlay = cv2.addWeighted(sar_rgb, 0.7, colored_pred, 0.3, 0)

    fig, axes = plt.subplots(1, 3, figsize=figsize)

    # Set background color based on theme
    bg_color = '#0a0a1f' if st.session_state.theme == 'dark' else '#ffffff'
    text_color = 'white' if st.session_state.theme == 'dark' else 'black'

    panels = (
        (original_sar[:, :, 0], 'Original SAR', 'gray'),
        (colored_pred, 'Prediction', None),
        (overlay, 'Colorized Output', None),
    )
    for ax, (panel, title, cmap) in zip(axes, panels):
        ax.imshow(panel, cmap=cmap)
        ax.set_title(title, color=text_color)
        ax.axis('off')
        ax.set_facecolor(bg_color)
    fig.patch.set_facecolor(bg_color)
    fig.tight_layout()

    # Convert plot to image
    buf = io.BytesIO()
    fig.savefig(buf, format='png', facecolor=bg_color, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)
    return buf
# ==================== MODEL DEFINITIONS ====================
# Define the U-Net model
def get_unet(input_shape=(256, 256, 1), drop_rate=0.3, classes=11):
    """Build and compile a U-Net for per-pixel classification.

    Architecture:
    - Encoder: four stages with 64/128/256/512 filters. Each stage is two
      3x3 conv + batch-norm layers, followed by 2x2 max-pooling.
    - Bridge: 1024 filters.
    - Decoder: mirrors the encoder. Each step upsamples 2x, applies a 2x2
      conv, concatenates the matching encoder features and applies two
      3x3 conv + batch-norm layers.
    - Dropout is applied after encoder stage 4 (that dropped tensor is also
      the stage-4 skip connection) and after the bridge.

    Args:
        input_shape: (height, width, channels) of the input. Height and width
            should be divisible by 16 so the skip concatenations line up after
            the four 2x2 poolings.
        drop_rate: Dropout rate used after encoder stage 4 and the bridge.
        classes: Number of output classes (softmax channels of the 1x1 head).

    Returns:
        A Keras ``Model`` compiled with Adam (learning rate 1e-4), categorical
        cross-entropy loss and accuracy.
    """
    # NOTE: keep the layer creation order stable. Keras assigns default layer
    # names sequentially, and this file also loads weights with by_name=True.
    inputs = Input(input_shape)
    # Encoder
    conv1_1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
    batch1_1 = BatchNormalization()(conv1_1)
    conv1_2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch1_1)
    batch1_2 = BatchNormalization()(conv1_2)
    pool1 = MaxPooling2D(pool_size=(2, 2))(batch1_2)
    conv2_1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
    batch2_1 = BatchNormalization()(conv2_1)
    conv2_2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch2_1)
    batch2_2 = BatchNormalization()(conv2_2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(batch2_2)
    conv3_1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
    batch3_1 = BatchNormalization()(conv3_1)
    conv3_2 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch3_1)
    batch3_2 = BatchNormalization()(conv3_2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(batch3_2)
    conv4_1 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
    batch4_1 = BatchNormalization()(conv4_1)
    conv4_2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch4_1)
    batch4_2 = BatchNormalization()(conv4_2)
    # drop4 feeds both the next pooling and the skip connection into merge6
    drop4 = Dropout(drop_rate)(batch4_2)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
    # Bridge
    conv5_1 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
    batch5_1 = BatchNormalization()(conv5_1)
    conv5_2 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch5_1)
    batch5_2 = BatchNormalization()(conv5_2)
    drop5 = Dropout(drop_rate)(batch5_2)
    # Decoder: upsample 2x, 2x2 conv, concatenate with the same-resolution encoder output
    up6 = Conv2D(512, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(drop5))
    merge6 = concatenate([drop4, up6])
    conv6_1 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
    batch6_1 = BatchNormalization()(conv6_1)
    conv6_2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch6_1)
    batch6_2 = BatchNormalization()(conv6_2)
    up7 = Conv2D(256, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch6_2))
    merge7 = concatenate([batch3_2, up7])
    conv7_1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
    batch7_1 = BatchNormalization()(conv7_1)
    conv7_2 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch7_1)
    batch7_2 = BatchNormalization()(conv7_2)
    up8 = Conv2D(128, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch7_2))
    merge8 = concatenate([batch2_2, up8])
    conv8_1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
    batch8_1 = BatchNormalization()(conv8_1)
    conv8_2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch8_1)
    batch8_2 = BatchNormalization()(conv8_2)
    up9 = Conv2D(64, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch8_2))
    merge9 = concatenate([batch1_2, up9])
    conv9_1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
    batch9_1 = BatchNormalization()(conv9_1)
    conv9_2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch9_1)
    batch9_2 = BatchNormalization()(conv9_2)
    # 1x1 head: per-pixel softmax over `classes`
    outputs = Conv2D(classes, 1, activation='softmax')(batch9_2)
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model
# Custom upsampling layer for dynamic resizing
class BilinearUpsampling(Layer):
    """Keras layer that resizes feature maps to a fixed spatial size (bilinear).

    Args:
        size: Target ``(height, width)`` handed to ``tf.image.resize``.
        **kwargs: Forwarded to the base ``Layer``.
    """

    def __init__(self, size=(1, 1), **kwargs):
        super().__init__(**kwargs)
        self.size = size

    def call(self, inputs):
        return tf.image.resize(inputs, self.size, method='bilinear')

    def compute_output_shape(self, input_shape):
        # Batch and channel dimensions pass through; only height/width change.
        target_height, target_width = self.size[0], self.size[1]
        return (input_shape[0], target_height, target_width, input_shape[3])

    def get_config(self):
        # Serialise `size` so the layer can be re-created from a saved model.
        return {**super().get_config(), 'size': self.size}
# DeepLabV3+ model definition
def DeepLabV3Plus(input_shape=(256, 256, 1), classes=11, output_stride=16):
    """
    DeepLabV3+ model with an Xception-like backbone.

    Architecture:
    - Backbone: downsamples by 16 in total (a stride-2 entry conv plus three
      blocks ending in stride-2 max-pooling). It then runs 16 residual
      "middle flow" units of dilated (rate 2) separable convs.
    - ASPP head: a 1x1 conv, three dilated 3x3 convs and an image-pooling
      branch.
    - Decoder: upsamples the ASPP output to the resolution of the Block 2
      features (1/8 of the input) and concatenates them. It refines the
      result with two 3x3 convs, then upsamples to the input size.

    Args:
        input_shape: Shape of input images (height, width, channels). Height
            and width should be divisible by 16 so the integer upsampling
            factors restore the input size exactly.
        classes: Number of classes for segmentation (softmax channels).
        output_stride: 16 or 8. Selects the ASPP atrous rates: (6, 12, 18)
            for 16, (12, 24, 36) for 8.
            NOTE(review): this does not change the backbone, which always
            downsamples by 16. Confirm whether output_stride=8 was meant to
            drop a stride.

    Returns:
        model: DeepLabV3+ Keras model compiled with Adam (learning rate 1e-4),
        categorical cross-entropy loss and accuracy.

    Raises:
        ValueError: If output_stride is not 8 or 16.
    """
    # NOTE: keep the layer creation order stable; default layer names (and so
    # name-based weight loading) depend on it.
    # Input layer
    inputs = Input(input_shape)
    # Ensure we're using the right dilation rates based on output_stride
    if output_stride == 16:
        atrous_rates = (6, 12, 18)
    elif output_stride == 8:
        atrous_rates = (12, 24, 36)
    else:
        raise ValueError("Output stride must be 8 or 16")
    # === ENCODER (BACKBONE) ===
    # Entry block: stride-2 conv -> 1/2 resolution
    x = Conv2D(32, 3, strides=(2, 2), padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(64, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Xception-like blocks with strided 1x1 residual shortcuts
    # Block 1 -> 1/4 resolution
    residual = Conv2D(128, 1, strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)
    x = SeparableConv2D(128, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(128, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
    x = Add()([x, residual])
    # Block 2 -> 1/8 resolution
    residual = Conv2D(256, 1, strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)
    x = Activation('relu')(x)
    x = SeparableConv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
    x = Add()([x, residual])
    # Save low_level_features for the decoder skip connection. This is 1/8 of
    # the input size: the entry conv, Block 1 and Block 2 each halve it.
    low_level_features = x
    # Block 3 -> 1/16 resolution
    residual = Conv2D(728, 1, strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)
    x = Activation('relu')(x)
    x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
    x = Add()([x, residual])
    # Middle flow: 16 residual units of three dilated separable convs (resolution unchanged)
    for i in range(16):
        residual = x
        x = Activation('relu')(x)
        x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Add()([x, residual])
    # Exit flow (modified: no further downsampling, ends with 1024 channels)
    x = Activation('relu')(x)
    x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(1024, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    # === ASPP (Atrous Spatial Pyramid Pooling) ===
    # 1x1 convolution branch
    aspp_out1 = Conv2D(256, 1, padding='same', use_bias=False)(x)
    aspp_out1 = BatchNormalization()(aspp_out1)
    aspp_out1 = Activation('relu')(aspp_out1)
    # 3x3 dilated convolution branches with different rates
    aspp_out2 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[0], use_bias=False)(x)
    aspp_out2 = BatchNormalization()(aspp_out2)
    aspp_out2 = Activation('relu')(aspp_out2)
    aspp_out3 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[1], use_bias=False)(x)
    aspp_out3 = BatchNormalization()(aspp_out3)
    aspp_out3 = Activation('relu')(aspp_out3)
    aspp_out4 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[2], use_bias=False)(x)
    aspp_out4 = BatchNormalization()(aspp_out4)
    aspp_out4 = Activation('relu')(aspp_out4)
    # Global pooling branch: image-level features, broadcast back to x's spatial size
    aspp_out5 = GlobalAveragePooling2D()(x)
    aspp_out5 = Reshape((1, 1, 1024))(aspp_out5)  # Use 1024 to match x's channels
    aspp_out5 = Conv2D(256, 1, padding='same', use_bias=False)(aspp_out5)
    aspp_out5 = BatchNormalization()(aspp_out5)
    aspp_out5 = Activation('relu')(aspp_out5)
    # Get current (static) spatial shape of x
    _, height, width, _ = tf.keras.backend.int_shape(x)
    aspp_out5 = UpSampling2D(size=(height, width), interpolation='bilinear')(aspp_out5)
    # Concatenate all ASPP branches
    aspp_out = Concatenate()([aspp_out1, aspp_out2, aspp_out3, aspp_out4, aspp_out5])
    # Project ASPP output to 256 filters
    aspp_out = Conv2D(256, 1, padding='same', use_bias=False)(aspp_out)
    aspp_out = BatchNormalization()(aspp_out)
    aspp_out = Activation('relu')(aspp_out)
    # === DECODER ===
    # Process low-level features from Block 2 (1/8 size), reduced to 48 channels
    low_level_features = Conv2D(48, 1, padding='same', use_bias=False)(low_level_features)
    low_level_features = BatchNormalization()(low_level_features)
    low_level_features = Activation('relu')(low_level_features)
    # Upsample the ASPP output (1/16) to the low-level feature size (1/8), i.e. 2x.
    # The integer factor is computed from static shapes.
    low_level_shape = tf.keras.backend.int_shape(low_level_features)
    # Upsample to match low_level_features shape
    x = UpSampling2D(size=(low_level_shape[1] // tf.keras.backend.int_shape(aspp_out)[1],
                           low_level_shape[2] // tf.keras.backend.int_shape(aspp_out)[2]),
                     interpolation='bilinear')(aspp_out)
    # Concatenate with low-level features
    x = Concatenate()([x, low_level_features])
    # Final convolutions
    x = Conv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Calculate upsampling factor back to the original input size (8x from 1/8)
    x_shape = tf.keras.backend.int_shape(x)
    upsampling_size = (input_shape[0] // x_shape[1], input_shape[1] // x_shape[2])
    # Upsample to original size
    x = UpSampling2D(size=upsampling_size, interpolation='bilinear')(x)
    # Final segmentation output: per-pixel softmax over `classes`
    outputs = Conv2D(classes, 1, padding='same', activation='softmax')(x)
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model
# SegNet model definition
def SegNet(input_shape=(256, 256, 1), classes=11):
    """
    SegNet-style encoder-decoder for semantic segmentation.

    Encoder: five VGG-like blocks with 2, 2, 3, 3 and 3 conv-BN-ReLU layers
    (64, 128, 256, 512 and 512 filters), each ending in a 2x2 max-pool.
    Decoder: mirrors the encoder, using ``UpSampling2D`` in place of
    index-based max-unpooling. There are no skip connections.

    Args:
        input_shape: Shape of input images (height, width, channels). Height
            and width should be divisible by 32 (five 2x poolings with 'same'
            padding) for the output to match the input size.
        classes: Number of classes for segmentation (softmax channels).

    Returns:
        model: SegNet Keras model. Unlike ``get_unet`` and ``DeepLabV3Plus``,
        it is returned uncompiled (no ``compile`` call here).
    """
    # NOTE: keep the layer creation order stable; default layer names (and so
    # name-based weight loading) depend on it.
    # Input layer
    inputs = Input(input_shape)
    # === ENCODER ===
    # Encoder block 1
    x = Conv2D(64, (3, 3), padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Regular MaxPooling without indices
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)
    # Encoder block 2
    x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)
    # Encoder block 3
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)
    # Encoder block 4
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)
    # Encoder block 5 (ends at 1/32 resolution)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)
    # === DECODER ===
    # Using UpSampling2D instead of MaxUnpooling: the encoder does not keep pooling indices
    # Decoder block 5
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Decoder block 4
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Decoder block 3
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Decoder block 2
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Decoder block 1 (back to input resolution)
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Output layer: per-pixel softmax over `classes`
    outputs = Conv2D(classes, (1, 1), padding='same', activation='softmax')(x)
    model = Model(inputs=inputs, outputs=outputs)
    return model
# ==================== SAR SEGMENTATION CLASS ====================
class SARSegmentation:
def __init__(self, img_rows=256, img_cols=256, drop_rate=0.5, model_type='unet'):
self.img_rows = img_rows
self.img_cols = img_cols
self.drop_rate = drop_rate
self.num_channels = 1 # Single-pol SAR
self.model = None
self.model_type = model_type.lower()
# ESA WorldCover class definitions
self.class_definitions = {
'trees': 10,
'shrubland': 20,
'grassland': 30,
'cropland': 40,
'built_up': 50,
'bare': 60,
'snow': 70,
'water': 80,
'wetland': 90,
'mangroves': 95,
'moss': 100
}
self.num_classes = len(self.class_definitions)
# Class colors for visualization
self.class_colors = get_esa_colors()
def load_sar_data(self, file_path_or_bytes, is_bytes=False):
"""Load SAR data from file path or bytes"""
try:
if is_bytes:
# Create a temporary file to use with rasterio
with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp:
tmp.write(file_path_or_bytes)
tmp_path = tmp.name
try:
with rasterio.open(tmp_path) as src:
sar_data = src.read(1) # Read single band
sar_data = np.expand_dims(sar_data, axis=-1)
except Exception as e:
# If rasterio fails, try PIL
img = Image.open(tmp_path).convert('L')
sar_data = np.array(img)
sar_data = np.expand_dims(sar_data, axis=-1)
# Clean up the temporary file
os.unlink(tmp_path)
else:
try:
with rasterio.open(file_path_or_bytes) as src:
sar_data = src.read(1) # Read single band
sar_data = np.expand_dims(sar_data, axis=-1)
except RasterioIOError:
# Try to open as a regular image if rasterio fails
img = Image.open(file_path_or_bytes).convert('L')
sar_data = np.array(img)
sar_data = np.expand_dims(sar_data, axis=-1)
# Resize if needed
if sar_data.shape[:2] != (self.img_rows, self.img_cols):
sar_data = cv2.resize(sar_data, (self.img_cols, self.img_rows))
sar_data = np.expand_dims(sar_data, axis=-1)
return sar_data
except Exception as e:
raise ValueError(f"Failed to load SAR data: {str(e)}")
def preprocess_sar(self, sar_data):
"""Preprocess SAR data"""
# Check if data is already normalized (0-255 range)
if np.max(sar_data) <= 255 and np.min(sar_data) >= 0:
# Normalize to -1 to 1 range
sar_normalized = (sar_data / 127.5) - 1
else:
# Assume it's in dB scale
sar_clipped = np.clip(sar_data, -50, 20)
sar_normalized = (sar_clipped - np.min(sar_clipped)) / (np.max(sar_clipped) - np.min(sar_clipped)) * 2 - 1
return sar_normalized
def one_hot_encode(self, labels):
"""Convert ESA WorldCover labels to one-hot encoded format"""
encoded = np.zeros((labels.shape[0], labels.shape[1], self.num_classes))
for i, value in enumerate(sorted(self.class_definitions.values())):
encoded[:, :, i] = (labels == value)
return encoded
def load_trained_model(self, model_path):
"""Load a trained model from file"""
try:
# If model_path is a filename without path, prepend the models directory
if not os.path.dirname(model_path) and not model_path.startswith('models/'):
model_path = os.path.join('models', os.path.basename(model_path))
# First try to load the complete model
self.model = load_model_with_weights(model_path)
if self.model is not None:
has_dilated_convs = False
for layer in self.model.layers:
if 'conv' in layer.name.lower() and hasattr(layer, 'dilation_rate'):
if isinstance(layer.dilation_rate, (list, tuple)):
if any(rate > 1 for rate in layer.dilation_rate):
has_dilated_convs = True
break
elif layer.dilation_rate > 1:
has_dilated_convs = True
break
if has_dilated_convs:
self.model_type = 'deeplabv3plus'
print("Detected DeepLabV3+ model")
# Check for SegNet architecture (typically has 5 encoder and 5 decoder blocks)
elif len([l for l in self.model.layers if isinstance(l, MaxPooling2D)]) >= 5:
self.model_type = 'segnet'
print("Detected SegNet model")
else:
self.model_type = 'unet'
print("Detected U-Net model")
if self.model is None:
# If that fails, try to create a model with the expected architecture
if self.model_type == 'unet':
self.model = get_unet(
input_shape=(self.img_rows, self.img_cols, self.num_channels),
drop_rate=self.drop_rate,
classes=self.num_classes
)
elif self.model_type == 'deeplabv3plus':
self.model = DeepLabV3Plus(
input_shape=(self.img_rows, self.img_cols, self.num_channels),
classes=self.num_classes
)
elif self.model_type == 'segnet':
self.model = SegNet(
input_shape=(self.img_rows, self.img_cols, self.num_channels),
classes=self.num_classes
)
else:
raise ValueError(f"Model type {self.model_type} not supported")
# Try to load weights, allowing for mismatch
self.model.load_weights(model_path, by_name=True, skip_mismatch=True)
# Check if any weights were loaded
if not any(np.any(w) for w in self.model.get_weights()):
raise ValueError("No weights were loaded. The model architecture is incompatible.")
except Exception as e:
raise ValueError(f"Failed to load model: {str(e)}")
def predict(self, sar_data):
    """Run the loaded segmentation model on one SAR image (or a batch).

    The input goes through ``self.preprocess_sar`` first; a 3-D result
    (a single image) gets a leading batch axis before ``self.model.predict``.

    Args:
        sar_data: Raw SAR data as accepted by ``preprocess_sar``.

    Returns:
        Whatever ``self.model.predict`` returns (presumably per-pixel class
        scores with a leading batch axis — confirm against the model).

    Raises:
        ValueError: If no model has been trained or loaded yet.
    """
    model = self.model
    if model is None:
        raise ValueError("Model not trained. Call train() first or load a trained model.")
    batch = self.preprocess_sar(sar_data)
    # A lone image has rank 3; the model expects a batch dimension in front.
    if len(batch.shape) == 3:
        batch = np.expand_dims(batch, axis=0)
    return model.predict(batch)
def get_colored_prediction(self, prediction):
    """Turn model output into an RGB class map.

    Only the first item of ``prediction`` is used; the winning class per
    pixel is the argmax over its last axis. Pixels whose class index has an
    entry in ``self.class_colors`` are painted with that colour, all others
    stay black.

    Args:
        prediction: Batched model output; ``prediction[0]`` must be 3-D.

    Returns:
        Tuple ``(colored_pred, pred_class)``: a ``(H, W, 3)`` uint8 image and
        the ``(H, W)`` array of winning class indices.
    """
    labels = np.argmax(prediction[0], axis=-1)
    height, width = labels.shape[0], labels.shape[1]
    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    for idx, rgb_value in self.class_colors.items():
        rgb[labels == idx] = rgb_value
    return rgb, labels
# ==================== UI SETUP AND STYLING ====================
# Per-session defaults. Each value is only created when its key is absent, so
# values already stored in st.session_state survive reruns. Factories (not
# plain values) keep SARSegmentation from being constructed on every rerun.
_SESSION_DEFAULTS = (
    ('app_mode', lambda: "SAR Colorization"),
    ('model_loaded', lambda: False),
    ('segmentation', lambda: SARSegmentation(img_rows=256, img_cols=256)),
    ('processed_images', list),
    ('theme', lambda: "dark"),  # Default theme
)
for _key, _make_default in _SESSION_DEFAULTS:
    if _key not in st.session_state:
        st.session_state[_key] = _make_default()
del _key, _make_default
# Page-level CSS, chosen per application mode
def set_app_style(app_mode):
    """Inject the CSS palette that belongs to the selected application.

    "SAR Colorization" gets the dark palette, "SAR to Optical Translation"
    the light one. Both use the same background photo under a blurred,
    semi-transparent overlay. Any other mode name injects nothing.

    Args:
        app_mode: Application name as chosen in the sidebar.
    """
    css_by_mode = {
        "SAR Colorization": """
<style>
.stApp {
background-color: #0a0a1f;
color: white;
}
.main {
background-image: url("https://images.unsplash.com/photo-1451187580459-43490279c0fa?ixlib=rb-1.2.1&auto=format&fit=crop&w=1352&q=80");
background-size: cover;
background-position: center;
background-repeat: no-repeat;
background-attachment: fixed;
position: relative;
}
.main::before {
content: "";
position: absolute;
top: 0;
left: 0;
width: 100%;
height: 100%;
background-color: rgba(10, 10, 31, 0.7);
backdrop-filter: blur(5px);
z-index: -1;
}
/* Rest of your dark theme CSS */
/* ... */
</style>
""",
        "SAR to Optical Translation": """
<style>
.stApp {
background-color: #f8f9fa;
color: #333;
}
.main {
background-image: url("https://images.unsplash.com/photo-1451187580459-43490279c0fa?ixlib=rb-1.2.1&auto=format&fit=crop&w=1352&q=80");
background-size: cover;
background-position: center;
background-repeat: no-repeat;
background-attachment: fixed;
position: relative;
}
.main::before {
content: "";
position: absolute;
top: 0;
left: 0;
width: 100%;
height: 100%;
background-color: rgba(248, 249, 250, 0.7);
backdrop-filter: blur(5px);
z-index: -1;
}
/* Adjust text colors for light theme */
h1, h2, h3, h4, h5, h6 {
color: #333 !important;
}
p, span, div, label {
color: #333 !important;
}
/* Adjust card styling for light theme */
.card {
background-color: rgba(255, 255, 255, 0.7) !important;
border: 1px solid rgba(147, 51, 234, 0.3) !important;
}
/* Adjust metric card styling for light theme */
.metric-card {
background-color: rgba(255, 255, 255, 0.7) !important;
}
.metric-value {
color: #7c3aed !important;
}
.metric-label {
color: #333 !important;
}
/* Rest of your light theme adjustments */
/* ... */
</style>
""",
    }
    css = css_by_mode.get(app_mode)
    if css is not None:
        st.markdown(css, unsafe_allow_html=True)
# Decorative star field markup
def create_stars_html(num_stars=100):
    """Return raw HTML for ``num_stars`` randomly placed star ``div``s.

    Each star gets a random size (1-3 px), a top/left position (0-100 %),
    and two CSS custom properties: ``--duration`` (3-8, in seconds) and
    ``--opacity`` (0.2-0.8). Presumably a twinkle animation in the page CSS
    consumes them. All stars sit inside one ``<div class="stars">``.

    Args:
        num_stars: How many star elements to generate.

    Returns:
        The HTML snippet as a single string.
    """
    pieces = ['<div class="stars">']
    for _ in range(num_stars):
        # Draw the values in this exact order so a seeded RNG is reproducible.
        size = random.uniform(1, 3)
        top = random.uniform(0, 100)
        left = random.uniform(0, 100)
        duration = random.uniform(3, 8)
        opacity = random.uniform(0.2, 0.8)
        pieces.append(f"""
<div class="star" style="
width: {size}px;
height: {size}px;
top: {top}%;
left: {left}%;
--duration: {duration}s;
--opacity: {opacity};
"></div>
""")
    pieces.append("</div>")
    return "".join(pieces)
# Add logo
def add_logo(logo_path='assets/logo2.png'):
    """Pin an image file to the top-left corner of the page as a logo.

    The file is embedded as a base64 ``data:`` URI inside raw HTML and shown
    150 px wide. The URI's MIME type comes from the file extension, so
    JPEG/SVG/GIF logos render as well as PNG. Unknown or non-image
    extensions fall back to ``image/png``.

    Args:
        logo_path: Path of the image file to embed.

    A missing file produces a Streamlit warning instead of an exception.
    """
    import mimetypes  # only needed here

    try:
        with open(logo_path, "rb") as img_file:
            logo_base64 = base64.b64encode(img_file.read()).decode()
    except FileNotFoundError:
        st.warning(f"Logo file not found: {logo_path}")
        return
    # A data URI that mislabels its payload (e.g. SVG bytes as image/png) may
    # not render, so declare the real type whenever it can be guessed.
    mime_type, _ = mimetypes.guess_type(logo_path)
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/png"
    st.markdown(
        f"""<div style="position: absolute; top: 0.5rem; left: 1rem; z-index: 999;">
<img src="data:{mime_type};base64,{logo_base64}" width="150px"></div>""",
        unsafe_allow_html=True
    )
# ==================== MAIN APP LOGIC ====================
# Decorative star field, injected as raw HTML on every script run.
st.markdown(create_stars_html(num_stars=100), unsafe_allow_html=True)
# Add app mode selector to sidebar
with st.sidebar:
    # st.image errors out on a local path it cannot open, which would abort
    # the whole script run. Degrade to a warning (as add_logo() does) instead.
    sidebar_logo_path = 'assets/logo2.png'
    if os.path.exists(sidebar_logo_path):
        st.image(sidebar_logo_path, width=150)
    else:
        st.warning(f"Logo file not found: {sidebar_logo_path}")
    # Add the mode selector
    st.title("Applications")
    app_mode = st.radio(
        "Select Application",
        ["SAR Colorization", "SAR to Optical Translation"]
    )
    # The rest of the script reads the mode from session state.
    st.session_state.app_mode = app_mode
    # Add theme selector
    st.markdown("---")
    st.title("Appearance")
    theme = st.radio(
        "Select Theme",
        ["Dark", "Light"]
    )
    set_app_style(st.session_state.app_mode)
    # The theme is stored lowercase; rerun once whenever it changes.
    # NOTE(review): nothing visible here reads st.session_state.theme —
    # set_app_style() picks its palette from app_mode only — so this selector
    # currently has no visual effect. Confirm whether it should.
    if theme.lower() != st.session_state.theme:
        st.session_state.theme = theme.lower()
        st.rerun()  # Force a rerun to apply the new theme
    st.markdown("---")
    # Sidebar content for SAR Colorization app
    if st.session_state.app_mode == "SAR Colorization":
        st.title("About")
        st.markdown("""
### SAR Image Colorization
This application uses deep learning models to segment and colorize Synthetic Aperture Radar (SAR) images into land cover classes.
#### Features:
- Load pre-trained U-Net,DeepLabV3+ or SegNet models
- Process single SAR images
- Batch process multiple images
- Visualize Pixel Level Classification with ESA WorldCover color scheme
#### Developed by:
Varun & Mokshyagna
(NRSC, ISRO)
#### Technologies:
- TensorFlow/Keras
- Streamlit
- Rasterio
- OpenCV
#### Version:
1.0.0
""")
    # Sidebar content for the SAR to Optical app
    elif st.session_state.app_mode == "SAR to Optical Translation":
        st.header("Model Configuration")
        # Predefined model paths
        unet_weights_path = "models/unet_model.h5"
        generator_path = "models/final_generator.keras"
        # Display the paths that will be used
        st.info(f"U-Net Weights Path: {unet_weights_path}")
        use_generator = st.checkbox("Use Generator Model for Colorization", value=True)
        if use_generator:
            st.info(f"Generator Model Path: {generator_path}")
        else:
            generator_path = None
        # Load models button
        if st.button("Load Models"):
            with st.spinner("Loading models..."):
                gpu_status = setup_gpu()
                st.info(gpu_status)
                try:
                    # generator_path is already None when the checkbox is unticked.
                    unet_model, generator_model = load_models(unet_weights_path, generator_path)
                    st.session_state['unet_model'] = unet_model
                    st.session_state['generator_model'] = generator_model
                    st.success("Models loaded successfully!")
                except Exception as e:
                    st.error(f"Error loading models: {e}")
        # Class information: legend swatches (RGB) for the ESA WorldCover classes.
        st.header("ESA WorldCover Classes")
        class_info = {
            'Trees': [0, 100, 0],
            'Shrubland': [255, 165, 0],
            'Grassland': [144, 238, 144],
            'Cropland': [255, 255, 0],
            'Built-up': [255, 0, 0],
            'Bare': [139, 69, 19],
            'Snow': [255, 255, 255],
            'Water': [0, 0, 255],
            'Wetland': [0, 139, 139],
            'Mangroves': [0, 255, 0],
            'Moss': [220, 220, 220]
        }
        for class_name, color in class_info.items():
            st.markdown(
                f'<div style="display: flex; align-items: center;">'
                f'<div style="width: 20px; height: 20px; background-color: rgb({color[0]}, {color[1]}, {color[2]}); margin-right: 10px;"></div>'
                f'<span>{class_name}</span>'
                f'</div>',
                unsafe_allow_html=True
            )
    st.markdown("---")
    st.markdown("© 2025 | All Rights Reserved")
# Main content area - conditional rendering based on app mode
if st.session_state.app_mode == "SAR Colorization":
# SAR Colorization app
st.markdown("""
<div style="text-align: center; margin-bottom: 2rem; position: relative; z-index: 100;">
<h1 style="color: #a78bfa; font-size: 3rem; font-weight: bold; text-shadow: 0 0 10px rgba(167, 139, 250, 0.5);">
SAR Image Colorization
</h1>
<p style="color: #bfdbfe; font-size: 1.2rem;">
Pixel Level Classification of Synthetic Aperture Radar images into land cover classes with deep learning
</p>
</div>
""", unsafe_allow_html=True)
# Create a card container
st.markdown("<div class='card'>", unsafe_allow_html=True)
# Create tabs
# Create tabs
tab1, tab2, tab3, tab4 = st.tabs(["📥 Load Model", "🖼️ Process Single Image", "📁 Process Multiple Images", "🔍 Sample Images"])
# Tab 1: Load Model
with tab1:
    st.markdown("<h3 style='color: #a78bfa;'>Load Segmentation Model</h3>", unsafe_allow_html=True)
    # UI label -> canonical architecture key. load_trained_model() compares
    # model_type against 'unet' / 'deeplabv3plus' / 'segnet'. A generic
    # lower()/replace('-', '') would give "deeplabv3+" for DeepLabV3+, which it
    # rejects and which the info panel cannot name.
    arch_keys = {
        "U-Net": "unet",
        "DeepLabV3+": "deeplabv3plus",
        "SegNet": "segnet",
    }
    model_type = st.selectbox(
        "Select model architecture",
        list(arch_keys),
        index=0,
        help="Select the architecture of the model to load"
    )
    # Predefined weight file per architecture key.
    model_paths = {
        "unet": "models/unet_model.h5",
        "deeplabv3plus": "models/deeplabv3plus_model.h5",
        "segnet": "models/segnet_model.h5",
    }
    selected_model_path = model_paths[arch_keys[model_type]]
    # Display the path that will be used
    st.info(f"Model will be loaded from: {selected_model_path}")
    if st.button("Load Model", key="load_model_btn"):
        with st.spinner(f"Loading {model_type} model..."):
            try:
                # Set the architecture only when actually loading. Assigning it
                # on every rerun would overwrite the type that
                # load_trained_model() detects from the loaded network.
                st.session_state.segmentation.model_type = arch_keys[model_type]
                st.session_state.segmentation.load_trained_model(selected_model_path)
                st.session_state.model_loaded = True
                st.success("Model loaded successfully!")
            except Exception as e:
                # load_trained_model() may already have replaced self.model
                # before raising, so a previously loaded model is no longer usable.
                st.session_state.model_loaded = False
                st.error(f"Error loading model: {str(e)}")
    # Display model information if loaded
    if st.session_state.model_loaded:
        segmenter = st.session_state.segmentation
        st.markdown("<div class='card'>", unsafe_allow_html=True)
        st.markdown("<h4 style='color: #a78bfa;'>Model Information</h4>", unsafe_allow_html=True)
        col1, col2, col3 = st.columns(3)
        model_arch_map = {
            'unet': "U-Net",
            'deeplabv3plus': "DeepLabV3+",
            'segnet': "SegNet"
        }
        # (value, label) per metric card, read from the segmentation object
        # instead of hard-coding the class count and input size.
        metric_cards = [
            (model_arch_map.get(segmenter.model_type, "Unknown"), "Architecture"),
            (segmenter.num_classes, "Land Cover Classes"),
            (f"{segmenter.img_rows} x {segmenter.img_cols}", "Input Size"),
        ]
        for column, (value, label) in zip((col1, col2, col3), metric_cards):
            with column:
                st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                st.markdown(f"<p class='metric-value'>{value}</p>", unsafe_allow_html=True)
                st.markdown(f"<p class='metric-label'>{label}</p>", unsafe_allow_html=True)
                st.markdown("</div>", unsafe_allow_html=True)
        # Display legend
        st.markdown("<h4 style='color: #a78bfa; margin-top: 20px;'>Land Cover Classes</h4>", unsafe_allow_html=True)
        legend_img = create_legend()
        st.image(legend_img, use_container_width=True)
        st.markdown("</div>", unsafe_allow_html=True)
    else:
        st.info("Please load a model to continue.")
# Tab 2: Process Single Image
with tab2:
    st.markdown("<h3 style='color: #a78bfa;'>Process Single SAR Image</h3>", unsafe_allow_html=True)
    if not st.session_state.model_loaded:
        st.warning("Please load a model in the 'Load Model' tab first.")
    else:
        st.markdown("<div class='upload-box'>", unsafe_allow_html=True)
        col1, col2 = st.columns(2)
        with col1:
            uploaded_file = st.file_uploader(
                "Upload a SAR image (.tif or common image formats)",
                type=["tif", "tiff", "png", "jpg", "jpeg"],
                key="single_sar_uploader"
            )
        with col2:
            # Optional ground truth enables the side-by-side view and pixel accuracy.
            ground_truth_file = st.file_uploader(
                "Upload ground truth (optional)",
                type=["tif", "tiff", "png", "jpg", "jpeg"],
                key="single_gt_uploader"
            )
        st.markdown("</div>", unsafe_allow_html=True)
        if uploaded_file is not None and st.button("Process Image", key="process_single_btn"):
            with st.spinner("Processing image..."):
                try:
                    sar_data = st.session_state.segmentation.load_sar_data(uploaded_file.getvalue(), is_bytes=True)
                    # Min-max stretch to uint8 for display. A constant image would
                    # divide 0 by 0 and cast NaNs to uint8, so render it black.
                    min_val = np.min(sar_data)
                    max_val = np.max(sar_data)
                    if max_val > min_val:
                        sar_normalized = ((sar_data - min_val) / (max_val - min_val) * 255).astype(np.uint8)
                    else:
                        sar_normalized = np.zeros(np.shape(sar_data), dtype=np.uint8)
                    prediction = st.session_state.segmentation.predict(sar_data)
                    # Without ground truth (or if it fails) show the prediction-only view.
                    show_plain_result = ground_truth_file is None
                    if not show_plain_result:
                        try:
                            gt_data = st.session_state.segmentation.load_sar_data(ground_truth_file.getvalue(), is_bytes=True)
                            # sar_normalized is always uint8 in [0, 255]; rescale to [-1, 1].
                            sar_for_viz = sar_normalized.astype(np.float32) / 127.5 - 1
                            # Create visualization with ground truth using ESA WorldCover colors
                            result_buf, colored_gt, colored_pred, overlay = visualize_with_ground_truth(
                                sar_for_viz,
                                gt_data,
                                prediction
                            )
                            st.markdown("<h4 style='color: #a78bfa;'>Segmentation Results with Ground Truth</h4>", unsafe_allow_html=True)
                            st.image(result_buf, use_container_width=True)
                            pred_class = np.argmax(prediction[0], axis=-1)
                            gt_class = gt_data[:, :, 0].astype(np.int32)
                            # Labels stored as ESA WorldCover codes (> 10) are mapped to
                            # 0-based indices: the i-th smallest code becomes class i.
                            if np.max(gt_class) > 10:
                                gt_mapped = np.zeros_like(gt_class)
                                class_values = sorted(st.session_state.segmentation.class_definitions.values())
                                for class_index, esa_value in enumerate(class_values):
                                    gt_mapped[gt_class == esa_value] = class_index
                                gt_class = gt_mapped
                            accuracy = np.mean(pred_class == gt_class) * 100
                            st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                            st.markdown(f"<p class='metric-value'>{accuracy:.2f}%</p>", unsafe_allow_html=True)
                            st.markdown("<p class='metric-label'>Pixel Accuracy</p>", unsafe_allow_html=True)
                            st.markdown("</div>", unsafe_allow_html=True)
                            st.download_button(
                                label="Download Result",
                                data=result_buf,
                                file_name="segmentation_result_with_gt.png",
                                mime="image/png",
                                key="download_single_result_with_gt"
                            )
                        except Exception as e:
                            st.error(f"Error processing ground truth: {str(e)}")
                            show_plain_result = True
                    if show_plain_result:
                        result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1))
                        st.markdown("<h4 style='color: #a78bfa;'>Segmentation Results</h4>", unsafe_allow_html=True)
                        st.image(result_img, use_container_width=True)
                        st.download_button(
                            label="Download Result",
                            data=result_img,
                            file_name="segmentation_result.png",
                            mime="image/png",
                            key="download_single_result"
                        )
                except Exception as e:
                    st.error(f"Error processing image: {str(e)}")
# Tab 3: Process Multiple Images
with tab3:
    st.markdown("<h3 style='color: #a78bfa;'>Process Multiple SAR Images</h3>", unsafe_allow_html=True)
    if not st.session_state.model_loaded:
        st.warning("Please load a model in the 'Load Model' tab first.")
    else:
        st.markdown("<div class='upload-box'>", unsafe_allow_html=True)
        # Add option for ground truth
        use_gt = st.checkbox("Include ground truth data", value=False)
        col1, col2 = st.columns(2)
        with col1:
            uploaded_files = st.file_uploader(
                "Upload SAR images or a ZIP file containing images",
                type=["tif", "tiff", "png", "jpg", "jpeg", "zip"],
                accept_multiple_files=True,
                key="batch_sar_uploader"
            )
        # The ground-truth uploader only exists while the checkbox is ticked.
        gt_files = None
        if use_gt:
            with col2:
                gt_files = st.file_uploader(
                    "Upload ground truth images or a ZIP file (must match SAR filenames)",
                    type=["tif", "tiff", "png", "jpg", "jpeg", "zip"],
                    accept_multiple_files=True,
                    key="batch_gt_uploader"
                )
                st.info("Ground truth filenames should match SAR image filenames")
        st.markdown("</div>", unsafe_allow_html=True)
        col1, col2 = st.columns([3, 1])
        with col1:
            max_images = st.slider("Maximum number of images to display", min_value=1, max_value=20, value=10)
        with col2:
            st.markdown("<br>", unsafe_allow_html=True)
            process_btn = st.button("Process Images", key="process_multi_btn")
        if process_btn and uploaded_files:
            # Clear previous results
            st.session_state.processed_images = []
            image_exts = ('.tif', '.tiff', '.png', '.jpg', '.jpeg')
            # Pixel accuracy of every image that had matching ground truth.
            overall_accuracy = []
            with st.spinner("Processing images..."):
                with tempfile.TemporaryDirectory() as temp_dir:
                    sar_image_files = []
                    gt_image_files = {}  # image file name -> ground-truth path
                    for upload_idx, uploaded_file in enumerate(uploaded_files):
                        # Names come from the client: keep only the final path component.
                        upload_name = os.path.basename(uploaded_file.name)
                        if upload_name.lower().endswith('.zip'):
                            # Each archive gets its own folder and only that folder is
                            # scanned. Re-walking a shared folder would collect images
                            # from earlier uploads a second time.
                            extract_dir = os.path.join(temp_dir, 'sar', f'zip_{upload_idx}')
                            with zipfile.ZipFile(io.BytesIO(uploaded_file.getvalue())) as zip_ref:
                                zip_ref.extractall(extract_dir)
                            for root, _, files in os.walk(extract_dir):
                                for file in files:
                                    if file.lower().endswith(image_exts):
                                        sar_image_files.append(os.path.join(root, file))
                        else:
                            file_path = os.path.join(temp_dir, 'sar', 'files', upload_name)
                            os.makedirs(os.path.dirname(file_path), exist_ok=True)
                            with open(file_path, 'wb') as f:
                                f.write(uploaded_file.getvalue())
                            sar_image_files.append(file_path)
                    if use_gt and gt_files:
                        for upload_idx, gt_file in enumerate(gt_files):
                            upload_name = os.path.basename(gt_file.name)
                            if upload_name.lower().endswith('.zip'):
                                extract_dir = os.path.join(temp_dir, 'gt', f'zip_{upload_idx}')
                                with zipfile.ZipFile(io.BytesIO(gt_file.getvalue())) as zip_ref:
                                    zip_ref.extractall(extract_dir)
                                for root, _, files in os.walk(extract_dir):
                                    for file in files:
                                        if file.lower().endswith(image_exts):
                                            # Ground truth is paired with SAR images by file name.
                                            gt_image_files[file] = os.path.join(root, file)
                            else:
                                file_path = os.path.join(temp_dir, 'gt', 'files', upload_name)
                                os.makedirs(os.path.dirname(file_path), exist_ok=True)
                                with open(file_path, 'wb') as f:
                                    f.write(gt_file.getvalue())
                                gt_image_files[upload_name] = file_path
                    # If there are too many images, randomly select a subset
                    if len(sar_image_files) > max_images:
                        st.info(f"Found {len(sar_image_files)} images. Randomly selecting {max_images} images to display.")
                        sar_image_files = random.sample(sar_image_files, max_images)
                    progress_bar = st.progress(0)
                    for file_idx, image_path in enumerate(sar_image_files):
                        try:
                            progress_bar.progress((file_idx + 1) / len(sar_image_files))
                            sar_data = st.session_state.segmentation.load_sar_data(image_path)
                            # Min-max stretch to uint8 for display; constant images
                            # would otherwise divide 0 by 0, so render them black.
                            min_val = np.min(sar_data)
                            max_val = np.max(sar_data)
                            if max_val > min_val:
                                sar_normalized = ((sar_data - min_val) / (max_val - min_val) * 255).astype(np.uint8)
                            else:
                                sar_normalized = np.zeros(np.shape(sar_data), dtype=np.uint8)
                            prediction = st.session_state.segmentation.predict(sar_data)
                            image_basename = os.path.basename(image_path)
                            if use_gt and image_basename in gt_image_files:
                                gt_data = st.session_state.segmentation.load_sar_data(gt_image_files[image_basename])
                                # sar_normalized is always uint8 in [0, 255]; rescale to [-1, 1].
                                sar_for_viz = sar_normalized.astype(np.float32) / 127.5 - 1
                                result_buf, colored_gt, colored_pred, overlay = visualize_with_ground_truth(
                                    sar_for_viz,
                                    gt_data,
                                    prediction
                                )
                                pred_class = np.argmax(prediction[0], axis=-1)
                                gt_class = gt_data[:, :, 0].astype(np.int32)
                                # ESA WorldCover codes (> 10) -> 0-based class indices.
                                if np.max(gt_class) > 10:
                                    gt_mapped = np.zeros_like(gt_class)
                                    class_values = sorted(st.session_state.segmentation.class_definitions.values())
                                    for class_index, esa_value in enumerate(class_values):
                                        gt_mapped[gt_class == esa_value] = class_index
                                    gt_class = gt_mapped
                                accuracy = np.mean(pred_class == gt_class) * 100
                                overall_accuracy.append(accuracy)
                                st.session_state.processed_images.append({
                                    'filename': image_basename,
                                    'result': result_buf,
                                    'accuracy': accuracy
                                })
                            else:
                                result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1))
                                st.session_state.processed_images.append({
                                    'filename': image_basename,
                                    'result': result_img
                                })
                        except Exception as e:
                            st.error(f"Error processing {os.path.basename(image_path)}: {str(e)}")
                    progress_bar.empty()
            # Display results
            if st.session_state.processed_images:
                st.markdown("<h4 style='color: #a78bfa;'>Segmentation Results</h4>", unsafe_allow_html=True)
                if use_gt and overall_accuracy:
                    avg_accuracy = np.mean(overall_accuracy)
                    st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                    st.markdown(f"<p class='metric-value'>{avg_accuracy:.2f}%</p>", unsafe_allow_html=True)
                    st.markdown("<p class='metric-label'>Average Pixel Accuracy</p>", unsafe_allow_html=True)
                    st.markdown("</div>", unsafe_allow_html=True)
                # Bundle every result image into one downloadable ZIP.
                zip_buffer = io.BytesIO()
                with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
                    for result_idx, img_data in enumerate(st.session_state.processed_images, start=1):
                        zip_file.writestr(f"result_{result_idx}_{img_data['filename']}.png", img_data['result'].getvalue())
                st.download_button(
                    label="Download All Results",
                    data=zip_buffer.getvalue(),
                    file_name="segmentation_results.zip",
                    mime="application/zip",
                    key="download_all_results"
                )
                for img_data in st.session_state.processed_images:
                    st.markdown(f"<h5 style='color: #bfdbfe;'>Image: {img_data['filename']}</h5>", unsafe_allow_html=True)
                    if 'accuracy' in img_data:
                        st.markdown(f"<p style='color: #a78bfa;'>Pixel Accuracy: {img_data['accuracy']:.2f}%</p>", unsafe_allow_html=True)
                    st.image(img_data['result'], use_container_width=True)
                    st.markdown("<hr style='border-color: rgba(147, 51, 234, 0.3);'>", unsafe_allow_html=True)
            else:
                st.warning("No images were successfully processed.")
        elif process_btn:
            st.warning("Please upload at least one image file or ZIP archive.")
# Tab 4: Sample Images
with tab4:
    st.markdown("<h3 style='color: #a78bfa;'>Sample Images</h3>", unsafe_allow_html=True)
    if not st.session_state.model_loaded:
        st.warning("Please load a model in the 'Load Model' tab first.")
    else:
        st.markdown("<div class='card'>", unsafe_allow_html=True)
        sample_dir = "samples/SAR"
        image_exts = ('.tif', '.tiff', '.png', '.jpg', '.jpeg')
        if os.path.exists(sample_dir):
            # Case-insensitive extension match (".TIF" too), sorted so the
            # dropdown order does not depend on the filesystem.
            sample_files = sorted(
                f for f in os.listdir(sample_dir) if f.lower().endswith(image_exts)
            )
        else:
            # First run: create the expected sample folder layout.
            os.makedirs(sample_dir, exist_ok=True)
            os.makedirs("samples/OPTICAL", exist_ok=True)
            os.makedirs("samples/LABELS", exist_ok=True)
            sample_files = []
        if sample_files:
            selected_sample = st.selectbox(
                "Select a sample image",
                sample_files,
                key="sample_selector"
            )
            sar_path = os.path.join(sample_dir, selected_sample)
            col1, col2, col3 = st.columns(3)
            with col1:
                st.subheader("SAR Image")
                display_image(sar_path)
            with col2:
                st.subheader("Optical Image (Ground Truth)")
                # The optical counterpart is expected under the same file name.
                opt_path = os.path.join("samples/OPTICAL", selected_sample)
                if os.path.exists(opt_path):
                    display_image(opt_path)
                else:
                    st.info("No matching optical image found")
            with col3:
                st.subheader("Label Image")
                samples_dir = "samples"
                # Candidate label folders, searched in this order.
                possible_label_dirs = [
                    os.path.join(samples_dir, name)
                    for name in ("labels", "label", "LABELS", "LABEL", "Labels", "Label",
                                 "gt", "GT", "ground_truth", "groundtruth")
                ]
                label_exts = ['.tif', '.tiff', '.png', '.jpg', '.jpeg', '.TIF', '.TIFF', '.PNG', '.JPG', '.JPEG']
                base_name = os.path.splitext(selected_sample)[0]
                label_path = None
                for dir_path in possible_label_dirs:
                    if not os.path.isdir(dir_path):
                        continue
                    # Exact file name first, then the same stem with a known extension.
                    for candidate in [selected_sample] + [base_name + ext for ext in label_exts]:
                        candidate_path = os.path.join(dir_path, candidate)
                        if os.path.exists(candidate_path):
                            label_path = candidate_path
                            break
                    # Last resort: same stem, ignoring case.
                    if label_path is None:
                        for file in os.listdir(dir_path):
                            if os.path.splitext(file)[0].lower() == base_name.lower():
                                label_path = os.path.join(dir_path, file)
                                break
                    if label_path is not None:
                        break
                if label_path and os.path.exists(label_path):
                    try:
                        if label_path.lower().endswith(('.tif', '.tiff')):
                            with rasterio.open(label_path) as src:
                                label_data = src.read(1)  # Read first band
                            colors = get_esa_colors()
                            colored_label = np.zeros((label_data.shape[0], label_data.shape[1], 3), dtype=np.uint8)
                            if np.max(label_data) > 10:
                                # ESA WorldCover codes: class index i is the i-th smallest
                                # code in class_definitions.
                                class_values = sorted(st.session_state.segmentation.class_definitions.values())
                                index_to_value = dict(enumerate(class_values))
                                for class_idx, color in colors.items():
                                    if class_idx in index_to_value:
                                        colored_label[label_data == index_to_value[class_idx]] = color
                            else:
                                # Values are already class indices.
                                for class_idx, color in colors.items():
                                    colored_label[label_data == class_idx] = color
                            st.image(colored_label, use_container_width=True)
                        else:
                            display_image(label_path)
                    except Exception as e:
                        st.error(f"Error displaying label image: {str(e)}")
                        # Fallback to regular display
                        display_image(label_path)
                else:
                    st.info("No matching label image found")
            if st.button("Process Selected Sample", key="process_sample_btn"):
                with st.spinner("Processing sample image..."):
                    try:
                        sar_data = st.session_state.segmentation.load_sar_data(sar_path)
                        # Min-max stretch to uint8 for display; constant images
                        # would otherwise divide 0 by 0, so render them black.
                        min_val = np.min(sar_data)
                        max_val = np.max(sar_data)
                        if max_val > min_val:
                            sar_normalized = ((sar_data - min_val) / (max_val - min_val) * 255).astype(np.uint8)
                        else:
                            sar_normalized = np.zeros(np.shape(sar_data), dtype=np.uint8)
                        prediction = st.session_state.segmentation.predict(sar_data)
                        # label_path is None when no label was found, and
                        # os.path.exists(None) raises TypeError, so test it first.
                        if label_path and os.path.exists(label_path):
                            gt_data = st.session_state.segmentation.load_sar_data(label_path)
                            # sar_normalized is always uint8 in [0, 255]; rescale to [-1, 1].
                            sar_for_viz = sar_normalized.astype(np.float32) / 127.5 - 1
                            result_buf, colored_gt, colored_pred, overlay = visualize_with_ground_truth(
                                sar_for_viz,
                                gt_data,
                                prediction
                            )
                            st.markdown("<h4 style='color: #a78bfa;'>Segmentation Results with Ground Truth</h4>", unsafe_allow_html=True)
                            st.image(result_buf, use_container_width=True)
                            pred_class = np.argmax(prediction[0], axis=-1)
                            gt_class = gt_data[:, :, 0].astype(np.int32)
                            # ESA WorldCover codes (> 10) -> 0-based class indices.
                            if np.max(gt_class) > 10:
                                gt_mapped = np.zeros_like(gt_class)
                                class_values = sorted(st.session_state.segmentation.class_definitions.values())
                                for class_index, esa_value in enumerate(class_values):
                                    gt_mapped[gt_class == esa_value] = class_index
                                gt_class = gt_mapped
                            accuracy = np.mean(pred_class == gt_class) * 100
                            st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                            st.markdown(f"<p class='metric-value'>{accuracy:.2f}%</p>", unsafe_allow_html=True)
                            st.markdown("<p class='metric-label'>Pixel Accuracy</p>", unsafe_allow_html=True)
                            st.markdown("</div>", unsafe_allow_html=True)
                            st.download_button(
                                label="Download Result",
                                data=result_buf,
                                file_name=f"sample_result_{selected_sample}.png",
                                mime="image/png",
                                key="download_sample_result_with_gt"
                            )
                        else:
                            result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1))
                            st.markdown("<h4 style='color: #a78bfa;'>Segmentation Results</h4>", unsafe_allow_html=True)
                            st.image(result_img, use_container_width=True)
                            st.download_button(
                                label="Download Result",
                                data=result_img,
                                file_name=f"sample_result_{selected_sample}.png",
                                mime="image/png",
                                key="download_sample_result"
                            )
                    except Exception as e:
                        st.error(f"Error processing sample image: {str(e)}")
        else:
            st.info("No sample images found. Please add some images to the 'samples/SAR' directory.")
# Close the card container
st.markdown("</div>", unsafe_allow_html=True)
elif st.session_state.app_mode == "SAR to Optical Translation":
# SAR to Optical Translation app
st.markdown("""
<div style="text-align: center; margin-bottom: 2rem; position: relative; z-index: 100;">
<h1 style="color: #a78bfa; font-size: 3rem; font-weight: bold; text-shadow: 0 0 10px rgba(167, 139, 250, 0.5);">
SAR to Optical Translation
</h1>
<p style="color: #bfdbfe; font-size: 1.2rem;">
Convert Synthetic Aperture Radar images to optical-like imagery using deep learning
</p>
</div>
""", unsafe_allow_html=True)
# Create a card container
st.markdown("<div class='card'>", unsafe_allow_html=True)
# Check if models are loaded
models_loaded = 'unet_model' in st.session_state
if not models_loaded:
st.warning("Please load the models from the sidebar first.")
else:
st.success("Models loaded successfully! You can now process SAR images.")
# Create tabs for single image and batch processing
# Create tabs for single image and batch processing
tab1, tab2, tab3 = st.tabs(["Process Single Image", "Batch Processing", "Sample Images"])
with tab1:
st.markdown("<h3 style='color: #a78bfa;'>Upload SAR Image</h3>", unsafe_allow_html=True)
# Create two columns for SAR and optional ground truth
# Create two columns for SAR and optional ground truth
col1, col2 = st.columns(2)
with col1:
st.markdown("<div class='upload-box'>", unsafe_allow_html=True)
uploaded_file = st.file_uploader(
"Upload a SAR image (.tif or common image formats)",
type=["tif", "tiff", "png", "jpg", "jpeg"],
key="sar_optical_uploader"
)
st.markdown("</div>", unsafe_allow_html=True)
# Add ground truth upload option
with col2:
st.markdown("<div class='upload-box'>", unsafe_allow_html=True)
gt_file = st.file_uploader(
"Upload ground truth optical image (optional)",
type=["tif", "tiff", "png", "jpg", "jpeg"],
key="optical_gt_uploader"
)
st.markdown("</div>", unsafe_allow_html=True)
if uploaded_file is not None:
# Process button
if st.button("Generate Optical-like Image", key="generate_optical_btn"):
with st.spinner("Processing image..."):
try:
# Load and process the SAR image
sar_batch, sar_image = load_sar_image(uploaded_file)
if sar_batch is not None:
# Process with models
seg_mask, colorized = process_image(
sar_batch,
st.session_state['unet_model'],
st.session_state.get('generator_model')
)
# Visualize results
sar_rgb, colored_pred, overlay, colorized_img = visualize_results(
sar_image, seg_mask, colorized
)
# Display results
st.header("Results")
# If ground truth is provided, include it in visualization
if gt_file is not None:
try:
# Load ground truth image
with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as tmp_file:
tmp_file.write(gt_file.getbuffer())
tmp_file_path = tmp_file.name
try:
# Try to open with rasterio
with rasterio.open(tmp_file_path) as src:
gt_image = src.read()
# Debug info
st.info(f"Ground truth shape: {gt_image.shape}, dtype: {gt_image.dtype}, min: {np.min(gt_image)}, max: {np.max(gt_image)}")
if gt_image.shape[0] == 3: # RGB image
gt_image = np.transpose(gt_image, (1, 2, 0))
else: # Single band
gt_image = src.read(1)
# Check if the image is all zeros or all ones
if np.all(gt_image == 0) or np.all(gt_image == 1):
st.warning("Ground truth image appears to be blank (all zeros or ones)")
# Convert to RGB for display
gt_image = np.expand_dims(gt_image, axis=-1)
gt_image = np.repeat(gt_image, 3, axis=-1)
except Exception as rasterio_error:
st.warning(f"Rasterio failed: {str(rasterio_error)}. Trying PIL...")
try:
# If rasterio fails, try PIL
gt_image = np.array(Image.open(tmp_file_path).convert('RGB'))
# Debug info
st.info(f"Ground truth shape (PIL): {gt_image.shape}, dtype: {gt_image.dtype}, min: {np.min(gt_image)}, max: {np.max(gt_image)}")
# Check if the image is all white
if np.all(gt_image > 250):
st.warning("Ground truth image appears to be all white")
except Exception as pil_error:
st.error(f"Both rasterio and PIL failed to load the ground truth: {str(pil_error)}")
raise
# Clean up the temporary file
os.unlink(tmp_file_path)
# Resize if needed
if gt_image.shape[:2] != (256, 256):
gt_image = cv2.resize(gt_image, (256, 256))
# Normalize if needed - make sure values are in 0-255 range for display
if gt_image.dtype != np.uint8:
if np.max(gt_image) > 1.0 and np.max(gt_image) <= 255:
gt_image = gt_image.astype(np.uint8)
elif np.max(gt_image) <= 1.0:
gt_image = (gt_image * 255).astype(np.uint8)
else:
# Scale to 0-255
gt_min, gt_max = np.min(gt_image), np.max(gt_image)
gt_image = ((gt_image - gt_min) / (gt_max - gt_min) * 255).astype(np.uint8)
# Create 4-panel visualization with ground truth
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
# Original SAR
axes[0].imshow(sar_rgb, cmap='gray')
axes[0].set_title('Original SAR', color='white')
axes[0].axis('off')
# Ground Truth
axes[1].imshow(gt_image)
axes[1].set_title('Ground Truth', color='white')
axes[1].axis('off')
# Segmentation
axes[2].imshow(colored_pred)
axes[2].set_title('Segmentation', color='white')
axes[2].axis('off')
# Generated Image
if colorized_img is not None:
# Convert from -1,1 to 0,1 range
colorized_display = (colorized_img * 0.5) + 0.5
axes[3].imshow(colorized_display)
else:
axes[3].imshow(overlay)
axes[3].set_title('Generated Image', color='white')
axes[3].axis('off')
# Set dark background
fig.patch.set_facecolor('#0a0a1f')
for ax in axes:
ax.set_facecolor('#0a0a1f')
plt.tight_layout()
# Display the figure
st.pyplot(fig)
# Calculate metrics if ground truth is provided
if colorized_img is not None:
# Normalize both images to 0-1 range for comparison
colorized_norm = (colorized_img * 0.5) + 0.5
gt_norm = gt_image.astype(np.float32) / 255.0
# Calculate PSNR
mse = np.mean((colorized_norm - gt_norm) ** 2)
psnr = 20 * np.log10(1.0 / np.sqrt(mse))
# Calculate SSIM
from skimage.metrics import structural_similarity as ssim
try:
# Check image dimensions and set appropriate window size
min_dim = min(colorized_norm.shape[0], colorized_norm.shape[1])
win_size = min(7, min_dim - (min_dim % 2) + 1) # Ensure it's odd and smaller than min dimension
ssim_value = ssim(
colorized_norm,
gt_norm,
win_size=win_size, # Explicitly set window size
channel_axis=2, # Specify channel axis for RGB images
data_range=1.0
)
except Exception as e:
st.warning(f"Could not calculate SSIM: {str(e)}")
ssim_value = 0.0 # Default value if calculation fails
# Display metrics
col1, col2 = st.columns(2)
with col1:
st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
st.markdown(f"<p class='metric-value'>{psnr:.2f}</p>", unsafe_allow_html=True)
st.markdown("<p class='metric-label'>PSNR (dB)</p>", unsafe_allow_html=True)
st.markdown("</div>", unsafe_allow_html=True)
with col2:
st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
st.markdown(f"<p class='metric-value'>{ssim_value:.4f}</p>", unsafe_allow_html=True)
st.markdown("<p class='metric-label'>SSIM</p>", unsafe_allow_html=True)
st.markdown("</div>", unsafe_allow_html=True)
except Exception as e:
st.error(f"Error processing ground truth: {str(e)}")
# Fall back to regular visualization
col1, col2, col3 = st.columns(3)
with col1:
st.subheader("Original SAR Image")
st.image(sar_rgb, use_container_width=True)
with col2:
st.subheader("Predicted Segmentation")
st.image(colored_pred, use_container_width=True)
with col3:
st.subheader("Colorized SAR")
st.image(overlay, use_container_width=True)
else:
# Regular 3-panel visualization without ground truth
col1, col2, col3 = st.columns(3)
with col1:
st.subheader("Original SAR Image")
st.image(sar_rgb, use_container_width=True)
with col2:
st.subheader("Predicted Segmentation")
st.image(colored_pred, use_container_width=True)
with col3:
st.subheader("Colorized SAR")
st.image(overlay, use_container_width=True)
# Display colorized image if available
if colorized_img is not None:
st.header("Translated Optical Image")
# Convert from -1,1 to 0,1 range
colorized_display = (colorized_img * 0.5) + 0.5
# Create a figure with controlled size
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(colorized_display)
ax.axis('off')
# Use the figure for display instead of direct image
st.pyplot(fig, use_container_width=False)
# Add download buttons
col1, col2 = st.columns(2)
with col1:
# Save segmentation image
seg_buf = io.BytesIO()
plt.imsave(seg_buf, colored_pred, format='png')
seg_buf.seek(0)
st.download_button(
label="Download Segmentation",
data=seg_buf,
file_name="segmentation.png",
mime="image/png",
key="download_seg"
)
with col2:
# Save generated image
gen_buf = io.BytesIO()
plt.imsave(gen_buf, colorized_display, format='png')
gen_buf.seek(0)
st.download_button(
label="Download Optical-like Image",
data=gen_buf,
file_name="optical_like.png",
mime="image/png",
key="download_optical"
)
except Exception as e:
st.error(f"Error processing image: {str(e)}")
# Batch processing tab
with tab2:
st.markdown("<h3 style='color: #a78bfa;'>Batch Process SAR Images</h3>", unsafe_allow_html=True)
st.markdown("<div class='upload-box'>", unsafe_allow_html=True)
# Add option for ground truth
use_gt = st.checkbox("Include ground truth data", value=False)
col1, col2 = st.columns(2)
with col1:
batch_files = st.file_uploader(
"Upload SAR images or a ZIP file containing images",
type=["tif", "tiff", "png", "jpg", "jpeg", "zip"],
accept_multiple_files=True,
key="batch_sar_optical_uploader"
)
# Add ground truth uploader if option is selected
batch_gt_files = None
if use_gt:
with col2:
batch_gt_files = st.file_uploader(
"Upload ground truth optical images or a ZIP file (must match SAR filenames)",
type=["tif", "tiff", "png", "jpg", "jpeg", "zip"],
accept_multiple_files=True,
key="batch_optical_gt_uploader"
)
st.info("Ground truth filenames should match SAR image filenames")
st.markdown("</div>", unsafe_allow_html=True)
col1, col2 = st.columns([3, 1])
with col1:
max_images = st.slider("Maximum number of images to display", min_value=1, max_value=20, value=5)
with col2:
st.markdown("<br>", unsafe_allow_html=True)
batch_process_btn = st.button("Process Images", key="batch_process_btn")
if batch_process_btn and batch_files:
# Clear previous results
if 'batch_results' not in st.session_state:
st.session_state.batch_results = []
else:
st.session_state.batch_results = []
# Process uploaded files
with st.spinner("Processing images..."):
# Create a temporary directory to extract zip files if needed
with tempfile.TemporaryDirectory() as temp_dir:
# Process each uploaded file
sar_image_files = []
gt_image_files = {} # Dictionary to map SAR filenames to GT filenames
# Process SAR files
for uploaded_file in batch_files:
if uploaded_file.name.lower().endswith('.zip'):
# Extract zip file
zip_path = os.path.join(temp_dir, uploaded_file.name)
with open(zip_path, 'wb') as f:
f.write(uploaded_file.getvalue())
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(os.path.join(temp_dir, 'sar'))
# Find all image files in the extracted directory
for root, _, files in os.walk(os.path.join(temp_dir, 'sar')):
for file in files:
if file.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')):
sar_image_files.append(os.path.join(root, file))
else:
# Save the file to temp directory
file_path = os.path.join(temp_dir, 'sar', uploaded_file.name)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'wb') as f:
f.write(uploaded_file.getvalue())
sar_image_files.append(file_path)
# Process ground truth files if provided
if use_gt and batch_gt_files:
for gt_file in batch_gt_files:
if gt_file.name.lower().endswith('.zip'):
# Extract zip file
zip_path = os.path.join(temp_dir, gt_file.name)
with open(zip_path, 'wb') as f:
f.write(gt_file.getvalue())
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(os.path.join(temp_dir, 'gt'))
# Find all image files in the extracted directory
for root, _, files in os.walk(os.path.join(temp_dir, 'gt')):
for file in files:
if file.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')):
# Map GT file to SAR file by filename
gt_path = os.path.join(root, file)
gt_image_files[os.path.basename(file)] = gt_path
else:
# Save the file to temp directory
file_path = os.path.join(temp_dir, 'gt', gt_file.name)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'wb') as f:
f.write(gt_file.getvalue())
gt_image_files[os.path.basename(gt_file.name)] = file_path
# If there are too many images, randomly select a subset
if len(sar_image_files) > max_images:
st.info(f"Found {len(sar_image_files)} images. Randomly selecting {max_images} images to display.")
sar_image_files = random.sample(sar_image_files, max_images)
# Process each image
progress_bar = st.progress(0)
# Track overall metrics if ground truth is provided
if use_gt and gt_image_files:
overall_psnr = []
overall_ssim = []
for i, image_path in enumerate(sar_image_files):
try:
# Update progress
progress_bar.progress((i + 1) / len(sar_image_files))
# Load and process the SAR image
with open(image_path, 'rb') as f:
file_bytes = f.read()
sar_batch, sar_image = load_sar_image(io.BytesIO(file_bytes))
if sar_batch is not None:
# Process with models
seg_mask, colorized = process_image(
sar_batch,
st.session_state['unet_model'],
st.session_state.get('generator_model')
)
# Visualize results
sar_rgb, colored_pred, overlay, colorized_img = visualize_results(
sar_image, seg_mask, colorized
)
# Check if we have a matching ground truth file
image_basename = os.path.basename(image_path)
has_gt = image_basename in gt_image_files
if has_gt and use_gt:
# Load ground truth
gt_path = gt_image_files[image_basename]
try:
# Try to open with rasterio
with rasterio.open(gt_path) as src:
gt_image = src.read()
if gt_image.shape[0] == 3: # RGB image
gt_image = np.transpose(gt_image, (1, 2, 0))
else: # Single band
gt_image = src.read(1)
# Convert to RGB for display
gt_image = np.expand_dims(gt_image, axis=-1)
gt_image = np.repeat(gt_image, 3, axis=-1)
except Exception as rasterio_error:
try:
# If rasterio fails, try PIL
gt_image = np.array(Image.open(gt_path).convert('RGB'))
except Exception as pil_error:
st.error(f"Both rasterio and PIL failed to load the ground truth: {str(pil_error)}")
raise
# Resize if needed
if gt_image.shape[:2] != (256, 256):
gt_image = cv2.resize(gt_image, (256, 256))
# Normalize if needed - make sure values are in 0-255 range for display
if gt_image.dtype != np.uint8:
if np.max(gt_image) > 1.0 and np.max(gt_image) <= 255:
gt_image = gt_image.astype(np.uint8)
elif np.max(gt_image) <= 1.0:
gt_image = (gt_image * 255).astype(np.uint8)
else:
# Scale to 0-255
gt_min, gt_max = np.min(gt_image), np.max(gt_image)
gt_image = ((gt_image - gt_min) / (gt_max - gt_min) * 255).astype(np.uint8)
# Create visualization with ground truth
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
# Original SAR
axes[0].imshow(sar_rgb, cmap='gray')
axes[0].set_title('Original SAR', color='white')
axes[0].axis('off')
# Ground Truth
axes[1].imshow(gt_image)
axes[1].set_title('Ground Truth', color='white')
axes[1].axis('off')
# Segmentation
axes[2].imshow(colored_pred)
axes[2].set_title('Segmentation', color='white')
axes[2].axis('off')
# Generated Image
if colorized_img is not None:
# Convert from -1,1 to 0,1 range
colorized_display = (colorized_img * 0.5) + 0.5
axes[3].imshow(colorized_display)
else:
axes[3].imshow(overlay)
axes[3].set_title('Generated Image', color='white')
axes[3].axis('off')
# Set dark background
fig.patch.set_facecolor('#0a0a1f')
for ax in axes:
ax.set_facecolor('#0a0a1f')
plt.tight_layout()
# Convert plot to image
result_buf = io.BytesIO()
plt.savefig(result_buf, format='png', facecolor='#0a0a1f', bbox_inches='tight')
result_buf.seek(0)
plt.close(fig)
# Calculate metrics if colorized image is available
metrics = {'psnr': 0.0, 'ssim': 0.0} # Default values
if colorized_img is not None:
try:
# Normalize both images to 0-1 range for comparison
colorized_norm = (colorized_img * 0.5) + 0.5
gt_norm = gt_image.astype(np.float32) / 255.0
# Calculate PSNR
mse = np.mean((colorized_norm - gt_norm) ** 2)
if mse > 0:
psnr = 20 * np.log10(1.0 / np.sqrt(mse))
metrics['psnr'] = psnr
overall_psnr.append(psnr)
# Calculate SSIM with explicit window size
from skimage.metrics import structural_similarity as ssim
# Check image dimensions and set appropriate window size
min_dim = min(colorized_norm.shape[0], colorized_norm.shape[1])
win_size = min(7, min_dim - (min_dim % 2) + 1) # Ensure it's odd and smaller than min dimension
ssim_value = ssim(
colorized_norm,
gt_norm,
win_size=win_size, # Explicitly set window size
channel_axis=2, # Specify channel axis for RGB images
data_range=1.0
)
metrics['ssim'] = ssim_value
overall_ssim.append(ssim_value)
except Exception as e:
st.warning(f"Could not calculate metrics for {os.path.basename(image_path)}: {str(e)}")
# Save generated image for download
gen_buf = io.BytesIO()
if colorized_img is not None:
plt.imsave(gen_buf, colorized_display, format='png')
else:
plt.imsave(gen_buf, overlay, format='png')
gen_buf.seek(0)
# Add to batch results
st.session_state.batch_results.append({
'filename': os.path.basename(image_path),
'result': result_buf,
'generated': gen_buf,
'metrics': metrics
})
else:
# Regular visualization without ground truth
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
# Original SAR
axes[0].imshow(sar_rgb, cmap='gray')
axes[0].set_title('Original SAR', color='white')
axes[0].axis('off')
# Segmentation
axes[1].imshow(colored_pred)
axes[1].set_title('Segmentation', color='white')
axes[1].axis('off')
# Generated Image
if colorized_img is not None:
# Convert from -1,1 to 0,1 range
colorized_display = (colorized_img * 0.5) + 0.5
axes[2].imshow(colorized_display)
else:
axes[2].imshow(overlay)
axes[2].set_title('Generated Image', color='white')
axes[2].axis('off')
# Set dark background
fig.patch.set_facecolor('#0a0a1f')
for ax in axes:
ax.set_facecolor('#0a0a1f')
plt.tight_layout()
# Convert plot to image
result_buf = io.BytesIO()
plt.savefig(result_buf, format='png', facecolor='#0a0a1f', bbox_inches='tight')
result_buf.seek(0)
plt.close(fig)
# Save generated image for download
gen_buf = io.BytesIO()
if colorized_img is not None:
plt.imsave(gen_buf, colorized_display, format='png')
else:
plt.imsave(gen_buf, overlay, format='png')
gen_buf.seek(0)
# Add to batch results
st.session_state.batch_results.append({
'filename': os.path.basename(image_path),
'result': result_buf,
'generated': gen_buf
})
except Exception as e:
st.error(f"Error processing {os.path.basename(image_path)}: {str(e)}")
# Clear progress bar
progress_bar.empty()
# Display results
if st.session_state.batch_results:
st.markdown("<h4 style='color: #a78bfa;'>Translation Results</h4>", unsafe_allow_html=True)
# Display overall metrics if ground truth was provided
if use_gt and 'overall_psnr' in locals() and overall_psnr:
avg_psnr = np.mean(overall_psnr)
avg_ssim = np.mean(overall_ssim)
col1, col2 = st.columns(2)
with col1:
st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
st.markdown(f"<p class='metric-value'>{avg_psnr:.2f}</p>", unsafe_allow_html=True)
st.markdown("<p class='metric-label'>Average PSNR (dB)</p>", unsafe_allow_html=True)
st.markdown("</div>", unsafe_allow_html=True)
with col2:
st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
st.markdown(f"<p class='metric-value'>{avg_ssim:.4f}</p>", unsafe_allow_html=True)
st.markdown("<p class='metric-label'>Average SSIM</p>", unsafe_allow_html=True)
st.markdown("</div>", unsafe_allow_html=True)
# Create a zip file with all results
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
for i, result in enumerate(st.session_state.batch_results):
# Add visualization
zip_file.writestr(f"result_{i+1}_{result['filename']}.png", result['result'].getvalue())
# Add generated image
zip_file.writestr(f"generated_{i+1}_{result['filename']}.png", result['generated'].getvalue())
# Add download button for all results
st.download_button(
label="Download All Results",
data=zip_buffer.getvalue(),
file_name="translation_results.zip",
mime="application/zip",
key="download_all_translation_results"
)
# Display each result
for i, result in enumerate(st.session_state.batch_results):
st.markdown(f"<h5 style='color: #bfdbfe;'>Image: {result['filename']}</h5>", unsafe_allow_html=True)
# Display metrics if available
if 'metrics' in result:
col1, col2 = st.columns(2)
with col1:
st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
if 'psnr' in result['metrics']:
st.markdown(f"<p class='metric-value'>{result['metrics']['psnr']:.2f}</p>", unsafe_allow_html=True)
else:
st.markdown("<p class='metric-value'>N/A</p>", unsafe_allow_html=True)
st.markdown("<p class='metric-label'>PSNR (dB)</p>", unsafe_allow_html=True)
st.markdown("</div>", unsafe_allow_html=True)
with col2:
st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
if 'ssim' in result['metrics']:
st.markdown(f"<p class='metric-value'>{result['metrics']['ssim']:.4f}</p>", unsafe_allow_html=True)
else:
st.markdown("<p class='metric-value'>N/A</p>", unsafe_allow_html=True)
st.markdown("<p class='metric-label'>SSIM</p>", unsafe_allow_html=True)
st.markdown("</div>", unsafe_allow_html=True)
st.image(result['result'], use_container_width=True)
# Add download button for individual result
col1, col2 = st.columns(2)
with col1:
st.download_button(
label="Download Visualization",
data=result['result'].getvalue(),
file_name=f"result_{result['filename']}.png",
mime="image/png",
key=f"download_viz_{i}"
)
with col2:
st.download_button(
label="Download Generated Image",
data=result['generated'].getvalue(),
file_name=f"generated_{result['filename']}.png",
mime="image/png",
key=f"download_gen_{i}"
)
st.markdown("<hr style='border-color: rgba(147, 51, 234, 0.3);'>", unsafe_allow_html=True)
else:
st.warning("No images were successfully processed.")
elif batch_process_btn:
st.warning("Please upload at least one image file or ZIP archive.")
# Tab 3: Sample Images
with tab3:
st.markdown("<h3 style='color: #a78bfa;'>Sample Images</h3>", unsafe_allow_html=True)
st.markdown("<div class='card'>", unsafe_allow_html=True)
# Get list of sample images
import os
sample_dir = "samples/SAR"
if os.path.exists(sample_dir):
sample_files = [f for f in os.listdir(sample_dir) if f.endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg'))]
else:
os.makedirs(sample_dir, exist_ok=True)
os.makedirs("samples/OPTICAL", exist_ok=True)
os.makedirs("samples/LABELS", exist_ok=True)
sample_files = []
if sample_files and 'unet_model' in st.session_state:
# Create a dropdown to select sample images
selected_sample = st.selectbox(
"Select a sample image",
sample_files,
key="optical_sample_selector"
)
# Display the selected sample
col1, col2 = st.columns(2)
with col1:
st.subheader("SAR Image")
sar_path = os.path.join("samples/SAR", selected_sample)
display_image(sar_path)
with col2:
st.subheader("Optical Image (Ground Truth)")
# Try to find matching optical image
opt_path = os.path.join("samples/OPTICAL", selected_sample)
if os.path.exists(opt_path):
display_image(opt_path)
else:
st.info("No matching optical image found")
# Add a button to process the selected sample
if st.button("Generate Optical-like Image", key="process_optical_sample_btn"):
with st.spinner("Processing sample image..."):
try:
# Load the SAR image
with open(sar_path, 'rb') as f:
file_bytes = f.read()
sar_batch, sar_image = load_sar_image(io.BytesIO(file_bytes))
if sar_batch is not None:
# Process with models
seg_mask, colorized = process_image(
sar_batch,
st.session_state['unet_model'],
st.session_state.get('generator_model')
)
# Visualize results
sar_rgb, colored_pred, overlay, colorized_img = visualize_results(
sar_image, seg_mask, colorized
)
# Check if ground truth exists
has_gt = os.path.exists(opt_path)
if has_gt:
# Load ground truth
try:
# Try to open with rasterio
with rasterio.open(opt_path) as src:
gt_image = src.read()
if gt_image.shape[0] == 3: # RGB image
gt_image = np.transpose(gt_image, (1, 2, 0))
else: # Single band
gt_image = src.read(1)
# Convert to RGB for display
# Convert to RGB for display
gt_image = np.expand_dims(gt_image, axis=-1)
gt_image = np.repeat(gt_image, 3, axis=-1)
except Exception as rasterio_error:
try:
# If rasterio fails, try PIL
gt_image = np.array(Image.open(opt_path).convert('RGB'))
except Exception as pil_error:
st.error(f"Both rasterio and PIL failed to load the ground truth: {str(pil_error)}")
raise
# Resize if needed
if gt_image.shape[:2] != (256, 256):
gt_image = cv2.resize(gt_image, (256, 256))
# Normalize if needed - make sure values are in 0-255 range for display
if gt_image.dtype != np.uint8:
if np.max(gt_image) > 1.0 and np.max(gt_image) <= 255:
gt_image = gt_image.astype(np.uint8)
elif np.max(gt_image) <= 1.0:
gt_image = (gt_image * 255).astype(np.uint8)
else:
# Scale to 0-255
gt_min, gt_max = np.min(gt_image), np.max(gt_image)
gt_image = ((gt_image - gt_min) / (gt_max - gt_min) * 255).astype(np.uint8)
# Create 4-panel visualization with ground truth
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
# Original SAR
axes[0].imshow(sar_rgb, cmap='gray')
axes[0].set_title('Original SAR', color='white')
axes[0].axis('off')
# Ground Truth
axes[1].imshow(gt_image)
axes[1].set_title('Ground Truth', color='white')
axes[1].axis('off')
# Segmentation
axes[2].imshow(colored_pred)
axes[2].set_title('Segmentation', color='white')
axes[2].axis('off')
# Generated Image
if colorized_img is not None:
# Convert from -1,1 to 0,1 range
colorized_display = (colorized_img * 0.5) + 0.5
axes[3].imshow(colorized_display)
else:
axes[3].imshow(overlay)
axes[3].set_title('Generated Image', color='white')
axes[3].axis('off')
# Set dark background
fig.patch.set_facecolor('#0a0a1f')
for ax in axes:
ax.set_facecolor('#0a0a1f')
plt.tight_layout()
# Display the figure
st.pyplot(fig)
# Calculate metrics if colorized image is available
if colorized_img is not None:
# Normalize both images to 0-1 range for comparison
colorized_norm = (colorized_img * 0.5) + 0.5
gt_norm = gt_image.astype(np.float32) / 255.0
# Calculate PSNR
mse = np.mean((colorized_norm - gt_norm) ** 2)
psnr = 20 * np.log10(1.0 / np.sqrt(mse))
# Calculate SSIM
from skimage.metrics import structural_similarity as ssim
try:
# Check image dimensions and set appropriate window size
min_dim = min(colorized_norm.shape[0], colorized_norm.shape[1])
win_size = min(7, min_dim - (min_dim % 2) + 1) # Ensure it's odd and smaller than min dimension
ssim_value = ssim(
colorized_norm,
gt_norm,
win_size=win_size, # Explicitly set window size
channel_axis=2, # Specify channel axis for RGB images
data_range=1.0
)
except Exception as e:
st.warning(f"Could not calculate SSIM: {str(e)}")
ssim_value = 0.0 # Default value if calculation fails
# Display metrics
col1, col2 = st.columns(2)
with col1:
st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
st.markdown(f"<p class='metric-value'>{psnr:.2f}</p>", unsafe_allow_html=True)
st.markdown("<p class='metric-label'>PSNR (dB)</p>", unsafe_allow_html=True)
st.markdown("</div>", unsafe_allow_html=True)
with col2:
st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
st.markdown(f"<p class='metric-value'>{ssim_value:.4f}</p>", unsafe_allow_html=True)
st.markdown("<p class='metric-label'>SSIM</p>", unsafe_allow_html=True)
st.markdown("</div>", unsafe_allow_html=True)
else:
# Regular 3-panel visualization without ground truth
col1, col2, col3 = st.columns(3)
with col1:
st.subheader("Original SAR Image")
st.image(sar_rgb, use_container_width=True)
with col2:
st.subheader("Predicted Segmentation")
st.image(colored_pred, use_container_width=True)
with col3:
st.subheader("Colorized SAR")
if colorized_img is not None:
# Convert from -1,1 to 0,1 range
colorized_display = (colorized_img * 0.5) + 0.5
st.image(colorized_display, use_container_width=True)
else:
st.image(overlay, use_container_width=True)
# Add download buttons
col1, col2 = st.columns(2)
with col1:
# Save segmentation image
seg_buf = io.BytesIO()
plt.imsave(seg_buf, colored_pred, format='png')
seg_buf.seek(0)
st.download_button(
label="Download Segmentation",
data=seg_buf,
file_name=f"sample_segmentation_{selected_sample}.png",
mime="image/png",
key="download_sample_seg"
)
with col2:
# Save generated image
gen_buf = io.BytesIO()
if colorized_img is not None:
plt.imsave(gen_buf, (colorized_img * 0.5) + 0.5, format='png')
else:
plt.imsave(gen_buf, overlay, format='png')
gen_buf.seek(0)
st.download_button(
label="Download Optical-like Image",
data=gen_buf,
file_name=f"sample_optical_{selected_sample}.png",
mime="image/png",
key="download_sample_optical"
)
except Exception as e:
st.error(f"Error processing sample image: {str(e)}")
elif not sample_files:
st.info("No sample images found. Please add some images to the 'samples/SAR' directory.")
else:
st.warning("Please load the models from the sidebar first.")
# Close the card container
st.markdown("</div>", unsafe_allow_html=True)
# Footer
st.markdown("""
<div style="text-align: center; margin-top: 2rem; padding: 1rem; background-color: rgba(0, 0, 0, 0.3); border-radius: 0.5rem;">
<p style="color: #bfdbfe; font-size: 0.9rem;">
SAR IMAGE PROCESSING | VARUN & MOKSHYAGNA
</p>
</div>
""", unsafe_allow_html=True)
# ==================== UTILITY FUNCTIONS ====================
def create_stars_html(num_stars: int = 100) -> str:
    """Build the HTML markup for the twinkling-stars background effect.

    Args:
        num_stars: Number of star ``<div>`` elements to generate. The default
            of 100 matches the original hard-coded count.

    Returns:
        An HTML string: a ``<div class="stars">`` wrapper containing one
        ``<div class="star">`` per star. Each star gets:
        - a random size of 1-3 px, used for both width and height;
        - a random top/left position of 0-100 %;
        - a random animation duration of 3-8 s, exposed as ``--duration``;
        - a random opacity of 0.2-0.8, exposed as ``--opacity``.
        Values come from the module-level ``random`` generator, so output is
        reproducible after ``random.seed(...)``.
    """
    star_divs = []
    for _ in range(num_stars):
        # Draw order (size, top, left, duration, opacity) is kept fixed so a
        # seeded generator yields the same layout as before.
        size = random.uniform(1, 3)
        top = random.uniform(0, 100)
        left = random.uniform(0, 100)
        duration = random.uniform(3, 8)
        opacity = random.uniform(0.2, 0.8)
        # Lines are assembled explicitly (not via an indented triple-quoted
        # string) so the markup never picks up leading indentation from the
        # source; Markdown renderers can treat 4-space-indented lines as code.
        star_divs.append(
            "\n"
            '<div class="star" style="\n'
            f"width: {size}px;\n"
            f"height: {size}px;\n"
            f"top: {top}%;\n"
            f"left: {left}%;\n"
            f"--duration: {duration}s;\n"
            f"--opacity: {opacity};\n"
            '"></div>\n'
        )
    # str.join avoids repeatedly re-copying a growing string in the loop.
    return '\n<div class="stars">\n' + "".join(star_divs) + "</div>"
# Called at the beginning of the app to set up the page. The injected CSS
# supports both themes: shared rules plus rules for the selected theme.
def setup_page_style():
    """Inject the app-wide CSS for the currently selected theme.

    The stylesheet is assembled, in cascade order, from:

    * ``common_css``: star animation, tab/button sizing and general spacing,
      shared by both themes;
    * ``backdrop_css``: the dark page background (solid colour, space photo
      and a blurred translucent overlay), also shared by both themes;
    * a theme-specific part, either ``dark_css`` or ``light_css``.

    Despite its name, the "light" theme keeps the same dark backdrop. It
    differs by forcing white text on all common elements and by using
    brighter accent colours.

    The theme is read from ``st.session_state["theme"]``. Any value other
    than ``"dark"`` selects the light variant. If the key has not been set
    yet, ``"dark"`` is used rather than raising.
    """
    # Shared by both themes: star animation, layout and spacing.
    common_css = """
    /* Create twinkling stars effect */
    @keyframes twinkle {
        0%, 100% { opacity: 0.2; }
        50% { opacity: 1; }
    }
    .stars {
        position: fixed;
        top: 0;
        left: 0;
        width: 100%;
        height: 100%;
        pointer-events: none;
        z-index: -1;
    }
    .star {
        position: absolute;
        background-color: white;
        border-radius: 50%;
        animation: twinkle var(--duration) infinite;
        opacity: var(--opacity);
    }
    /* Tab styling */
    .stTabs [data-baseweb="tab-list"] {
        gap: 24px !important;
        border-radius: 0.5rem;
        padding: 0.8rem;
        margin-bottom: 3rem !important;
        display: flex;
        justify-content: center !important;
        width: 100%;
    }
    .stTabs [data-baseweb="tab"] {
        height: 5rem !important;
        white-space: pre-wrap;
        border-radius: 0.5rem;
        font-weight: 600 !important;
        font-size: 1.6rem !important;
        padding: 0 25px !important;
        display: flex;
        align-items: center;
        justify-content: center;
        min-width: 200px !important;
    }
    /* Add more space between tab panels */
    .stTabs [data-baseweb="tab-panel"] {
        padding-top: 3rem !important;
        padding-bottom: 3rem !important;
    }
    /* Button styling */
    .stButton>button {
        border: none;
        border-radius: 0.5rem;
        padding: 0.8rem 1.5rem !important;
        font-weight: 500;
        font-size: 1.2rem !important;
        margin-top: 1.5rem !important;
        margin-bottom: 1.5rem !important;
    }
    /* Spacing */
    .element-container {
        margin-bottom: 2.5rem !important;
    }
    h3 {
        margin-top: 3rem !important;
        margin-bottom: 2rem !important;
        font-size: 1.8rem !important;
    }
    h4 {
        margin-top: 2.5rem !important;
        margin-bottom: 1.5rem !important;
        font-size: 1.5rem !important;
    }
    h5 {
        margin-top: 2rem !important;
        margin-bottom: 1.5rem !important;
        font-size: 1.3rem !important;
    }
    img {
        margin-top: 1.5rem !important;
        margin-bottom: 2.5rem !important;
    }
    .stProgress > div {
        margin-top: 2rem !important;
        margin-bottom: 2rem !important;
    }
    .stSlider {
        padding-top: 1.5rem !important;
        padding-bottom: 2.5rem !important;
    }
    .row-widget {
        margin-top: 1.5rem !important;
        margin-bottom: 2.5rem !important;
    }
    """

    # Shared by both themes: both variants use the same dark space backdrop,
    # so it is defined once here instead of being duplicated per theme.
    backdrop_css = """
    .stApp {
        background-color: #0a0a1f;
    }
    .main {
        background-image: url("https://images.unsplash.com/photo-1451187580459-43490279c0fa?ixlib=rb-1.2.1&auto=format&fit=crop&w=1352&q=80");
        background-size: cover;
        background-position: center;
        background-repeat: no-repeat;
        background-attachment: fixed;
        position: relative;
    }
    .main::before {
        content: "";
        position: absolute;
        top: 0;
        left: 0;
        width: 100%;
        height: 100%;
        background-color: rgba(10, 10, 31, 0.7);
        backdrop-filter: blur(5px);
        z-index: -1;
    }
    """

    # Dark theme: purple/blue accents on the shared dark backdrop.
    dark_css = """
    .stApp {
        color: white;
    }
    /* Title styling */
    h1.title {
        background: linear-gradient(to right, #a78bfa, #ec4899, #3b82f6);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        background-clip: text;
        color: transparent;
        font-size: 3rem !important;
        font-weight: bold !important;
        text-align: center !important;
        margin-bottom: 0.5rem !important;
        display: block !important;
        position: relative !important;
        z-index: 10 !important;
    }
    p.subtitle {
        color: #bfdbfe !important;
        font-size: 1.2rem !important;
        text-align: center !important;
        margin-bottom: 2rem !important;
        position: relative !important;
        z-index: 10 !important;
    }
    /* Tab styling */
    .stTabs [data-baseweb="tab-list"] {
        background-color: rgba(0, 0, 0, 0.3);
    }
    .stTabs [data-baseweb="tab"] {
        background-color: transparent;
        color: white;
    }
    .stTabs [aria-selected="true"] {
        background-color: rgba(147, 51, 234, 0.5) !important;
        transform: scale(1.05);
        transition: all 0.2s ease;
    }
    /* Card and box styling */
    .upload-box {
        border: 2px dashed rgba(147, 51, 234, 0.5);
        border-radius: 1rem;
        padding: 4rem !important;
        text-align: center;
        margin-bottom: 3rem !important;
    }
    .card {
        background-color: rgba(0, 0, 0, 0.3);
        border: 1px solid rgba(147, 51, 234, 0.3);
        border-radius: 1rem;
        padding: 2.5rem !important;
        backdrop-filter: blur(10px);
        margin-bottom: 3rem !important;
    }
    /* Button styling */
    .stButton>button {
        background: linear-gradient(to right, #7c3aed, #2563eb);
        color: white;
    }
    .stButton>button:hover {
        background: linear-gradient(to right, #6d28d9, #1d4ed8);
    }
    .download-btn {
        background-color: #2563eb !important;
    }
    .stSlider>div>div>div {
        background-color: #7c3aed;
    }
    /* Metrics styling */
    .plot-container {
        background-color: rgba(0, 0, 0, 0.3);
        border-radius: 1rem;
        padding: 2rem !important;
        margin-bottom: 3rem !important;
    }
    .metric-card {
        background-color: rgba(0, 0, 0, 0.3);
        border: 1px solid rgba(147, 51, 234, 0.3);
        border-radius: 0.5rem;
        padding: 1.5rem !important;
        text-align: center;
        margin-bottom: 2rem !important;
    }
    .metric-value {
        font-size: 2rem !important;
        font-weight: bold;
        color: #a78bfa;
    }
    .metric-label {
        font-size: 1.1rem !important;
        color: #bfdbfe;
    }
    /* Form elements */
    .stFileUploader > div {
        background-color: rgba(0, 0, 0, 0.3) !important;
        border: 1px dashed rgba(147, 51, 234, 0.5) !important;
        padding: 2rem !important;
        margin-bottom: 2rem !important;
    }
    .stSelectbox > div > div {
        background-color: rgba(0, 0, 0, 0.3) !important;
        border: 1px solid rgba(147, 51, 234, 0.3) !important;
    }
    """

    # "Light" theme: same dark backdrop, but forces white text everywhere and
    # uses brighter accent colours.
    # NOTE(review): `.css-1d391kg` / `.css-1lcbmhc` below look like
    # auto-generated Streamlit class names; they may not match other
    # Streamlit versions. Confirm they still apply.
    light_css = """
    /* Make all text white/light */
    p, span, label, div, h1, h2, h3, h4, h5, h6, li {
        color: white !important;
    }
    /* Title styling - brighter gradient for better visibility */
    h1.title {
        background: linear-gradient(to right, #d8b4fe, #f9a8d4, #93c5fd);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        background-clip: text;
        color: transparent;
        font-size: 3rem !important;
        font-weight: bold !important;
        text-align: center !important;
        margin-bottom: 0.5rem !important;
        display: block !important;
        position: relative !important;
        z-index: 10 !important;
    }
    p.subtitle {
        color: #e0e7ff !important; /* Lighter purple */
        font-size: 1.2rem !important;
        text-align: center !important;
        margin-bottom: 2rem !important;
        position: relative !important;
        z-index: 10 !important;
    }
    /* Tab styling - brighter for better visibility */
    .stTabs [data-baseweb="tab-list"] {
        background-color: rgba(0, 0, 0, 0.3);
    }
    .stTabs [data-baseweb="tab"] {
        background-color: transparent;
        color: white !important;
    }
    .stTabs [aria-selected="true"] {
        background-color: rgba(167, 139, 250, 0.5) !important; /* Brighter purple */
        transform: scale(1.05);
        transition: all 0.2s ease;
    }
    /* Card and box styling - brighter borders */
    .upload-box {
        border: 2px dashed rgba(167, 139, 250, 0.7); /* Brighter purple */
        border-radius: 1rem;
        padding: 4rem !important;
        text-align: center;
        margin-bottom: 3rem !important;
    }
    .card {
        background-color: rgba(0, 0, 0, 0.3);
        border: 1px solid rgba(167, 139, 250, 0.5); /* Brighter purple */
        border-radius: 1rem;
        padding: 2.5rem !important;
        backdrop-filter: blur(10px);
        margin-bottom: 3rem !important;
    }
    /* Button styling - brighter gradient */
    .stButton>button {
        background: linear-gradient(to right, #a78bfa, #60a5fa);
        color: white;
    }
    .stButton>button:hover {
        background: linear-gradient(to right, #8b5cf6, #3b82f6);
    }
    .download-btn {
        background-color: #60a5fa !important;
    }
    .stSlider>div>div>div {
        background-color: #a78bfa;
    }
    /* Metrics styling - brighter accents */
    .plot-container {
        background-color: rgba(0, 0, 0, 0.3);
        border-radius: 1rem;
        padding: 2rem !important;
        margin-bottom: 3rem !important;
    }
    .metric-card {
        background-color: rgba(0, 0, 0, 0.3);
        border: 1px solid rgba(167, 139, 250, 0.5); /* Brighter purple */
        border-radius: 0.5rem;
        padding: 1.5rem !important;
        text-align: center;
        margin-bottom: 2rem !important;
    }
    .metric-value {
        font-size: 2rem !important;
        font-weight: bold;
        color: #d8b4fe; /* Brighter purple */
    }
    .metric-label {
        font-size: 1.1rem !important;
        color: #e0e7ff; /* Lighter purple */
    }
    /* Form elements - brighter borders */
    .stFileUploader > div {
        background-color: rgba(0, 0, 0, 0.3) !important;
        border: 1px dashed rgba(167, 139, 250, 0.7) !important; /* Brighter purple */
        padding: 2rem !important;
        margin-bottom: 2rem !important;
    }
    .stSelectbox > div > div {
        background-color: rgba(0, 0, 0, 0.3) !important;
        border: 1px solid rgba(167, 139, 250, 0.5) !important; /* Brighter purple */
    }
    /* Make sure all text inputs have white text */
    input, textarea {
        color: white !important;
    }
    /* Ensure sidebar text is white */
    .css-1d391kg, .css-1lcbmhc {
        color: white !important;
    }
    /* Make sure plot text is visible on dark background */
    .js-plotly-plot .plotly .main-svg text {
        fill: white !important;
    }
    /* Keep stars visible in light theme */
    .star {
        background-color: white;
        opacity: 0.8;
    }
    /* Make sure all streamlit elements have white text */
    .stMarkdown, .stText, .stCode, .stTextInput, .stTextArea, .stSelectbox, .stMultiselect,
    .stSlider, .stCheckbox, .stRadio, .stNumber, .stDate, .stTime, .stDateInput, .stTimeInput {
        color: white !important;
    }
    /* Ensure dropdown options are visible */
    .stSelectbox ul li {
        color: black !important;
    }
    """

    # The main block calls this before initialising session state, so the
    # "theme" key may not exist yet. Attribute access would raise there;
    # .get() falls back to the dark theme instead.
    theme = st.session_state.get("theme", "dark")
    theme_css = dark_css if theme == "dark" else light_css
    st.markdown(
        f"<style>{common_css}{backdrop_css}{theme_css}</style>",
        unsafe_allow_html=True,
    )
# ==================== MAIN EXECUTION ====================
if __name__ == "__main__":
# Set up page style
setup_page_style()
# Initialize GPU if available
setup_gpu()
# Initialize session state variables
if 'app_mode' not in st.session_state:
st.session_state.app_mode = "SAR Colorization"
if 'model_loaded' not in st.session_state:
st.session_state.model_loaded = False
if 'segmentation' not in st.session_state:
st.session_state.segmentation = SARSegmentation(img_rows=256, img_cols=256)
if 'processed_images' not in st.session_state:
st.session_state.processed_images = []