# predictive_irrigation_models/pipelines/resample_impute_pipeline.py
# Uploaded via huggingface_hub by paolog-fbk (commit 64ab846, verified)
from prefect import task
import numpy as np
import pandas as pd
from prefect.logging import get_run_logger
import yaml
import xarray as xr
import warnings
from scipy.interpolate import interp1d
import pyproj
from datetime import timedelta
from pathlib import Path
from typing import Literal, Optional, Dict, Union
from tools.gis_utils import _get_field_pixel_indices, _get_closest_pixel_indices, _infer_dataset_crs, _load_field_mapping
# Load shared pipeline parameters once at import time.
# NOTE: the path is relative, so importing this module requires the process
# working directory to contain config/params.yml.
with open('config/params.yml') as file:
    config = yaml.safe_load(file)

# Column names and resampling settings used by every task in this module.
datetime_col = config['datetime_col']                    # timestamp column name
datastream_id_col = config['datastream_id_col']          # sensor id column name
datastream_name_col = config['datastream_name_col']      # sensor name column name
resampling_window = config['resampling_window']          # pandas freq string, e.g. '8h'
days_weather_forecast = config['days_weather_forecast']  # forecast horizon in days
### Helper functions
### Helper functions
@task(task_run_name='interpolate_satellite_data')
def interpolate_satellite_data(remote_sensing_data):
    """
    Efficiently interpolates satellite data across all bands using cloud mask.

    Band 0 is treated as a binary cloud mask: observations with mask == 1 are
    set to NaN, the time axis is expanded to a regular grid at the
    module-level `resampling_window` frequency, and each (band, y, x) pixel
    series is linearly interpolated (and extrapolated) along time.

    Parameters:
    -----------
    remote_sensing_data : xarray.Dataset / xarray.DataArray
        Dataset with dimensions (time, band, y, x) where band 0 is cloud cover
        Must have a 'time' coordinate (from the previous pipeline)

    Returns:
    --------
    xarray.Dataset
        Interpolated dataset with same structure as input, with complete time
        series at specified frequency. Band 0 carries the original cloud mask
        where an observation existed and 0 at purely interpolated timestamps.
    """
    # init logger
    logger = get_run_logger()
    # get dates from the xarray time coordinate
    dates_dt = pd.to_datetime(remote_sensing_data['time'].values)
    # pull the raw numpy array (time, band, y, x) out of the Dataset/DataArray
    if isinstance(remote_sensing_data, xr.Dataset):
        satellite_data = remote_sensing_data['satellite_data'].values
    else:
        # if it is a data array
        satellite_data = remote_sensing_data.values
    # cloud mask (band 0) and data bands (bands 1+)
    cloud_mask = satellite_data[:, 0, :, :]
    data_bands = satellite_data[:, 1:, :, :].copy()
    # expand cloud_mask to match all data bands, then NaN-out cloudy pixels
    cloud_mask_expanded = np.repeat(cloud_mask[:, np.newaxis, :, :], data_bands.shape[1], axis=1)
    data_bands[cloud_mask_expanded == 1] = np.nan
    # create complete date range with specified resampling window
    # normalize start_date to midnight to ensure proper alignment (00:00, 08:00, 16:00 for 8H, etc.)
    start_date = dates_dt.min().normalize()  # sets time to 00:00:00
    end_date = dates_dt.max()
    all_dates_dt = pd.date_range(start_date, end_date, freq=resampling_window)
    # init complete array (NaN = "no observation yet")
    n_times = len(all_dates_dt)
    n_bands = data_bands.shape[1]
    n_y = data_bands.shape[2]
    n_x = data_bands.shape[3]
    data_complete = np.full((n_times, n_bands, n_y, n_x), np.nan)
    # fill in existing data - snap each original date to its closest grid slot
    # NOTE(review): if two observations snap to the same slot, the later one
    # silently overwrites the earlier — confirm this is acceptable upstream
    for i, date in enumerate(dates_dt):
        # get the closest index in all_dates_dt
        idx = np.argmin(np.abs(all_dates_dt - date))
        data_complete[idx, :, :, :] = data_bands[i, :, :, :]
    # per-pixel temporal interpolation
    # reshape to (time, band * y * x) so each column is one pixel/band series
    original_shape = data_complete.shape
    data_reshaped = data_complete.reshape(n_times, -1)
    # time indices for interpolation (grid positions, not actual timestamps)
    time_indices = np.arange(n_times)
    # process each pixel series
    for pixel_idx in range(data_reshaped.shape[1]):
        pixel_values = data_reshaped[:, pixel_idx]
        valid_mask = ~np.isnan(pixel_values)
        # only interpolate if we have at least 2 valid points !!!
        if np.sum(valid_mask) >= 2:
            f = interp1d(
                time_indices[valid_mask],
                pixel_values[valid_mask],
                kind='linear',
                fill_value='extrapolate',  # rethink to fill_value=np.nan (safety reasons)
                bounds_error=False
            )
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=RuntimeWarning)
                data_reshaped[:, pixel_idx] = f(time_indices)
        elif np.sum(valid_mask) == 1:
            # if only one valid point, fill with that constant value
            data_reshaped[:, pixel_idx] = pixel_values[valid_mask][0]
        # If no valid points, leave as NaN
    # back to original dimensions (reshape views share memory, so data is set)
    data_complete = data_reshaped.reshape(original_shape)
    logger.info(f"Number of NaN values after interpolation: {np.isnan(data_complete).sum()}")
    logger.info(f"Original time steps: {len(dates_dt)}, Interpolated time steps: {n_times}")
    logger.info(f"Resampling frequency: {resampling_window}")
    logger.info(f"Processed {n_bands} bands with shape (y={n_y}, x={n_x})")
    # xarray Dataset with same structure as input
    new_time_coord = all_dates_dt.values
    # recreate the full satellite data including cloud mask
    # for interpolated dates without original data, set cloud mask to 0 (no clouds)
    full_data = np.zeros((n_times, n_bands + 1, n_y, n_x))
    # fill cloud mask for original dates only
    cloud_mask_complete = np.zeros((n_times, n_y, n_x))
    for i, date in enumerate(dates_dt):
        idx = np.argmin(np.abs(all_dates_dt - date))
        cloud_mask_complete[idx, :, :] = satellite_data[i, 0, :, :]
    full_data[:, 0, :, :] = cloud_mask_complete
    full_data[:, 1:, :, :] = data_complete
    # generate xarray Dataset
    interpolated_dataset = xr.Dataset(
        {
            'satellite_data': (['time', 'band', 'y', 'x'], full_data),
        },
        coords={
            'time': new_time_coord,
            'x': remote_sensing_data.coords['x'],
            'y': remote_sensing_data.coords['y'],
        }
    )
    # copy over other data variables if they exist
    # NOTE(review): .data_vars only exists on Dataset — this would raise for a
    # DataArray input; confirm callers always pass a Dataset here
    for var in remote_sensing_data.data_vars:
        if var != 'satellite_data':
            interpolated_dataset[var] = remote_sensing_data[var]
    return interpolated_dataset
@task(task_run_name='get_tabular_satellite_{consortium_name}')
def get_tabular_satellite(
    consortium_name: str,
    satellite_ds: xr.Dataset,
    location_df: pd.DataFrame,
    method: Literal['closest', 'field_level'] = 'closest',
    sensor_field_mapping: Optional[Union[str, Path, Dict]] = None,
    field_agg: Literal['mean', 'median'] = 'mean',
) -> pd.DataFrame:
    """
    Extract satellite data at sensor locations and convert to tabular format.

    Parameters
    ----------
    consortium_name : str
        Consortium key; used to locate the field-geometry files via config.
    satellite_ds : xr.Dataset
        Dataset with 'satellite_data' of shape (time, band, y, x); band 0 is
        the cloud mask, bands 1-24 are the spectral indices listed below.
    location_df : pd.DataFrame
        Sensor metadata with 'datastream_name' plus 'x'/'y' coordinates.
        NOTE: 'x' holds latitude and 'y' longitude (swapped upstream).
    method : {'closest', 'field_level'}
        'closest' samples the nearest pixel to the sensor; 'field_level'
        aggregates all pixels inside the sensor's field polygon.
    sensor_field_mapping : str | Path | dict, optional
        Required when method == 'field_level'.
    field_agg : {'mean', 'median'}
        Spatial aggregation over field pixels.

    Returns
    -------
    pd.DataFrame
        One row per (sensor, timestamp) with one column per spectral index,
        sorted by sensor name then datetime. Empty (with the expected
        columns) if no sensor yields any pixels.

    Raises
    ------
    ValueError
        For an unknown `method`, or 'field_level' without a mapping.
    """
    logger = get_run_logger()
    # names for data bands 1-24, in band order (band 0 is the cloud mask).
    # Defined once here instead of being rebuilt for every timestamp, and no
    # longer shadowing the pixel-index variable (short-term naming solution).
    band_names = [
        "ndvi", "grvi", "rvi", "rgi", "aci", "maci", "gndvi", "ngrdi", "ngbdi", "bgvi", "brvi",
        "wi", "varig", "gli", "g_perc", "ndmi", "ndwi", "reci", "ndre_lower_end",
        "ndre_upper_end", "msavi", "arvi", "sipi", "gci"
    ]
    # validate inputs
    if method not in ('closest', 'field_level'):
        raise ValueError(f"method must be 'closest' or 'field_level', got '{method}'")
    if method == 'field_level' and sensor_field_mapping is None:
        raise ValueError("sensor_field_mapping required for field_level method")
    # ensure datastream_id exists (work on a copy so the caller's frame is untouched)
    if datastream_id_col not in location_df.columns:
        location_df = location_df.copy()
        location_df[datastream_id_col] = range(len(location_df))
    # keep only tensiometer and ELMED datastreams
    filtered_location_df = location_df[
        location_df['datastream_name'].str.contains('TN|ELMED|TENSIO', case=False, na=False)
    ]
    # load field mapping if needed
    field_geometries = None
    if method == 'field_level':
        # for services
        # field_geometries = _load_field_mapping(sensor_field_mapping, filtered_location_df)
        # for fields files
        field_geometries = _load_field_mapping(consortium_name, config['consortia_data_folders'][consortium_name], filtered_location_df)
    # create coordinate transformer (WGS84 to dataset CRS)
    dataset_crs = _infer_dataset_crs(satellite_ds)
    transformer = pyproj.Transformer.from_crs("EPSG:4326", dataset_crs, always_xy=True)
    # pre-compute pixel indices for all sensors, keyed by the DataFrame index
    logger.info("Computing pixel masks for sensors...")
    sensor_indices = {}
    for idx, row in filtered_location_df.iterrows():
        sensor_name = row[datastream_name_col]
        # ps: x is lat, y is lon (swapped!)
        lat, lon = row['x'], row['y']
        if method == 'closest':
            y_idx, x_idx = _get_closest_pixel_indices(satellite_ds, lat, lon, transformer)
            sensor_indices[idx] = {'y': [y_idx], 'x': [x_idx]}
        elif sensor_name not in field_geometries:
            # field_level but polygon unknown: fall back to the nearest pixel
            logger.warning(f"No field geometry for {sensor_name}, using closest pixel")
            y_idx, x_idx = _get_closest_pixel_indices(satellite_ds, lat, lon, transformer)
            sensor_indices[idx] = {'y': [y_idx], 'x': [x_idx]}
        else:
            field_geom = field_geometries[sensor_name]
            y_indices, x_indices = _get_field_pixel_indices(satellite_ds, field_geom, transformer)
            sensor_indices[idx] = {'y': y_indices, 'x': x_indices}
    # extract data (efficiently) using xarray operations
    logger.info("Extracting satellite data...")
    results = []
    # band indices to extract (skip band 0 = cloud mask, use bands 1-24)
    band_indices = list(range(1, 25))
    # get times
    times = satellite_ds['time'].values
    n_sensors = len(filtered_location_df)
    # enumerate gives a reliable progress counter; the raw DataFrame index is
    # not contiguous after the filtering above
    for position, (idx, row) in enumerate(filtered_location_df.iterrows(), start=1):
        sensor_name = row[datastream_name_col]
        sensor_id = row[datastream_id_col]
        pixel_indices = sensor_indices[idx]
        if len(pixel_indices['y']) == 0:
            logger.warning(f"No pixels found for {sensor_name}")
            continue
        # only the bands we need (1-24) and the specific pixels;
        # extract lazily WITHOUT loading the full array into memory
        data_subset = satellite_ds['satellite_data'].isel(
            band=band_indices,
            y=xr.DataArray(pixel_indices['y'], dims='points'),
            x=xr.DataArray(pixel_indices['x'], dims='points')
        )
        # data_subset shape: (time, 24, n_points) — aggregate spatially
        if field_agg == 'mean':
            aggregated = data_subset.mean(dim='points')
        else:  # median
            aggregated = data_subset.median(dim='points')
        # only now load the aggregated data into memory (avoid overload)
        band_data = aggregated.values  # shape: (time, 24)
        # create records for each timestamp
        for t_idx, timestamp in enumerate(times):
            record = {
                datetime_col: pd.Timestamp(timestamp),
                datastream_name_col: sensor_name,
                datastream_id_col: sensor_id
            }
            record.update(zip(band_names, band_data[t_idx]))
            results.append(record)
        if position % 10 == 0:
            logger.info(f"Processed {position}/{n_sensors} sensors")
    # create df
    logger.info("Creating final satellite DataFrame...")
    if not results:
        # nothing extracted: sort_values would fail on missing columns, so
        # return an empty frame carrying the expected schema instead
        return pd.DataFrame(columns=[datetime_col, datastream_name_col, datastream_id_col] + band_names)
    df = pd.DataFrame(results)
    df = df.sort_values([datastream_name_col, datetime_col]).reset_index(drop=True)
    return df
@task(task_run_name="generate_weather_forecast_df")
def generate_weather_forecast_df(df):
    """
    Turn a time-indexed weather table into a forecast-lookup table.

    For every timestamp `dt` in `df[datetime_col]`, the rows falling inside
    the window [dt + one resampling step, dt + days_weather_forecast days +
    one step) are duplicated with:
      - `datetime_col` set to the reference time `dt`
      - `f'{datetime_col}_forecast'` set to the row's own (future) timestamp

    Parameters
    ----------
    df : pd.DataFrame
        Weather data containing the module-configured `datetime_col` column.

    Returns
    -------
    pd.DataFrame
        Forecast table sorted by (reference time, forecast time), with the
        two time columns first and the remaining columns in their original
        order. Empty (only the two time columns) when `df` has no rows.
    """
    forecast_col = f'{datetime_col}_forecast'
    # size of one resampling step in hours, e.g. '8h' -> 8
    res_window = int(resampling_window.split('h')[0])
    all_data = []
    for dt in df[datetime_col].unique():
        # forecast window: strictly after the current step, up to the horizon
        start_filter = dt + timedelta(hours=res_window)
        end_filter = dt + timedelta(hours=days_weather_forecast * 24 + res_window)
        df_forecast = df[(df[datetime_col] >= start_filter) & (df[datetime_col] < end_filter)].copy()
        df_forecast[forecast_col] = df_forecast[datetime_col]
        df_forecast[datetime_col] = dt
        all_data.append(df_forecast)
    if not all_data:
        # empty input: pd.concat([]) would raise ValueError, so return an
        # empty shell with the two key columns instead
        return pd.DataFrame(columns=[datetime_col, forecast_col])
    data_out = pd.concat(all_data)
    # put the two time columns first and keep the rest in their original
    # order (the previous revision used set(), which made the output column
    # order non-deterministic between runs)
    other_cols = [c for c in data_out.columns if c not in (datetime_col, forecast_col)]
    data_out = data_out[[datetime_col, forecast_col] + other_cols]
    return data_out.sort_values([datetime_col, forecast_col])