Performance improvement: initialize the LLM agent only once, at startup
Browse files
classifier/classifier.py
CHANGED
|
@@ -1,69 +1,79 @@
|
|
| 1 |
import logging
|
| 2 |
-
from fastapi import FastAPI
|
| 3 |
-
from pydantic import BaseModel
|
| 4 |
-
from typing import List
|
| 5 |
-
import os
|
| 6 |
-
from string import Formatter
|
| 7 |
-
|
| 8 |
import os
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
import outlines
|
| 11 |
from outlines.models import openai
|
| 12 |
from outlines.generate import choice
|
| 13 |
|
| 14 |
# Configure logger
|
| 15 |
-
|
| 16 |
-
tools.setLevel(logging.DEBUG)
|
| 17 |
-
ch = logging.StreamHandler()
|
| 18 |
-
ch.setLevel(logging.DEBUG)
|
| 19 |
-
formatter = logging.Formatter("%(levelname)s:%(name)s:%(message)s")
|
| 20 |
-
ch.setFormatter(formatter)
|
| 21 |
-
tools.addHandler(ch)
|
| 22 |
-
|
| 23 |
-
# Configure logger
|
| 24 |
-
logging.basicConfig(
|
| 25 |
-
format="%(asctime)s %(levelname)s:%(name)s: %(message)s",
|
| 26 |
-
level=logging.DEBUG,
|
| 27 |
-
)
|
| 28 |
logger = logging.getLogger("classifier")
|
| 29 |
|
| 30 |
app = FastAPI()
|
| 31 |
|
| 32 |
-
#
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
| 35 |
model_name: str
|
| 36 |
base_url: str
|
| 37 |
class_set: List[str]
|
| 38 |
-
prompt_template: str
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
class Resp(BaseModel):
|
| 41 |
result: str
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
prompt = req.prompt_template.replace("{message}", req.message)
|
| 53 |
-
logger.debug(f"Rendered prompt: {prompt!r}")
|
| 54 |
|
| 55 |
api_key = os.getenv("TOGETHERAI_API_KEY")
|
| 56 |
logger.debug(f"Using API_KEY: {'set' if api_key else 'missing'}")
|
| 57 |
-
llm = openai(req.model_name, api_key=api_key, base_url=req.base_url)
|
| 58 |
-
clf = choice(llm, req.class_set)
|
| 59 |
-
logger.debug(f"Choice classifier created with labels: {req.class_set}")
|
| 60 |
|
| 61 |
try:
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
except Exception as e:
|
| 66 |
-
|
| 67 |
-
|
| 68 |
|
| 69 |
return Resp(result=result)
|
|
|
|
| 1 |
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import os
|
| 3 |
+
from fastapi import FastAPI, HTTPException
|
| 4 |
+
from pydantic import BaseModel
|
| 5 |
+
from typing import List, Optional
|
| 6 |
|
| 7 |
import outlines
|
| 8 |
from outlines.models import openai
|
| 9 |
from outlines.generate import choice
|
| 10 |
|
| 11 |
# Configure logger
|
| 12 |
+
logging.basicConfig(level=logging.DEBUG)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
logger = logging.getLogger("classifier")
|
| 14 |
|
| 15 |
app = FastAPI()
|
| 16 |
|
| 17 |
+
# Global variables for shared config and classifier
|
| 18 |
+
clf = None
|
| 19 |
+
config_set = False
|
| 20 |
+
|
| 21 |
+
class Config(BaseModel):
|
| 22 |
model_name: str
|
| 23 |
base_url: str
|
| 24 |
class_set: List[str]
|
| 25 |
+
prompt_template: str
|
| 26 |
+
|
| 27 |
+
class Req(BaseModel):
|
| 28 |
+
message: str
|
| 29 |
|
| 30 |
class Resp(BaseModel):
|
| 31 |
result: str
|
| 32 |
|
| 33 |
+
@app.post("/config")
|
| 34 |
+
def configure(req: Config):
|
| 35 |
+
"""Receive and initialize classifier configuration."""
|
| 36 |
+
global clf, config_set
|
| 37 |
|
| 38 |
+
if config_set:
|
| 39 |
+
logger.warning("Classifier already configured. Ignoring new config.")
|
| 40 |
+
return {"status": "already_configured"}
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
api_key = os.getenv("TOGETHERAI_API_KEY")
|
| 43 |
logger.debug(f"Using API_KEY: {'set' if api_key else 'missing'}")
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
try:
|
| 46 |
+
llm = openai(req.model_name, api_key=api_key, base_url=req.base_url)
|
| 47 |
+
clf = choice(llm, req.class_set)
|
| 48 |
+
clf.class_set = req.class_set
|
| 49 |
+
clf.prompt_template = req.prompt_template
|
| 50 |
+
config_set = True
|
| 51 |
+
logger.info("Classifier configured successfully.")
|
| 52 |
+
return {"status": "configured"}
|
| 53 |
+
except Exception as e:
|
| 54 |
+
logger.error(f"Failed to configure classifier: {e}")
|
| 55 |
+
raise HTTPException(status_code=500, detail="Classifier configuration failed")
|
| 56 |
+
|
| 57 |
+
@app.post("/classify", response_model=Resp)
|
| 58 |
+
def classify(req: Req):
|
| 59 |
+
global clf
|
| 60 |
+
if clf is None or not config_set:
|
| 61 |
+
raise HTTPException(status_code=503, detail="Classifier not configured yet")
|
| 62 |
+
|
| 63 |
+
# Render the prompt using the template
|
| 64 |
+
try:
|
| 65 |
+
prompt = clf.prompt_template.replace("{message}", req.message)
|
| 66 |
+
logger.debug(f"Rendered prompt: {prompt!r}")
|
| 67 |
+
except Exception as e:
|
| 68 |
+
logger.warning(f"Prompt rendering failed: {e}")
|
| 69 |
+
prompt = req.message
|
| 70 |
+
|
| 71 |
+
# Run classifier
|
| 72 |
+
try:
|
| 73 |
+
result = clf(prompt)
|
| 74 |
+
logger.debug(f"Classification result: {result}")
|
| 75 |
except Exception as e:
|
| 76 |
+
logger.error(f"Classification error: {e}. Falling back to: {clf.class_set[-1]}")
|
| 77 |
+
result = clf.class_set[-1]
|
| 78 |
|
| 79 |
return Resp(result=result)
|
custom_components/llm_intent_classifier_client.py
CHANGED
|
@@ -30,6 +30,7 @@ class LlmIntentClassifier(IntentClassifier):
|
|
| 30 |
component_config: Optional[Dict[Text, Any]] = None,
|
| 31 |
) -> None:
|
| 32 |
super().__init__(component_config or {})
|
|
|
|
| 33 |
self.url: str = self.component_config.get("classifier_url")
|
| 34 |
self.timeout: float = float(self.component_config.get("timeout"))
|
| 35 |
self.model_name: Optional[Text] = self.component_config.get("model_name")
|
|
@@ -52,6 +53,26 @@ class LlmIntentClassifier(IntentClassifier):
|
|
| 52 |
f"Missing configuration for {', '.join(missing)} in LlmIntentClassifier"
|
| 53 |
)
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
def train(
|
| 56 |
self,
|
| 57 |
training_data: TrainingData,
|
|
@@ -67,13 +88,7 @@ class LlmIntentClassifier(IntentClassifier):
|
|
| 67 |
confidence: float = 0.0
|
| 68 |
|
| 69 |
if text:
|
| 70 |
-
payload: Dict[str, Any] = {
|
| 71 |
-
"message": text,
|
| 72 |
-
"model_name": self.model_name,
|
| 73 |
-
"base_url": self.base_url,
|
| 74 |
-
"class_set": self.class_set,
|
| 75 |
-
"prompt_template": self.prompt_template,
|
| 76 |
-
}
|
| 77 |
try:
|
| 78 |
resp = requests.post(self.url, json=payload, timeout=self.timeout)
|
| 79 |
resp.raise_for_status()
|
|
|
|
| 30 |
component_config: Optional[Dict[Text, Any]] = None,
|
| 31 |
) -> None:
|
| 32 |
super().__init__(component_config or {})
|
| 33 |
+
|
| 34 |
self.url: str = self.component_config.get("classifier_url")
|
| 35 |
self.timeout: float = float(self.component_config.get("timeout"))
|
| 36 |
self.model_name: Optional[Text] = self.component_config.get("model_name")
|
|
|
|
| 53 |
f"Missing configuration for {', '.join(missing)} in LlmIntentClassifier"
|
| 54 |
)
|
| 55 |
|
| 56 |
+
# Push config to classifier backend
|
| 57 |
+
self._configure_remote_classifier()
|
| 58 |
+
|
| 59 |
+
def _configure_remote_classifier(self) -> None:
|
| 60 |
+
"""Send configuration to the classifier backend to initialize the model."""
|
| 61 |
+
payload = {
|
| 62 |
+
"model_name": self.model_name,
|
| 63 |
+
"base_url": self.base_url,
|
| 64 |
+
"class_set": self.class_set,
|
| 65 |
+
"prompt_template": self.prompt_template,
|
| 66 |
+
}
|
| 67 |
+
try:
|
| 68 |
+
config_url = self.url.replace("/classify", "/config")
|
| 69 |
+
logger.debug(f"Sending classifier config to: {config_url}")
|
| 70 |
+
response = requests.post(config_url, json=payload, timeout=self.timeout)
|
| 71 |
+
response.raise_for_status()
|
| 72 |
+
logger.info("Remote classifier initialized successfully.")
|
| 73 |
+
except Exception as e:
|
| 74 |
+
logger.warning(f"Failed to initialize remote classifier: {e}")
|
| 75 |
+
|
| 76 |
def train(
|
| 77 |
self,
|
| 78 |
training_data: TrainingData,
|
|
|
|
| 88 |
confidence: float = 0.0
|
| 89 |
|
| 90 |
if text:
|
| 91 |
+
payload: Dict[str, Any] = {"message": text}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
try:
|
| 93 |
resp = requests.post(self.url, json=payload, timeout=self.timeout)
|
| 94 |
resp.raise_for_status()
|