Dokumentassistent / src /agent /intent_classifier.py
XQ
Add LLM Provider Fallback
4d2a2da
raw
history blame
2.11 kB
"""Intent classification for incoming user queries."""
import logging
import re
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from src.agent.prompts import get_prompt
from src.models import IntentType
logger = logging.getLogger(__name__)
_THINK_CLOSED_RE = re.compile(r"<think>.*?</think>\s*", re.DOTALL)
_THINK_UNCLOSED_RE = re.compile(r"<think>.*", re.DOTALL)
_VALID_INTENTS = {intent.value for intent in IntentType}
_SYSTEM_PROMPT = get_prompt("intent_classify").template
class IntentClassifier:
"""Classifies user queries into predefined intent categories."""
def __init__(self, llm: BaseChatModel, *, model_name: str = "") -> None:
"""Initialize the intent classifier.
Args:
llm: A LangChain BaseChatModel instance from provider.py.
model_name: Model identifier used to detect models that lack
system-message support (e.g. Gemma via Ollama).
"""
if "gemma3" in model_name.lower():
prompt = ChatPromptTemplate.from_messages([
("human", _SYSTEM_PROMPT + "\n\nQuery: {query}"),
])
else:
prompt = ChatPromptTemplate.from_messages([
("system", _SYSTEM_PROMPT),
("human", "{query}"),
])
self._chain = prompt | llm | StrOutputParser()
def classify(self, query: str) -> IntentType:
"""Classify a user query into an intent type.
Args:
query: The user's natural language query.
Returns:
The classified IntentType.
"""
_raw_out = self._chain.invoke({"query": query})
raw = _THINK_UNCLOSED_RE.sub("", _THINK_CLOSED_RE.sub("", _raw_out)).strip().lower()
logger.debug("Raw classification result: %s", raw)
if raw in _VALID_INTENTS:
return IntentType(raw)
logger.warning("Unrecognized intent '%s', falling back to UNKNOWN", raw)
return IntentType.UNKNOWN