| import json |
| import logging |
| import os |
| from typing import Dict, Callable, Any |
| import langchain_openai |
| from langchain_core import prompts |
|
|
|
|
def agent(base_llm_model_name: str) -> Callable[[str, Dict[str, str]], Dict[str, str]]:
    """Create an agent for query analysis and column mapping.

    Builds a LangChain chain (prompt | ChatOpenAI) that, given a keyword-rich
    user query and a set of dataset columns with descriptions, asks the LLM to
    assign the most relevant query fragments to each column.

    Args:
        base_llm_model_name: Name of the OpenAI chat model to use.

    Returns:
        Callable: Function that accepts ``(query, columns_and_descriptions)``
        and returns a ``{column_name: query_fragment}`` mapping restricted to
        the provided columns.
    """
    config = {
        'model': base_llm_model_name,
        'temperature': 0,        # deterministic decoding
        'max_tokens': 4000,
        'max_retries': 10,
        'seed': 123456,          # fixed seed for reproducibility
    }
    # NOTE: fixed "Does not change" -> "Do not change" so the rule reads as a
    # valid instruction to the model.
    system_prompt = '''
    You are a smart assistant that receives:
    - a user search query with a lot of keywords,
    - a list of columns extracted from a dataset,
    - and for each column, its description explaining what it contains.

    Your task:
    - Analyze the query.
    - For each column, determine if part of the query is highly relevant to it.
    - Extract only the most relevant keywords or parts of the query that fit the topic and meaning of the column.
    - Output a list of (query fragment, column name) pairs.

    Rules:
    - The query fragment must make sense for that specific column.
    - If the column is not relevant to any part of the query, you can skip it.
    - Do not modify the meaning of the user's query, but you can split and adapt it into multiple parts.
    - Be concise but precise in fragment construction.
    - Include the most important 5-10 columns, maximum.
    - Do not change the names of the columns.

    Output format: a JSON object with the key the column names and the values the query fragments.
    '''

    # Lazy %-formatting: the message is only built if INFO logging is enabled.
    logging.info("Loading model %s...", base_llm_model_name)
    model = langchain_openai.ChatOpenAI(
        api_key=os.getenv("OPENAI_API_KEY"),
        model=config['model'],
        temperature=config['temperature'],
        max_tokens=config['max_tokens'],
        max_retries=config['max_retries'],
        seed=config['seed'],
    )
    prompt = prompts.ChatPromptTemplate.from_messages([
        ('system', system_prompt),
        ('human', 'User Query: {query}, Columns and Descriptions: {columns}'),
    ])
    chain = prompt | model

    def invoke(query: str, columns_and_descriptions: Dict[str, str]) -> Dict[str, str]:
        """Run the chain on *query* and post-process the LLM's JSON reply."""
        formatted_columns = "\n".join(
            f"- {col}: {desc}" for col, desc in columns_and_descriptions.items()
        )
        return post_process(
            chain.invoke({'query': query, 'columns': formatted_columns}),
            columns_and_descriptions,
        )

    return invoke
|
|
|
|
def post_process(response: Any, columns_and_descriptions: Dict[str, str]) -> Dict[str, str]:
    """Post-process an LLM response to extract the column-query mapping.

    The model may wrap its JSON answer in a Markdown code fence
    (```` ```json ... ``` ````), possibly with surrounding whitespace; this
    strips any such fence before parsing.

    Args:
        response: LLM response object whose ``content`` attribute holds the
            (possibly fenced) JSON text.
        columns_and_descriptions: Dictionary of available columns and
            descriptions; only these columns are kept in the result.

    Returns:
        Dict[str, str]: Dictionary mapping column names to relevant query
        fragments.

    Raises:
        json.JSONDecodeError: If the cleaned text is not valid JSON.
    """
    # Trim whitespace first so a fence preceded/followed by spaces or
    # newlines is still recognized, then drop the backticks and the optional
    # "json" language tag. (The previous lstrip('json\n') was a character-set
    # strip, not a prefix removal.)
    text = response.content.strip().strip('`')
    text = text.removeprefix('json').strip()
    json_response = json.loads(text)
    # Keep only fragments for columns that actually exist in the input.
    return {col: json_response[col] for col in columns_and_descriptions if col in json_response}
|
|