"""Query-to-column mapping agent built on a LangChain OpenAI chat model."""

import json
import logging
import os
import re
from typing import Any, Callable, Dict

import langchain_openai
from langchain_core import prompts


def agent(base_llm_model_name: str) -> Callable[[str, Dict[str, str]], Dict[str, str]]:
    """Create an agent for query analysis and column mapping.

    Args:
        base_llm_model_name: Name of the LLM model to use

    Returns:
        Callable: Function that accepts (query, columns_and_descriptions)
        and returns column-query mapping
    """
    # Deterministic decoding (temperature 0 + fixed seed) so repeated calls
    # with the same query map to the same columns.
    config = {
        'model': base_llm_model_name,
        'temperature': 0,
        'max_tokens': 4000,
        'max_retries': 10,
        'seed': 123456,
    }
    system_prompt = ''' You are a smart assistant that receives: - a user search query with a lot of keywords, - a list of columns extracted from a dataset, - and for each column, its description explaining what it contains. Your task: - Analyze the query. - For each column, determine if part of the query is highly relevant to it. - Extract only the most relevant keywords or parts of the query that fit the topic and meaning of the column. - Output a list of (query fragment, column name) pairs. Rules: - The query fragment must make sense for that specific column. - If the column is not relevant to any part of the query, you can skip it. - Do not modify the meaning of the user's query, but you can split and adapt it into multiple parts. - Be concise but precise in fragment construction. - Include the most important 5-10 columns, maximum. - Do not change the names of the columns. Output format: a JSON object with the key the column names and the values the query fragments. 
'''
    logging.info(f"Loading model {base_llm_model_name}...")
    model = langchain_openai.ChatOpenAI(
        api_key=os.getenv("OPENAI_API_KEY"),
        model=config['model'],
        temperature=config['temperature'],
        max_tokens=config['max_tokens'],
        max_retries=config['max_retries'],
        seed=config['seed'],
    )
    prompt = prompts.ChatPromptTemplate.from_messages([
        ('system', system_prompt),
        ('human', 'User Query: {query}, Columns and Descriptions: {columns}'),
    ])
    chain = prompt | model

    def invoke(query: str, columns_and_descriptions: Dict[str, str]) -> Dict[str, str]:
        """Run the chain on one query and post-process the LLM reply."""
        # Render the columns dict as a bulleted "- name: description" list
        # so the model sees one column per line.
        formatted_columns = "\n".join(
            f"- {col}: {desc}"
            for col, desc in columns_and_descriptions.items()
        )
        raw = chain.invoke({'query': query, 'columns': formatted_columns})
        return post_process(raw, columns_and_descriptions)

    return invoke


def post_process(response: Any, columns_and_descriptions: Dict[str, str]) -> Dict[str, str]:
    """Post-process LLM response to extract column-query mapping.

    Args:
        response: LLM response containing JSON (possibly wrapped in a
            markdown code fence or surrounding prose)
        columns_and_descriptions: Dictionary of available columns and descriptions

    Returns:
        Dict[str, str]: Dictionary mapping column names to relevant query fragments

    Raises:
        json.JSONDecodeError: If no parseable JSON object is found in the response.
    """
    text = response.content.strip()
    # Extract the outermost {...} span. This tolerates markdown fences
    # (```json ... ```) and any prose the model emits around the JSON,
    # unlike lstrip('json\n'), which strips a *character set* and could
    # eat leading characters of the payload itself.
    match = re.search(r'\{.*\}', text, flags=re.DOTALL)
    if match:
        text = match.group(0)
    json_response = json.loads(text)
    # Keep only keys that are real column names, preserving the input's
    # column order and dropping any hallucinated columns.
    return {
        col: json_response[col]
        for col in columns_and_descriptions
        if col in json_response
    }