| | import logging |
| | import os |
| | from fastapi import FastAPI, HTTPException |
| | from pydantic import BaseModel |
| | from typing import List |
| | import json |
| |
|
| | import outlines |
| | from outlines.models import openai |
| | from outlines.types import Choice |
| |
|
| | |
# Root log level. Defaults to DEBUG (the previous hard-coded value) but can be
# overridden with the LOG_LEVEL environment variable (e.g. "INFO"). DEBUG
# output in this module includes rendered prompts, which contain user text.
# logging.getLevelName maps a registered level name to its int value and
# returns a non-int for unknown names, in which case we keep DEBUG.
_log_level = logging.getLevelName(os.getenv("LOG_LEVEL", "DEBUG").strip().upper())
if not isinstance(_log_level, int):
    _log_level = logging.DEBUG
logging.basicConfig(level=_log_level)
logger = logging.getLogger("classifier")

app = FastAPI()

# Process-wide classifier state, filled in by the first successful POST /config.
llm = None  # model handle built by outlines.models.openai
clf_output = None  # outlines Choice output type built from the configured labels
prompt_template = None  # template text; "{message}" is replaced by the input
config_set = False  # True once /config has succeeded; later configs are ignored
| |
|
class Config(BaseModel):
    """Body of ``POST /config``: which model to call and which labels to choose from."""

    # Model identifier passed to ``outlines.models.openai``.
    model_name: str
    # Endpoint URL, passed as ``base_url`` to ``outlines.models.openai``.
    base_url: str
    # Candidate labels, wrapped in ``outlines.types.Choice`` as the output type.
    class_set: List[str]
    # Prompt text; every "{message}" occurrence is replaced by the input text.
    prompt_template: str
| |
|
class Req(BaseModel):
    """Body of ``POST /classify``."""

    message: str  # text to classify; substituted into the prompt template
| |
|
class Resp(BaseModel):
    """Response body of ``POST /classify``."""

    result: str  # the predicted label (or the fallback label on failure)
| |
|
# Labels from the active configuration. Kept with the other classifier state
# so /classify can validate model output and pick a fallback label: the
# incoming Req only carries the message, not the label set.
class_labels: List[str] = []


def _render_prompt(template: str, message: str) -> str:
    """Return *template* with every ``{message}`` placeholder replaced by *message*.

    ``str.replace`` (rather than ``str.format``) leaves any other literal
    braces in the template untouched.
    """
    return template.replace("{message}", message)


def _extract_label(raw: object) -> object:
    """Unwrap a raw model response into a candidate label.

    A JSON-object string carrying a ``"result"`` key yields that value; any
    other input (plain text, non-object JSON, non-string values) is
    returned unchanged.
    """
    if not isinstance(raw, str):
        return raw
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return raw
    if isinstance(parsed, dict):
        return parsed.get("result", raw)
    return raw


@app.post("/config")
def configure(req: Config):
    """Initialise the classifier from *req*; only the first successful call takes effect.

    Returns:
        ``{"status": "configured"}`` on success, or
        ``{"status": "already_configured"}`` if a configuration is already
        active (the new one is ignored).

    Raises:
        HTTPException(422): ``class_set`` is empty.
        HTTPException(500): building the model or the output type failed.
    """
    global llm, clf_output, prompt_template, config_set, class_labels

    if config_set:
        logger.warning("Classifier already configured. Ignoring new config.")
        return {"status": "already_configured"}

    if not req.class_set:
        # An empty label set can neither constrain generation nor supply a
        # fallback label for /classify.
        raise HTTPException(status_code=422, detail="class_set must contain at least one label")

    if "{message}" not in req.prompt_template:
        logger.warning(
            "prompt_template has no {message} placeholder; "
            "the input text will not appear in the prompt."
        )

    api_key = os.getenv("TOGETHERAI_API_KEY")
    logger.debug("Using API_KEY: %s", "set" if api_key else "missing")

    try:
        # Build into locals first so a failure part-way through leaves the
        # module in its previous (unconfigured) state instead of half-set.
        new_llm = openai(req.model_name, api_key=api_key, base_url=req.base_url)
        new_output = Choice(req.class_set)
    except Exception as e:
        logger.exception("Failed to configure classifier: %s", e)
        raise HTTPException(status_code=500, detail="Classifier configuration failed") from e

    llm = new_llm
    clf_output = new_output
    prompt_template = req.prompt_template
    class_labels = list(req.class_set)
    config_set = True

    logger.info("Classifier configured successfully.")
    return {"status": "configured"}


@app.post("/classify", response_model=Resp)
def classify(req: Req):
    """Classify ``req.message`` with the configured model.

    The returned ``Resp.result`` is always one of the configured labels. If
    the model call raises, or its output is not one of the labels, the last
    label in ``class_set`` is returned instead.

    Raises:
        HTTPException(503): no configuration is active yet.
    """
    if not config_set or llm is None or clf_output is None or not class_labels:
        raise HTTPException(status_code=503, detail="Classifier not configured yet")

    # NOTE(review): the last label is used as the fallback, presumably a
    # catch-all such as "other" - confirm with the config authors.
    fallback = class_labels[-1]

    try:
        prompt = _render_prompt(prompt_template, req.message)
    except (AttributeError, TypeError) as e:
        # Best effort: classify the bare message if the template is unusable.
        logger.warning("Prompt rendering failed: %s", e)
        prompt = req.message
    logger.debug("Rendered prompt: %r", prompt)

    try:
        raw = llm(prompt, clf_output)
    except Exception as e:
        logger.exception("Classification error: %s. Falling back to: %s", e, fallback)
        return Resp(result=fallback)

    result = _extract_label(raw)
    if isinstance(result, str) and result not in class_labels:
        # Tolerate surrounding whitespace/newlines around an otherwise valid label.
        result = result.strip()
    if result not in class_labels:
        # Off-list or non-string output would otherwise leak to clients or
        # fail Resp validation with an unhandled 500.
        logger.warning("Unexpected model output %r. Falling back to: %s", result, fallback)
        result = fallback

    logger.debug("Classification result: %s", result)
    return Resp(result=result)
| |
|