# Provenance (HuggingFace upload metadata, kept as comments so the file parses):
# paolog-fbk's picture
# Upload folder using huggingface_hub
# 64ab846 verified
from prefect import flow, task
from prefect.task_runners import ThreadPoolTaskRunner
from prefect.logging import get_run_logger
import pandas as pd
import os
import yaml
import pickle
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from xgboost import XGBRegressor
from sklearn.metrics import r2_score
from pipelines import xgcast_pipeline
from pipelines.data_collection_pipeline import get_satellite_data
from pipelines.preprocessing_pipeline import save_to_file, read_satellite_data, update_historical_weather_data
from pipelines.resample_impute_pipeline import interpolate_satellite_data, get_tabular_satellite, generate_weather_forecast_df
from pipelines.xgcast_pipeline import normalise_df
from pipelines.model_preparation_pipeline import model_preparation_pipeline, data_availability
# Module-wide configuration, loaded once at import time:
# general pipeline parameters plus the XGCast model hyper-parameters.
with open('config/params.yml') as file:
    config = yaml.safe_load(file)
with open('config/xgcast_params.yml') as file:
    model_params = yaml.safe_load(file)
# Column names shared by the tabular data files.
datetime_col = config['datetime_col']
value_col = config['value_col']
datastream_name_col = config['datastream_name_col']
# Resampling window is a pandas-style hour string (e.g. '6h');
# `resampling` is its numeric hour count.
resampling_window = config['resampling_window']
resampling = int(resampling_window.split('h')[0])
# Forecast horizon: three days' worth of steps at the resampling frequency.
num_predictions = 3 * (24 // resampling)
### Flows
def _primary_path(filename):
    """Return the path of *filename* inside the 03_primary data layer.

    Keeps the double-slash path style used throughout this module; the OS
    collapses the extra separator, so the paths resolve identically.
    """
    return f'data//03_primary//{filename}'


@flow(name='xgcast_preparation_pipeline', retries=1, task_runner=ThreadPoolTaskRunner())
def xgcast_run(consortium_name):
    """Prepare data and launch the XGCast model pipelines for one consortium.

    Mandatory primary tables (field sensor measurements, irrigation data,
    locations ids) must exist on disk, otherwise the run stops immediately.
    Optional sources (crop, soil, on-field weather, remote sensing) merely
    toggle availability flags forwarded to the downstream pipelines.  Public
    historical weather, training weather forecasts, and the final remote
    sensing table are generated and saved on first run when absent.

    Parameters
    ----------
    consortium_name : str
        Consortium identifier used to locate and name all data files.
    """
    logger = get_run_logger()
    logger.info(f'Starting XGCast run for consortium {consortium_name}!')
    has_weather_data, has_crop_data, has_soil_data, has_remote_sensing_data = True, True, True, True

    # --- mandatory inputs: stop the whole run when any is missing -----------
    required_tables = {
        f'field_sensor_data_{consortium_name}.parquet': 'Table of field sensor measurements',
        f'irrigation_data_{consortium_name}.parquet': 'Table of irrigation data',
        f'locations_ids_{consortium_name}.parquet': 'Table of locations ids',
    }
    for filename, description in required_tables.items():
        if not os.path.exists(_primary_path(filename)):
            logger.error(f'{description} not found for consortium {consortium_name}. Stopping.')
            return

    # --- optional inputs: only flip the corresponding availability flag -----
    if not os.path.exists(_primary_path(f'crop_type_data_{consortium_name}.pickle')):
        has_crop_data = False
        logger.info(f'Crop data not found for consortium {consortium_name}')
    if not os.path.exists(_primary_path(f'soil_type_data_{consortium_name}.parquet')):
        has_soil_data = False
        logger.info(f'Soil data not found for consortium {consortium_name}')
    if not os.path.exists(_primary_path(f'weather_data_{consortium_name}.parquet')):
        has_weather_data = False
        logger.info(f'On-field weather data not found for consortium {consortium_name}')

    # --- public historical weather: download and persist when absent --------
    if not os.path.exists(_primary_path(f'historical_weather_data_{consortium_name}.parquet')):
        logger.info(f'Public weather data not found for consortium {consortium_name}. Downloading data...')
        locations_ids = pd.read_parquet(_primary_path(f'locations_ids_{consortium_name}.parquet'))
        historical_weather_data = update_historical_weather_data(consortium_name=consortium_name, locations_ids=locations_ids)
        save_to_file(df=historical_weather_data, output_file=_primary_path(f'historical_weather_data_{consortium_name}.parquet'))
        logger.info(f'Public weather data has been downloaded and saved for consortium {consortium_name}.')

    # --- weather forecast: derived from historical data for training --------
    if not os.path.exists(_primary_path(f'forecasted_weather_data_{consortium_name}.parquet')):
        # for model training purposes we use 'real' weather forecast data
        logger.info(f'Weather forecast data not found for consortium {consortium_name}. Generating data for model training purposes...')
        historical_weather_data = pd.read_parquet(_primary_path(f'historical_weather_data_{consortium_name}.parquet'))
        forecasted_weather_data = generate_weather_forecast_df(historical_weather_data.reset_index())
        save_to_file(df=forecasted_weather_data, output_file=_primary_path(f'forecasted_weather_data_{consortium_name}.parquet'))
        logger.info(f'Weather forecast data has been downloaded and saved for consortium {consortium_name}.')

    # --- remote sensing: best-effort download, interpolation, aggregation ---
    if not os.path.exists(_primary_path(f'remote_sensing_data_final_{consortium_name}.parquet')):
        try:
            # TODO: make sure such params exist. AS THE USED TASKS WERE MEANT TO BE PRIVATE, THERE ARE REQUIRED PARAMS THAT MIGHT NOT EXIST
            # (bbox, data folder, start/end dates and request_script come from pre-anonymisation params)
            get_satellite_data.submit(
                consortium=consortium_name,
                overwrite=False
            )
            # merge the downloaded scenes into one in-memory dataset
            remote_sensing_data = read_satellite_data.submit(consortium_name=consortium_name).result()
            # interpolate gaps before spatial aggregation; only the final
            # tabular form is persisted to disk
            satellite_interpolated = interpolate_satellite_data.submit(remote_sensing_data).result()
            locations_ids = pd.read_parquet(_primary_path(f'locations_ids_{consortium_name}.parquet'))
            # TODO: requires pre_anonym_params - 'spatial_agg_method' and 'field_agg'
            spatial_agg_method = config['spatial_agg_method']
            field_agg = config['field_agg']
            remote_sensing_data_final = get_tabular_satellite.submit(
                consortium_name,
                satellite_interpolated,
                locations_ids,
                method=spatial_agg_method,
                sensor_field_mapping='config/sensor_field_mapping.yaml',
                field_agg=field_agg).result()
            save_to_file(df=remote_sensing_data_final, output_file=_primary_path(f'remote_sensing_data_final_{consortium_name}.parquet'))
        except Exception as exc:
            # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
            # still propagate; remote sensing remains best-effort, but the
            # actual failure is now logged instead of silently discarded.
            has_remote_sensing_data = False
            logger.warning(f'Warning: satellite data could not be downloaded: {exc}')

    # Record availability for this consortium and hand off to the
    # downstream preparation and model pipelines.
    data_availability[consortium_name] = {'has_weather_data': has_weather_data, 'has_crop_data': has_crop_data, 'has_soil_data': has_soil_data, 'has_remote_sensing_data': has_remote_sensing_data}
    model_preparation_pipeline(override_consortia=[consortium_name], override_data_availability=data_availability[consortium_name])
    xgcast_pipeline.xgcast_preparation_pipeline(override_consortia=[consortium_name], override_data_availability=data_availability[consortium_name])
    xgcast_pipeline.xgcast_model_pipeline(override_consortia=[consortium_name])