# pnrr-data-processor / modules/column_query_agent.py
# feat: Add cluster analysis and semantic filtering modules
# (commit 7e85729, author: beppeinthesky)
import json
import logging
import os
from typing import Dict, Callable, Any
import langchain_openai
from langchain_core import prompts
def agent(base_llm_model_name: str) -> Callable[[str, Dict[str, str]], Dict[str, str]]:
    """Create an agent that maps fragments of a user query to dataset columns.

    Args:
        base_llm_model_name: Name of the OpenAI chat model to use.

    Returns:
        Callable: Function accepting (query, columns_and_descriptions) and
            returning a dict mapping column names to relevant query fragments.
    """
    # Deterministic decoding (temperature=0, fixed seed) so repeated calls
    # with the same query yield the same column mapping.
    config = {
        'model': base_llm_model_name,
        'temperature': 0,
        'max_tokens': 4000,
        'max_retries': 10,
        'seed': 123456,
    }
    system_prompt = '''
You are a smart assistant that receives:
- a user search query with a lot of keywords,
- a list of columns extracted from a dataset,
- and for each column, its description explaining what it contains.
Your task:
- Analyze the query.
- For each column, determine if part of the query is highly relevant to it.
- Extract only the most relevant keywords or parts of the query that fit the topic and meaning of the column.
- Output a list of (query fragment, column name) pairs.
Rules:
- The query fragment must make sense for that specific column.
- If the column is not relevant to any part of the query, you can skip it.
- Do not modify the meaning of the user's query, but you can split and adapt it into multiple parts.
- Be concise but precise in fragment construction.
- Include the most important 5-10 columns, maximum.
- Do not change the names of the columns.
Output format: a JSON object with the key the column names and the values the query fragments.
'''
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logging.info("Loading model %s...", base_llm_model_name)
    model = langchain_openai.ChatOpenAI(
        api_key=os.getenv("OPENAI_API_KEY"),
        model=config['model'],
        temperature=config['temperature'],
        max_tokens=config['max_tokens'],
        max_retries=config['max_retries'],
        seed=config['seed'],
    )
    prompt = prompts.ChatPromptTemplate.from_messages([
        ('system', system_prompt),
        ('human', 'User Query: {query}, Columns and Descriptions: {columns}'),
    ])
    chain = prompt | model

    def invoke(query: str, columns_and_descriptions: Dict[str, str]) -> Dict[str, str]:
        """Run the chain on one query and post-process the LLM response."""
        # Render the columns dict as a bulleted "- name: description" list
        # so the model sees one column per line.
        formatted_columns = "\n".join(
            f"- {col}: {desc}" for col, desc in columns_and_descriptions.items()
        )
        return post_process(
            chain.invoke({'query': query, 'columns': formatted_columns}),
            columns_and_descriptions,
        )

    return invoke
def post_process(response: Any, columns_and_descriptions: Dict[str, str]) -> Dict[str, str]:
    """Post-process an LLM response to extract the column-to-query mapping.

    Strips an optional Markdown code fence (```json ... ```) around the
    payload, parses it as JSON, and keeps only keys that are known columns.

    Args:
        response: LLM response object; its ``.content`` holds the JSON text,
            possibly wrapped in a fenced code block.
        columns_and_descriptions: Dictionary of available columns and descriptions.

    Returns:
        Dict[str, str]: Dictionary mapping column names to relevant query fragments.

    Raises:
        json.JSONDecodeError: If the cleaned content is not valid JSON.
    """
    # NOTE: the previous `lstrip('json\n')` was a character-set strip, not a
    # prefix strip, and `strip('`')` missed fences padded with whitespace.
    text = response.content.strip().strip('`').strip()
    # Drop an optional language tag ("json"/"JSON") left over from the fence.
    if text[:4].lower() == 'json':
        text = text[4:].lstrip()
    json_response = json.loads(text)
    # Keep only keys that name real columns; ignore anything the model invented.
    return {col: json_response[col] for col in columns_and_descriptions if col in json_response}