File size: 2,738 Bytes
0d2a25f
 
ef8c7e3
 
9ffbce4
 
0d2a25f
 
 
9ffbce4
0d2a25f
 
ef8c7e3
0d2a25f
 
 
 
ef8c7e3
9ffbce4
 
 
ef8c7e3
 
 
0d2a25f
 
 
ef8c7e3
 
 
 
0d2a25f
 
 
 
ef8c7e3
 
 
9ffbce4
0d2a25f
ef8c7e3
 
 
0d2a25f
 
 
 
 
9ffbce4
ef8c7e3
9ffbce4
 
ef8c7e3
9ffbce4
ef8c7e3
 
 
 
 
 
 
 
9ffbce4
 
 
 
ef8c7e3
 
9ffbce4
ef8c7e3
9ffbce4
ef8c7e3
 
 
 
 
9ffbce4
ef8c7e3
9ffbce4
 
 
 
 
 
 
ef8c7e3
0d2a25f
9ffbce4
 
0d2a25f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import logging
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
import json

import outlines
from outlines.models import openai
from outlines.types import Choice

# Configure root logging at DEBUG so rendered prompts and raw results are visible.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("classifier")

app = FastAPI()

# Shared classifier state: written once by POST /config, read by POST /classify.
# NOTE(review): plain module globals — safe for single-process serving; confirm
# behavior under multiple workers, where each process holds its own copy.
llm = None              # model handle returned by outlines' openai(...)
clf_output = None       # Choice output type constraining predictions to class_set
prompt_template = None  # raw template string containing a "{message}" placeholder
config_set = False      # guards against reconfiguration and unconfigured use

class Config(BaseModel):
    """Classifier configuration payload accepted by POST /config."""

    # Model identifier forwarded to the outlines `openai(...)` constructor.
    model_name: str
    # Base URL of the OpenAI-compatible endpoint (presumably TogetherAI,
    # given the TOGETHERAI_API_KEY env var used in /config).
    base_url: str
    # Candidate labels; classification output is constrained to this set.
    class_set: List[str]
    # Prompt template; the literal substring "{message}" is replaced with the
    # incoming text by /classify (plain str.replace, not str.format).
    prompt_template: str

class Req(BaseModel):
    """Classification request body for POST /classify."""

    # The text to classify; substituted into the configured prompt template.
    message: str

class Resp(BaseModel):
    """Classification response body: the chosen label."""

    # Predicted class label (one of the configured class_set on success).
    result: str

@app.post("/config")
def configure(req: Config):
    """Receive and initialize classifier configuration.

    Builds the shared LLM handle, the constrained-choice output type, and the
    prompt template. Configuration is one-shot: subsequent calls are ignored.

    Returns:
        {"status": "configured"} on success, or
        {"status": "already_configured"} if already set up.

    Raises:
        HTTPException 400: if ``class_set`` is empty (nothing to classify into).
        HTTPException 500: if model instantiation fails.
    """
    global llm, clf_output, prompt_template, config_set

    if config_set:
        logger.warning("Classifier already configured. Ignoring new config.")
        return {"status": "already_configured"}

    # Reject an empty label set up front: Choice() over no labels is useless
    # and downstream code assumes at least one class exists.
    if not req.class_set:
        raise HTTPException(status_code=400, detail="class_set must not be empty")

    api_key = os.getenv("TOGETHERAI_API_KEY")
    # Log only whether the key is present — never the key itself.
    logger.debug("Using API_KEY: %s", "set" if api_key else "missing")

    try:
        # Instantiate model and choice output type.
        llm = openai(req.model_name, api_key=api_key, base_url=req.base_url)
        clf_output = Choice(req.class_set)
        prompt_template = req.prompt_template
        config_set = True
    except Exception as e:
        # logger.exception records the traceback; chain the cause so the
        # original error survives into the HTTPException context.
        logger.exception("Failed to configure classifier")
        raise HTTPException(status_code=500, detail="Classifier configuration failed") from e

    logger.info("Classifier configured successfully.")
    return {"status": "configured"}

@app.post("/classify", response_model=Resp)
def classify(req: Req):
    """Run text classification using the configured LLM.

    Renders the configured prompt template with the incoming message, invokes
    the LLM constrained to the configured label set, and unwraps a possible
    JSON envelope around the label.

    Raises:
        HTTPException 503: if /config has not been called yet.
        HTTPException 500: if the LLM invocation fails.
    """
    global llm, clf_output, prompt_template, config_set

    if not config_set or llm is None or clf_output is None:
        raise HTTPException(status_code=503, detail="Classifier not configured yet")

    # Render prompt. str.replace on a str cannot raise here; the guard only
    # covers a misconfigured (e.g. None) template, falling back to raw input.
    try:
        prompt = prompt_template.replace("{message}", req.message)
        logger.debug("Rendered prompt: %r", prompt)
    except Exception as e:
        logger.warning("Prompt rendering failed: %s", e)
        prompt = req.message

    # Invoke LLM classifier.
    try:
        raw = llm(prompt, clf_output)
    except Exception as e:
        # BUG FIX: the original fallback read req.class_set[-1], but Req has no
        # class_set field — the handler itself crashed with AttributeError.
        # Surface a clean 500 instead of an unhandled exception.
        logger.exception("Classification error")
        raise HTTPException(status_code=500, detail="Classification failed") from e

    # The model may return a bare label or a JSON envelope like
    # '{"result": "Label"}' — unwrap the latter, tolerate anything else.
    try:
        unwrapped = json.loads(raw).get("result", raw)
    except Exception:
        unwrapped = raw
    result = str(unwrapped)  # coerce so Resp's str field always validates
    logger.debug("Classification result: %s", result)

    return Resp(result=result)