File size: 3,129 Bytes
cd6f412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# insucompass/core/agents/query_intent_classifier.py

import logging
import pydantic
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate

from insucompass.core.models import IntentType, QueryIntent
from insucompass.config import settings
from insucompass.prompts.prompt_loader import load_prompt

# Set up root logging once at import time; the level comes from project settings.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=settings.LOG_LEVEL,
)
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

# Prompt text for the classifier, loaded once when the module is imported.
QUERY_INTENT_PROMPT_TEMPLATE = load_prompt('query_intent_classifier')


class QueryIntentClassifierAgent:
    """
    A classifier that uses an LLM to determine the user's intent and
    pre-processes the query for optimal RAG performance.

    The prompt template, the LLM and a Pydantic output parser are composed
    into a single chain in ``__init__``. ``classify_intent`` invokes that
    chain and falls back to a 'SIMPLE' intent whenever the chain raises.
    """

    def __init__(self, llm):
        """
        Initializes the classifier with a given LLM.

        Args:
            llm: An instance of a LLM model. It is piped between the prompt
                template and the output parser, so it must be composable
                with the ``|`` operator.
        """
        self.llm = llm
        # Parses the raw LLM output into a QueryIntent instance.
        self.parser = PydanticOutputParser(pydantic_object=QueryIntent)
        # Only "query" is supplied per call; the format instructions and the
        # intent labels are fixed at construction time.
        self.prompt_template = PromptTemplate(
            template=QUERY_INTENT_PROMPT_TEMPLATE,
            input_variables=["query"],
            partial_variables={
                "format_instructions": self.parser.get_format_instructions(),
                "simple": IntentType.SIMPLE.value,
                "ambiguous": IntentType.AMBIGUOUS.value,
                "complex": IntentType.COMPLEX.value,
                "concise": IntentType.CONCISE.value,
            },
        )
        self.chain = self.prompt_template | self.llm | self.parser
        logger.info("QueryIntentClassifierAgent initialized successfully.")

    @staticmethod
    def _fallback_intent(reasoning: str) -> QueryIntent:
        """Build the default 'SIMPLE' QueryIntent carrying the given reasoning."""
        return QueryIntent(
            intent=IntentType.SIMPLE,
            reasoning=reasoning,
        )

    def classify_intent(self, query: str) -> QueryIntent:
        """
        Classifies the intent of the query and transforms it accordingly.

        Args:
            query: The user's input query.

        Returns:
            A QueryIntent object containing the classification and transformed queries.
            Returns a default 'SIMPLE' classification when the query is empty
            or whitespace-only, or when the chain fails for any reason.
        """
        # A whitespace-only string carries no more information than an empty
        # one, so skip the LLM call for it as well. The isinstance guard keeps
        # non-string truthy inputs on the original code path.
        if not query or (isinstance(query, str) and not query.strip()):
            logger.warning("Received an empty query. Cannot classify.")
            return self._fallback_intent("Input query was empty.")

        logger.debug("Classifying query: '%s'", query)
        try:
            result = self.chain.invoke({"query": query})
            # Logging stays inside the try so that a malformed result (e.g. one
            # without ``intent``) still triggers the fallback instead of raising.
            logger.info("Successfully classified query. Intent: %s", result.intent.value)
            logger.debug("Classification reasoning: %s", result.reasoning)
            return result
        except Exception as e:
            # ``Exception`` already covers pydantic.ValidationError as well as
            # parser and API errors; this is a deliberate catch-all boundary.
            # exc_info=True keeps the traceback in the log for diagnosis.
            logger.error(
                "Failed to classify query or parse LLM output for query: '%s'. Error: %s. "
                "Defaulting to 'SIMPLE' intent.",
                query,
                e,
                exc_info=True,
            )
            # Fallback mechanism to ensure the RAG chain doesn't break
            return self._fallback_intent(
                "Classification failed due to a parsing or API error. Defaulting to simple retrieval."
            )