File size: 4,650 Bytes
0d2a25f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef8c7e3
0d2a25f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef8c7e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d2a25f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef8c7e3
0d2a25f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import logging
import requests
from typing import Any, Dict, List, Optional, Text

from rasa.nlu.classifiers.classifier import IntentClassifier
from rasa.shared.nlu.constants import TEXT, INTENT
from rasa.nlu.config import RasaNLUModelConfig
from rasa.shared.nlu.training_data.training_data import TrainingData
from rasa.shared.nlu.training_data.message import Message
from rasa.nlu.model import Metadata

logger = logging.getLogger(__name__)


class LlmIntentClassifier(IntentClassifier):
    """Delegates intent classification to an external HTTP micro-service.

    This component performs no local training. At construction time it pushes
    its model configuration to the backend's ``/config`` endpoint (best
    effort); at inference time it POSTs the user message to
    ``classifier_url`` and stores the returned intent on the message.
    """

    name = "LlmIntentClassifier"

    # ``model_name``, ``base_url``, ``class_set`` and ``prompt_template``
    # default to empty values but are required; ``__init__`` validates them.
    defaults = {
        "classifier_url": "http://classifier:8000/classify",
        "timeout": 5.0,
        "model_name": None,
        "base_url": None,
        "class_set": [],
        "prompt_template": None,
    }

    def __init__(
        self,
        component_config: Optional[Dict[Text, Any]] = None,
    ) -> None:
        """Read and validate configuration, then initialize the remote model.

        Args:
            component_config: Component configuration merged over ``defaults``.

        Raises:
            ValueError: If any required configuration key is missing or empty.
        """
        super().__init__(component_config or {})

        self.url: str = self.component_config.get("classifier_url")
        self.timeout: float = float(self.component_config.get("timeout"))
        self.model_name: Optional[Text] = self.component_config.get("model_name")
        self.base_url: Optional[Text] = self.component_config.get("base_url")
        self.class_set: List[Text] = self.component_config.get("class_set", [])
        self.prompt_template: Optional[Text] = self.component_config.get("prompt_template")

        # Fail fast on incomplete configuration instead of surfacing a
        # confusing error from the remote service later on.
        missing: List[str] = [
            key
            for key, value in (
                ("model_name", self.model_name),
                ("base_url", self.base_url),
                ("class_set", self.class_set),
                ("prompt_template", self.prompt_template),
            )
            if not value
        ]
        if missing:
            raise ValueError(
                f"Missing configuration for {', '.join(missing)} in LlmIntentClassifier"
            )

        # Push config to classifier backend (best effort; see method docstring).
        self._configure_remote_classifier()

    def _configure_remote_classifier(self) -> None:
        """Send configuration to the classifier backend to initialize the model.

        Best effort: network/HTTP failures are logged as warnings rather than
        raised, so the component can still be constructed while the backend is
        temporarily unavailable.
        """
        payload = {
            "model_name": self.model_name,
            "base_url": self.base_url,
            "class_set": self.class_set,
            "prompt_template": self.prompt_template,
        }
        try:
            # NOTE(review): derives the config endpoint by string replacement;
            # assumes ``classifier_url`` contains "/classify" — confirm for
            # non-standard URLs.
            config_url = self.url.replace("/classify", "/config")
            logger.debug(f"Sending classifier config to: {config_url}")
            response = requests.post(config_url, json=payload, timeout=self.timeout)
            response.raise_for_status()
            logger.info("Remote classifier initialized successfully.")
        except requests.RequestException as e:
            # Narrowed from bare ``Exception``: only network/HTTP errors are
            # best-effort; programming errors should surface immediately.
            logger.warning(f"Failed to initialize remote classifier: {e}")

    def train(
        self,
        training_data: TrainingData,
        config: Optional[RasaNLUModelConfig] = None,
        **kwargs: Any,
    ) -> None:
        """No-op: classification is delegated to a remote service."""
        pass

    def process(self, message: Message, **kwargs: Any) -> None:
        """Classify ``message`` via the remote service and set its INTENT.

        On any HTTP/network failure or a malformed JSON response, the intent
        is set to ``None`` with confidence ``0.0`` so downstream components
        still find an INTENT entry on the message.
        """
        text: Optional[Text] = message.get(TEXT)
        intent_name: Optional[Text] = None
        confidence: float = 0.0

        if text:
            payload: Dict[str, Any] = {"message": text}
            try:
                resp = requests.post(self.url, json=payload, timeout=self.timeout)
                resp.raise_for_status()
                result = resp.json().get("result")
                # Only accept a plain string intent name from the backend.
                if isinstance(result, str):
                    intent_name = result
                    # Backend reports no score, so treat its answer as certain.
                    confidence = 1.0
            except (requests.RequestException, ValueError) as e:
                # RequestException: network/timeout/HTTP errors.
                # ValueError: ``resp.json()`` on a non-JSON body.
                logger.warning(f"LlmIntentClassifier HTTP error: {e}")

        message.set(INTENT, {"name": intent_name, "confidence": confidence}, add_to_output=True)

    def persist(
        self,
        file_name: Text,
        model_dir: Text,
    ) -> Optional[Dict[Text, Any]]:
        """Return the configuration metadata needed to reload this component.

        Nothing is written to ``model_dir``; the returned dict is stored in
        the model metadata and handed back to :meth:`load`.
        """
        return {
            "classifier_url": self.url,
            "timeout": self.timeout,
            "model_name": self.model_name,
            "base_url": self.base_url,
            "class_set": self.class_set,
            "prompt_template": self.prompt_template,
        }

    @classmethod
    def load(
        cls,
        meta: Dict[Text, Any],
        model_dir: Text,
        model_metadata: Optional[Metadata] = None,  # fixed: was ``Metadata = None``
        cached_component: Optional["LlmIntentClassifier"] = None,
        **kwargs: Any,
    ) -> "LlmIntentClassifier":
        """Recreate the component from the configuration saved by ``persist``."""
        return cls(meta)