Spaces:
Running on L4
Running on L4
File size: 5,919 Bytes
eb1e3d4 e930f8f eb1e3d4 e930f8f eb1e3d4 f6113e2 eb1e3d4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 | # ==============================================================================
# 1. Standard Library Imports
# ==============================================================================
import json
import os
from io import BytesIO
# ==============================================================================
# 2. Third-Party Library Imports
# ==============================================================================
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import geopandas as gpd
from affine import Affine # Recommended for working with rasterio transforms
from shapely.geometry import Polygon, Point
import pyproj
from sklearn.decomposition import PCA
from PIL import Image
# Geospatial/Cloud Tools
import rasterio as rio
from rasterio.windows import from_bounds
import s3fs
from fsspec.parquet import open_parquet_file
import pyarrow.parquet as pq
from .grid import *
# Location of the AEF annual index parquet on the public source.coop bucket.
_AEF_INDEX_URL = 's3://us-west-2.opendata.source.coop/tge-labs/aef/v1/annual/aef_index.parquet'

# Anonymous (unsigned) S3 filesystem handle.
fs = s3fs.S3FileSystem(anon=True)

print('[DONE]\nReading AEF index from remote location...')
# Index of AEF COG tiles. cell_to_aef relies on (at least) a footprint
# geometry plus the `year`, `crs` and `path` columns.
aef_df = gpd.read_parquet(_AEF_INDEX_URL)

# Global Major TOM grid; 10 is presumably the cell size in km (it matches
# get_product_window's mt_grid_dist default) — confirm against Grid.
mt_grid = Grid(
    10,
    latitude_range=(-90, 90),
    longitude_range=(-180, 180),
)
# ==============================================================================
# 3. Utility Functions
# ==============================================================================
def row2thumbnail(row):
    """Load the thumbnail image for one Major TOM metadata parquet row.

    Args:
        row: A record exposing ``parquet_url`` (the parquet file to open) and
            ``parquet_row`` (the index of the row group holding the sample).

    Returns:
        PIL.Image.Image: The image decoded from the ``thumbnail`` column.
    """
    wanted = ["thumbnail"]
    with open_parquet_file(row.parquet_url, columns=wanted) as remote, \
            pq.ParquetFile(remote) as parquet:
        group = parquet.read_row_group(row.parquet_row, columns=wanted)
        # Only the first record of the selected row group is decoded.
        thumb_bytes = group['thumbnail'][0].as_py()
        return Image.open(BytesIO(thumb_bytes))
def rgb_pca(arr: np.ndarray) -> Image.Image:
    """Project a multi-channel array onto its first 3 principal components as RGB.

    Each pixel is treated as one sample with ``C`` features. PCA reduces these
    to three components. Each component is min-max scaled independently to
    0-255 and becomes one colour channel.

    Args:
        arr (np.ndarray): Input array of shape (H, W, C). A 3-component PCA
            needs C >= 3 and at least 3 pixels.

    Returns:
        Image.Image: An 8-bit RGB PIL image of size (W, H). Pixels whose
        channel values sum to ``C * 65535`` are set to black. This is
        presumably an all-channels nodata fill value — TODO confirm against
        the source rasters.

    Raises:
        ValueError: If ``arr`` is not 3-dimensional.
    """
    # Validate before unpacking; unpacking a non-3-D shape would otherwise
    # fail first with a less helpful message.
    if arr.ndim != 3:
        raise ValueError("Input array must be of shape (Height, Width, Channels)")
    h, w, c = arr.shape

    # One row per pixel, one column per channel.
    pixels = arr.reshape(-1, c)

    # Fit and transform separately, as before, so results are bit-identical.
    pca = PCA(n_components=3)
    pca.fit(pixels)
    components = pca.transform(pixels)

    # Min-max scale each component to [0, 255]. A component with zero range
    # (e.g. constant input) would divide by zero and give NaN, whose uint8
    # cast is undefined. Such components are mapped to 0 instead.
    mins = components.min(axis=0)
    ranges = components.max(axis=0) - mins
    safe_ranges = np.where(ranges > 0, ranges, 1)
    normalized = 255 * (components - mins) / safe_ranges

    # Back to (H, W, 3). The uint8 cast truncates toward zero.
    rgb = normalized.reshape(h, w, 3).astype(np.uint8)

    # Black out pixels whose channels sum to c * 65535 (see Returns).
    valid = arr.sum(-1) != c * 65535
    return Image.fromarray(rgb * valid[:, :, None])
def get_product_window(lat, lon, utm_zone=4326, mt_grid_dist=10, box_size=10680):
    """Build a square product footprint around a Major TOM grid cell.

    ``(lat, lon)`` is projected into the target CRS and used as the cell's
    lower-left (minimum x, minimum y) corner.
    NOTE(review): confirm this matches the Major TOM grid convention.

    A square of side ``box_size`` is then centred on the nominal
    ``mt_grid_dist``-km cell. The cell is padded by
    ``(box_size - mt_grid_dist * 1000) / 2`` on every side.

    All offsets are applied in the target CRS's native units. The result is
    therefore only meaningful for a metric CRS such as UTM. With the default
    4326 (degrees), the box would be sized in degrees.

    Args:
        lat (float): Latitude of the reference corner, in EPSG:4326 degrees.
        lon (float): Longitude of the reference corner, in EPSG:4326 degrees.
        utm_zone (int | str): Target CRS. Integers (Python or numpy) are
            treated as EPSG codes and become ``'EPSG:<code>'``. Any other
            value is passed to pyproj as-is. This should be taken from the
            product the footprint is meant for.
        mt_grid_dist (float): Major TOM grid cell size in km (default 10).
        box_size (float): Side length of the output square, in target-CRS
            units (metres for UTM).

    Returns:
        tuple: ``(footprint, crs)``, where ``footprint`` is a shapely
        ``Polygon`` in the target CRS and ``crs`` is the CRS value that was
        used.
    """
    # Padding added on each side so the box is centred on the grid cell.
    box_offset = (box_size - mt_grid_dist * 1000) / 2  # metres

    # Also accept numpy integers (e.g. an EPSG code read from a DataFrame).
    # They are not ``int`` instances and would otherwise be passed through
    # unchanged.
    if isinstance(utm_zone, (int, np.integer)):
        utm_crs = f'EPSG:{utm_zone}'
    else:
        utm_crs = utm_zone

    # always_xy=True: coordinates are (x, y), i.e. (lon, lat) on input.
    transformer = pyproj.Transformer.from_crs('EPSG:4326', utm_crs, always_xy=True)
    x0, y0 = transformer.transform(lon, lat)

    left = x0 - box_offset
    bottom = y0 - box_offset
    right = left + box_size
    top = bottom + box_size

    utm_footprint = Polygon([
        (left, bottom),
        (right, bottom),
        (right, top),
        (left, top),
    ])
    return utm_footprint, utm_crs
def cell_to_aef(cell, year, return_centre=False, return_gridcell=False, return_timestamp=False):
    """Render the AEF tile for a Major TOM grid cell as a PCA RGB image.

    Args:
        cell (str): Major TOM grid-cell id of the form ``"<row>_<col>"``.
        year: Value matched against the ``year`` column of ``aef_df``.
        return_centre (bool): Also return ``(lat, lon)`` from
            ``mt_grid.rowcol2latlon``. This is the same point that
            get_product_window receives as the cell's reference corner.
        return_gridcell (bool): Also return ``cell`` unchanged.
        return_timestamp (bool): Also return ``str(year)``.

    Returns:
        list | None: Either ``[image, *extras]`` or ``None``.

        ``image`` is a 1068x1068 PIL image. Any extras follow in the order
        centre, grid cell, timestamp.

        ``None`` is returned when no AEF tile covers the cell for ``year``, or
        when the tile could not be read or processed. Such failures are
        logged as a warning.

    Raises:
        ValueError: If ``cell`` does not split into exactly two parts on "_".
    """
    import logging  # local import: only used to report failed reads

    row, col = cell.split('_')

    # 1. Reference coordinate of the Major TOM grid cell.
    lats, lons = mt_grid.rowcol2latlon([row], [col])

    try:
        lat, lon = lats[0], lons[0]

        # 2. Find the AEF COG covering that point for the requested year.
        matches = aef_df[aef_df.contains(Point(lon, lat)) & (aef_df.year == year)]
        if matches.empty:
            # No tile for this location/year: an expected outcome, not an error.
            return None
        aef_row = matches.iloc[0]

        # 3. Footprint of the cell in the tile's CRS.
        window, _ = get_product_window(lat, lon, utm_zone=aef_row.crs)

        # Anonymous (unsigned) S3 access to the public bucket, plus GDAL
        # options for reading the COG remotely.
        with rio.Env(
            aws_unsigned=True,
            aws_no_sign_request=True,
            GDAL_DISABLE_READ_LOCK='YES',
            # NOTE(review): GDAL documents this option as CPL_VSIL_CURL_USE_HEAD,
            # so this spelling is probably ignored. Confirm before renaming:
            # fixing it changes how remote reads are issued.
            CPL_VSI_CURL_USE_HEAD='NO',
        ):
            with rio.open(aef_row.path) as src:
                minx, miny, maxx, maxy = window.bounds
                # CAREFUL - these COGs are non-standard (positive y direction),
                # so bottom and top are swapped when building the window.
                src_window = from_bounds(minx, maxy, maxx, miny, src.transform)
                # (bands, rows, cols) -> (rows, cols, bands). The row order is
                # then flipped to undo the positive-y orientation.
                arr = src.read(window=src_window).transpose(1, 2, 0)[::-1, ...]

        # 4. Visualise the embedding with a 3-component PCA.
        image = rgb_pca(arr).resize((1068, 1068))
    except Exception:
        # Best effort: an unreadable or unprocessable tile still yields None,
        # but the failure is reported rather than silently swallowed.
        # KeyboardInterrupt/SystemExit are no longer caught.
        logging.getLogger(__name__).warning(
            "cell_to_aef failed for cell=%s year=%s", cell, year, exc_info=True
        )
        return None

    ret = [image]
    if return_centre:
        ret.append((lat, lon))
    if return_gridcell:
        ret.append(cell)
    if return_timestamp:
        ret.append(str(year))
    return ret