"""Intent classifier backed by a small Gemma instruction-tuned model.

Classifies a user message into one of the intents allowed for the current
conversation state, using a constrained system prompt and greedy decoding.
"""

import os
import re
import threading
from typing import Optional

import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

from knowledge.classifier_prompt import (
    build_system_prompt,
    get_allowed_intents_for_state,
)

MODEL_ID = os.getenv("MODEL_ID", "google/gemma-3-1b-it")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "12"))
HF_TOKEN = os.getenv("HF_TOKEN")
ENABLE_MODEL_CLASSIFIER = os.getenv("ENABLE_MODEL_CLASSIFIER", "true").lower() == "true"

# Lazily initialized singletons shared across requests; the lock guards
# against concurrent first-time loads from multiple threads.
_model = None
_tokenizer = None
_model_lock = threading.Lock()


def _normalize_label(text: str, allowed_intents: list[str]) -> str:
    """Map the model's raw text output onto one of the allowed intents.

    Falls back to "unclear" when no allowed intent appears in the output.
    """
    cleaned = (text or "").strip().lower()
    # Strip code fences and backticks the model sometimes wraps labels in.
    cleaned = cleaned.replace("```", "").replace("`", "").strip()
    for intent in allowed_intents:
        if re.search(rf"\b{re.escape(intent.lower())}\b", cleaned):
            return intent
    return "unclear"


def _load_model_once():
    """Load the tokenizer and model exactly once (double-checked locking)."""
    global _model, _tokenizer
    if _model is not None and _tokenizer is not None:
        return _model, _tokenizer
    with _model_lock:
        # Re-check inside the lock: another thread may have finished loading.
        if _model is not None and _tokenizer is not None:
            return _model, _tokenizer
        if not HF_TOKEN:
            raise RuntimeError("HF_TOKEN is missing. Add it in Hugging Face Space Secrets.")
        print(f"[intent-classifier] loading model: {MODEL_ID}")
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
        _model = Gemma3ForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN).eval()
        print("[intent-classifier] model loaded successfully")
    return _model, _tokenizer


def _run_generation(user_message: str, state: str, flow_data: Optional[dict] = None) -> dict:
    model, tokenizer = _load_model_once()
    allowed_intents = get_allowed_intents_for_state(state)
    system_prompt = build_system_prompt(
        state=state,
        flow_data=flow_data or {},
        allowed_intents=allowed_intents,
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The chat template already embeds the BOS token in the prompt string, so
    # skip special tokens here to avoid prepending a second BOS.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    with torch.inference_mode():
        generation = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,  # greedy decoding for deterministic labels
            temperature=None,
            top_p=None,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    input_len = inputs["input_ids"].shape[-1]
    generated_tokens = generation[0][input_len:]
    raw_output = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
    final_intent = _normalize_label(raw_output, allowed_intents)
    return {
        "intent": final_intent,
        "raw_output": raw_output,
        "model": MODEL_ID,
        "allowed_intents": allowed_intents,
    }


def classify_message_with_model(user_message: str, state: str, flow_data: Optional[dict] = None) -> Optional[dict]:
    """Classify a user message into an intent for the given conversation state.

    Returns:
        {
            "intent": "...",
            "raw_output": "...",
            "model": "...",
            "allowed_intents": [...]
        }
        or None if the classifier is disabled or the message is empty.
    """
    if not ENABLE_MODEL_CLASSIFIER:
        return None
    if not user_message or not user_message.strip():
        return None
    return _run_generation(
        user_message=user_message.strip(),
        state=state,
        flow_data=flow_data or {},
    )
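

# --- Usage sketch (illustrative) ---
# A minimal smoke test, assuming HF_TOKEN is set in the environment. The
# state name and message below are hypothetical examples, not part of this
# module's API; substitute a state known to knowledge.classifier_prompt.
if __name__ == "__main__":
    result = classify_message_with_model(
        user_message="I want to check my balance",
        state="awaiting_account_choice",  # hypothetical state
    )
    if result is None:
        print("classifier disabled or empty message")
    else:
        print(f"intent={result['intent']!r} raw={result['raw_output']!r}")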