import os
import re
import threading
from typing import Optional

import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

from knowledge.classifier_prompt import (
    build_system_prompt,
    get_allowed_intents_for_state,
)

MODEL_ID = os.getenv("MODEL_ID", "google/gemma-3-1b-it")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "12"))
HF_TOKEN = os.getenv("HF_TOKEN")
ENABLE_MODEL_CLASSIFIER = os.getenv("ENABLE_MODEL_CLASSIFIER", "true").lower() == "true"

# Lazily-initialized singletons: the model and tokenizer are loaded once and
# shared process-wide; the lock guards first-time initialization across threads.
_model = None
_tokenizer = None
_model_lock = threading.Lock()


def _normalize_label(text: str, allowed_intents: list[str]) -> str:
    """Map raw model output onto one allowed intent, falling back to "unclear"."""
    cleaned = (text or "").strip().lower()
    cleaned = cleaned.replace("```", "").replace("`", "").strip()

    for intent in allowed_intents:
        if re.search(rf"\b{re.escape(intent.lower())}\b", cleaned):
            return intent

    return "unclear"
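
# Illustrative behavior (hypothetical labels): with allowed_intents
# ["book_flight", "cancel_booking"], the raw output "`book_flight`" normalizes
# to "book_flight", while "the user wants to fly" yields "unclear" because no
# allowed label appears as a whole word in the output.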


def _load_model_once():
    """Load and cache the model and tokenizer; safe to call from multiple threads."""
    global _model, _tokenizer

    # Fast path: already initialized.
    if _model is not None and _tokenizer is not None:
        return _model, _tokenizer

    with _model_lock:
        # Double-checked locking: another thread may have finished loading
        # while this one was waiting on the lock.
        if _model is not None and _tokenizer is not None:
            return _model, _tokenizer

        if not HF_TOKEN:
            raise RuntimeError("HF_TOKEN is missing. Add it in Hugging Face Space Secrets.")

        print(f"[intent-classifier] loading model: {MODEL_ID}")

        _tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
        )

        _model = Gemma3ForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
        ).eval()

        print("[intent-classifier] model loaded successfully")

    return _model, _tokenizer
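
# Note: from_pretrained() is called here without dtype or device arguments, so
# the weights load in float32 on CPU; on a GPU host you would likely pass
# torch_dtype and a device placement explicitly (left out to keep the loader portable).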


def _run_generation(user_message: str, state: str, flow_data: Optional[dict] = None) -> dict:
    model, tokenizer = _load_model_once()

    allowed_intents = get_allowed_intents_for_state(state)
    system_prompt = build_system_prompt(
        state=state,
        flow_data=flow_data or {},
        allowed_intents=allowed_intents,
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(prompt, return_tensors="pt")

    # Greedy decoding: classification should be deterministic, so sampling is
    # off and temperature/top_p are explicitly unset to avoid warnings.
    with torch.inference_mode():
        generation = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            temperature=None,
            top_p=None,
        )

    # generate() returns prompt + completion; keep only the newly generated tokens.
    input_len = inputs["input_ids"].shape[-1]
    generated_tokens = generation[0][input_len:]
    raw_output = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    final_intent = _normalize_label(raw_output, allowed_intents)

    return {
        "intent": final_intent,
        "raw_output": raw_output,
        "model": MODEL_ID,
        "allowed_intents": allowed_intents,
    }


def classify_message_with_model(user_message: str, state: str, flow_data: Optional[dict] = None) -> Optional[dict]:
    """
    Classify a user message into one of the intents allowed in the current state.

    Returns:
        {
            "intent": "...",
            "raw_output": "...",
            "model": "...",
            "allowed_intents": [...]
        }
        or None if the classifier is disabled or the message is empty.
    """
    if not ENABLE_MODEL_CLASSIFIER:
        return None

    if not user_message or not user_message.strip():
        return None

    return _run_generation(
        user_message=user_message.strip(),
        state=state,
        flow_data=flow_data or {},
    )
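

# Minimal local smoke test, a sketch rather than part of the service. It
# assumes HF_TOKEN is set in the environment and that knowledge.classifier_prompt
# defines the states and intents; "greeting" and the sample message below are
# made-up placeholders.
if __name__ == "__main__":
    result = classify_message_with_model(
        user_message="hi, I need some help",
        state="greeting",
    )
    print(result)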