File size: 5,919 Bytes
eb1e3d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e930f8f
 
eb1e3d4
 
 
e930f8f
eb1e3d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6113e2
eb1e3d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# ==============================================================================
# 1. Standard Library Imports
# ==============================================================================
import json
import os
from io import BytesIO

# ==============================================================================
# 2. Third-Party Library Imports
# ==============================================================================
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import geopandas as gpd
from affine import Affine # Recommended for working with rasterio transforms
from shapely.geometry import Polygon, Point
import pyproj
from sklearn.decomposition import PCA
from PIL import Image

# Geospatial/Cloud Tools
import rasterio as rio
from rasterio.windows import from_bounds
import s3fs
from fsspec.parquet import open_parquet_file
import pyarrow.parquet as pq

from .grid import *

# ------------------------------------------------------------------------------
# Module-level setup. Everything below runs once, at import time, and the
# parquet read goes to a remote S3 location.
# ------------------------------------------------------------------------------

# Anonymous (unsigned) S3 filesystem handle.
fs = s3fs.S3FileSystem(anon=True)
print('[DONE]\nReading AEF index from remote location...')
# Index of AEF products as a GeoDataFrame. cell_to_aef() filters it by
# geometry and `year`, then reads the `crs` and `path` fields of the match.
aef_df = gpd.read_parquet('s3://us-west-2.opendata.source.coop/tge-labs/aef/v1/annual/aef_index.parquet')
# Global Major TOM grid used to turn row/col labels into lat/lon.
# NOTE(review): `Grid` presumably comes from `.grid` and the first argument is
# presumably the grid spacing in km - confirm against that module.
mt_grid = Grid(10, latitude_range=(-90,90), longitude_range=(-180,180))

# ==============================================================================
# 3. Utility Functions
# ==============================================================================

def row2thumbnail(row):
    """
    Load the thumbnail image referenced by a Major TOM metadata row.

    Args:
        row: Record exposing ``parquet_url`` (location of the parquet file)
            and ``parquet_row`` (index of the row group to read).

    Returns:
        Image.Image: PIL image opened from the first ``thumbnail`` value of
        that row group.
    """
    # Only the thumbnail column is requested, both when opening the remote
    # file and when reading the row group.
    with open_parquet_file(row.parquet_url, columns=["thumbnail"]) as handle, \
            pq.ParquetFile(handle) as parquet:
        table = parquet.read_row_group(row.parquet_row, columns=['thumbnail'])

    thumbnail_bytes = table['thumbnail'][0].as_py()
    return Image.open(BytesIO(thumbnail_bytes))

def rgb_pca(arr: np.ndarray) -> "Image.Image":
    """
    Applies PCA to a 3D array to reduce its channels to 3,
    then converts it into a PIL Image.

    Each principal component is min-max scaled to the 0-255 range
    independently. A component with no variation (max == min) is mapped to 0
    rather than dividing by zero. Pixels whose channel sum equals
    ``C * 65535`` are set to black in the output (presumably these are
    nodata pixels with every channel at 65535 - confirm against the data).

    Args:
        arr (np.ndarray): Input array of shape (H, W, C).

    Returns:
        Image.Image: A PIL Image object ready to be saved.

    Raises:
        ValueError: If ``arr`` is not 3-dimensional.
    """
    # Validate the rank *before* unpacking the shape, so that callers get the
    # intended message instead of a tuple-unpacking error.
    if arr.ndim != 3:
        raise ValueError("Input array must be of shape (Height, Width, Channels)")

    h, w, c = arr.shape

    # Flatten the spatial dimensions: one sample per pixel, one feature per channel
    reshaped_arr = arr.reshape(-1, c)

    # Fit PCA and project every pixel onto the first 3 principal components
    pca = PCA(n_components=3)
    pca.fit(reshaped_arr)
    transformed_arr = pca.transform(reshaped_arr)

    # Min-max scale each principal component to the 0-255 range
    min_vals = transformed_arr.min(axis=0)
    spans = transformed_arr.max(axis=0) - min_vals
    # A constant component has zero span; dividing by 1 instead leaves it at 0
    # and avoids NaNs that would turn into garbage after the uint8 cast.
    spans[spans == 0] = 1
    normalized_arr = 255 * (transformed_arr - min_vals) / spans

    # Back to image layout (h, w, 3) as 8-bit unsigned integers
    rgb_image_data = normalized_arr.reshape(h, w, 3).astype(np.uint8)

    # Black out pixels whose channels sum to c * 65535
    valid_mask = arr.sum(-1) != c * 65535
    pil_image = Image.fromarray(rgb_image_data * valid_mask[:, :, None])

    return pil_image

def get_product_window(lat, lon, utm_zone=4326, mt_grid_dist = 10, box_size = 10680):
    """
        Build a square product footprint around a Major TOM cell reference point.

        The (lat, lon) reference is projected from EPSG:4326 into the target CRS
        and treated as the left/bottom corner of the grid cell. The box is then
        grown by the same margin on every side so that it measures `box_size`
        along each edge.

        lat, lon : reference coordinate of the Major TOM cell (degrees, EPSG:4326)
        utm_zone : target CRS; an int is interpreted as an EPSG code, any other
                   value is passed to pyproj unchanged
        mt_grid_dist (km) : spacing of the Major TOM grid (10 km is the default)
        box_size (m) : edge length of the returned footprint

        Returns a tuple (footprint, crs): a shapely Polygon in the target CRS
        and the CRS that was used for the projection.
    """
    # Overhang of the box beyond the grid cell, split equally between both sides (metres)
    margin = (box_size - mt_grid_dist * 1000) / 2

    # Integer zones are EPSG codes; everything else is used as given
    utm_crs = f'EPSG:{utm_zone}' if isinstance(utm_zone, int) else utm_zone

    # always_xy=True means the transformer takes (lon, lat) and returns (x, y)
    to_target = pyproj.Transformer.from_crs('EPSG:4326', utm_crs, always_xy=True)
    ref_x, ref_y = to_target.transform(lon, lat)

    # Shift the reference corner outwards, then span the full box from there
    west, south = ref_x - margin, ref_y - margin
    east, north = west + box_size, south + box_size

    corners = [(west, south), (east, south), (east, north), (west, north)]
    return Polygon(corners), utm_crs

def cell_to_aef(cell, year, return_centre=False, return_gridcell=False, return_timestamp=False):
    """
    Render a PCA thumbnail of the AEF product covering a Major TOM grid cell.

    The cell label is resolved to a lat/lon via ``mt_grid``, the matching
    product for ``year`` is looked up in ``aef_df``, the footprint window is
    read from that product and the result is reduced to 3 channels with
    ``rgb_pca``.

    Args:
        cell (str): Grid cell label of the form ``'<row>_<col>'``.
        year: Value compared against the ``year`` column of ``aef_df``.
        return_centre (bool): Also return the ``(lat, lon)`` obtained from the
            grid lookup for this cell.
        return_gridcell (bool): Also return ``cell``.
        return_timestamp (bool): Also return ``str(year)``.

    Returns:
        list | None: ``[image, ...]`` where ``image`` is the PIL image resized
        to 1068x1068, followed by the optional items in the order of the flags
        above. ``None`` if the lookup, the remote read or the PCA step fails
        (e.g. no product covers the cell for that year).

    Raises:
        ValueError: If ``cell`` does not split into exactly two parts on ``'_'``.
    """

    row, col = cell.split('_')

    # 1. define major tom grid cell
    lats, lons = mt_grid.rowcol2latlon([row],[col])

    try:
        # 2. find the right COG (first product containing the point for that year;
        #    .iloc[0] raises if there is none, which is reported as None below)
        cell_point = Point(lons[0], lats[0])
        aef_row = aef_df[aef_df.contains(cell_point) & (aef_df.year == year)].iloc[0]

        # 3. map lat-lon to utm footprint
        window, proj = get_product_window(lats[0], lons[0], utm_zone=aef_row.crs)

        # Set the environment variable for GDAL to enable S3 access over HTTPS
        # This is crucial for reading COGs directly from S3.
        with rio.Env(
            aws_unsigned=True,
            aws_no_sign_request=True,
            GDAL_DISABLE_READ_LOCK='YES',
            CPL_VSI_CURL_USE_HEAD='NO'
        ):
            with rio.open(aef_row.path) as src:
                minx, miny, maxx, maxy = window.bounds
                src_window = from_bounds(minx, maxy, maxx, miny, src.transform) # CAREFUL - COGs here are non-standard (positive y-dir, so require reordering)
                # (bands, rows, cols) -> (rows, cols, bands), then flip the row order
                arr = src.read(window=src_window).transpose(1,2,0)[::-1,...]

        # PCA
        pca_img = rgb_pca(arr)

        ret = [pca_img.resize((1068,1068))]
        if return_centre:
            ret.append((lats[0],lons[0]))
        if return_gridcell:
            ret.append(cell)
        if return_timestamp:
            ret.append(str(year))

        return ret
    except Exception:
        # Best-effort: any failure yields None. Unlike a bare `except:`, this
        # lets KeyboardInterrupt and SystemExit propagate.
        return None