| from prefect import task
|
| from prefect.logging import get_run_logger
|
| import torch
|
| import os
|
| from typing import List, Any
|
| import pickle
|
| import yaml
|
| import re
|
| import rasterio
|
| import xarray as xr
|
| import pandas as pd
|
| import numpy as np
|
| from datetime import datetime
|
| from tools.historical_weather import get_historical_weather_data
|
|
|
|
|
# Pipeline parameters are loaded once, at import time, from a path that is
# relative to the current working directory.
# The file is decoded as UTF-8 explicitly so that parsing does not depend on
# the platform's default locale encoding.
with open('config/params.yml', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# Name of the column under which the tasks below store normalised timestamps.
datetime_col = config['datetime_col']
|
|
|
|
|
|
|
|
|
@task(task_run_name="get_consortia")
def get_consortia() -> List[str]:
    """
    Look up the consortia this pipeline run should process.

    The value is taken from the ``consortia`` entry of the module-level
    ``config`` mapping and is written to the Prefect run log before being
    handed back to the caller.

    Returns:
        List[str]: The configured consortia names.
    """
    run_logger = get_run_logger()
    consortia: List[str] = config['consortia']
    run_logger.info(f"We work with the following consortia: {consortia}")
    return consortia
|
|
|
|
|
# NOTE(review): the tag below is a plain string, so '{df.name}' is presumably
# not substituted per call — confirm against Prefect's tag semantics.
@task(tags=['write_to_file_{df.name}'], retries=3)
def save_to_file(df: Any, output_file: str) -> None:
    """
    Save an object to disk, choosing the writer from the file extension.

    The format is the text after the last '.' in ``output_file``. Supported
    extensions: 'parquet', 'geojson', 'zarr', 'pickle', 'netcdf'/'nc' and
    torch 'pt'. The parent directory of ``output_file`` is created if needed.

    Args:
        df (Any): Object to save. It must provide the writer method matching
            the extension (``to_parquet``, ``to_file``, ``to_zarr``,
            ``to_netcdf``); for 'pickle' and 'pt' any object accepted by
            ``pickle.dump`` / ``torch.save`` works.
        output_file (str): Path to the output file, including its extension.

    Raises:
        ValueError: If the extension is not one of the supported formats.
    """
    logger = get_run_logger()
    logger.info(f'Saving data to {output_file}...')

    file_format = output_file.split('.')[-1]
    logger.info(f'The output file has format {file_format}.')

    # os.path.dirname copes with any separator style ('/', '//' or the OS
    # one). A bare file name has no parent, and os.makedirs('') would raise.
    folder = os.path.dirname(output_file)
    if folder:
        os.makedirs(folder, exist_ok=True)

    if file_format == 'parquet':
        df.to_parquet(output_file)
    elif file_format == 'geojson':
        df.to_file(output_file, driver="GeoJSON")
    elif file_format == 'zarr':
        df.to_zarr(output_file, mode='w')
    elif file_format == 'pickle':
        with open(output_file, 'wb') as handle:
            pickle.dump(df, handle)
    elif file_format in ('netcdf', 'nc'):
        df.to_netcdf(output_file, engine='netcdf4', format='NETCDF4')
    elif file_format == 'pt':
        torch.save(df, output_file)
    else:
        # ValueError is still an Exception, so existing handlers keep working.
        raise ValueError('Format specified is not supported.')
|
|
|
|
|
@task(task_run_name='read_satellite_data')
def read_satellite_data(consortium_name):
    """
    Load every GeoTIFF of a consortium into a single xarray Dataset.

    Files ending in '.tif'/'.tiff' are read from
    ``data//01_raw//<consortium folder>//satellite_data`` in file-name order.
    Each file name must contain a timestamp of the form
    ``YYYY-MM-DDTHH_MM_SS``, which becomes the entry on the ``time`` axis.

    Args:
        consortium_name: Key into ``config['consortia_data_folders']`` that
            selects the consortium's raw-data folder.

    Returns:
        xarray.Dataset with
            * ``satellite_data``: dims (time, band, y, x), with pixel-centre
              coordinates on ``x`` and ``y``;
            * ``pixel_bounds_left/right/top/bottom``: dims (y, x), the edge
              coordinates of every pixel.

    Raises:
        FileNotFoundError: If the folder contains no GeoTIFF files.
        ValueError: If a file name does not contain a timestamp.
    """
    logger = get_run_logger()
    logger.info('Collecting satellite data...')

    raster_dir = f"data//01_raw//{config['consortia_data_folders'][consortium_name]}//satellite_data"
    file_names = sorted(
        name for name in os.listdir(raster_dir)
        if name.lower().endswith((".tif", ".tiff"))
    )
    if not file_names:
        # Without this guard the grid code below fails on unbound variables.
        raise FileNotFoundError(f'No GeoTIFF files found in {raster_dir}.')

    timestamp_pattern = re.compile(r'\d{4}-\d{2}-\d{2}T\d{2}_\d{2}_\d{2}')

    rasters = []
    datetimes = []
    transform = None
    for file_name in file_names:
        match = timestamp_pattern.search(file_name)
        if match is None:
            raise ValueError(
                f'Cannot find a YYYY-MM-DDTHH_MM_SS timestamp in file name {file_name!r}.'
            )
        datetimes.append(datetime.strptime(match.group(), '%Y-%m-%dT%H_%M_%S'))

        with rasterio.open(os.path.join(raster_dir, file_name)) as src:
            # NOTE(review): only the last file's transform is kept, so all
            # rasters are assumed to share one grid — confirm.
            transform = src.transform
            # read() without indexes returns all bands as (band, y, x).
            rasters.append(src.read())

    # Stack to (time, band, y, x); fails loudly if raster shapes differ.
    stacked = np.stack(rasters)
    height, width = stacked.shape[2], stacked.shape[3]

    # Affine coefficients by index: [0] pixel width, [2] x origin,
    # [4] pixel height, [5] y origin. The +0.5 moves from the pixel corner
    # to the pixel centre.
    x_coords = transform[2] + (np.arange(width) + 0.5) * transform[0]
    y_coords = transform[5] + (np.arange(height) + 0.5) * transform[4]

    satellite_data = xr.DataArray(
        stacked,
        dims=['time', 'band', 'y', 'x'],
        coords={
            'time': datetimes,
            'x': x_coords,
            'y': y_coords,
        },
        name='satellite_data'
    )

    # Pixel edges: one more edge than pixels along each axis.
    x_edges = transform[2] + np.arange(width + 1) * transform[0]
    y_edges = transform[5] + np.arange(height + 1) * transform[4]

    # Broadcast the 1-D edges to full (y, x) grids.
    left_edges = np.tile(x_edges[:-1], (height, 1))
    right_edges = np.tile(x_edges[1:], (height, 1))
    top_edges = np.tile(y_edges[:-1][:, None], (1, width))
    bottom_edges = np.tile(y_edges[1:][:, None], (1, width))

    dataset = xr.Dataset(
        {
            "satellite_data": satellite_data,
            "pixel_bounds_left": (("y", "x"), left_edges),
            "pixel_bounds_right": (("y", "x"), right_edges),
            "pixel_bounds_top": (("y", "x"), top_edges),
            "pixel_bounds_bottom": (("y", "x"), bottom_edges),
        }
    )

    return dataset
|
|
|
|
|
def _download_weather_window(location_ids, start_date, end_date):
    """Download one date window and move its timestamps into ``datetime_col``.

    The ``datetime`` column returned by ``get_historical_weather_data`` is
    stripped of its timezone, stored under ``datetime_col`` and then dropped.
    """
    frame = get_historical_weather_data(location_ids, start_date=start_date, end_date=end_date)
    frame[datetime_col] = pd.to_datetime(frame['datetime'].dt.tz_localize(None))
    return frame.drop(columns=['datetime'])


@task(task_run_name='update_historical_weather_data_{consortium_name}')
def update_historical_weather_data(consortium_name, location_ids):
    """
    Update historical weather data for a consortium.

    The data are cached in ``data//01_raw//`` as
    ``historical_weather_data_<consortium>_<start>_<end>.parquet``, where
    ``<start>``/``<end>`` are ``config['start_date']``/``config['end_date']``.
    If that exact file exists it is returned as is. Otherwise any existing
    cache file for the consortium is extended with the missing past and/or
    recent data (or everything is downloaded when there is none), the result
    is saved under the requested name and the superseded file is removed.

    Parameters:
    -----------
    consortium_name : str
        Name of the consortium
    location_ids : pd.DataFrame
        Passed unchanged to ``get_historical_weather_data``. Per the original
        documentation: columns ['datastream_name', 'datastream_id', 'x', 'y']
        where x = latitude, y = longitude (TODO confirm against that helper).

    Returns:
    --------
    pd.DataFrame with historical weather data. When an existing cache file
    already covers more than the requested range, its full content is
    returned (it is not trimmed).
    """
    # The module imports only the `datetime` class, which has no `timedelta`.
    from datetime import timedelta

    logger = get_run_logger()

    # NOTE(review): the comparisons below assume these are datetime.date
    # objects (e.g. unquoted YAML dates) — confirm against config/params.yml.
    start_date = config['start_date']
    end_date = config['end_date']

    raw_dir = 'data//01_raw//'
    prefix = f'historical_weather_data_{consortium_name}_'
    target_name = f'{prefix}{start_date}_{end_date}.parquet'
    target_path = f'{raw_dir}{target_name}'

    # Files directly inside raw_dir; an absent directory counts as empty
    # instead of raising StopIteration.
    file_names = next(os.walk(raw_dir), (raw_dir, [], []))[-1]

    if target_name in file_names:
        logger.info(f'Consortium {consortium_name}: historical weather data is already fully downloaded!')
        return pd.read_parquet(target_path)

    logger.info(f'Consortium {consortium_name}: some historical weather data need to be downloaded...')

    # Match the full file name so that consortia whose names are prefixes of
    # one another, or stray files, are not mistaken for this cache.
    name_pattern = re.compile(
        re.escape(prefix) + r'(\d{4}-\d{2}-\d{2})_(\d{4}-\d{2}-\d{2})\.parquet'
    )
    existing = None
    for file_name in file_names:
        match = name_pattern.fullmatch(file_name)
        if match:
            # As before, the last matching file wins if there are several.
            existing = (
                file_name,
                datetime.strptime(match.group(1), '%Y-%m-%d').date(),
                datetime.strptime(match.group(2), '%Y-%m-%d').date(),
            )

    old_path = None
    if existing is None:
        data = _download_weather_window(location_ids, start_date, end_date)
    else:
        old_name, current_start_date, current_end_date = existing
        old_path = f'{raw_dir}{old_name}'
        data = pd.read_parquet(old_path)

        downloads = []
        if current_start_date > start_date:
            logger.info('Need to download some data in the past...')
            downloads.append(_download_weather_window(
                location_ids, start_date, current_start_date - timedelta(days=1)))

        if current_end_date < end_date:
            logger.info('Need to download some more recent data...')
            # The last cached day is requested again; duplicates are dropped
            # below.
            downloads.append(_download_weather_window(
                location_ids, current_end_date, end_date))

        if downloads:
            # Fresh rows come first so they win over cached duplicates.
            # Duplicates are per timestamp AND datastream; timestamps alone
            # would discard the other locations' rows.
            data = (
                pd.concat([*downloads, data], axis=0)
                .drop_duplicates(subset=[datetime_col, 'datastream_name'])
                .sort_values(datetime_col)
            )
        else:
            logger.info('Existing historical weather data already cover the requested period.')

    logger.info('New historical weather data are being saved...')
    save_to_file(df=data, output_file=target_path)

    # Remove the superseded file only after the new one has been written, so
    # a failed save cannot lose the cache.
    if old_path is not None and old_path != target_path:
        logger.info('Old historical weather data are being removed...')
        os.remove(old_path)

    logger.info('Done!')

    return data