RasaBot / custom_components /llm_intent_classifier_client.py
Luigi's picture
performance improvement: initialize llm agent only once after startup
ef8c7e3
import logging
from typing import Any, Dict, List, Optional, Text
from urllib.parse import urlsplit, urlunsplit

import requests
from rasa.nlu.classifiers.classifier import IntentClassifier
from rasa.nlu.config import RasaNLUModelConfig
from rasa.nlu.model import Metadata
from rasa.shared.nlu.constants import TEXT, INTENT
from rasa.shared.nlu.training_data.message import Message
from rasa.shared.nlu.training_data.training_data import TrainingData
logger = logging.getLogger(__name__)


class LlmIntentClassifier(IntentClassifier):
    """Delegates intent classification to an external HTTP micro-service.

    On construction the component validates its configuration and pushes the
    LLM settings (``model_name``, ``base_url``, ``class_set``,
    ``prompt_template``) to the service's ``/config`` endpoint. For every
    message, the text is POSTed to ``classifier_url`` and the string
    ``result`` field of the JSON reply becomes the message's intent.
    """

    name = "LlmIntentClassifier"

    defaults = {
        # Classification endpoint. The configuration endpoint is derived from
        # it by replacing the trailing "/classify" path segment with "/config".
        "classifier_url": "http://classifier:8000/classify",
        # Timeout in seconds applied to every request sent to the service.
        "timeout": 5.0,
        # Required settings, forwarded verbatim to the service's /config
        # endpoint; __init__ raises ValueError if any of them is empty.
        "model_name": None,
        "base_url": None,
        "class_set": [],
        "prompt_template": None,
    }

    def __init__(
        self,
        component_config: Optional[Dict[Text, Any]] = None,
    ) -> None:
        """Validate the configuration and initialize the remote classifier.

        Args:
            component_config: Pipeline configuration for this component; the
                base class merges it over ``defaults``.

        Raises:
            ValueError: If ``classifier_url``, ``model_name``, ``base_url``,
                ``class_set`` or ``prompt_template`` is missing or empty.
        """
        super().__init__(component_config or {})
        self.url: Text = self.component_config.get("classifier_url")
        self.timeout: float = float(self.component_config.get("timeout"))
        self.model_name: Optional[Text] = self.component_config.get("model_name")
        self.base_url: Optional[Text] = self.component_config.get("base_url")
        self.class_set: List[Text] = self.component_config.get("class_set", [])
        self.prompt_template: Optional[Text] = self.component_config.get(
            "prompt_template"
        )

        # Fail fast on incomplete configuration rather than on every message.
        # classifier_url is checked as well: without it the component can
        # neither configure nor query the service.
        required = {
            "classifier_url": self.url,
            "model_name": self.model_name,
            "base_url": self.base_url,
            "class_set": self.class_set,
            "prompt_template": self.prompt_template,
        }
        missing: List[str] = [key for key, value in required.items() if not value]
        if missing:
            raise ValueError(
                f"Missing configuration for {', '.join(missing)} in LlmIntentClassifier"
            )

        # Push config to the classifier backend once per instance; load()
        # hands back a cached instance when one exists so this is not repeated.
        self._configure_remote_classifier()

    @staticmethod
    def _derive_config_url(classify_url: Text) -> Optional[Text]:
        """Return the ``/config`` endpoint that pairs with ``classify_url``.

        Only a trailing ``/classify`` path segment is replaced; the scheme,
        host, query and fragment are kept as-is, so hosts that happen to
        contain "classify" are not rewritten. Returns ``None`` when the URL's
        path does not end in ``/classify``.
        """
        parts = urlsplit(classify_url)
        suffix = "/classify"
        if not parts.path.endswith(suffix):
            return None
        config_path = parts.path[: -len(suffix)] + "/config"
        return urlunsplit(parts._replace(path=config_path))

    def _configure_remote_classifier(self) -> None:
        """Send the LLM settings to the classifier backend (best effort).

        Network and HTTP failures are logged as warnings rather than raised,
        so constructing the component does not fail when the service is
        unreachable at startup.
        """
        config_url = self._derive_config_url(self.url)
        if config_url is None:
            # Posting the config payload to the classify endpoint (the old
            # fallback) cannot be right, so skip and tell the operator.
            logger.warning(
                "Cannot derive config endpoint from classifier_url %r "
                "(expected a path ending in '/classify'); "
                "remote classifier not initialized.",
                self.url,
            )
            return

        payload = {
            "model_name": self.model_name,
            "base_url": self.base_url,
            "class_set": self.class_set,
            "prompt_template": self.prompt_template,
        }
        logger.debug("Sending classifier config to: %s", config_url)
        try:
            response = requests.post(config_url, json=payload, timeout=self.timeout)
            response.raise_for_status()
        except requests.RequestException as e:
            logger.warning("Failed to initialize remote classifier: %s", e)
            return
        logger.info("Remote classifier initialized successfully.")

    def train(
        self,
        training_data: TrainingData,
        config: Optional[RasaNLUModelConfig] = None,
        **kwargs: Any,
    ) -> None:
        """Do nothing: classification is delegated to the remote service."""

    def process(self, message: Message, **kwargs: Any) -> None:
        """Classify ``message`` through the remote service and set its intent.

        The intent is set (and added to the output) as
        ``{"name": <str>, "confidence": 1.0}`` when the service returns a
        non-empty string ``result``; otherwise, including for messages with
        no text, it is ``{"name": None, "confidence": 0.0}``.
        """
        text: Optional[Text] = message.get(TEXT)
        intent_name = self._classify(text) if text else None
        # The service returns only a label, no score, so confidence is binary.
        confidence = 1.0 if intent_name else 0.0
        message.set(
            INTENT, {"name": intent_name, "confidence": confidence}, add_to_output=True
        )

    def _classify(self, text: Text) -> Optional[Text]:
        """POST ``text`` to the classify endpoint; return the intent or ``None``.

        Returns ``None`` on transport/HTTP errors, undecodable JSON, a reply
        that is not a JSON object, or a missing/empty/non-string ``result``.
        """
        try:
            resp = requests.post(
                self.url, json={"message": text}, timeout=self.timeout
            )
            resp.raise_for_status()
            body = resp.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers JSON decode errors raised by resp.json().
            logger.warning("LlmIntentClassifier HTTP error: %s", e)
            return None

        result = body.get("result") if isinstance(body, dict) else None
        if isinstance(result, str) and result:
            return result
        logger.debug("LlmIntentClassifier got no usable 'result' in: %r", body)
        return None

    def persist(
        self,
        file_name: Text,
        model_dir: Text,
    ) -> Optional[Dict[Text, Any]]:
        """Return the configuration to store in the model metadata.

        Nothing is written to ``model_dir``; the returned dict is everything
        ``load`` needs to rebuild the component.
        """
        return {
            "classifier_url": self.url,
            "timeout": self.timeout,
            "model_name": self.model_name,
            "base_url": self.base_url,
            "class_set": self.class_set,
            "prompt_template": self.prompt_template,
        }

    @classmethod
    def load(
        cls,
        meta: Dict[Text, Any],
        model_dir: Text,
        model_metadata: Optional[Metadata] = None,
        cached_component: Optional["LlmIntentClassifier"] = None,
        **kwargs: Any,
    ) -> "LlmIntentClassifier":
        """Rebuild the component from the metadata written by ``persist``.

        A ``cached_component`` supplied by Rasa is returned as-is, so the
        remote classifier is not re-configured (another POST to ``/config``)
        on every load.
        """
        if cached_component is not None:
            return cached_component
        # meta contains the saved configuration
        return cls(meta)