"""
Model-based intent classifier for Query Router.
Replaces brittle rule-based heuristics with a trained classifier.
Backends: tfidf (default), fasttext, distilbert.
Intents: small_to_big (detail), fast (keyword), deep (natural language)
"""
import logging
from pathlib import Path
from typing import Optional
import joblib
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
logger = logging.getLogger(__name__)
INTENTS = ["small_to_big", "fast", "deep"]
class IntentClassifier:
    """
    Intent classifier with pluggable backends:
    - tfidf: TF-IDF + LogisticRegression (~1–2ms)
    - fasttext: FastText (~1ms, requires fasttext package)
    - distilbert: Zero-shot DistilBERT (~50–100ms, higher accuracy)

    Exactly one backend is active at a time; call load() before predict().
    """

    def __init__(self, model_path: Optional[Path] = None):
        # One attribute per backend; the inactive ones stay None.
        self.pipeline: Optional[Pipeline] = None  # tfidf (sklearn) backend
        self._fasttext_model = None               # fasttext backend
        self._distilbert_pipeline = None          # distilbert backend
        self._backend = "tfidf"
        self.model_path = Path(model_path) if model_path else None

    def load(self, path: Optional[Path] = None) -> bool:
        """Load a trained model from disk.

        Accepts a direct file path (.pkl or .bin) or a directory expected
        to contain intent_classifier.pkl / intent_classifier.bin.

        Returns:
            True if a usable backend was loaded, False otherwise.
        """
        p = path or self.model_path
        if not p:
            return False
        p = Path(p)
        base = p.parent if p.suffix in (".pkl", ".bin") else p
        pkl_path = p if p.suffix == ".pkl" else base / "intent_classifier.pkl"
        bin_path = p if p.suffix == ".bin" else base / "intent_classifier.bin"
        # Try .pkl first (tfidf or distilbert).
        if pkl_path.exists():
            try:
                data = joblib.load(pkl_path)
                if isinstance(data, dict):
                    self.pipeline = data.get("pipeline")
                    self._backend = data.get("backend", "tfidf")
                    if self._backend == "distilbert":
                        self._load_distilbert(data)
                        if self._distilbert_pipeline is None:
                            # DistilBERT setup failed; report failure instead
                            # of returning True with no usable backend.
                            return False
                    elif self.pipeline is None:
                        # Fix: a dict payload without a "pipeline" entry is
                        # unusable for sklearn-style backends. Previously the
                        # raw dict was stored as the pipeline, so load()
                        # reported success but predict() crashed later.
                        logger.warning(
                            "Model file %s has no 'pipeline' entry", pkl_path
                        )
                        return False
                else:
                    # Legacy format: the pickle *is* the sklearn pipeline.
                    self.pipeline = data
                self.model_path = pkl_path
                logger.info("Intent classifier loaded from %s (backend=%s)", pkl_path, self._backend)
                return True
            except Exception as e:
                logger.warning("Failed to load intent classifier: %s", e)
        # Fall back to .bin (FastText).
        if bin_path.exists():
            try:
                import fasttext
                self._fasttext_model = fasttext.load_model(str(bin_path))
                self._backend = "fasttext"
                self.model_path = bin_path
                logger.info("Intent classifier loaded from %s (FastText)", bin_path)
                return True
            except ImportError:
                logger.warning("FastText not installed; pip install fasttext")
            except Exception as e:
                logger.warning("Failed to load FastText: %s", e)
        return False

    def _load_distilbert(self, data: dict) -> None:
        """Lazy-load the DistilBERT zero-shot pipeline from saved config.

        On failure, _distilbert_pipeline stays None and the caller (load)
        treats the load as failed.
        """
        model_name = data.get("distilbert_model", "distilbert-base-uncased")
        try:
            from transformers import pipeline
            self._distilbert_pipeline = pipeline(
                "zero-shot-classification",
                model=model_name,
                device=-1,  # CPU inference
            )
        except Exception as e:
            logger.warning("DistilBERT pipeline load failed: %s", e)
        self.pipeline = None  # Use distilbert, not the sklearn pipeline

    def predict(self, query: str) -> str:
        """Predict intent for a query.

        Returns:
            One of "small_to_big", "fast", "deep". Empty/whitespace queries
            default to "deep".

        Raises:
            RuntimeError: if no backend has been loaded.
        """
        q = query.strip()
        if not q:
            return "deep"
        if self._fasttext_model is not None:
            # FastText returns ((labels,), (probs,)); strip its label prefix.
            pred = self._fasttext_model.predict(q)
            return pred[0][0].replace("__label__", "")
        if self._distilbert_pipeline is not None:
            out = self._distilbert_pipeline(q, INTENTS, multi_label=False)
            return out["labels"][0]  # labels sorted by score, best first
        if self.pipeline is None:
            raise RuntimeError("Intent classifier not loaded; call load() first")
        return str(self.pipeline.predict([q])[0])

    def predict_proba(self, query: str) -> dict[str, float]:
        """Return intent -> probability mapping for debugging.

        Empty queries get a uniform distribution over INTENTS.

        Raises:
            RuntimeError: if no backend has been loaded.
        """
        q = query.strip()
        if not q:
            return {intent: 1.0 / len(INTENTS) for intent in INTENTS}
        if self._fasttext_model is not None:
            pred = self._fasttext_model.predict(q, k=len(INTENTS))
            names = [label.replace("__label__", "") for label in pred[0]]
            return dict(zip(names, pred[1]))
        if self._distilbert_pipeline is not None:
            out = self._distilbert_pipeline(q, INTENTS, multi_label=False)
            return dict(zip(out["labels"], out["scores"]))
        if self.pipeline is None:
            raise RuntimeError("Intent classifier not loaded")
        probs = self.pipeline.predict_proba([q])[0]
        # Read class order from the final estimator so labels align with probs.
        last_step = self.pipeline.steps[-1][1]
        classes = getattr(last_step, "classes_", INTENTS)
        return dict(zip(classes, probs))
def train_classifier(
    queries: list[str],
    labels: list[str],
    max_features: int = 5000,
    C: float = 1.0,
    backend: str = "tfidf",
):
    """Train an intent classifier.

    Returns:
        A fitted sklearn Pipeline (tfidf), a FastText model (fasttext),
        or a config dict (distilbert), depending on ``backend``.
    """
    if backend == "fasttext":
        return _train_fasttext(queries, labels)
    if backend == "distilbert":
        return _train_distilbert(queries, labels)
    # Default backend: TF-IDF features feeding a balanced logistic regression.
    vectorizer = TfidfVectorizer(
        max_features=max_features,
        ngram_range=(1, 2),
        min_df=1,
        lowercase=True,
    )
    classifier = LogisticRegression(
        C=C,
        max_iter=500,
        class_weight="balanced",  # intents may be imbalanced in training data
        random_state=42,
    )
    model = Pipeline([("tfidf", vectorizer), ("clf", classifier)])
    model.fit(queries, labels)
    return model
def _train_fasttext(queries: list[str], labels: list[str]):
"""Train FastText classifier. Requires fasttext package."""
try:
import fasttext
import tempfile
import os
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
for q, l in zip(queries, labels):
line = q.replace("\n", " ").strip()
f.write(f"__label__{l} {line}\n")
path = f.name
model = fasttext.train_supervised(path, epoch=25, lr=0.5, wordNgrams=2)
os.unlink(path)
return model
except ImportError:
raise RuntimeError("FastText not installed: pip install fasttext")
def _train_distilbert(queries: list[str], labels: list[str]) -> dict:
"""DistilBERT zero-shot: creates pipeline (no training). Saves config for inference."""
try:
from transformers import pipeline
pipe = pipeline(
"zero-shot-classification",
model="distilbert-base-uncased",
device=-1,
)
return {
"backend": "distilbert",
"distilbert_model": "distilbert-base-uncased",
"intents": INTENTS,
}
except Exception as e:
raise RuntimeError(f"DistilBERT setup failed: {e}")