| | import os |
| |
|
| | from typing import Any |
| | import asyncio |
| | from climateqa.engine.llm import get_llm |
| | from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot |
| | from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column |
| | from climateqa.engine.talk_to_data.objects.plot import Plot |
| | from climateqa.engine.talk_to_data.objects.states import State, TTDOutput |
| | from climateqa.engine.talk_to_data.ipcc.config import IPCC_TABLES, IPCC_INDICATOR_COLUMNS_PER_TABLE, IPCC_PLOT_PARAMETERS |
| | from climateqa.engine.talk_to_data.ipcc.plots import IPCC_PLOTS |
| |
|
# Directory two levels above the process's current working directory.
# NOTE: this is relative to where the process was launched, not to this module's location.
ROOT_PATH = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
| |
|
async def process_output(
    output_title: str,
    table: str,
    plot: Plot,
    params: dict[str, Any]
) -> tuple[str, TTDOutput, dict[str, bool]]:
    """
    Process a table for a given plot and parameters: builds the SQL query, executes it,
    and generates the corresponding figure.

    The caller's ``params`` dict is never modified. The table-specific
    ``indicator_column`` is added to a local copy, so one params dict can be
    shared safely across several tables.

    Args:
        output_title (str): Title for the output (used as key in outputs dict).
        table (str): The name of the table to process.
        plot (Plot): The plot object providing the ``sql_query``,
            ``plot_information`` and ``plot_function`` callables.
        params (dict[str, Any]): Parameters used for querying the table.

    Returns:
        tuple: (output_title, results dict, errors dict).
            - results['status'] is 'OK', 'ERROR' (no SQL query could be built),
              or 'NO_DATA' (the query returned None or an empty result).
            - errors['have_sql_query'] and errors['have_dataframe'] report how
              far the processing got.
    """
    results: TTDOutput = {
        'status': 'OK',
        'plot': plot,
        'table': table,
        'sql_query': None,
        'dataframe': None,
        'figure': None,
        'plot_information': None,
    }
    errors = {
        'have_sql_query': False,
        'have_dataframe': False
    }

    # Work on a copy so the table-specific indicator column never leaks back
    # into the caller's dict (and from there into another table's query).
    params = dict(params)
    indicator_column = find_indicator_column(table, IPCC_INDICATOR_COLUMNS_PER_TABLE)
    if indicator_column:
        params['indicator_column'] = indicator_column

    # Build the SQL query for this table; an empty/None query means the
    # plot cannot be produced from this table with these parameters.
    sql_query = plot['sql_query'](table, params)
    if not sql_query:
        results['status'] = 'ERROR'
        return output_title, results, errors

    results['plot_information'] = plot['plot_information'](table, params)
    results['sql_query'] = sql_query
    errors['have_sql_query'] = True

    # Execute the query; a None or empty result is reported as NO_DATA.
    df = await execute_sql_query(sql_query)
    if df is not None and not df.empty:
        results['dataframe'] = df
        errors['have_dataframe'] = True
    else:
        results['status'] = 'NO_DATA'

    # The figure entry depends only on params, so it is set whether or not
    # the query returned data.
    results['figure'] = plot['plot_function'](params)

    return output_title, results, errors
| |
|
async def ipcc_workflow(user_input: str) -> State:
    """
    Performs the complete workflow of Talk To IPCC: from user input to SQL queries, dataframes, and figures.

    Steps:
        1. Select the relevant plots from IPCC_PLOTS with ``find_relevant_plots``.
        2. For each selected plot, find its relevant tables among IPCC_TABLES.
           Each (plot, table) pair becomes one output titled
           "<plot short_name> - <Table name>".
        3. Extract every parameter of IPCC_PLOT_PARAMETERS concurrently.
        4. Run ``process_output`` for every output concurrently and merge the results.

    Args:
        user_input (str): The user's question.

    Returns:
        State: Final state with all the results and error messages if any.
            ``state['error']`` stays empty unless no plot was found, no table was
            found, no output produced a SQL query, or no output produced data.
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'outputs': {},
        'error': ''
    }

    llm = get_llm(provider="openai")
    plots = await find_relevant_plots(state, llm, IPCC_PLOTS)
    state['plots'] = plots

    if not plots:
        state['error'] = 'There is no plot to answer to the question'
        return state

    # One entry per (plot, table) pair, keyed by a human-readable title.
    outputs = {}
    for plot_name in plots:
        plot = next((p for p in IPCC_PLOTS if p['name'] == plot_name), None)
        if plot is None:
            # Ignore names that do not match any entry of IPCC_PLOTS.
            continue

        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm, IPCC_TABLES)
        # `or ()` also tolerates a None result, not just an empty list.
        for table in relevant_tables or ():
            output_title = f"{plot['short_name']} - {' '.join(table.capitalize().split('_'))}"
            outputs[output_title] = {
                'table': table,
                'plot': plot,
                'status': 'OK'
            }

    # The same dict object is filled in below, so the state sees the final results.
    state['outputs'] = outputs

    if not outputs:
        # Nothing to query: stop before spending calls on parameter extraction.
        state['error'] = (
            "Sorry, I couldn't find any relevant table in our database to answer your question.\n"
            "Try asking about a different climate indicator like temperature or precipitation."
        )
        return state

    # Each parameter is extracted independently, so the calls run concurrently.
    param_results = await asyncio.gather(
        *(find_param(state, param_name, mode='IPCC') for param_name in IPCC_PLOT_PARAMETERS)
    )
    params = {}
    for param in param_results:
        if param:  # falsy results (nothing found) are skipped
            params.update(param)

    # Give each output its own copy of params so table-specific additions
    # cannot leak between outputs.
    results = await asyncio.gather(*(
        process_output(output_title, output['table'], output['plot'], params.copy())
        for output_title, output in outputs.items()
    ))

    # Merge per-output results; a flag is True if ANY output reached that stage.
    have_sql_query = False
    have_dataframe = False
    for output_title, task_results, task_errors in results:
        output = outputs[output_title]
        for key in ('sql_query', 'dataframe', 'figure', 'plot_information', 'status'):
            output[key] = task_results[key]
        have_sql_query |= task_errors['have_sql_query']
        have_dataframe |= task_errors['have_dataframe']

    if not have_sql_query:
        state['error'] = (
            "Sorry, I couldn't generate a relevant SQL query to answer your question.\n"
            "Try rephrasing your question to focus on a specific location, a year, or a month."
        )
    elif not have_dataframe:
        state['error'] = (
            "Sorry, there is no data in our tables that can answer your question.\n"
            "Try asking about a more common location, or a different year."
        )
    return state