# ADK-Bot: services/intent_classifier_client.py
import os
import re
import threading
from typing import Optional

import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

from knowledge.classifier_prompt import (
    build_system_prompt,
    get_allowed_intents_for_state,
)

MODEL_ID = os.getenv("MODEL_ID", "google/gemma-3-1b-it")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "12"))
HF_TOKEN = os.getenv("HF_TOKEN")
ENABLE_MODEL_CLASSIFIER = os.getenv("ENABLE_MODEL_CLASSIFIER", "true").lower() == "true"
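
# Configuration via environment variables (values shown are the defaults;
# HF_TOKEN has no default and is required at load time):
#   MODEL_ID=google/gemma-3-1b-it      # HF model to load
#   MAX_NEW_TOKENS=12                  # generation budget for the label
#   ENABLE_MODEL_CLASSIFIER=true       # "false" disables the classifier
#   HF_TOKEN=...                       # set via Hugging Face Space Secrets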

# Lazily initialised singletons, guarded by _model_lock.
_model = None
_tokenizer = None
_model_lock = threading.Lock()


def _normalize_label(text: str, allowed_intents: list[str]) -> str:
    """Map raw model output onto one of the allowed intents, falling back to "unclear"."""
    cleaned = (text or "").strip().lower()
    # Strip markdown code fences the model sometimes wraps its answer in.
    cleaned = cleaned.replace("```", "").replace("`", "").strip()
    for intent in allowed_intents:
        # Whole-word match, so a label is only accepted if it appears intact.
        if re.search(rf"\b{re.escape(intent.lower())}\b", cleaned):
            return intent
    return "unclear"


def _load_model_once():
    """Load the model and tokenizer once, with double-checked locking for thread safety."""
    global _model, _tokenizer
    if _model is not None and _tokenizer is not None:
        return _model, _tokenizer
    with _model_lock:
        # Re-check inside the lock: another thread may have finished loading
        # while this one was waiting.
        if _model is not None and _tokenizer is not None:
            return _model, _tokenizer
        if not HF_TOKEN:
            raise RuntimeError("HF_TOKEN is missing. Add it in Hugging Face Space Secrets.")
        print(f"[intent-classifier] loading model: {MODEL_ID}")
        _tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
        )
        _model = Gemma3ForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
        ).eval()
        print("[intent-classifier] model loaded successfully")
        return _model, _tokenizer
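
# Optional warm-up (an assumption, not part of the original flow): because
# _load_model_once() is idempotent and lock-guarded, a server could pre-load
# the weights in a background thread at startup rather than on the first
# request, e.g.:
#
#     threading.Thread(target=_load_model_once, daemon=True).start()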


def _run_generation(user_message: str, state: str, flow_data: Optional[dict] = None) -> dict:
    model, tokenizer = _load_model_once()

    allowed_intents = get_allowed_intents_for_state(state)
    system_prompt = build_system_prompt(
        state=state,
        flow_data=flow_data or {},
        allowed_intents=allowed_intents,
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Keep the inputs on the same device as the model (a no-op on CPU).
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.inference_mode():
        generation = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,   # greedy decoding for a deterministic label
            temperature=None,  # explicitly unset so transformers does not
            top_p=None,        # warn about unused sampling parameters
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    input_len = inputs["input_ids"].shape[-1]
    generated_tokens = generation[0][input_len:]
    raw_output = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    final_intent = _normalize_label(raw_output, allowed_intents)
    return {
        "intent": final_intent,
        "raw_output": raw_output,
        "model": MODEL_ID,
        "allowed_intents": allowed_intents,
    }


def classify_message_with_model(user_message: str, state: str, flow_data: Optional[dict] = None) -> Optional[dict]:
    """
    Returns:
        {
            "intent": "...",
            "raw_output": "...",
            "model": "...",
            "allowed_intents": [...]
        }
    or None if the classifier is disabled or the message is empty.
    """
    if not ENABLE_MODEL_CLASSIFIER:
        return None
    if not user_message or not user_message.strip():
        return None
    return _run_generation(
        user_message=user_message.strip(),
        state=state,
        flow_data=flow_data or {},
    )
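

if __name__ == "__main__":
    # Minimal smoke test, assuming HF_TOKEN is set and the weights can be
    # downloaded. The state value "greeting" is illustrative; valid states are
    # whatever knowledge.classifier_prompt defines.
    result = classify_message_with_model("hi, I need some help", state="greeting")
    print(result)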