| | from typing import Any, Literal, Optional, cast |
| | import ast |
| | from langchain_core.prompts import ChatPromptTemplate |
| | from geopy.geocoders import Nominatim |
| | from climateqa.engine.llm import get_llm |
| | import duckdb |
| | import os |
| | from climateqa.engine.talk_to_data.config import DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH, IPCC_COORDINATES_PATH |
| | from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput |
| | from climateqa.engine.talk_to_data.objects.location import Location |
| | from climateqa.engine.talk_to_data.objects.plot import Plot |
| | from climateqa.engine.talk_to_data.objects.states import State |
| | import calendar |
| |
|
async def detect_location_with_openai(sentence: str) -> str:
    """Detect the first location mentioned in a sentence using the LLM.

    Args:
        sentence (str): The sentence to analyze.

    Returns:
        str: The first detected location, or "" when no location is found
            or the LLM answer cannot be parsed as a Python list.
    """
    llm = get_llm()

    prompt = f"""
    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
    Return the result as a Python list. If no locations are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    response = await llm.ainvoke(prompt)
    # str.strip(chars) removes any of the given *characters* from both ends,
    # not the literal prefix, so strip("```python\n") could also eat leading
    # 'p'/'y'/'t'... characters of the real answer. Remove the Markdown code
    # fence explicitly instead.
    raw = response.content.strip()
    raw = raw.removeprefix("```python").removeprefix("```").removesuffix("```").strip()
    try:
        location_list = ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        # The LLM did not return a valid Python literal; treat as "not found".
        return ""
    return location_list[0] if location_list else ""
|
def loc_to_coords(location: str) -> tuple[float, float]:
    """Resolve a place name to geographic coordinates.

    Queries the Nominatim geocoding service for the given place name
    (e.g. a city) and returns its coordinates.

    Args:
        location (str): Name of the place to geocode.

    Returns:
        tuple[float, float]: The (latitude, longitude) pair.

    Raises:
        AttributeError: If the location cannot be found (the geocoder
            returns None).
    """
    geocoder = Nominatim(user_agent="city_to_latlong", timeout=5)
    match = geocoder.geocode(location)
    return match.latitude, match.longitude
|
def coords_to_country(coords: tuple[float, float]) -> tuple[str, str]:
    """Converts geographic coordinates to a country code and country name.

    This function uses the Nominatim reverse geocoding service to convert
    latitude and longitude coordinates to country information.

    Args:
        coords (tuple[float, float]): A tuple containing (latitude, longitude)

    Returns:
        tuple[str, str]: A tuple containing (country_code, country_name),
            e.g. ("FR", "France").

    Raises:
        AttributeError: If the coordinates cannot be reverse-geocoded.
        KeyError: If the resolved address has no country information.
    """
    # timeout=5 for consistency with loc_to_coords; avoids hanging on a
    # slow Nominatim response.
    geolocator = Nominatim(user_agent="latlong_to_country", timeout=5)
    location = geolocator.reverse(coords)
    address = location.raw['address']
    return address['country_code'].upper(), address['country']
|
def nearest_neighbour_sql(location: tuple, mode: Literal['DRIAS', 'IPCC']) -> tuple[str, str, Optional[str]]:
    """Finds a dataset grid point close to the given coordinates.

    Searches the mode-specific coordinate table for points inside a small
    bounding box around the requested location and returns the first match.

    Args:
        location (tuple): A (latitude, longitude) pair.
        mode (Literal['DRIAS', 'IPCC']): Which dataset to search. DRIAS uses
            a +/-0.3 degree window; IPCC uses +/-0.5 degrees and also
            returns the admin1 region.

    Returns:
        tuple[str, str, Optional[str]]: (latitude, longitude, admin1) of the
            first matching grid point; ("", "", "") when no point falls in
            the search window. admin1 is None for DRIAS mode.
    """
    lat = round(location[0], 3)
    long = round(location[1], 3)

    if mode == 'DRIAS':
        table_path = DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH
        columns = "latitude, longitude"
        delta = 0.3
    else:
        table_path = IPCC_COORDINATES_PATH
        columns = "latitude, longitude, admin1"
        delta = 0.5

    conn = duckdb.connect()
    try:
        # Bind the numeric bounds as parameters; the table path comes from
        # trusted config constants and must be interpolated (DuckDB cannot
        # bind identifiers/paths).
        results = conn.execute(
            f"SELECT {columns} FROM '{table_path}' "
            "WHERE latitude BETWEEN ? AND ? AND longitude BETWEEN ? AND ?",
            [lat - delta, lat + delta, long - delta, long + delta],
        ).fetchdf()
    finally:
        # The connection was previously leaked; always release it.
        conn.close()

    if len(results) == 0:
        return "", "", ""

    admin1 = results['admin1'].iloc[0] if 'admin1' in results.columns else None
    return results['latitude'].iloc[0], results['longitude'].iloc[0], admin1
| |
|
async def detect_year_with_openai(sentence: str) -> str:
    """Detect the first year mentioned in a sentence using the LLM.

    Args:
        sentence (str): The sentence to analyze.

    Returns:
        str: The first detected year, or "" when no year is found or the
            LLM answer cannot be parsed as a Python list.
    """
    llm = get_llm()

    prompt = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no year are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    prompt = ChatPromptTemplate.from_template(prompt)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
    try:
        years_list = ast.literal_eval(response['array'])
    except (ValueError, SyntaxError):
        # Structured output was not a valid Python literal; treat as "not found".
        return ""
    return years_list[0] if years_list else ""
| |
|
| |
|
async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
    """Identifies relevant tables for a plot based on user input.

    This function uses an LLM to analyze the user's question and the plot
    description to determine which tables in the DRIAS database would be
    most relevant for generating the requested visualization.

    Args:
        user_question (str): The user's question about climate data
        plot (Plot): The plot configuration object
        llm: The language model instance to use for analysis
        table_names_list (list[str]): Candidate table names to choose from

    Returns:
        list[str]: A list of table names that are relevant for the plot

    Example:
        >>> detect_relevant_tables(
        ...     "What will the temperature be like in Paris?",
        ...     indicator_evolution_at_location,
        ...     llm
        ... )
        ['mean_annual_temperature', 'mean_summer_temperature']
    """
    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}."
        f"You are given a list of tables and a user question."
        f"Based on the description of the plot, which table are appropriate for that kind of plot."
        f"Write the 3 most relevant tables to use. Answer only a python list of table name."
        f"### List of tables : {table_names_list}"
        f"### User question : {user_question}"
        f"### List of table name : "
    )

    # str.strip(chars) strips a character *set*, not the literal prefix;
    # remove a possible Markdown code fence explicitly instead.
    answer = (await llm.ainvoke(prompt)).content.strip()
    answer = answer.removeprefix("```python").removeprefix("```").removesuffix("```").strip()
    table_names = ast.literal_eval(answer)
    return table_names
| |
|
async def detect_relevant_plots(user_question: str, llm, plot_list: list[Plot]) -> list[str]:
    """Selects the plots relevant to a user question using an LLM.

    Args:
        user_question (str): The user's question about climate data
        llm: The language model instance to use for analysis
        plot_list (list[Plot]): Candidate plots with 'name' and 'description'

    Returns:
        list[str]: Plot names sorted from most to least relevant
    """
    plots_description = ""
    for plot in plot_list:
        plots_description += "Name: " + plot["name"]
        plots_description += " - Description: " + plot["description"] + "\n"

    prompt = (
        "You are helping to answer a question with insightful visualizations.\n"
        "You are given a user question and a list of plots with their name and description.\n"
        "Based on the descriptions of the plots, select ALL plots that could provide a useful answer to this question. "
        "Include any plot that could show relevant information, even if their perspectives (such as time series or spatial distribution) are different.\n"
        "For example, for a question like 'What will be the total rainfall in China in 2050?', both a time series plot and a spatial map plot could be relevant.\n"
        "Return only a Python list of plot names sorted from the most relevant one to the less relevant one.\n"
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}\n"
        f"### Names of the plots : "
    )

    # str.strip(chars) strips a character *set*, not the literal prefix;
    # remove a possible Markdown code fence explicitly instead.
    answer = (await llm.ainvoke(prompt)).content.strip()
    answer = answer.removeprefix("```python").removeprefix("```").removesuffix("```").strip()
    plot_names = ast.literal_eval(answer)
    return plot_names
| |
|
async def find_location(user_input: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> Location:
    """Extract a location from the user input and enrich it with coordinates.

    Detects a location name via the LLM, then — when one is found — resolves
    its coordinates, country, and the nearest dataset grid point for the
    requested mode.

    Args:
        user_input (str): The user's query text.
        mode (Literal['DRIAS', 'IPCC']): Which dataset's grid to snap to.

    Returns:
        Location: Dict with 'location', 'latitude', 'longitude',
            'country_code', 'country_name', 'admin1'; geographic fields are
            None when no location was detected.
    """
    print(f"---- Find location in user input ----")
    detected = await detect_location_with_openai(user_input)
    result: Location = {
        'location': detected,
        'longitude': None,
        'latitude': None,
        'country_code': None,
        'country_name': None,
        'admin1': None,
    }
    if not detected:
        return result

    coords = loc_to_coords(detected)
    country_code, country_name = coords_to_country(coords)
    grid_lat, grid_long, admin1 = nearest_neighbour_sql(coords, mode)
    result.update({
        "latitude": grid_lat,
        "longitude": grid_long,
        "country_code": country_code,
        "country_name": country_name,
        "admin1": admin1,
    })
    return cast(Location, result)
| |
|
async def find_year(user_input: str) -> str | None:
    """Extracts year information from user input using LLM.

    This function uses an LLM to identify and extract year information from
    the user's query, which is used to filter data in subsequent queries.

    Args:
        user_input (str): The user's query text

    Returns:
        str | None: The extracted year, or None if no year was found
    """
    print(f"---- Find year ---")
    detected = await detect_year_with_openai(user_input)
    # detect_year_with_openai returns "" when nothing was found; map to None.
    return detected or None
| |
|
async def find_month(user_input: str) -> dict[str, str|None]:
    """
    Extracts month information from user input using an LLM.

    This function analyzes the user's query to detect if a month is mentioned.
    It returns both the month number (as a string, e.g. '7' for July) and the
    full English month name (e.g. 'July'). If no month is found (or the LLM
    answer is unparsable or out of the 1-12 range), both values will be None.

    Args:
        user_input (str): The user's query text.

    Returns:
        dict[str, str|None]: A dictionary with keys:
            - "month_number": the month number as a string (e.g. '7'), or None if not found
            - "month_name": the full English month name (e.g. 'July'), or None if not found

    Example:
        >>> await find_month("Show me the temperature in Paris in July")
        {'month_number': '7', 'month_name': 'July'}
        >>> await find_month("Show me the temperature in Paris")
        {'month_number': None, 'month_name': None}
    """
    llm = get_llm()
    prompt = """
    Extract the month (as a number from 1 to 12) mentioned in the following sentence.
    Return the result as a Python list of integers. If no month is mentioned, return an empty list.

    Sentence: "{sentence}"
    """
    prompt = ChatPromptTemplate.from_template(prompt)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": user_input})

    not_found: dict[str, str | None] = {"month_number": None, "month_name": None}
    try:
        months_list = ast.literal_eval(response['array'])
    except (ValueError, SyntaxError):
        # Structured output was not a valid Python literal.
        return not_found
    if not months_list:
        return not_found

    month_number = int(months_list[0])
    if not 1 <= month_number <= 12:
        # Guard against out-of-range LLM answers: calendar.month_name only
        # has indices 0-12 and would raise IndexError otherwise.
        return not_found
    return {
        "month_number": str(month_number),
        "month_name": calendar.month_name[month_number],
    }
| |
|
| |
|
async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
    """Select the plots relevant to the user's question in the given state."""
    print("---- Find relevant plots ----")
    return await detect_relevant_plots(state['user_input'], llm, plots)
| |
|
async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: list[str]) -> list[str]:
    """Select the tables relevant to the given plot for the user's question."""
    print(f"---- Find relevant tables for {plot['name']} ----")
    return await detect_relevant_tables(state['user_input'], plot, llm, tables)
| |
|
| | async def find_param(state: State, param_name: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None: |
| | """ |
| | Retrieves a specific parameter (location, year, month, etc.) from the user's input using the appropriate extraction method. |
| | |
| | Args: |
| | state (State): The current state containing at least the user's input under 'user_input'. |
| | param_name (str): The name of the parameter to extract. Supported: 'location', 'year', 'month'. |
| | mode (Literal['DRIAS', 'IPCC']): The data mode to use for location extraction. |
| | |
| | Returns: |
| | - For 'location': a Location object (dict with keys like 'location', 'latitude', etc.), or None if not found. |
| | - For 'year': a dict {'year': year or None}. |
| | - For 'month': a dict {'month_number': str or None, 'month_name': str or None}. |
| | - None if the parameter is not recognized or not found. |
| | |
| | Example: |
| | >>> await find_param(state, 'location') |
| | {'location': 'Paris', 'latitude': ..., ...} |
| | >>> await find_param(state, 'year') |
| | {'year': '2050'} |
| | >>> await find_param(state, 'month') |
| | {'month_number': '7', 'month_name': 'July'} |
| | """ |
| | if param_name == 'location': |
| | location = await find_location(state['user_input'], mode) |
| | return location |
| | if param_name == 'year': |
| | year = await find_year(state['user_input']) |
| | return {'year': year} |
| | if param_name == 'month': |
| | month = await find_month(state['user_input']) |
| | return month |
| | return None |