# RasaBot/classifier/classifier.py
# Adapted to outlines v1.0.4 (commit 9ffbce4).
import logging
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
import json
import outlines
from outlines.models import openai
from outlines.types import Choice
# Configure logger
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("classifier")
app = FastAPI()
# Global variables for shared config and classifier
llm = None  # outlines model handle; created in /config
clf_output = None  # outlines Choice output type built from the configured class_set
prompt_template = None  # prompt string; "{message}" is substituted per request
config_set = False  # one-shot guard: /config is honored only once per process
class Config(BaseModel):
    """Classifier configuration payload accepted by POST /config."""
    model_name: str  # model identifier passed to the OpenAI-compatible backend
    base_url: str  # base URL of the OpenAI-compatible API endpoint
    class_set: List[str]  # allowed output labels (constrained-decoding choices)
    prompt_template: str  # template expected to contain a "{message}" placeholder
class Req(BaseModel):
    """Classification request body for POST /classify."""
    message: str  # text to classify; substituted into the prompt template
class Resp(BaseModel):
    """Classification response body."""
    result: str  # predicted label (expected to be one of the configured class_set)
@app.post("/config")
def configure(req: Config):
    """Receive and initialize the classifier configuration.

    Builds the LLM client and the constrained-choice output type from the
    submitted config. Configuration is applied at most once per process;
    subsequent calls are acknowledged but ignored.

    Returns:
        dict: ``{"status": "configured"}`` on success, or
        ``{"status": "already_configured"}`` if /config was already called.

    Raises:
        HTTPException: 500 if constructing the model or output type fails.
    """
    global llm, clf_output, prompt_template, config_set
    if config_set:
        logger.warning("Classifier already configured. Ignoring new config.")
        return {"status": "already_configured"}
    api_key = os.getenv("TOGETHERAI_API_KEY")
    # Never log the key itself — only whether it is present.
    logger.debug("Using API_KEY: %s", "set" if api_key else "missing")
    try:
        # Instantiate model and choice output type
        llm = openai(req.model_name, api_key=api_key, base_url=req.base_url)
        clf_output = Choice(req.class_set)
        prompt_template = req.prompt_template
        config_set = True
        logger.info("Classifier configured successfully.")
        return {"status": "configured"}
    except Exception as e:
        # logger.exception records the traceback (logger.error did not);
        # `from e` preserves the original cause on the raised HTTPException.
        logger.exception("Failed to configure classifier: %s", e)
        raise HTTPException(status_code=500, detail="Classifier configuration failed") from e
@app.post("/classify", response_model=Resp)
def classify(req: Req):
    """Run text classification using the configured LLM.

    Renders the configured prompt template with the incoming message,
    invokes the choice-constrained LLM, and unwraps a possible
    ``{"result": "Label"}`` JSON envelope from the model output.

    Raises:
        HTTPException: 503 if /config has not been called yet;
                       500 if the LLM invocation fails.
    """
    global llm, clf_output, prompt_template, config_set
    if not config_set or llm is None or clf_output is None:
        raise HTTPException(status_code=503, detail="Classifier not configured yet")
    # Render prompt; fall back to the raw message if the template is unusable.
    try:
        prompt = prompt_template.replace("{message}", req.message)
        logger.debug("Rendered prompt: %r", prompt)
    except Exception as e:
        logger.warning("Prompt rendering failed: %s", e)
        prompt = req.message
    # Invoke LLM classifier
    try:
        raw = llm(prompt, clf_output)
        try:
            # llm may return '{"result": "Label"}' → parse it
            unwrapped = json.loads(raw).get("result", raw)
        except Exception:
            unwrapped = raw
        # Coerce to str so Resp's `result: str` field always validates.
        result = str(unwrapped)
        logger.debug("Classification result: %s", result)
    except Exception as e:
        # BUG FIX: the original fell back to req.class_set[-1], but Req has no
        # class_set field — the handler itself raised AttributeError, which
        # surfaced as an unhandled 500. Raise an explicit, intentional 500.
        logger.exception("Classification error: %s", e)
        raise HTTPException(status_code=500, detail="Classification failed") from e
    return Resp(result=result)