| from prefect import flow, task
|
| from prefect.task_runners import ThreadPoolTaskRunner
|
| from prefect.logging import get_run_logger
|
| import pandas as pd
|
| import os
|
| import yaml
|
| import pickle
|
| import plotly.graph_objects as go
|
| from plotly.subplots import make_subplots
|
| from xgboost import XGBRegressor
|
| from sklearn.metrics import r2_score
|
| from pipelines import xgcast_pipeline
|
| from pipelines.data_collection_pipeline import get_satellite_data
|
| from pipelines.preprocessing_pipeline import save_to_file, read_satellite_data, update_historical_weather_data
|
| from pipelines.resample_impute_pipeline import interpolate_satellite_data, get_tabular_satellite, generate_weather_forecast_df
|
| from pipelines.xgcast_pipeline import normalise_df
|
| from pipelines.model_preparation_pipeline import model_preparation_pipeline, data_availability
|
|
|
|
|
# Load the shared pipeline configuration and the XGCast model hyperparameters
# from their YAML files at import time; both are consumed module-wide.
with open('config/params.yml') as params_file:
    config = yaml.safe_load(params_file)

with open('config/xgcast_params.yml') as model_params_file:
    model_params = yaml.safe_load(model_params_file)
|
|
|
|
|
# Column-name settings pulled from the shared configuration.
datetime_col = config['datetime_col']
value_col = config['value_col']
datastream_name_col = config['datastream_name_col']

# The resampling window is a pandas-style frequency string such as "3h";
# the integer hour count is the part before the first 'h'.
resampling_window = config['resampling_window']
resampling = int(resampling_window.partition('h')[0])

# Forecast horizon: three days' worth of steps at the resampling frequency.
num_predictions = 3 * (24 // resampling)
|
|
|
|
|
|
|
|
|
@flow(name='xgcast_preparation_pipeline', retries=1, task_runner=ThreadPoolTaskRunner())
def xgcast_run(consortium_name):
    """Prepare all data sources for one consortium, then run the XGCast pipelines.

    Verifies that the mandatory primary tables exist (aborting the flow if any
    is missing), downloads or generates missing optional data sources (public
    historical weather, weather forecast, remote sensing), records which
    sources are available in ``data_availability``, and finally triggers model
    preparation plus the XGCast preparation and model pipelines.

    Parameters
    ----------
    consortium_name : str
        Identifier of the consortium whose data should be processed; used to
        build every primary-layer file path.
    """
    logger = get_run_logger()
    logger.info(f'Starting XGCast run for consortium {consortium_name}!')

    def _primary_path(stem, ext='parquet'):
        # All primary-layer artifacts live under data//03_primary// and are
        # suffixed with the consortium name. The double slashes are preserved
        # byte-for-byte so paths match the files written elsewhere in the project.
        return f'data//03_primary//{stem}_{consortium_name}.{ext}'

    # These three tables are mandatory: without them nothing can be trained.
    required_tables = [
        ('field_sensor_data', 'field sensor measurements'),
        ('irrigation_data', 'irrigation data'),
        ('locations_ids', 'locations ids'),
    ]
    for stem, description in required_tables:
        if not os.path.exists(_primary_path(stem)):
            logger.error(f'Table of {description} not found for consortium {consortium_name}. Stopping.')
            return

    # Optional data sources: flag them as unavailable instead of aborting.
    has_weather_data, has_crop_data, has_soil_data, has_remote_sensing_data = True, True, True, True

    if not os.path.exists(_primary_path('crop_type_data', ext='pickle')):
        has_crop_data = False
        logger.info(f'Crop data not found for consortium {consortium_name}')

    if not os.path.exists(_primary_path('soil_type_data')):
        has_soil_data = False
        logger.info(f'Soil data not found for consortium {consortium_name}')

    if not os.path.exists(_primary_path('weather_data')):
        has_weather_data = False
        logger.info(f'On-field weather data not found for consortium {consortium_name}')

    if not os.path.exists(_primary_path('historical_weather_data')):
        logger.info(f'Public weather data not found for consortium {consortium_name}. Downloading data...')
        locations_ids = pd.read_parquet(_primary_path('locations_ids'))
        historical_weather_data = update_historical_weather_data(consortium_name=consortium_name, locations_ids=locations_ids)
        save_to_file(df=historical_weather_data, output_file=_primary_path('historical_weather_data'))
        logger.info(f'Public weather data has been downloaded and saved for consortium {consortium_name}.')

    if not os.path.exists(_primary_path('forecasted_weather_data')):
        logger.info(f'Weather forecast data not found for consortium {consortium_name}. Generating data for model training purposes...')
        historical_weather_data = pd.read_parquet(_primary_path('historical_weather_data'))
        forecasted_weather_data = generate_weather_forecast_df(historical_weather_data.reset_index())
        save_to_file(df=forecasted_weather_data, output_file=_primary_path('forecasted_weather_data'))
        logger.info(f'Weather forecast data has been downloaded and saved for consortium {consortium_name}.')

    if not os.path.exists(_primary_path('remote_sensing_data_final')):
        try:
            # Block until the satellite download has finished. The future is
            # not passed to read_satellite_data, so Prefect cannot infer the
            # dependency; without .result() the read below would race the
            # download on the thread-pool runner.
            get_satellite_data.submit(
                consortium=consortium_name,
                overwrite=False
            ).result()

            remote_sensing_data = read_satellite_data.submit(consortium_name=consortium_name).result()

            satellite_interpolated = interpolate_satellite_data.submit(remote_sensing_data).result()

            locations_ids = pd.read_parquet(_primary_path('locations_ids'))

            spatial_agg_method = config['spatial_agg_method']
            field_agg = config['field_agg']
            remote_sensing_data_final = get_tabular_satellite.submit(
                consortium_name,
                satellite_interpolated,
                locations_ids,
                method=spatial_agg_method,
                sensor_field_mapping='config/sensor_field_mapping.yaml',
                field_agg=field_agg).result()

            save_to_file(df=remote_sensing_data_final, output_file=_primary_path('remote_sensing_data_final'))
        except Exception as e:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; satellite data remains best-effort, but the
            # failure cause is now logged instead of silently dropped.
            has_remote_sensing_data = False
            logger.warning(f'Warning: satellite data could not be downloaded. ({e})')

    # Record which optional sources exist so downstream pipelines can adapt
    # their feature sets for this consortium.
    data_availability[consortium_name] = {'has_weather_data': has_weather_data, 'has_crop_data': has_crop_data, 'has_soil_data': has_soil_data, 'has_remote_sensing_data': has_remote_sensing_data}

    model_preparation_pipeline(override_consortia=[consortium_name], override_data_availability=data_availability[consortium_name])

    xgcast_pipeline.xgcast_preparation_pipeline(override_consortia=[consortium_name], override_data_availability=data_availability[consortium_name])

    xgcast_pipeline.xgcast_model_pipeline(override_consortia=[consortium_name])
|
|
|