Mr-Help commited on
Commit
5ed8ade
·
verified ·
1 Parent(s): 0c83b5e

Create services/intent_classifier_client.py

Browse files
Files changed (1) hide show
  1. services/intent_classifier_client.py +132 -0
services/intent_classifier_client.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import threading
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from transformers import AutoTokenizer, Gemma3ForCausalLM
8
+
9
+ from knowledge.classifier_prompt import (
10
+ build_system_prompt,
11
+ get_allowed_intents_for_state,
12
+ )
13
+
14
+ MODEL_ID = os.getenv("MODEL_ID", "google/gemma-3-1b-it")
15
+ MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "12"))
16
+ HF_TOKEN = os.getenv("HF_TOKEN")
17
+ ENABLE_MODEL_CLASSIFIER = os.getenv("ENABLE_MODEL_CLASSIFIER", "true").lower() == "true"
18
+
19
+ _model = None
20
+ _tokenizer = None
21
+ _model_lock = threading.Lock()
22
+
23
+
24
+ def _normalize_label(text: str, allowed_intents: list[str]) -> str:
25
+ cleaned = (text or "").strip().lower()
26
+ cleaned = cleaned.replace("```", "").replace("`", "").strip()
27
+
28
+ for intent in allowed_intents:
29
+ if re.search(rf"\b{re.escape(intent.lower())}\b", cleaned):
30
+ return intent
31
+
32
+ return "unclear"
33
+
34
+
35
+ def _load_model_once():
36
+ global _model, _tokenizer
37
+
38
+ if _model is not None and _tokenizer is not None:
39
+ return _model, _tokenizer
40
+
41
+ with _model_lock:
42
+ if _model is not None and _tokenizer is not None:
43
+ return _model, _tokenizer
44
+
45
+ if not HF_TOKEN:
46
+ raise RuntimeError("HF_TOKEN is missing. Add it in Hugging Face Space Secrets.")
47
+
48
+ print(f"[intent-classifier] loading model: {MODEL_ID}")
49
+
50
+ _tokenizer = AutoTokenizer.from_pretrained(
51
+ MODEL_ID,
52
+ token=HF_TOKEN
53
+ )
54
+
55
+ _model = Gemma3ForCausalLM.from_pretrained(
56
+ MODEL_ID,
57
+ token=HF_TOKEN
58
+ ).eval()
59
+
60
+ print("[intent-classifier] model loaded successfully")
61
+
62
+ return _model, _tokenizer
63
+
64
+
65
+ def _run_generation(user_message: str, state: str, flow_data: Optional[dict] = None) -> dict:
66
+ model, tokenizer = _load_model_once()
67
+
68
+ allowed_intents = get_allowed_intents_for_state(state)
69
+ system_prompt = build_system_prompt(
70
+ state=state,
71
+ flow_data=flow_data or {},
72
+ allowed_intents=allowed_intents,
73
+ )
74
+
75
+ messages = [
76
+ {"role": "system", "content": system_prompt},
77
+ {"role": "user", "content": user_message},
78
+ ]
79
+
80
+ prompt = tokenizer.apply_chat_template(
81
+ messages,
82
+ tokenize=False,
83
+ add_generation_prompt=True
84
+ )
85
+
86
+ inputs = tokenizer(prompt, return_tensors="pt")
87
+
88
+ with torch.inference_mode():
89
+ generation = model.generate(
90
+ **inputs,
91
+ max_new_tokens=MAX_NEW_TOKENS,
92
+ do_sample=False,
93
+ temperature=None,
94
+ top_p=None,
95
+ )
96
+
97
+ input_len = inputs["input_ids"].shape[-1]
98
+ generated_tokens = generation[0][input_len:]
99
+ raw_output = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
100
+
101
+ final_intent = _normalize_label(raw_output, allowed_intents)
102
+
103
+ return {
104
+ "intent": final_intent,
105
+ "raw_output": raw_output,
106
+ "model": MODEL_ID,
107
+ "allowed_intents": allowed_intents,
108
+ }
109
+
110
+
111
+ def classify_message_with_model(user_message: str, state: str, flow_data: Optional[dict] = None) -> Optional[dict]:
112
+ """
113
+ Returns:
114
+ {
115
+ "intent": "...",
116
+ "raw_output": "...",
117
+ "model": "...",
118
+ "allowed_intents": [...]
119
+ }
120
+ or None if classifier is disabled
121
+ """
122
+ if not ENABLE_MODEL_CLASSIFIER:
123
+ return None
124
+
125
+ if not user_message or not user_message.strip():
126
+ return None
127
+
128
+ return _run_generation(
129
+ user_message=user_message.strip(),
130
+ state=state,
131
+ flow_data=flow_data or {},
132
+ )