| from prefect import task
|
| import numpy as np
|
| import pandas as pd
|
| from prefect.logging import get_run_logger
|
| import yaml
|
| import xarray as xr
|
| import warnings
|
| from scipy.interpolate import interp1d
|
| import pyproj
|
| from datetime import timedelta
|
|
|
| from pathlib import Path
|
| from typing import Literal, Optional, Dict, Union
|
|
|
| from tools.gis_utils import _get_field_pixel_indices, _get_closest_pixel_indices, _infer_dataset_crs, _load_field_mapping
|
|
|
|
|
# Pipeline-wide parameters, loaded once at import time from the YAML config.
config = yaml.safe_load(Path('config/params.yml').read_text())

# Column names shared by the tabular outputs of the tasks below.
datetime_col = config['datetime_col']
datastream_id_col = config['datastream_id_col']
datastream_name_col = config['datastream_name_col']

# Target frequency of the interpolated satellite time series (e.g. '6h').
resampling_window = config['resampling_window']

# Forecast horizon, in days, used when expanding weather data.
days_weather_forecast = config['days_weather_forecast']
|
|
|
|
|
|
|
|
|
@task(task_run_name='interpolate_satellite_data')
def interpolate_satellite_data(remote_sensing_data):
    """
    Interpolate cloud-masked satellite data onto a regular time grid.

    Band 0 of the input is treated as a cloud mask (1 = cloudy). Cloudy
    pixels in the remaining bands are set to NaN and then filled by
    per-pixel linear interpolation (with extrapolation) along time.

    Parameters
    ----------
    remote_sensing_data : xarray.Dataset or xarray.DataArray
        Data with dimensions (time, band, y, x) where band 0 is cloud
        cover. Must carry a 'time' coordinate. If a Dataset, the array is
        read from its 'satellite_data' variable.

    Returns
    -------
    xarray.Dataset
        Dataset with a 'satellite_data' variable of shape
        (time, band, y, x) on a complete time axis at the configured
        resampling frequency. Band 0 holds the original (non-interpolated)
        cloud mask, zero-filled on synthetic time steps. Extra data
        variables of a Dataset input are copied through unchanged.
    """
    logger = get_run_logger()

    dates_dt = pd.to_datetime(remote_sensing_data['time'].values)

    if isinstance(remote_sensing_data, xr.Dataset):
        satellite_data = remote_sensing_data['satellite_data'].values
    else:
        satellite_data = remote_sensing_data.values

    # Split off the cloud mask (band 0) and blank out cloudy pixels in the
    # data bands; copy so the caller's array is never mutated.
    cloud_mask = satellite_data[:, 0, :, :]
    data_bands = satellite_data[:, 1:, :, :].copy()
    cloud_mask_expanded = np.repeat(cloud_mask[:, np.newaxis, :, :], data_bands.shape[1], axis=1)
    data_bands[cloud_mask_expanded == 1] = np.nan

    # Build the complete, regularly spaced time axis.
    start_date = dates_dt.min().normalize()
    end_date = dates_dt.max()
    all_dates_dt = pd.date_range(start_date, end_date, freq=resampling_window)

    n_times = len(all_dates_dt)
    n_bands = data_bands.shape[1]
    n_y = data_bands.shape[2]
    n_x = data_bands.shape[3]
    data_complete = np.full((n_times, n_bands, n_y, n_x), np.nan)

    # Place each observation at the nearest slot of the new time axis.
    # NOTE(review): if two observations map to the same slot the later one
    # silently overwrites the earlier — confirm acquisitions are sparser
    # than the resampling window.
    for i, date in enumerate(dates_dt):
        idx = np.argmin(np.abs(all_dates_dt - date))
        data_complete[idx, :, :, :] = data_bands[i, :, :, :]

    # Interpolate every (band, y, x) pixel independently along time.
    original_shape = data_complete.shape
    data_reshaped = data_complete.reshape(n_times, -1)
    time_indices = np.arange(n_times)

    for pixel_idx in range(data_reshaped.shape[1]):
        pixel_values = data_reshaped[:, pixel_idx]
        valid_mask = ~np.isnan(pixel_values)
        n_valid = np.sum(valid_mask)

        if n_valid >= 2:
            f = interp1d(
                time_indices[valid_mask],
                pixel_values[valid_mask],
                kind='linear',
                fill_value='extrapolate',
                bounds_error=False
            )
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=RuntimeWarning)
                data_reshaped[:, pixel_idx] = f(time_indices)
        elif n_valid == 1:
            # Single observation: propagate the constant value everywhere.
            data_reshaped[:, pixel_idx] = pixel_values[valid_mask][0]
        # n_valid == 0: pixel stays all-NaN.

    data_complete = data_reshaped.reshape(original_shape)
    logger.info(f"Number of NaN values after interpolation: {np.isnan(data_complete).sum()}")
    logger.info(f"Original time steps: {len(dates_dt)}, Interpolated time steps: {n_times}")
    logger.info(f"Resampling frequency: {resampling_window}")
    logger.info(f"Processed {n_bands} bands with shape (y={n_y}, x={n_x})")

    new_time_coord = all_dates_dt.values

    # Re-assemble: band 0 = cloud mask (zeros on synthetic time steps),
    # bands 1.. = interpolated data.
    full_data = np.zeros((n_times, n_bands + 1, n_y, n_x))
    cloud_mask_complete = np.zeros((n_times, n_y, n_x))
    for i, date in enumerate(dates_dt):
        idx = np.argmin(np.abs(all_dates_dt - date))
        cloud_mask_complete[idx, :, :] = satellite_data[i, 0, :, :]
    full_data[:, 0, :, :] = cloud_mask_complete
    full_data[:, 1:, :, :] = data_complete

    interpolated_dataset = xr.Dataset(
        {
            'satellite_data': (['time', 'band', 'y', 'x'], full_data),
        },
        coords={
            'time': new_time_coord,
            'x': remote_sensing_data.coords['x'],
            'y': remote_sensing_data.coords['y'],
        }
    )

    # Bug fix: only Datasets expose .data_vars; a DataArray input
    # (explicitly accepted above) used to raise AttributeError here.
    if isinstance(remote_sensing_data, xr.Dataset):
        for var in remote_sensing_data.data_vars:
            if var != 'satellite_data':
                interpolated_dataset[var] = remote_sensing_data[var]
    return interpolated_dataset
|
|
|
|
|
@task(task_run_name='get_tabular_satellite_{consortium_name}')
def get_tabular_satellite(
    consortium_name: str,
    satellite_ds: xr.Dataset,
    location_df: pd.DataFrame,
    method: Literal['closest', 'field_level'] = 'closest',
    sensor_field_mapping: Optional[Union[str, Path, Dict]] = None,
    field_agg: Literal['mean', 'median'] = 'mean',
) -> pd.DataFrame:
    """
    Extract satellite data at sensor locations and convert to tabular format.

    Parameters
    ----------
    consortium_name : str
        Key into config['consortia_data_folders'] used to load field mappings.
    satellite_ds : xr.Dataset
        Dataset with a 'satellite_data' variable of dims (time, band, y, x);
        bands 1..24 are the spectral indices extracted here.
    location_df : pd.DataFrame
        Sensor locations with 'x'/'y' columns and a 'datastream_name' column.
    method : 'closest' | 'field_level'
        Pixel selection: the single nearest pixel, or every pixel inside the
        sensor's field polygon.
    sensor_field_mapping : str | Path | dict, optional
        Required when method == 'field_level'.
    field_agg : 'mean' | 'median'
        Spatial aggregation over the selected pixels.

    Returns
    -------
    pd.DataFrame
        One row per (sensor, timestamp), one column per spectral index,
        sorted by sensor name then timestamp.

    Raises
    ------
    ValueError
        On an unknown method, or field_level without a mapping.
    """
    logger = get_run_logger()

    if method not in ['closest', 'field_level']:
        raise ValueError(f"method must be 'closest' or 'field_level', got '{method}'")

    if method == 'field_level' and sensor_field_mapping is None:
        raise ValueError("sensor_field_mapping required for field_level method")

    # Assign surrogate ids when the configured id column is absent
    # (copy first so the caller's frame is not mutated).
    if datastream_id_col not in location_df.columns:
        location_df = location_df.copy()
        location_df[datastream_id_col] = range(len(location_df))

    # Keep only the datastreams of interest for this extraction.
    filtered_location_df = location_df[
        location_df['datastream_name'].str.contains('TN|ELMED|TENSIO', case=False, na=False)
    ]

    field_geometries = None
    if method == 'field_level':
        field_geometries = _load_field_mapping(consortium_name, config['consortia_data_folders'][consortium_name], filtered_location_df)

    dataset_crs = _infer_dataset_crs(satellite_ds)
    transformer = pyproj.Transformer.from_crs("EPSG:4326", dataset_crs, always_xy=True)

    logger.info("Computing pixel masks for sensors...")
    sensor_indices = {}

    for idx, row in filtered_location_df.iterrows():
        sensor_name = row[datastream_name_col]

        # NOTE(review): 'x' feeds lat and 'y' feeds lon — this looks swapped;
        # confirm the column convention against the gis_utils helpers before
        # changing it (behavior kept as-is here).
        lat, lon = row['x'], row['y']

        if method == 'closest':
            y_idx, x_idx = _get_closest_pixel_indices(satellite_ds, lat, lon, transformer)
            sensor_indices[idx] = {'y': [y_idx], 'x': [x_idx]}
        elif sensor_name not in field_geometries:
            # field_level requested but no polygon known: fall back to
            # the nearest pixel rather than dropping the sensor.
            logger.info(f"Warning: No field geometry for {sensor_name}, using closest pixel")
            y_idx, x_idx = _get_closest_pixel_indices(satellite_ds, lat, lon, transformer)
            sensor_indices[idx] = {'y': [y_idx], 'x': [x_idx]}
        else:
            field_geom = field_geometries[sensor_name]
            y_indices, x_indices = _get_field_pixel_indices(satellite_ds, field_geom, transformer)
            sensor_indices[idx] = {'y': y_indices, 'x': x_indices}

    logger.info("Extracting satellite data...")
    results = []

    # Bands 1..24 hold the spectral indices (band 0 is the cloud mask).
    band_indices = list(range(1, 25))

    # Output column names for the 24 extracted bands, in band order.
    # Hoisted out of the per-timestamp loop: it was rebuilt on every
    # iteration and shadowed the pixel-index dict also named `indices`.
    index_names = [
        "ndvi", "grvi", "rvi", "rgi", "aci", "maci", "gndvi", "ngrdi", "ngbdi", "bgvi", "brvi",
        "wi", "varig", "gli", "g_perc", "ndmi", "ndwi", "reci", "ndre_lower_end",
        "ndre_upper_end", "msavi", "arvi", "sipi", "gci"
    ]

    times = satellite_ds['time'].values

    for pos, (idx, row) in enumerate(filtered_location_df.iterrows()):
        sensor_name = row[datastream_name_col]
        sensor_id = row[datastream_id_col]
        pixel_indices = sensor_indices[idx]

        if len(pixel_indices['y']) == 0:
            logger.info(f"Warning: No pixels found for {sensor_name}")
            continue

        # Vectorized point selection: paired y/x index arrays share the
        # 'points' dim, so this picks individual pixels, not a box.
        data_subset = satellite_ds['satellite_data'].isel(
            band=band_indices,
            y=xr.DataArray(pixel_indices['y'], dims='points'),
            x=xr.DataArray(pixel_indices['x'], dims='points')
        )

        # Aggregate spatially over the selected pixels.
        if field_agg == 'mean':
            aggregated = data_subset.mean(dim='points')
        else:
            aggregated = data_subset.median(dim='points')

        band_data = aggregated.values

        for t_idx, timestamp in enumerate(times):
            record = {
                datetime_col: pd.Timestamp(timestamp),
                datastream_name_col: sensor_name,
                datastream_id_col: sensor_id
            }
            for name, band_val in zip(index_names, band_data[t_idx]):
                record[name] = band_val
            results.append(record)

        # Bug fix: progress used the DataFrame index label (`idx + 1`),
        # which is wrong (or crashes) for non-integer indexes, and compared
        # against the unfiltered frame's length.
        if (pos + 1) % 10 == 0:
            logger.info(f"Processed {pos + 1}/{len(filtered_location_df)} sensors")

    logger.info("Creating final satellite DataFrame...")
    df = pd.DataFrame(results)
    df = df.sort_values([datastream_name_col, datetime_col]).reset_index(drop=True)

    return df
|
|
|
@task(task_run_name="generate_weather_forecast_df")
def generate_weather_forecast_df(df):
    """
    Expand a weather dataframe into (issue time, forecast time) pairs.

    For every timestamp t in df[datetime_col], the rows falling in
    [t + resampling_window, t + days_weather_forecast*24h + resampling_window)
    are duplicated with datetime_col set to t (the "issue" time) and their
    original timestamp preserved in '<datetime_col>_forecast'.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain datetime_col with timestamp values.

    Returns
    -------
    pd.DataFrame
        Forecast rows sorted by (issue time, forecast time), with those two
        columns first and the remaining columns in their input order.
        Empty (same schema) when df has no rows.
    """
    forecast_col = f'{datetime_col}_forecast'
    # resampling_window is an hour-based frequency string like '6h'.
    res_window = int(resampling_window.split('h')[0])

    all_data = []
    for dt in df[datetime_col].unique():
        start_filter = dt + timedelta(hours=res_window)
        end_filter = dt + timedelta(hours=days_weather_forecast * 24 + res_window)
        df_forecast = df[(df[datetime_col] >= start_filter) & (df[datetime_col] < end_filter)].copy()
        df_forecast[forecast_col] = df_forecast[datetime_col]
        df_forecast[datetime_col] = dt
        all_data.append(df_forecast)

    # Robustness: pd.concat raises ValueError on an empty list (empty input df).
    if all_data:
        data_out = pd.concat(all_data)
    else:
        data_out = df.iloc[0:0].copy()
        data_out[forecast_col] = data_out[datetime_col]

    # Bug fix: the trailing columns were ordered with list(set(...)), which is
    # non-deterministic across runs; preserve the input column order instead.
    other_cols = [c for c in data_out.columns if c not in (datetime_col, forecast_col)]
    data_out = data_out[[datetime_col, forecast_col] + other_cols].sort_values([datetime_col, forecast_col])
    return data_out