| | from typing import Callable, TypedDict |
| | from matplotlib.figure import figaspect |
| | import pandas as pd |
| | from plotly.graph_objects import Figure |
| | import plotly.graph_objects as go |
| | import plotly.express as px |
| |
|
| | from climateqa.engine.talk_to_data.sql_query import ( |
| | indicator_for_given_year_query, |
| | indicator_per_year_at_location_query, |
| | ) |
| | from climateqa.engine.talk_to_data.config import INDICATOR_TO_UNIT |
| |
|
| |
|
| |
|
| |
|
class Plot(TypedDict):
    """Typed description of one plot type offered by the DRIAS plotting module.

    Every entry couples a figure builder with the function that produces the
    SQL query for its data, together with a name, a description and the list
    of parameters the plot needs.

    Keys:
        name (str): Human-readable name of the plot type.
        description (str): Explanation of what the plot displays.
        params (list[str]): Names of the parameters the plot requires.
        plot_function (Callable[..., Callable[..., Figure]]): Factory that
            receives the parameters and returns a callable turning a
            DataFrame into a plotly ``Figure``.
        sql_query (Callable[..., str]): Builds the SQL query string used to
            fetch the plot's data.
    """

    # Display name of the plot type.
    name: str
    # Free-text explanation of the plot.
    description: str
    # Required parameter names.
    params: list[str]
    # params -> (DataFrame -> Figure) factory.
    plot_function: Callable[..., Callable[..., Figure]]
    # Produces the SQL query string for the plot's data.
    sql_query: Callable[..., str]
| |
|
| |
|
def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
    """Generates a function to plot indicator evolution over time at a location.

    The returned function draws a line chart of the yearly indicator values
    and, when at least one full 10-year window is available, a dashed
    10-year rolling average on top of it. If the data holds exactly one
    model, that model's values are plotted; otherwise the indicator is
    averaged across models for each year.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - location (str): The location shown in the plot title
            - model (str): The climate model; not read in this function
              (presumably applied by the SQL query — TODO confirm)

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure

    Example:
        >>> plot_func = plot_indicator_evolution_at_location({
        ...     'indicator_column': 'mean_temperature',
        ...     'location': 'Paris',
        ...     'model': 'ALL'
        ... })
        >>> fig = plot_func(df)  # doctest: +SKIP
    """
    indicator = params["indicator_column"]
    location = params["location"]
    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
    unit = INDICATOR_TO_UNIT.get(indicator, "")
    # Size (in rows, i.e. years) of the rolling-average window.
    rolling_window = 10

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generates the actual plot from the data.

        Args:
            df (pd.DataFrame): Rows with at least the ``year`` and ``model``
                columns and the indicator column.

        Returns:
            Figure: A plotly Figure object showing the indicator evolution
        """
        if df["model"].nunique() != 1:
            # Zero or several models: average the indicator over models for
            # each year (groupby returns the years in ascending order).
            df_plot = df.groupby("year", as_index=False)[indicator].mean()
            model_label = "Model Average"
        else:
            # Single model: sort chronologically so that neither the line nor
            # the rolling window depends on the row order of the query result.
            df_plot = df.sort_values("year", kind="stable")
            model_label = f"Model : {df['model'].unique()[0]}"

        values = df_plot[indicator].astype(float)
        indicators = values.tolist()
        years = df_plot["year"].astype(int).tolist()
        # min_periods == window: the first (window - 1) points stay NaN rather
        # than being averaged over a partial window.
        sliding_averages = (
            values.rolling(window=rolling_window, min_periods=rolling_window)
            .mean()
            .tolist()
        )

        fig = go.Figure()
        # Draw the rolling average only when at least one full window exists.
        if any(pd.notna(value) for value in sliding_averages):
            fig.add_scatter(
                x=years,
                y=sliding_averages,
                mode="lines",
                name="10 years rolling average",
                line=dict(dash="dash"),
                marker=dict(color="#d62728"),
                hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>",
            )

        fig.add_scatter(
            x=years,
            y=indicators,
            name=f"Yearly {indicator_label}",
            mode="lines",
            marker=dict(color="#1f77b4"),
            hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>",
        )
        fig.update_layout(
            title=f"Plot of {indicator_label} in {location} ({model_label})",
            xaxis_title="Year",
            yaxis_title=f"{indicator_label} ({unit})",
            template="plotly_white",
        )
        return fig

    return plot_data
| |
|
| |
|
# Line chart of a yearly indicator at one location (with a rolling mean).
indicator_evolution_at_location: Plot = dict(
    name="Indicator evolution at location",
    description="Plot an evolution of the indicator at a certain location",
    params=["indicator_column", "location", "model"],
    plot_function=plot_indicator_evolution_at_location,
    sql_query=indicator_per_year_at_location_query,
)
| |
|
| |
|
def plot_indicator_number_of_days_per_year_at_location(
    params: dict,
) -> Callable[..., Figure]:
    """Generates a function to plot the number of days per year for an indicator.

    The returned function draws a bar chart with one bar per year. If the data
    holds exactly one model, that model's values are used; otherwise the
    indicator is averaged across models for each year.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - location (str): The location shown in the plot title
            - model (str): The climate model; not read in this function
              (presumably applied by the SQL query — TODO confirm)

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
    """
    indicator = params["indicator_column"]
    location = params["location"]
    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
    unit = INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generate the bar chart from the dataframe.

        Args:
            df (pd.DataFrame): Rows with at least the ``year`` and ``model``
                columns and the indicator column.

        Returns:
            Figure: Plotly figure
        """
        if df["model"].nunique() != 1:
            # Zero or several models: one averaged value per year.
            df_plot = df.groupby("year", as_index=False)[indicator].mean()
            model_label = "Model Average"
        else:
            df_plot = df
            model_label = f"Model : {df['model'].unique()[0]}"

        indicators = df_plot[indicator].astype(float).tolist()
        years = df_plot["year"].astype(int).tolist()

        fig = go.Figure()
        fig.add_trace(
            go.Bar(
                x=years,
                y=indicators,
                width=0.5,
                marker=dict(color="#1f77b4"),
                hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>",
            )
        )

        # max() raises on an empty list and is order-dependent with NaN, and
        # a [0, 0] range is degenerate: pin the axis only for a positive max,
        # otherwise let plotly autorange from zero.
        valid_values = [value for value in indicators if pd.notna(value)]
        y_max = max(valid_values, default=0.0)
        yaxis = dict(range=[0, y_max]) if y_max > 0 else dict(rangemode="tozero")

        fig.update_layout(
            title=f"{indicator_label} in {location} ({model_label})",
            xaxis_title="Year",
            yaxis_title=f"{indicator_label} ({unit})",
            yaxis=yaxis,
            bargap=0.5,
            template="plotly_white",
        )
        return fig

    return plot_data
| |
|
| |
|
# Bar chart of a frequency-type indicator, one bar per year, at one location.
indicator_number_of_days_per_year_at_location: Plot = dict(
    name="Indicator number of days per year at location",
    description="Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicator.",
    params=["indicator_column", "location", "model"],
    plot_function=plot_indicator_number_of_days_per_year_at_location,
    sql_query=indicator_per_year_at_location_query,
)
| |
|
| |
|
def plot_distribution_of_indicator_for_given_year(
    params: dict,
) -> Callable[..., Figure]:
    """Build a histogram factory for one indicator in one year.

    The returned callable plots how the indicator's values are distributed
    across locations, with each bin expressed as a percentage of the values.

    Args:
        params (dict): Must provide:
            - indicator_column (str): column holding the indicator values
            - year (str): year shown in the plot title
            - model (str): listed among the plot's params; not read here

    Returns:
        Callable[..., Figure]: takes a DataFrame and returns a plotly Figure.
    """
    column = params["indicator_column"]
    year = params["year"]
    label = " ".join(part.capitalize() for part in column.split("_"))
    unit = INDICATOR_TO_UNIT.get(column, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Render the histogram for ``df``.

        Args:
            df (pd.DataFrame): needs the ``model`` column and the indicator
                column; ``latitude`` and ``longitude`` are also needed when
                the frame does not hold exactly one model.

        Returns:
            Figure: Plotly histogram figure.
        """
        if df["model"].nunique() == 1:
            values = df[column]
            model_label = f"Model : {df['model'].unique()[0]}"
        else:
            # Collapse models: one mean value per (latitude, longitude) point.
            per_point = df.groupby(["latitude", "longitude"], as_index=False)[
                column
            ].mean()
            values = per_point[column]
            model_label = "Model Average"

        histogram = go.Histogram(
            x=values.astype(float).tolist(),
            opacity=0.8,
            histnorm="percent",
            marker=dict(color="#1f77b4"),
            hovertemplate=f"{label}: %{{x:.2f}} {unit}<br>Frequency: %{{y:.2f}}%<extra></extra>",
        )

        fig = go.Figure()
        fig.add_trace(histogram)
        fig.update_layout(
            title=f"Distribution of {label} in {year} ({model_label})",
            xaxis_title=f"{label} ({unit})",
            yaxis_title="Frequency (%)",
            plot_bgcolor="rgba(0, 0, 0, 0)",
            showlegend=False,
        )
        return fig

    return plot_data
| |
|
| |
|
# Histogram of an indicator's values across locations for a single year.
distribution_of_indicator_for_given_year: Plot = dict(
    name="Distribution of an indicator for a given year",
    description="Plot an histogram of the distribution for a given year of the values of an indicator",
    params=["indicator_column", "model", "year"],
    plot_function=plot_distribution_of_indicator_for_given_year,
    sql_query=indicator_for_given_year_query,
)
| |
|
| |
|
def plot_map_of_france_of_indicator_for_given_year(
    params: dict,
) -> Callable[..., Figure]:
    """Generates a function to plot a map of France for an indicator.

    The returned function draws one colored marker per (latitude, longitude)
    point on an OpenStreetMap background centered on France, showing the
    indicator value for a specific year. If the data holds exactly one model,
    that model's values are used; otherwise values are averaged across models
    for each point.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - year (str): The year shown in the plot title
            - model (str): The climate model; not read in this function
              (presumably applied by the SQL query — TODO confirm)

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
    """
    indicator = params["indicator_column"]
    year = params["year"]
    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
    unit = INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generate the map figure from the dataframe.

        Args:
            df (pd.DataFrame): Rows with at least the ``latitude``,
                ``longitude`` and ``model`` columns and the indicator column.

        Returns:
            Figure: Plotly figure with a Scattermapbox trace.
        """
        if df["model"].nunique() != 1:
            # Zero or several models: one averaged value per point.
            df_plot = df.groupby(["latitude", "longitude"], as_index=False)[
                indicator
            ].mean()
            model_label = "Model Average"
        else:
            df_plot = df
            model_label = f"Model : {df['model'].unique()[0]}"

        indicators = df_plot[indicator].astype(float).tolist()
        latitudes = df_plot["latitude"].astype(float).tolist()
        longitudes = df_plot["longitude"].astype(float).tolist()

        marker = dict(
            size=10,
            color=indicators,
            colorscale="Turbo",
            showscale=True,
            # The displayed colorbar belongs to the marker; a layout-level
            # coloraxis colorbar is ignored unless marker.coloraxis is set.
            colorbar=dict(title=dict(text=f"{indicator_label} ({unit})")),
        )
        # min()/max() raise on empty input and misbehave with NaN; only fix
        # the color range when there are real values, else let plotly decide.
        valid_values = [value for value in indicators if pd.notna(value)]
        if valid_values:
            marker["cmin"] = min(valid_values)
            marker["cmax"] = max(valid_values)

        fig = go.Figure()
        fig.add_trace(
            go.Scattermapbox(
                lat=latitudes,
                lon=longitudes,
                mode="markers",
                marker=marker,
                text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators],
                hoverinfo="text",
            )
        )
        fig.update_layout(
            mapbox_style="open-street-map",
            mapbox_zoom=3,
            mapbox_center={"lat": 46.6, "lon": 2.0},
            title=f"{indicator_label} in {year} in France ({model_label}) ",
        )
        return fig

    return plot_data
| |
|
| |
|
# Marker map of France showing an indicator's values for a single year.
map_of_france_of_indicator_for_given_year: Plot = {
    "name": "Map of France of an indicator for a given year",
    "description": "Heatmap on the map of France of the values of an indicator for a given year",
    "params": ["indicator_column", "year", "model"],
    "plot_function": plot_map_of_france_of_indicator_for_given_year,
    "sql_query": indicator_for_given_year_query,
}
| |
|
| |
|
# Every plot configuration defined in this module, in declaration order.
PLOTS: list[Plot] = [
    indicator_evolution_at_location,
    indicator_number_of_days_per_year_at_location,
    distribution_of_indicator_for_given_year,
    map_of_france_of_indicator_for_given_year,
]
| |
|