| | """ |
| | Cloud Mask Prediction and Visualization Module |
| | |
| | This script processes Sentinel-2 satellite imagery bands to predict cloud masks |
| | using the omnicloudmask library. It reads blue, red, green, and near-infrared bands, |
| | resamples them as needed, creates a stacked array for prediction, and visualizes |
| | the cloud mask overlaid on the original RGB image. |
| | """ |
| |
|
| | import rasterio |
| | import numpy as np |
| | from rasterio.enums import Resampling |
| | from omnicloudmask import predict_from_array |
| | import matplotlib.pyplot as plt |
| | from matplotlib.colors import ListedColormap |
| | import matplotlib.patches as mpatches |
| |
|
def load_band(file_path, resample=False, target_height=None, target_width=None):
    """
    Load the first band of a raster file as float32, optionally resampled.

    Args:
        file_path (str): Path to the raster file.
        resample (bool): Whether to resample the band to the target size.
        target_height (int, optional): Output height in pixels when resampling.
        target_width (int, optional): Output width in pixels when resampling.

    Returns:
        numpy.ndarray: 2-D float32 array holding band 1 of the file.
        Bilinear resampling is applied only when ``resample`` is True AND
        both ``target_height`` and ``target_width`` are given; otherwise the
        band is returned at its native size (no error is raised).
    """
    with rasterio.open(file_path) as src:
        if resample and target_height is not None and target_width is not None:
            # Read band 1 only, straight onto the requested grid. Reading
            # every band and then indexing [0] wastes I/O and memory on
            # multi-band files.
            band_data = src.read(
                1,
                out_shape=(target_height, target_width),
                resampling=Resampling.bilinear,
            )
        else:
            band_data = src.read(1)

    return band_data.astype(np.float32)
| |
|
def prepare_input_array(base_path="jp2s/"):
    """
    Prepare a stacked array of satellite bands for cloud mask prediction.

    Loads the Sentinel-2 blue (B02), green (B03), red (B04) and near-infrared
    (B8A) JP2 files from ``base_path``. The NIR band is resampled onto the
    red band's pixel grid (B8A is typically 20 m vs. 10 m for B04). Returns
    the prediction stack together with an RGB preview image.

    Args:
        base_path (str or os.PathLike): Directory containing the JP2 band
            files. A trailing path separator is optional.

    Returns:
        tuple: (prediction_array, rgb_image)
            - prediction_array: float32 array of shape (3, H, W) with the
              channels red, green, NIR (CHW format), in that order.
            - rgb_image: array of shape (H, W, 3) with red, green and blue
              divided by 10000 and clipped to [0, 1], for display only.
    """
    import os  # local import keeps this block self-contained

    # os.path.join works whether or not base_path ends with a separator
    # (and for Path objects), unlike plain string concatenation.
    band_paths = {
        'blue': os.path.join(base_path, "B02.jp2"),
        'green': os.path.join(base_path, "B03.jp2"),
        'red': os.path.join(base_path, "B04.jp2"),
        'nir': os.path.join(base_path, "B8A.jp2"),
    }

    blue_data = load_band(band_paths['blue'])
    green_data = load_band(band_paths['green'])
    red_data = load_band(band_paths['red'])

    # The red band defines the target grid. Take its shape from the data
    # already loaded instead of reopening the file just for its dimensions.
    target_height, target_width = red_data.shape
    nir_data = load_band(
        band_paths['nir'],
        resample=True,
        target_height=target_height,
        target_width=target_width,
    )

    print(f"Band shapes - Blue: {blue_data.shape}, Green: {green_data.shape}, Red: {red_data.shape}, NIR: {nir_data.shape}")

    # Display-only scaling of digital numbers to roughly [0, 1] reflectance.
    # NOTE(review): newer Sentinel-2 L2A products also carry an additive
    # offset (BOA_ADD_OFFSET); dividing by 10000 alone ignores it. Confirm
    # this against the product metadata for your data.
    scale_factor = 10000.0
    rgb_image = np.stack([
        red_data / scale_factor,
        green_data / scale_factor,
        blue_data / scale_factor,
    ], axis=-1)
    rgb_image = np.clip(rgb_image, 0, 1)

    # The model input is unscaled red, green, NIR in CHW order. Blue is used
    # only for the RGB preview above.
    prediction_array = np.stack([red_data, green_data, nir_data], axis=0)

    return prediction_array, rgb_image
| |
|
def visualize_cloud_mask(rgb_image, cloud_mask, output_path="cloud_mask_visualization.png"):
    """
    Visualize the cloud mask next to, and overlaid on, the original RGB image.

    Draws three panels: the RGB image, the class mask, and the RGB image with
    a semi-transparent class overlay. The figure is saved to ``output_path``,
    shown, and then closed.

    Args:
        rgb_image (numpy.ndarray): RGB image array (HWC format).
        cloud_mask (numpy.ndarray): Predicted cloud mask with class values
            0 (clear), 1 (thick cloud), 2 (thin cloud) and 3 (cloud shadow).
            Singleton axes (e.g. a leading channel axis) are squeezed away.
        output_path (str): Path to save the visualization.
    """
    from matplotlib.colors import to_rgba  # local import keeps this block self-contained

    # Single source of truth: (class value, label, colour, overlay alpha).
    # The colormap, both legends and the overlay are all derived from it, so
    # they cannot drift apart.
    class_styles = [
        (0, 'Clear', 'green', 0.3),
        (1, 'Thick Cloud', 'red', 0.5),
        (2, 'Thin Cloud', 'yellow', 0.5),
        (3, 'Cloud Shadow', 'blue', 0.5),
    ]

    if cloud_mask.ndim > 2:
        # Drop singleton axes so imshow receives a 2-D mask.
        print(f"Original cloud mask shape: {cloud_mask.shape}")
        cloud_mask = np.squeeze(cloud_mask)
        print(f"Squeezed cloud mask shape: {cloud_mask.shape}")

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))

    ax1.imshow(rgb_image)
    ax1.set_title("Original RGB Image")
    ax1.axis('off')

    # One discrete colour per class value; vmin/vmax pin value i to colour i.
    cloud_cmap = ListedColormap([color for _, _, color, _ in class_styles])
    ax2.imshow(cloud_mask, cmap=cloud_cmap, vmin=0, vmax=len(class_styles) - 1)
    ax2.set_title("Cloud Mask")
    ax2.axis('off')

    legend_patches = [
        mpatches.Patch(color=color, label=label)
        for _, label, color, _ in class_styles
    ]
    ax2.legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left')

    # Transparent RGBA layer. Pixels whose value is not a known class stay
    # fully transparent.
    cloud_mask_rgba = np.zeros((*cloud_mask.shape, 4))
    for value, _, color, alpha in class_styles:
        cloud_mask_rgba[cloud_mask == value] = to_rgba(color, alpha)

    ax3.imshow(rgb_image)
    ax3.imshow(cloud_mask_rgba)
    ax3.set_title("RGB with Cloud Mask Overlay")
    ax3.axis('off')
    ax3.legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left')

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    print(f"Visualization saved to {output_path}")
| |
|
def main():
    """
    Run the end-to-end workflow.

    Builds the band stack, predicts the cloud mask with omnicloudmask,
    prints the mask shape, unique values and per-class pixel counts, then
    renders the visualization.
    """
    stacked_bands, rgb_preview = prepare_input_array()
    mask = predict_from_array(stacked_bands)

    print("Cloud mask prediction results:")
    print(f"Cloud mask shape: {mask.shape}")
    print(f"Unique classes in mask: {np.unique(mask)}")

    # Pixel counts are totals over the whole array, so singleton axes do not
    # need to be squeezed before counting.
    class_names = ("Clear", "Thick Cloud", "Thin Cloud", "Cloud Shadow")
    distribution = ", ".join(
        f"{name}: {np.count_nonzero(mask == value)}"
        for value, name in enumerate(class_names)
    )
    print(f"Class distribution: {distribution}")

    visualize_cloud_mask(rgb_preview, mask)
| |
|
if __name__ == "__main__":
    # Run the workflow only when executed as a script, not when imported.
    main()
| |
|