Side-Info_Generation / src /gee_spatial_join.py
mayesh's picture
deploy bird explorer dashboard
2572f0f
Raw
History Blame Contribute Delete
10.3 kB
import os
import time
import logging
import argparse
import numpy as np
import pandas as pd
import ee
import pyarrow.parquet as pq
import pyarrow as pa
from datetime import datetime
# Configure the root logger once at import time: INFO level and above, with
# every record rendered as "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def authenticate_gee(service_account_json):
    """Initialise the Earth Engine client from a service-account key file.

    Args:
        service_account_json: Path to the service-account JSON key.

    Raises:
        Exception: Whatever the Earth Engine client raised while building the
            credentials or initialising; it is logged and then re-raised
            unchanged.
    """
    logging.info(f"Authenticating to GEE using {service_account_json}")
    try:
        # The e-mail argument is deliberately blank; per the original author's
        # note, Earth Engine takes the account e-mail from the JSON key itself.
        ee.Initialize(ee.ServiceAccountCredentials('', service_account_json))
        logging.info("GEE Initialization successful.")
    except Exception as init_error:
        logging.error(f"GEE Initialization failed: {init_error}")
        raise
def optimize_dataframe(df):
    """Narrow column dtypes of *df* in place to cut memory use, then return it.

    Conversions applied per column:
      * ``object``  -> ``category``
      * ``float64`` -> ``float32``
      * ``int64``   -> smallest integer dtype that holds the values

    Columns of any other dtype are left as they are.

    Args:
        df: DataFrame to shrink. It is modified in place.

    Returns:
        The same DataFrame object that was passed in.
    """
    for name in df.columns:
        current_dtype = df[name].dtype
        if current_dtype == 'object':
            df[name] = df[name].astype('category')
            continue
        if current_dtype == 'float64':
            df[name] = df[name].astype('float32')
            continue
        if current_dtype == 'int64':
            df[name] = pd.to_numeric(df[name], downcast='integer')
    return df
def execute_with_backoff(gee_task_func, max_retries=5, base_delay=2):
    """Run a GEE call, retrying transient failures with exponential back-off.

    An ``ee.EEException`` is treated as transient when its message mentions
    one of: "too many requests", "rate limited", "capacity", "timed out"
    (case-insensitive). Any other ``ee.EEException`` is re-raised immediately.

    Args:
        gee_task_func: Zero-argument callable that performs the GEE request.
        max_retries: Maximum number of attempts in total.
        base_delay: Delay in seconds before the second attempt; it doubles
            after every further failed attempt.

    Returns:
        Whatever ``gee_task_func`` returns on its first successful call.

    Raises:
        ee.EEException: For a non-transient Earth Engine error.
        RuntimeError: When every attempt failed with a transient error (or
            ``max_retries`` is not positive). The last transient error is
            attached as ``__cause__``.
    """
    transient_markers = ('too many requests', 'rate limited', 'capacity', 'timed out')
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            return gee_task_func()
        except ee.EEException as err:
            message = str(err).lower()
            if not any(marker in message for marker in transient_markers):
                # Not a throttling/timeout problem: retrying will not help.
                raise
            last_error = err
            if attempt == max_retries:
                # No attempt is left, so sleeping again would only delay the
                # failure report.
                break
            delay = base_delay * (2 ** (attempt - 1))
            logging.warning(f"GEE API error: {err}. Retrying in {delay} seconds (Attempt {attempt}/{max_retries})")
            time.sleep(delay)
    raise RuntimeError("Max retries exceeded for GEE API.") from last_error
def _observation_epoch_ms(raw_date, year, month):
    """Return *raw_date* as milliseconds since the Unix epoch, in UTC.

    Strings are parsed as ``YYYY-MM-DD`` (anything after the first space is
    ignored); other values go through ``pd.to_datetime``. Values without a
    time zone are interpreted as UTC so the result does not depend on the
    local time zone of the machine running the script. If the value is
    missing or cannot be parsed, the first day of ``year``-``month`` is used.
    """
    try:
        if isinstance(raw_date, str):
            parsed = datetime.strptime(raw_date.split(' ')[0], "%Y-%m-%d")
        else:
            parsed = pd.to_datetime(raw_date)
        stamp = pd.Timestamp(parsed)
        if stamp is pd.NaT:
            raise ValueError("missing observation date")
    except (ValueError, TypeError, OverflowError):
        # Best-effort fallback for malformed or missing dates.
        stamp = pd.Timestamp(year=year, month=month, day=1)
    if stamp.tzinfo is None:
        stamp = stamp.tz_localize('UTC')
    return int(stamp.timestamp() * 1000)


def get_gee_features_for_chunk(df_chunk, year, month):
    """
    Sample Earth Engine datasets at every observation of one Year-Month chunk.

    The rows are turned into an ee.FeatureCollection, a per-feature sampling
    function is mapped over it, and the result is downloaded with
    ``getInfo()`` through ``execute_with_backoff``.

    Args:
        df_chunk: DataFrame with 'longitude', 'latitude', 'date' and
            'image_id' columns. Rows with a missing coordinate are skipped.
        year: Year of the block; stored on each feature and used for the
            land-cover lookup and as the fallback date.
        month: Month of the block; used only for the fallback date.

    Returns:
        A list with one dict per returned feature, with keys 'image_id',
        'elevation', 'landcover_class', 'temperature_2m',
        'total_precipitation' and 'ndvi'. A value is None when the
        corresponding property is absent from the feature. When no ERA5 or
        MODIS image matches the date window, a constant -9999 image is
        sampled instead. Returns [] if no row has usable coordinates.
    """
    # 1. Prepare features
    features = []
    rows = zip(
        df_chunk['longitude'],
        df_chunk['latitude'],
        df_chunk['date'],
        df_chunk['image_id'],
    )
    for lon, lat, raw_date, image_id in rows:
        # Ensure we have valid coordinates
        if pd.isna(lon) or pd.isna(lat):
            continue
        # Plain Python floats keep the request payload serialisable even when
        # the columns hold NumPy scalars.
        point = ee.Geometry.Point([float(lon), float(lat)])
        features.append(ee.Feature(point, {
            'image_id': str(image_id),
            'date_ms': _observation_epoch_ms(raw_date, year, month),
            'year': year,
        }))

    if not features:
        return []

    # The caller is responsible for keeping the chunk small enough for one
    # request; the whole chunk is sent as a single FeatureCollection.
    fc = ee.FeatureCollection(features)

    # 2. Define the extraction function to map over the FeatureCollection.
    # map() is used because ERA5 and MODIS are looked up per feature date.
    def extract_point_data(feature):
        f_date = ee.Date(feature.getNumber('date_ms'))
        f_year = feature.getNumber('year')
        geom = feature.geometry()

        # A. SRTM Elevation (static)
        srtm = ee.Image('CGIAR/SRTM90_V4')
        elev_dict = srtm.reduceRegion(reducer=ee.Reducer.first(), geometry=geom, scale=90)

        # B. Copernicus Land Cover (yearly)
        copernicus_ic = ee.ImageCollection('COPERNICUS/Landcover/100m/Proba-V-C3/Global') \
            .filter(ee.Filter.calendarRange(f_year, f_year, 'year'))
        # Fall back to the 2019 image when the feature's year has no image
        # (original author's note: 2015-2019 are available).
        lc_image = ee.Algorithms.If(
            copernicus_ic.size().gt(0),
            copernicus_ic.first(),
            ee.ImageCollection('COPERNICUS/Landcover/100m/Proba-V-C3/Global').filterDate('2019-01-01', '2019-12-31').first()
        )
        lc_dict = ee.Image(lc_image).select('discrete_classification').reduceRegion(reducer=ee.Reducer.first(), geometry=geom, scale=100)

        # C. ERA5 Daily: images starting within [f_date, f_date + 1 day)
        era5_ic = ee.ImageCollection('ECMWF/ERA5/DAILY') \
            .filterDate(f_date, f_date.advance(1, 'day'))
        # Sample a constant -9999 image when no ERA5 image matches.
        era5_image = ee.Algorithms.If(
            era5_ic.size().gt(0),
            era5_ic.first(),
            ee.Image.constant(ee.Number(-9999)).rename('mean_2m_air_temperature').addBands(ee.Image.constant(ee.Number(-9999)).rename('total_precipitation'))
        )
        era5_dict = ee.Image(era5_image).select(['mean_2m_air_temperature', 'total_precipitation']).reduceRegion(reducer=ee.Reducer.first(), geometry=geom, scale=27830)

        # D. MODIS NDVI: first image of the +/-16 day window around f_date.
        # NOTE(review): first() is not guaranteed to be the composite nearest
        # to the date - confirm whether "closest" selection is required.
        modis_ic = ee.ImageCollection('MODIS/006/MOD13A2') \
            .filterDate(f_date.advance(-16, 'day'), f_date.advance(16, 'day'))
        modis_image = ee.Algorithms.If(
            modis_ic.size().gt(0),
            modis_ic.first(),
            ee.Image.constant(ee.Number(-9999)).rename('NDVI')
        )
        modis_dict = ee.Image(modis_image).select('NDVI').reduceRegion(reducer=ee.Reducer.first(), geometry=geom, scale=1000)

        # Combine all dictionaries into feature properties
        return feature.set(elev_dict).set(lc_dict).set(era5_dict).set(modis_dict)

    # Apply extraction
    extracted_fc = fc.map(extract_point_data)

    # 3. Retrieve results locally (retried on transient GEE errors)
    def fetch_data():
        return extracted_fc.getInfo()['features']

    raw_results = execute_with_backoff(fetch_data)

    # 4. Map GEE band names to the output column names used by the pipeline
    output_to_band = {
        'image_id': 'image_id',
        'elevation': 'elevation',
        'landcover_class': 'discrete_classification',
        'temperature_2m': 'mean_2m_air_temperature',
        'total_precipitation': 'total_precipitation',
        'ndvi': 'NDVI',
    }
    processed_data = []
    for feat in raw_results:
        props = feat.get('properties', {})
        processed_data.append({
            column: props.get(band) for column, band in output_to_band.items()
        })
    return processed_data
def chunker(seq, size):
    """Yield consecutive slices of *seq*, each holding at most *size* items.

    DataFrames are sliced by row position (``.iloc``); any other sequence is
    sliced with ordinary ``seq[a:b]`` slicing.

    Args:
        seq: A pandas DataFrame or any sliceable sequence supporting ``len``.
        size: Maximum number of items per chunk; must be positive.

    Yields:
        Slices of ``seq`` in order; the last one may be shorter than ``size``.

    Raises:
        ValueError: If ``size`` is not positive. A negative size would
            otherwise produce no chunks at all and silently drop the data.
            Because this is a generator, the error surfaces on first
            iteration.
    """
    if size <= 0:
        raise ValueError(f"size must be a positive integer, got {size}")
    sliceable = seq.iloc if isinstance(seq, pd.DataFrame) else seq
    for start in range(0, len(seq), size):
        yield sliceable[start:start + size]
def main():
    """Command-line entry point: enrich a metadata parquet with GEE variables.

    Steps: parse arguments, authenticate to Earth Engine, load the input
    parquet, group the rows by the year and month of their 'date', send each
    group to Earth Engine in sub-chunks of ``--chunk_size`` rows, left-merge
    the extracted values back onto the metadata by 'image_id', and write the
    result to ``--output_parquet``.

    Raises:
        ValueError: If the input parquet lacks the 'date', 'image_id',
            'longitude' or 'latitude' column.
    """
    parser = argparse.ArgumentParser(description="Extract environmental features from GEE.")
    parser.add_argument('--input_parquet', type=str, required=True, help="Path to base metadata parquet.")
    parser.add_argument('--output_parquet', type=str, required=True, help="Path to save enriched parquet.")
    parser.add_argument('--service_account_json', type=str, required=True, help="Path to GEE JSON key.")
    parser.add_argument('--chunk_size', type=int, default=2500, help="Max features per GEE payload.")
    args = parser.parse_args()
    if args.chunk_size <= 0:
        # A non-positive chunk size cannot produce any chunk to process.
        parser.error("--chunk_size must be a positive integer.")

    authenticate_gee(args.service_account_json)

    logging.info(f"Loading {args.input_parquet}")
    # Load metadata and optimize memory
    df = pd.read_parquet(args.input_parquet)
    df = optimize_dataframe(df)

    if 'date' not in df.columns:
        raise ValueError("No 'date' column found in parquet file.")
    missing_columns = [
        name for name in ('image_id', 'longitude', 'latitude')
        if name not in df.columns
    ]
    if missing_columns:
        raise ValueError(f"Input parquet is missing required column(s): {missing_columns}")

    # Derive the grouping keys as standalone Series instead of temporary
    # columns, so input columns that happen to be called 'year', 'month' or
    # 'parsed_date' are neither overwritten nor dropped.
    if pd.api.types.is_datetime64_any_dtype(df['date']):
        parsed_dates = df['date']
    else:
        parsed_dates = pd.to_datetime(df['date'], errors='coerce')
    years = parsed_dates.dt.year.rename('year')
    months = parsed_dates.dt.month.rename('month')

    # Rows whose date could not be parsed have NaN keys and are left out of
    # the groups; they keep empty environmental columns after the left merge.
    grouped = df.groupby([years, months])
    total_groups = len(grouped)
    logging.info(f"Starting GEE extraction across {total_groups} Year-Month blocks.")

    all_enriched_records = []
    progress_step = 10000
    next_progress_mark = progress_step
    for i, ((year, month), group_df) in enumerate(grouped, start=1):
        year = int(year)
        month = int(month)
        logging.info(f"Processing Block {i}/{total_groups}: {year}-{month:02d} ({len(group_df)} records)")
        # Sub-chunking to respect GEE payload limits
        for chunk_df in chunker(group_df, args.chunk_size):
            all_enriched_records.extend(get_gee_features_for_chunk(chunk_df, year, month))
            total_processed = len(all_enriched_records)
            # Report whenever a 10000-record boundary is crossed; an exact
            # "multiple of 10000" check is almost never hit with uneven chunks.
            if total_processed >= next_progress_mark:
                logging.info(f" ...processed {total_processed} records so far.")
                next_progress_mark = (total_processed // progress_step + 1) * progress_step

    logging.info("Extraction complete. Merging with base metadata.")
    if all_enriched_records:
        df_extracted = pd.DataFrame(all_enriched_records)
    else:
        # Without this, the frame would have no 'image_id' column and the
        # merge preparation below would fail with a KeyError.
        logging.warning("No records were extracted from GEE; environmental columns will be empty.")
        df_extracted = pd.DataFrame({
            'image_id': pd.Series(dtype='object'),
            'elevation': pd.Series(dtype='float64'),
            'landcover_class': pd.Series(dtype='float64'),
            'temperature_2m': pd.Series(dtype='float64'),
            'total_precipitation': pd.Series(dtype='float64'),
            'ndvi': pd.Series(dtype='float64'),
        })

    # Ensure image_id types match for merge
    df['image_id'] = df['image_id'].astype(str)
    df_extracted['image_id'] = df_extracted['image_id'].astype(str)

    # Keep one extracted record per image_id so the left merge cannot
    # multiply rows of the base metadata.
    duplicate_mask = df_extracted['image_id'].duplicated()
    if duplicate_mask.any():
        logging.warning(
            f"{int(duplicate_mask.sum())} duplicate image_id record(s) in extracted data; keeping the first of each."
        )
        df_extracted = df_extracted[~duplicate_mask]

    # Merge based on image_id
    df_final = pd.merge(df, df_extracted, on='image_id', how='left')

    # Memory optimization before saving
    df_final = optimize_dataframe(df_final)

    logging.info(f"Saving enriched metadata to {args.output_parquet}")
    df_final.to_parquet(args.output_parquet, engine='pyarrow', index=False)
    logging.info("Done.")
# Run the extraction pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()