File size: 8,526 Bytes
eb1aec4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 |
import argparse
import io
import os
import time
import warnings
import azure.storage.blob
import numpy as np
import pandas as pd
import planetary_computer
import pystac_client
import rioxarray # rioxarray is required for the .rio methods in xarray despite what mypy, ruff, etc. says :)
import stackstac
from tqdm import tqdm
def set_up_parser() -> argparse.ArgumentParser:
    """
    Set up and return a command-line argument parser for the Sentinel-2 patch downloader.

    The parser defines required arguments for specifying the patch download range,
    Azure blob storage configuration, and the source GeoParquet file used to sample
    Sentinel-2 items.

    Returns
    -------
    argparse.ArgumentParser
        Configured argument parser with all necessary CLI options.
    """
    parser = argparse.ArgumentParser()
    # NOTE: argparse ignores ``default=`` whenever ``required=True``, so the
    # defaults previously attached to --low/--high/--output_fn were dead code
    # and have been dropped; every option below must be supplied explicitly.
    parser.add_argument(
        "--low",
        type=int,
        required=True,
        help="Starting index",
    )
    parser.add_argument(
        "--high",
        type=int,
        required=True,
        help="Ending index",
    )
    parser.add_argument(
        "--output_fn",
        type=str,
        required=True,
        help="Output filename",
    )
    parser.add_argument(
        "--storage_account",
        type=str,
        required=True,
        help="Azure storage account URL (e.g. 'https://storageaccount.blob.core.windows.net')",
    )
    parser.add_argument(
        "--container_name",
        type=str,
        required=True,
        help="Azure blob container name",
    )
    parser.add_argument(
        "--sas_key",
        type=str,
        required=True,
        help="SAS key for Azure blob container",
    )
    parser.add_argument(
        "--s2_parquet_fn",
        type=str,
        required=True,
        help="GeoParquet index file to sample from",
    )
    return parser
def main(args):
    """
    Main processing function for downloading Sentinel-2 image patches from a STAC catalog
    and uploading them to an Azure Blob container as Cloud-Optimized GeoTIFFs (COGs).

    The function selects valid image patches from the input GeoParquet file,
    extracts a 256x256 region from a Sentinel-2 STAC item, filters based on NaN content,
    and embeds relevant metadata before uploading the patch to Azure Blob Storage.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments, including input range, output file, Azure credentials,
        and STAC sampling source.

    Raises
    ------
    FileExistsError
        If the output CSV already exists (refuses to clobber prior results).
    FileNotFoundError
        If the input GeoParquet index file does not exist.
    """
    # Explicit validation instead of `assert`: asserts are stripped when Python
    # runs with -O, which would silently disable these safety checks.
    if os.path.exists(args.output_fn):
        raise FileExistsError(f"Output file already exists: {args.output_fn}")
    if not os.path.exists(args.s2_parquet_fn):
        raise FileNotFoundError(f"Input parquet file not found: {args.s2_parquet_fn}")
    # Set up Azure blob container client
    container_client = azure.storage.blob.ContainerClient(
        args.storage_account,
        container_name=args.container_name,
        credential=args.sas_key,
    )
    # Connect to Microsoft Planetary Computer STAC API; sign_inplace attaches
    # SAS tokens to asset HREFs so stackstac can read them.
    catalog = pystac_client.Client.open(
        "https://planetarycomputer.microsoft.com/api/stac/v1/",
        modifier=planetary_computer.sign_inplace,
    )
    collection = catalog.get_collection("sentinel-2-l2a")
    # Load input patch candidates from parquet
    df = pd.read_parquet(args.s2_parquet_fn)
    num_rows = df.shape[0]
    # Initialize stats and result tracking
    num_retries = 0       # transient STAC fetch failures that were retried
    num_error_hits = 0    # rows abandoned after exhausting all retries
    num_empty_hits = 0    # sampled patches rejected for missing/zero data
    progress_bar = tqdm(total=args.high - args.low)
    results = []
    # Begin sampling loop; `idx` advances only when a patch is accepted, so
    # rejected samples are retried with a fresh random row.
    idx = args.low
    while idx < args.high:
        # Select a random row from GeoParquet file
        random_row = np.random.randint(0, num_rows)
        # Attempt to get this item with progressive exponential backoff
        # (1s, 2s, 4s, 8s between the four attempts).
        item = None
        for j in range(4):
            try:
                item = collection.get_item(df.iloc[random_row]["id"])
                break
            except Exception as e:
                # Broad catch is deliberate: any transient network/API error
                # should trigger a retry rather than abort the whole run.
                print(e)
                print("retrying", random_row, j)
                num_retries += 1
                time.sleep(2**j)
        if item is None:
            print(f"failed to get item {random_row}")
            num_error_hits += 1
            continue
        # Load selected STAC item into a multi-band raster stack
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            stack = stackstac.stack(
                item,
                assets=[
                    "B01",
                    "B02",
                    "B03",
                    "B04",
                    "B05",
                    "B06",
                    "B07",
                    "B08",
                    "B8A",
                    "B09",
                    "B11",
                    "B12",
                ],
                epsg=4326,
            )
        _, num_channels, height, width = stack.shape
        # Guard: np.random.randint raises ValueError when high <= low, so a
        # scene narrower/shorter than 257 px would crash the loop. Treat such
        # scenes as unusable and resample instead.
        if width <= 256 or height <= 256:
            num_empty_hits += 1
            continue
        # Randomly sample a 256x256 window within image bounds
        x = np.random.randint(0, width - 256)
        y = np.random.randint(0, height - 256)
        # Extract patch and compute in-memory
        patch = stack[0, :, y : y + 256, x : x + 256].compute()
        # Filter patches where more than 10% of pixels are all-NaN or all-zero
        # across every band (i.e. nodata regions at scene edges).
        percent_empty = np.mean((np.isnan(patch.data)).sum(axis=0) == num_channels)
        percent_zero = np.mean((patch.data == 0).sum(axis=0) == num_channels)
        if percent_empty > 0.1 or percent_zero > 0.1:
            num_empty_hits += 1
            continue
        # Save valid patch to Azure Blob Storage as GeoTIFF with metadata
        with io.BytesIO() as buffer:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # Sentinel-2 L2A reflectance values fit in uint16; cast before
                # writing to keep the COG compact.
                patch = patch.astype(np.uint16)
                # Extract STAC metadata for provenance and traceability
                metadata = {
                    "datetime": item.datetime.isoformat(),
                    "platform": item.properties.get("platform", ""),
                    "mgrs_tile": item.properties.get("s2:mgrs_tile", ""),
                    "granule_id": item.properties.get("s2:granule_id", ""),
                    "orbit_state": item.properties.get("sat:orbit_state", ""),
                    "relative_orbit": str(item.properties.get("sat:relative_orbit", "")),
                    "cloud_cover": str(item.properties.get("eo:cloud_cover", "")),
                    "mean_solar_zenith": str(
                        item.properties.get("s2:mean_solar_zenith", "")
                    ),
                    "mean_solar_azimuth": str(
                        item.properties.get("s2:mean_solar_azimuth", "")
                    ),
                }
                # Attach metadata to the patch for inclusion in the raster tags
                patch.attrs.update(metadata)
                # Write Cloud-Optimized GeoTIFF to memory
                patch.rio.to_raster(
                    buffer,
                    driver="GTiff",
                    dtype=np.uint16,
                    compress="LZW",
                    predictor=2,
                    tiled=True,
                    blockxsize=256,
                    blockysize=256,
                    interleave="pixel",
                )
            # Upload patch to Azure Blob
            buffer.seek(0)
            blob_client = container_client.get_blob_client(f"patch_{idx}.tif")
            blob_client.upload_blob(buffer, overwrite=True)
        # Store patch info for CSV log
        results.append(
            (
                idx,
                random_row,
                x,
                y,
                metadata["granule_id"],
            )
        )
        idx += 1
        progress_bar.update(1)
    progress_bar.close()
    # Save all patch locations and sample info to CSV
    df = pd.DataFrame(results, columns=["idx", "row", "x", "y", "granule_id"])
    df.to_csv(args.output_fn)
    # Print final stats
    print("Summary:")
    print(f"range: [{args.low}, {args.high})")
    print(f"num hits: {len(results)}")
    print(f"num empty hits: {num_empty_hits}")
    print(f"num error hits: {num_error_hits}")
    print(f"num retries: {num_retries}")
if __name__ == "__main__":
    # Build the CLI parser, parse sys.argv, and hand the namespace to main().
    main(set_up_parser().parse_args())
|