import os
import re
import threading
from typing import Optional

import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

from knowledge.classifier_prompt import (
    build_system_prompt,
    get_allowed_intents_for_state,
)

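# Configuration via environment variables, with defaults suitable for a
# Hugging Face Space deployment.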
MODEL_ID = os.getenv("MODEL_ID", "google/gemma-3-1b-it")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "12"))
HF_TOKEN = os.getenv("HF_TOKEN")
ENABLE_MODEL_CLASSIFIER = os.getenv("ENABLE_MODEL_CLASSIFIER", "true").lower() == "true"

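# Lazily-initialized singletons, shared across requests and guarded by a lock.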
_model = None
_tokenizer = None
_model_lock = threading.Lock()


def _normalize_label(text: str, allowed_intents: list[str]) -> str:
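    """Map the model's raw output onto one of the allowed intent labels.

    Strips code fences/backticks, then returns the first allowed intent that
    appears as a whole word in the output; falls back to "unclear" otherwise.
    """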
    cleaned = (text or "").strip().lower()
    cleaned = cleaned.replace("```", "").replace("`", "").strip()

    for intent in allowed_intents:
        if re.search(rf"\b{re.escape(intent.lower())}\b", cleaned):
            return intent

    return "unclear"


def _load_model_once():
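    """Load the model and tokenizer once, using double-checked locking.

    The unlocked fast path avoids lock contention on every request; the
    second check inside the lock prevents a duplicate load under races.
    """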
    global _model, _tokenizer

    if _model is not None and _tokenizer is not None:
        return _model, _tokenizer

    with _model_lock:
        if _model is not None and _tokenizer is not None:
            return _model, _tokenizer

        if not HF_TOKEN:
            raise RuntimeError("HF_TOKEN is missing. Add it in Hugging Face Space Secrets.")

        print(f"[intent-classifier] loading model: {MODEL_ID}")

        _tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN
        )

        _model = Gemma3ForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN
        ).eval()

        print("[intent-classifier] model loaded successfully")

    return _model, _tokenizer


def _run_generation(user_message: str, state: str, flow_data: Optional[dict] = None) -> dict:
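    """Build the chat prompt, run greedy decoding, and normalize the label."""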
    model, tokenizer = _load_model_once()

    allowed_intents = get_allowed_intents_for_state(state)
    system_prompt = build_system_prompt(
        state=state,
        flow_data=flow_data or {},
        allowed_intents=allowed_intents,
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # The Gemma chat template already emits the BOS token, so skip adding
    # special tokens a second time, and move the tensors onto the model's device.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    with torch.inference_mode():
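        # Greedy decoding: do_sample=False, with temperature/top_p passed
        # explicitly as None so the sampling defaults in the model's
        # generation config do not trigger warnings.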
        generation = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            temperature=None,
            top_p=None,
        )

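    # generate() returns prompt + completion; slice off the prompt tokens so
    # only newly generated text is decoded.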
    input_len = inputs["input_ids"].shape[-1]
    generated_tokens = generation[0][input_len:]
    raw_output = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    final_intent = _normalize_label(raw_output, allowed_intents)

    return {
        "intent": final_intent,
        "raw_output": raw_output,
        "model": MODEL_ID,
        "allowed_intents": allowed_intents,
    }


def classify_message_with_model(user_message: str, state: str, flow_data: Optional[dict] = None) -> Optional[dict]:
    """
    Returns:
    {
      "intent": "...",
      "raw_output": "...",
      "model": "...",
      "allowed_intents": [...]
    }
    or None if classifier is disabled
    """
    if not ENABLE_MODEL_CLASSIFIER:
        return None

    if not user_message or not user_message.strip():
        return None

    return _run_generation(
        user_message=user_message.strip(),
        state=state,
        flow_data=flow_data or {},
    )
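

if __name__ == "__main__":
    # Minimal local smoke test, not part of the service. Assumes HF_TOKEN is
    # set and that "start" is a valid state in classifier_prompt's state map;
    # substitute a real state from your own flow.
    print(classify_message_with_model("hi, I need help with my order", state="start"))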