| | import logging |
| | import requests |
| | from typing import Any, Dict, List, Optional, Text |
| |
|
| | from rasa.nlu.classifiers.classifier import IntentClassifier |
| | from rasa.shared.nlu.constants import TEXT, INTENT |
| | from rasa.nlu.config import RasaNLUModelConfig |
| | from rasa.shared.nlu.training_data.training_data import TrainingData |
| | from rasa.shared.nlu.training_data.message import Message |
| | from rasa.nlu.model import Metadata |
| |
|
# Module-level logger named after this module, per stdlib convention.
logger = logging.getLogger(__name__)
| |
|
| |
|
class LlmIntentClassifier(IntentClassifier):
    """Delegates intent classification to an external HTTP micro-service.

    On construction the component pushes its model configuration to the
    backend's ``/config`` endpoint; at inference time each message's text is
    POSTed to ``classifier_url`` and the returned label is attached to the
    message as its intent. This component does no local training.
    """

    name = "LlmIntentClassifier"

    # Default component configuration. ``model_name``, ``base_url``,
    # ``class_set`` and ``prompt_template`` are required and validated in
    # ``__init__``; the remaining keys have usable fallbacks.
    defaults = {
        "classifier_url": "http://classifier:8000/classify",
        "timeout": 5.0,
        "model_name": None,
        "base_url": None,
        "class_set": [],
        "prompt_template": None,
    }

    def __init__(
        self,
        component_config: Optional[Dict[Text, Any]] = None,
    ) -> None:
        """Validate the configuration and initialize the remote classifier.

        Args:
            component_config: Pipeline configuration for this component;
                merged over ``defaults`` by the base class.

        Raises:
            ValueError: If any of ``model_name``, ``base_url``, ``class_set``
                or ``prompt_template`` is missing or empty.
        """
        super().__init__(component_config or {})

        # Fall back to the class defaults so an explicit ``null`` in the
        # pipeline YAML cannot leave ``None`` where a value is required.
        # (Previously ``float(None)`` raised TypeError and ``None.replace``
        # raised AttributeError in ``_configure_remote_classifier``.)
        self.url: Text = (
            self.component_config.get("classifier_url")
            or self.defaults["classifier_url"]
        )
        self.timeout: float = float(
            self.component_config.get("timeout") or self.defaults["timeout"]
        )
        self.model_name: Optional[Text] = self.component_config.get("model_name")
        self.base_url: Optional[Text] = self.component_config.get("base_url")
        self.class_set: List[Text] = self.component_config.get("class_set") or []
        self.prompt_template: Optional[Text] = self.component_config.get(
            "prompt_template"
        )

        # Fail fast on incomplete configuration rather than at request time.
        missing: List[str] = [
            key
            for key, value in (
                ("model_name", self.model_name),
                ("base_url", self.base_url),
                ("class_set", self.class_set),
                ("prompt_template", self.prompt_template),
            )
            if not value
        ]
        if missing:
            raise ValueError(
                f"Missing configuration for {', '.join(missing)} in LlmIntentClassifier"
            )

        self._configure_remote_classifier()

    def _configure_remote_classifier(self) -> None:
        """Send configuration to the classifier backend to initialize the model.

        Failures are logged and swallowed deliberately (best-effort): the
        backend may not be up yet, and inference requests will surface any
        persistent problem.
        """
        payload: Dict[Text, Any] = {
            "model_name": self.model_name,
            "base_url": self.base_url,
            "class_set": self.class_set,
            "prompt_template": self.prompt_template,
        }
        # NOTE(review): assumes ``classifier_url`` contains "/classify"; if it
        # does not, the config payload is POSTed to the classify endpoint
        # itself — confirm the URL convention with the backend service.
        config_url = self.url.replace("/classify", "/config")
        try:
            logger.debug("Sending classifier config to: %s", config_url)
            response = requests.post(config_url, json=payload, timeout=self.timeout)
            response.raise_for_status()
            logger.info("Remote classifier initialized successfully.")
        except requests.RequestException as e:
            logger.warning(f"Failed to initialize remote classifier: {e}")

    def train(
        self,
        training_data: TrainingData,
        config: Optional[RasaNLUModelConfig] = None,
        **kwargs: Any,
    ) -> None:
        """No-op: all classification logic lives in the remote service."""
        pass

    def process(self, message: Message, **kwargs: Any) -> None:
        """Classify ``message`` by delegating to the remote service.

        Always sets an INTENT on the message; on HTTP error, malformed
        response, or empty text the intent name is ``None`` with confidence
        0.0 so downstream components see a consistent payload shape.
        """
        text: Optional[Text] = message.get(TEXT)
        intent_name: Optional[Text] = None
        confidence: float = 0.0

        if text:
            try:
                resp = requests.post(
                    self.url, json={"message": text}, timeout=self.timeout
                )
                resp.raise_for_status()
                body = resp.json()
                # Guard against non-object JSON bodies (e.g. a bare list),
                # which previously raised AttributeError on ``.get``.
                result = body.get("result") if isinstance(body, dict) else None
                # The backend returns the label as a plain string; anything
                # else is treated as "no prediction".
                if isinstance(result, str):
                    intent_name = result
                    confidence = 1.0
            except (requests.RequestException, ValueError) as e:
                # ValueError also covers json.JSONDecodeError from ``.json()``.
                logger.warning(f"LlmIntentClassifier HTTP error: {e}")

        message.set(
            INTENT, {"name": intent_name, "confidence": confidence}, add_to_output=True
        )

    def persist(
        self,
        file_name: Text,
        model_dir: Text,
    ) -> Optional[Dict[Text, Any]]:
        """Return the component configuration to be stored in model metadata.

        No files are written; the configuration alone is enough to
        reconstruct the component in ``load``.
        """
        return {
            "classifier_url": self.url,
            "timeout": self.timeout,
            "model_name": self.model_name,
            "base_url": self.base_url,
            "class_set": self.class_set,
            "prompt_template": self.prompt_template,
        }

    @classmethod
    def load(
        cls,
        meta: Dict[Text, Any],
        model_dir: Text,
        model_metadata: Optional[Metadata] = None,
        cached_component: Optional["LlmIntentClassifier"] = None,
        **kwargs: Any,
    ) -> "LlmIntentClassifier":
        """Recreate the component from persisted metadata.

        Reuses ``cached_component`` when available, which avoids re-sending
        the ``/config`` request to the backend on every load.
        """
        if cached_component is not None:
            return cached_component
        return cls(meta)
|