import os
import re
import threading
from typing import Optional

import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

from knowledge.classifier_prompt import (
    build_system_prompt,
    get_allowed_intents_for_state,
)

MODEL_ID = os.getenv("MODEL_ID", "google/gemma-3-1b-it")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "12"))
HF_TOKEN = os.getenv("HF_TOKEN")
ENABLE_MODEL_CLASSIFIER = os.getenv("ENABLE_MODEL_CLASSIFIER", "true").lower() == "true"

# Model and tokenizer are loaded lazily and shared across requests.
_model = None
_tokenizer = None
_model_lock = threading.Lock()


def _normalize_label(text: str, allowed_intents: list[str]) -> str:
    """Map raw model output onto one of the allowed intents, falling back to 'unclear'."""
    cleaned = (text or "").strip().lower()
    # Strip any markdown fences or backticks the model may have emitted.
    cleaned = cleaned.replace("```", "").replace("`", "").strip()
    for intent in allowed_intents:
        if re.search(rf"\b{re.escape(intent.lower())}\b", cleaned):
            return intent
    return "unclear"


def _load_model_once():
    """Load the tokenizer and model exactly once, using double-checked locking."""
    global _model, _tokenizer
    if _model is not None and _tokenizer is not None:
        return _model, _tokenizer
    with _model_lock:
        # Re-check inside the lock in case another thread finished loading first.
        if _model is not None and _tokenizer is not None:
            return _model, _tokenizer
        if not HF_TOKEN:
            raise RuntimeError("HF_TOKEN is missing. Add it in Hugging Face Space Secrets.")
        print(f"[intent-classifier] loading model: {MODEL_ID}")
        _tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
        )
        _model = Gemma3ForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
        ).eval()
        print("[intent-classifier] model loaded successfully")
    return _model, _tokenizer


def _run_generation(user_message: str, state: str, flow_data: Optional[dict] = None) -> dict:
    model, tokenizer = _load_model_once()

    allowed_intents = get_allowed_intents_for_state(state)
    system_prompt = build_system_prompt(
        state=state,
        flow_data=flow_data or {},
        allowed_intents=allowed_intents,
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt")

    # Greedy decoding: sampling is disabled, so temperature/top_p are explicitly unset.
    with torch.inference_mode():
        generation = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            temperature=None,
            top_p=None,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    input_len = inputs["input_ids"].shape[-1]
    generated_tokens = generation[0][input_len:]
    raw_output = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    final_intent = _normalize_label(raw_output, allowed_intents)
    return {
        "intent": final_intent,
        "raw_output": raw_output,
        "model": MODEL_ID,
        "allowed_intents": allowed_intents,
    }


def classify_message_with_model(user_message: str, state: str, flow_data: Optional[dict] = None) -> Optional[dict]:
    """
    Classify a user message into one of the intents allowed for the current state.

    Returns:
        {
            "intent": "...",
            "raw_output": "...",
            "model": "...",
            "allowed_intents": [...]
        }
        or None if the classifier is disabled or the message is empty.
    """
    if not ENABLE_MODEL_CLASSIFIER:
        return None
    if not user_message or not user_message.strip():
        return None
    return _run_generation(
        user_message=user_message.strip(),
        state=state,
        flow_data=flow_data or {},
    )
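

# Illustrative smoke test only, not part of the module's public API. The state name
# "greeting" and the sample message are assumptions for demonstration; substitute
# whatever states knowledge.classifier_prompt actually defines. Requires HF_TOKEN
# to be set, since it triggers a real model load.
if __name__ == "__main__":
    example = classify_message_with_model(
        user_message="I want to reschedule my appointment",
        state="greeting",
    )
    print(example)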