File size: 2,108 Bytes
31a2688
 
 
2745e27
31a2688
 
 
 
 
4d2a2da
31a2688
 
 
 
fdc3773
 
2745e27
31a2688
 
4d2a2da
31a2688
 
 
 
 
75620c6
31a2688
 
 
 
75620c6
 
31a2688
75620c6
 
 
 
 
 
 
 
 
31a2688
 
 
 
 
 
 
 
 
 
 
fdc3773
 
31a2688
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Intent classification for incoming user queries."""

import logging
import re

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

from src.agent.prompts import get_prompt
from src.models import IntentType

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Matches a complete <think>...</think> section together with any whitespace
# that follows it; DOTALL lets "." span newlines inside the section.
_THINK_CLOSED_RE = re.compile(r"<think>.*?</think>\s*", re.DOTALL)
# Matches from an opening <think> tag through the end of the text, i.e. a
# section that has no closing tag.
_THINK_UNCLOSED_RE = re.compile(r"<think>.*", re.DOTALL)

# String values of every IntentType member, used for membership checks.
_VALID_INTENTS = {intent.value for intent in IntentType}

# Template text of the "intent_classify" prompt, looked up once at import time.
_SYSTEM_PROMPT = get_prompt("intent_classify").template


class IntentClassifier:
    """Classifies user queries into predefined intent categories.

    The classifier pipes a prompt into the supplied chat model and a string
    output parser, then matches the model's reply against the values of
    ``IntentType``.
    """

    # Characters that models commonly wrap around a one-word answer
    # (quotes, backticks, markdown emphasis, trailing punctuation), plus
    # whitespace. They are trimmed from both ends of the reply before the
    # second matching attempt in ``classify``.
    _WRAPPER_CHARS = " \t\r\n\"'`*.,:;!"

    def __init__(self, llm: BaseChatModel, *, model_name: str = "") -> None:
        """Initialize the intent classifier.

        Args:
            llm: A LangChain BaseChatModel instance from provider.py.
            model_name: Model identifier used to detect models that lack
                system-message support (e.g. Gemma via Ollama). The check is
                a case-insensitive substring test for ``"gemma3"``.
        """
        if "gemma3" in model_name.lower():
            # No system message here: the instructions and the query are
            # sent together as a single human message.
            prompt = ChatPromptTemplate.from_messages([
                ("human", _SYSTEM_PROMPT + "\n\nQuery: {query}"),
            ])
        else:
            prompt = ChatPromptTemplate.from_messages([
                ("system", _SYSTEM_PROMPT),
                ("human", "{query}"),
            ])
        self._chain = prompt | llm | StrOutputParser()

    def classify(self, query: str) -> IntentType:
        """Classify a user query into an intent type.

        The model reply is cleaned of ``<think>`` sections (closed ones
        first, then an unclosed one running to the end of the text),
        stripped and lower-cased. It is first compared verbatim against the
        valid intent values; if that fails, surrounding quotes and
        punctuation are trimmed and the comparison is retried, so replies
        such as ``"search"`` or ``search.`` are still recognised.

        Args:
            query: The user's natural language query.

        Returns:
            The classified IntentType, or ``IntentType.UNKNOWN`` when the
            reply does not correspond to any valid intent value.
        """
        reply = self._chain.invoke({"query": query})
        # Closed sections must be removed first; otherwise the unclosed
        # pattern would swallow everything after the first <think> tag,
        # including the actual answer.
        without_closed = _THINK_CLOSED_RE.sub("", reply)
        raw = _THINK_UNCLOSED_RE.sub("", without_closed).strip().lower()
        logger.debug("Raw classification result: %s", raw)

        if raw in _VALID_INTENTS:
            return IntentType(raw)

        # Second chance: drop decoration around the answer. The exact match
        # above runs first so existing behaviour is never overridden.
        trimmed = raw.strip(self._WRAPPER_CHARS)
        if trimmed in _VALID_INTENTS:
            return IntentType(trimmed)

        logger.warning("Unrecognized intent '%s', falling back to UNKNOWN", raw)
        return IntentType.UNKNOWN