File size: 3,247 Bytes
7e85729
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import json
import logging
import os
from typing import Dict, Callable, Any
import langchain_openai
from langchain_core import prompts


def agent(base_llm_model_name: str) -> Callable[[str, Dict[str, str]], Dict[str, str]]:
    """Create an agent for query analysis and column mapping.

    Args:
        base_llm_model_name: Name of the LLM model to use

    Returns:
        Callable: Function that accepts (query, columns_and_descriptions) and
            returns a dict mapping column names to relevant query fragments.
    """
    # Deterministic generation settings: temperature 0 plus a fixed seed so
    # the same (query, columns) pair yields a stable mapping across runs.
    config = {
        'model': base_llm_model_name,
        'temperature': 0,
        'max_tokens': 4000,
        'max_retries': 10,
        'seed': 123456
    }
    system_prompt = '''
        You are a smart assistant that receives:
        - a user search query with a lot of keywords,
        - a list of columns extracted from a dataset,
        - and for each column, its description explaining what it contains.

        Your task:
        - Analyze the query.
        - For each column, determine if part of the query is highly relevant to it.
        - Extract only the most relevant keywords or parts of the query that fit the topic and meaning of the column.
        - Output a list of (query fragment, column name) pairs.

        Rules:
        - The query fragment must make sense for that specific column.
        - If the column is not relevant to any part of the query, you can skip it.
        - Do not modify the meaning of the user's query, but you can split and adapt it into multiple parts.
        - Be concise but precise in fragment construction.
        - Include the most important 5-10 columns, maximum.
        - Do not change the names of the columns.

        Output format: a JSON object with the key the column names and the values the query fragments.
    '''

    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logging.info("Loading model %s...", base_llm_model_name)
    model = langchain_openai.ChatOpenAI(
        api_key=os.getenv("OPENAI_API_KEY"),
        model=config['model'],
        temperature=config['temperature'],
        max_tokens=config['max_tokens'],
        max_retries=config['max_retries'],
        seed=config['seed'],
    )
    prompt = prompts.ChatPromptTemplate.from_messages([
        ('system', system_prompt),
        ('human', 'User Query: {query}, Columns and Descriptions: {columns}'),
    ])
    chain = prompt | model

    def invoke(query: str, columns_and_descriptions: Dict[str, str]) -> Dict[str, str]:
        """Run the chain on a query and reduce the response to a column->fragment dict."""
        # Render the columns as a bulleted "name: description" list for the prompt.
        formatted_columns = "\n".join(
            f"- {col}: {desc}" for col, desc in columns_and_descriptions.items()
        )
        response = chain.invoke({'query': query, 'columns': formatted_columns})
        return post_process(response, columns_and_descriptions)

    return invoke


def post_process(response: Any, columns_and_descriptions: Dict[str, str]) -> Dict[str, str]:
    """Post-process LLM response to extract column-query mapping.

    Strips an optional markdown code fence (``` or ```json) around the JSON
    payload before parsing, then keeps only keys that name known columns.

    Args:
        response: LLM response object; its ``content`` attribute holds the JSON text
        columns_and_descriptions: Dictionary of available columns and descriptions

    Returns:
        Dict[str, str]: Dictionary mapping column names to relevant query fragments

    Raises:
        json.JSONDecodeError: If the response payload is not valid JSON.
    """
    text = response.content.strip()
    # Remove a markdown code fence if present. NOTE: the previous
    # lstrip('json\n') stripped any of the CHARACTERS j/s/o/n/newline (not
    # the prefix "json"), which could corrupt a payload that legitimately
    # starts with one of them; remove the prefix explicitly instead.
    if text.startswith('`'):
        text = text.strip('`').strip()
        if text.startswith('json'):
            text = text[len('json'):].strip()
    json_response = json.loads(text)
    # Keep only fragments whose keys name columns we actually know about.
    return {col: json_response[col] for col in columns_and_descriptions if col in json_response}