"""Download random 256x256 Sentinel-2 L2A patches from the Planetary Computer STAC API and upload them to Azure Blob Storage as GeoTIFFs."""
import argparse
import io
import os
import time
import warnings
import azure.storage.blob
import numpy as np
import pandas as pd
import planetary_computer
import pystac_client
import rioxarray # rioxarray is required for the .rio methods in xarray despite what mypy, ruff, etc. says :)
import stackstac
from tqdm import tqdm
def set_up_parser() -> argparse.ArgumentParser:
    """
    Set up and return a command-line argument parser for the Sentinel-2 patch downloader.

    All arguments are required: the patch index range, the output CSV filename,
    the Azure blob storage configuration, and the source GeoParquet file used
    to sample Sentinel-2 items.

    Note: argparse never applies a ``default=`` to a ``required=True`` argument,
    so the defaults previously attached to --low/--high/--output_fn were dead
    configuration and have been removed.

    Returns
    -------
    argparse.ArgumentParser
        Configured argument parser with all necessary CLI options.
    """
    parser = argparse.ArgumentParser(
        description="Download Sentinel-2 patches and upload them to Azure blob storage."
    )
    parser.add_argument(
        "--low",
        type=int,
        required=True,
        help="Starting index",
    )
    parser.add_argument(
        "--high",
        type=int,
        required=True,
        help="Ending index",
    )
    parser.add_argument(
        "--output_fn",
        type=str,
        required=True,
        help="Output filename",
    )
    parser.add_argument(
        "--storage_account",
        type=str,
        required=True,
        help="Azure storage account URL (e.g. 'https://storageaccount.blob.core.windows.net')",
    )
    parser.add_argument(
        "--container_name",
        type=str,
        required=True,
        help="Azure blob container name",
    )
    parser.add_argument(
        "--sas_key",
        type=str,
        required=True,
        help="SAS key for Azure blob container",
    )
    parser.add_argument(
        "--s2_parquet_fn",
        type=str,
        required=True,
        help="GeoParquet index file to sample from",
    )
    return parser
def main(args):
    """
    Download Sentinel-2 image patches from the Planetary Computer STAC catalog
    and upload them to an Azure Blob container as tiled GeoTIFFs.

    For each index in [args.low, args.high), a random row is drawn from the
    input GeoParquet file, a 256x256 window is sampled from the corresponding
    Sentinel-2 L2A item, patches with more than 10% missing or all-zero pixels
    are rejected, and accepted patches are uploaded with STAC metadata embedded
    as raster tags. A CSV log of all accepted patches is written at the end.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments, including input range, output file,
        Azure credentials, and STAC sampling source.

    Raises
    ------
    FileExistsError
        If the output CSV already exists (refuses to overwrite prior results).
    FileNotFoundError
        If the input GeoParquet index file does not exist.
    """
    # Explicit validation instead of `assert`: asserts are stripped under
    # `python -O`, which would silently disable these sanity checks.
    if os.path.exists(args.output_fn):
        raise FileExistsError(f"Output file already exists: {args.output_fn}")
    if not os.path.exists(args.s2_parquet_fn):
        raise FileNotFoundError(f"Input parquet file not found: {args.s2_parquet_fn}")

    # Set up Azure blob container client
    container_client = azure.storage.blob.ContainerClient(
        args.storage_account,
        container_name=args.container_name,
        credential=args.sas_key,
    )

    # Connect to Microsoft Planetary Computer STAC API; sign_inplace attaches
    # SAS tokens to asset HREFs so stackstac can read them.
    catalog = pystac_client.Client.open(
        "https://planetarycomputer.microsoft.com/api/stac/v1/",
        modifier=planetary_computer.sign_inplace,
    )
    collection = catalog.get_collection("sentinel-2-l2a")

    # Load input patch candidates from parquet
    df = pd.read_parquet(args.s2_parquet_fn)
    num_rows = df.shape[0]

    # Initialize stats and result tracking
    num_retries = 0
    num_error_hits = 0
    num_empty_hits = 0
    progress_bar = tqdm(total=args.high - args.low)
    results = []

    # Sampling loop: `idx` only advances when a patch is accepted, so the loop
    # keeps drawing random rows until exactly (high - low) patches are uploaded.
    idx = args.low
    while idx < args.high:
        # Select a random row from GeoParquet file
        random_row = np.random.randint(0, num_rows)

        # Attempt to fetch the STAC item with exponential backoff (1s, 2s, 4s, 8s)
        item = None
        for j in range(4):
            try:
                item = collection.get_item(df.iloc[random_row]["id"])
                break
            except Exception as e:
                print(e)
                print("retrying", random_row, j)
                num_retries += 1
                time.sleep(2**j)
        if item is None:
            print(f"failed to get item {random_row}")
            num_error_hits += 1
            continue

        # Load selected STAC item into a multi-band raster stack
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            stack = stackstac.stack(
                item,
                assets=[
                    "B01",
                    "B02",
                    "B03",
                    "B04",
                    "B05",
                    "B06",
                    "B07",
                    "B08",
                    "B8A",
                    "B09",
                    "B11",
                    "B12",
                ],
                epsg=4326,
            )

        _, num_channels, height, width = stack.shape

        # Guard against scenes too small to hold a 256x256 window: without
        # this, np.random.randint would raise on a non-positive range. Treat
        # them like empty hits and resample.
        if height <= 256 or width <= 256:
            num_empty_hits += 1
            continue

        # Randomly sample a 256x256 window within image bounds
        x = np.random.randint(0, width - 256)
        y = np.random.randint(0, height - 256)

        # Extract patch and compute in-memory
        patch = stack[0, :, y : y + 256, x : x + 256].compute()

        # Reject patches where more than 10% of pixels are all-NaN or all-zero
        # across every channel (the per-pixel channel sum equals num_channels
        # exactly when every band is NaN/zero at that pixel).
        percent_empty = np.mean((np.isnan(patch.data)).sum(axis=0) == num_channels)
        percent_zero = np.mean((patch.data == 0).sum(axis=0) == num_channels)
        if percent_empty > 0.1 or percent_zero > 0.1:
            num_empty_hits += 1
            continue

        # Save valid patch to Azure Blob Storage as GeoTIFF with metadata
        with io.BytesIO() as buffer:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                patch = patch.astype(np.uint16)

                # Extract STAC metadata for provenance and traceability
                metadata = {
                    "datetime": item.datetime.isoformat(),
                    "platform": item.properties.get("platform", ""),
                    "mgrs_tile": item.properties.get("s2:mgrs_tile", ""),
                    "granule_id": item.properties.get("s2:granule_id", ""),
                    "orbit_state": item.properties.get("sat:orbit_state", ""),
                    "relative_orbit": str(item.properties.get("sat:relative_orbit", "")),
                    "cloud_cover": str(item.properties.get("eo:cloud_cover", "")),
                    "mean_solar_zenith": str(
                        item.properties.get("s2:mean_solar_zenith", "")
                    ),
                    "mean_solar_azimuth": str(
                        item.properties.get("s2:mean_solar_azimuth", "")
                    ),
                }

                # Attach metadata to the patch for inclusion in the raster tags
                patch.attrs.update(metadata)

                # Write tiled, LZW-compressed GeoTIFF to memory
                patch.rio.to_raster(
                    buffer,
                    driver="GTiff",
                    dtype=np.uint16,
                    compress="LZW",
                    predictor=2,
                    tiled=True,
                    blockxsize=256,
                    blockysize=256,
                    interleave="pixel",
                )

            # Upload patch to Azure Blob
            buffer.seek(0)
            blob_client = container_client.get_blob_client(f"patch_{idx}.tif")
            blob_client.upload_blob(buffer, overwrite=True)

        # Store patch info for CSV log
        results.append(
            (
                idx,
                random_row,
                x,
                y,
                metadata["granule_id"],
            )
        )
        idx += 1
        progress_bar.update(1)
    progress_bar.close()

    # Save all patch locations and sample info to CSV. Use a distinct name so
    # the input dataframe isn't shadowed, and index=False so pandas doesn't
    # write a redundant unnamed index column next to the explicit "idx" column.
    results_df = pd.DataFrame(results, columns=["idx", "row", "x", "y", "granule_id"])
    results_df.to_csv(args.output_fn, index=False)

    # Print final stats
    print("Summary:")
    print(f"range: [{args.low}, {args.high})")
    print(f"num hits: {len(results)}")
    print(f"num empty hits: {num_empty_hits}")
    print(f"num error hits: {num_error_hits}")
    print(f"num retries: {num_retries}")
if __name__ == "__main__":
parser = set_up_parser()
args = parser.parse_args()
main(args)