# Sentinel-2 patch sampler: downloads random 256x256 patches from the
# Planetary Computer STAC catalog and uploads them to Azure Blob Storage.
| import argparse | |
| import io | |
| import os | |
| import time | |
| import warnings | |
| import azure.storage.blob | |
| import numpy as np | |
| import pandas as pd | |
| import planetary_computer | |
| import pystac_client | |
| import rioxarray # rioxarray is required for the .rio methods in xarray despite what mypy, ruff, etc. says :) | |
| import stackstac | |
| from tqdm import tqdm | |
def set_up_parser() -> argparse.ArgumentParser:
    """
    Set up and return a command-line argument parser for the Sentinel-2 patch
    downloader.

    The parser defines the patch download index range, Azure blob storage
    configuration, and the source GeoParquet file used to sample Sentinel-2
    items. Every argument is required; argparse ignores ``default=`` on
    required arguments, so no defaults are declared (the originals were dead
    code that suggested a fallback which could never apply).

    Returns
    -------
    argparse.ArgumentParser
        Configured argument parser with all necessary CLI options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--low",
        type=int,
        required=True,
        help="Starting index",
    )
    parser.add_argument(
        "--high",
        type=int,
        required=True,
        help="Ending index",
    )
    parser.add_argument(
        "--output_fn",
        type=str,
        required=True,
        help="Output filename",
    )
    parser.add_argument(
        "--storage_account",
        type=str,
        required=True,
        help="Azure storage account URL (e.g. 'https://storageaccount.blob.core.windows.net')",
    )
    parser.add_argument(
        "--container_name",
        type=str,
        required=True,
        help="Azure blob container name",
    )
    parser.add_argument(
        "--sas_key",
        type=str,
        required=True,
        help="SAS key for Azure blob container",
    )
    parser.add_argument(
        "--s2_parquet_fn",
        type=str,
        required=True,
        help="GeoParquet index file to sample from",
    )
    return parser
def main(args):
    """
    Download Sentinel-2 image patches from the Planetary Computer STAC catalog
    and upload them to an Azure Blob container as tiled GeoTIFFs.

    Rows are sampled at random (with replacement) from the input GeoParquet
    index until ``args.high - args.low`` patches have been accepted. For each
    sampled STAC item a random 256x256 window is extracted; windows where more
    than 10% of pixels are all-NaN or all-zero across every band are
    discarded, STAC metadata is attached as raster tags, and the patch is
    uploaded as ``patch_{idx}.tif``. A CSV log of accepted patches is written
    at the end.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments, including input range, output file,
        Azure credentials, and STAC sampling source.

    Raises
    ------
    FileExistsError
        If ``args.output_fn`` already exists (refuses to clobber results).
    FileNotFoundError
        If ``args.s2_parquet_fn`` does not exist.
    """
    patch_size = 256  # side length in pixels of each sampled window

    # Validate inputs with real exceptions: `assert` is stripped under
    # `python -O`, so it must never guard external state.
    if os.path.exists(args.output_fn):
        raise FileExistsError(f"output file already exists: {args.output_fn}")
    if not os.path.exists(args.s2_parquet_fn):
        raise FileNotFoundError(f"parquet file not found: {args.s2_parquet_fn}")

    # Set up Azure blob container client
    container_client = azure.storage.blob.ContainerClient(
        args.storage_account,
        container_name=args.container_name,
        credential=args.sas_key,
    )

    # Connect to Microsoft Planetary Computer STAC API
    catalog = pystac_client.Client.open(
        "https://planetarycomputer.microsoft.com/api/stac/v1/",
        modifier=planetary_computer.sign_inplace,
    )
    collection = catalog.get_collection("sentinel-2-l2a")

    # Load input patch candidates from parquet
    df = pd.read_parquet(args.s2_parquet_fn)
    num_rows = df.shape[0]

    # Initialize stats and result tracking
    num_retries = 0
    num_error_hits = 0
    num_empty_hits = 0
    progress_bar = tqdm(total=args.high - args.low)
    results = []

    # Begin sampling loop. `idx` only advances when a patch is accepted, so a
    # rejected sample is transparently replaced by a fresh random draw.
    idx = args.low
    while idx < args.high:
        # Select a random row from GeoParquet file
        random_row = np.random.randint(0, num_rows)

        # Attempt to get this item with progressive exponential backoff
        item = None
        for j in range(4):
            try:
                item = collection.get_item(df.iloc[random_row]["id"])
                break
            except Exception as e:
                print(e)
                print("retrying", random_row, j)
                num_retries += 1
                time.sleep(2**j)
        if item is None:
            print(f"failed to get item {random_row}")
            num_error_hits += 1
            continue

        # Load selected STAC item into a multi-band raster stack
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            stack = stackstac.stack(
                item,
                assets=[
                    "B01",
                    "B02",
                    "B03",
                    "B04",
                    "B05",
                    "B06",
                    "B07",
                    "B08",
                    "B8A",
                    "B09",
                    "B11",
                    "B12",
                ],
                epsg=4326,
            )

        _, num_channels, height, width = stack.shape

        # Scenes smaller than the sampling window cannot host a patch;
        # np.random.randint raises ValueError when high <= low, so skip them
        # explicitly and draw a new sample instead of crashing.
        if height <= patch_size or width <= patch_size:
            num_empty_hits += 1
            continue

        # Randomly sample a window fully within image bounds
        x = np.random.randint(0, width - patch_size)
        y = np.random.randint(0, height - patch_size)

        # Extract patch and compute in-memory
        patch = stack[0, :, y : y + patch_size, x : x + patch_size].compute()

        # Filter patches with more than 10% missing data: a pixel counts as
        # missing only when *every* band is NaN (or zero) at that location.
        percent_empty = np.mean((np.isnan(patch.data)).sum(axis=0) == num_channels)
        percent_zero = np.mean((patch.data == 0).sum(axis=0) == num_channels)
        if percent_empty > 0.1 or percent_zero > 0.1:
            num_empty_hits += 1
            continue

        # Save valid patch to Azure Blob Storage as GeoTIFF with metadata
        with io.BytesIO() as buffer:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                patch = patch.astype(np.uint16)

                # Extract STAC metadata for provenance and traceability.
                # NOTE(review): assumes item.datetime is non-None for
                # sentinel-2-l2a items -- confirm against the collection.
                metadata = {
                    "datetime": item.datetime.isoformat(),
                    "platform": item.properties.get("platform", ""),
                    "mgrs_tile": item.properties.get("s2:mgrs_tile", ""),
                    "granule_id": item.properties.get("s2:granule_id", ""),
                    "orbit_state": item.properties.get("sat:orbit_state", ""),
                    "relative_orbit": str(item.properties.get("sat:relative_orbit", "")),
                    "cloud_cover": str(item.properties.get("eo:cloud_cover", "")),
                    "mean_solar_zenith": str(
                        item.properties.get("s2:mean_solar_zenith", "")
                    ),
                    "mean_solar_azimuth": str(
                        item.properties.get("s2:mean_solar_azimuth", "")
                    ),
                }

                # Attach metadata to the patch for inclusion in the raster tags
                patch.attrs.update(metadata)

                # Write tiled, LZW-compressed GeoTIFF to memory
                patch.rio.to_raster(
                    buffer,
                    driver="GTiff",
                    dtype=np.uint16,
                    compress="LZW",
                    predictor=2,
                    tiled=True,
                    blockxsize=256,
                    blockysize=256,
                    interleave="pixel",
                )

            # Upload patch to Azure Blob
            buffer.seek(0)
            blob_client = container_client.get_blob_client(f"patch_{idx}.tif")
            blob_client.upload_blob(buffer, overwrite=True)

        # Store patch info for CSV log
        results.append(
            (
                idx,
                random_row,
                x,
                y,
                metadata["granule_id"],
            )
        )
        idx += 1
        progress_bar.update(1)
    progress_bar.close()

    # Save all patch locations and sample info to CSV. Note: pandas also
    # writes its integer index as an unnamed first column; kept as-is so the
    # existing output format is unchanged.
    df = pd.DataFrame(results, columns=["idx", "row", "x", "y", "granule_id"])
    df.to_csv(args.output_fn)

    # Print final stats
    print("Summary:")
    print(f"range: [{args.low}, {args.high})")
    print(f"num hits: {len(results)}")
    print(f"num empty hits: {num_empty_hits}")
    print(f"num error hits: {num_error_hits}")
    print(f"num retries: {num_retries}")
if __name__ == "__main__":
    # Build the CLI, parse sys.argv, and hand the namespace straight to main.
    main(set_up_parser().parse_args())